Inference#
You can run inference with TerraTorch by providing the path to an input folder and output directory. You can do this directly via the CLI with:
terratorch predict -c config.yaml --ckpt_path path/to/model/checkpoint.ckpt --data.init_args.predict_data_root input/folder/ --predict_output_dir output/folder/
This approach works only for supported data modules like the TerraTorch GenericNonGeoSegmentationDataModule.
For example, the generic multimodal datamodule expects a dictionary for predict_data_root.
Therefore, you can also define the parameters in the config file:
data:
  class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
  init_args:
    ...
    predict_data_root: path/to/input/files/
We provide two tutorials, one for basic inference and one for simplified inference.
Tiled inference via CLI#
TerraTorch supports a tiled inference that splits up a tile into smaller chips. With this approach, you can run a model on very large tiles like a 10k x 10k pixel Sentinel-2 tile.
Define the tiled inference parameters in the yaml config like the following:
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    ...
    tiled_inference_parameters:
      crop: 224
      stride: 192
Next, you can run:
terratorch predict -c config.yaml --ckpt_path path/to/model/checkpoint.ckpt --data.init_args.predict_data_root input/folder/ --predict_output_dir output/folder/
Warning
The Lightning CLI automatically loads each input tile to the GPU before passing it to tiled_inference.
This can result in CUDA out-of-memory errors for very large tiles like 100k x 100k pixels.
In this case, run the tiled inference via a Python script and do not load the full tile into the GPU.
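A quick back-of-the-envelope estimate shows why this matters. The sketch below assumes a dense float32 tile with six bands (the band count is a hypothetical value, not something fixed by TerraTorch) and ignores model activations, which need additional memory.

```python
def tile_size_gib(height, width, bands, bytes_per_value=4):
    """Memory needed to hold one dense tile, in GiB (float32 by default)."""
    return height * width * bands * bytes_per_value / 1024**3


# Hypothetical 6-band input; adjust the band count to your data.
small = tile_size_gib(10_000, 10_000, bands=6)
large = tile_size_gib(100_000, 100_000, bands=6)

print(f"10k x 10k:   {small:.1f} GiB")   # 2.2 GiB
print(f"100k x 100k: {large:.1f} GiB")   # 223.5 GiB
```

Under these assumptions, a 10k x 10k tile fits on most GPUs, while a 100k x 100k tile exceeds the memory of any single device, so it has to stay on the CPU and be processed chip by chip.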
By default, tiled inference adds padding around the tile and removes the edge pixels of each chip before merging; the size of both is defined by the parameter delta.
Predictions of overlapping chips are merged using blend masks that smooth the chip edges in the predictions. This can be deactivated with blend_overlaps=False.
Here is a comparison of both options with their respective prediction counts per pixel.
Pixels along the right side and bottom can have more predictions if the tile size is not evenly divisible by the defined crop and stride sizes.
TerraTorch maximises the overlap to fully leverage the compute while generating at least one prediction per pixel.
- Without blending, "patchy" predictions with visible lines along the chip edges can appear.
- By default, a cosine-based blend mask is applied to each chip, which smooths the predictions.
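The prediction counts and the effect of blending can be illustrated with a small one-dimensional sketch. It is not TerraTorch's implementation: the edge handling (shifting the last chip back to the tile border), the window formula, and the helper names chip_starts, cosine_weights, and merge_1d are assumptions made for illustration, and delta and padding are ignored. It also assumes stride <= crop so that every pixel is covered.

```python
import math


def chip_starts(size, crop, stride):
    """Start offsets of chips along one axis. The last chip is shifted back
    so that it ends exactly at the tile border (assumed edge handling)."""
    if size <= crop:
        return [0]
    starts = list(range(0, size - crop + 1, stride))
    if starts[-1] + crop < size:
        starts.append(size - crop)
    return starts


def cosine_weights(crop):
    """1D cosine window: low weight at the chip edges, high in the center."""
    return [0.5 - 0.5 * math.cos(2 * math.pi * (i + 0.5) / crop) for i in range(crop)]


def merge_1d(size, crop, stride, chip_values, blend=True):
    """Merge one constant prediction per chip into a line of `size` pixels."""
    weights = cosine_weights(crop) if blend else [1.0] * crop
    total = [0.0] * size
    weight_sum = [0.0] * size
    counts = [0] * size
    for start, value in zip(chip_starts(size, crop, stride), chip_values):
        for i in range(crop):
            total[start + i] += weights[i] * value
            weight_sum[start + i] += weights[i]
            counts[start + i] += 1
    merged = [t / w for t, w in zip(total, weight_sum)]
    return merged, counts


size, crop, stride = 11, 4, 3
values = [0.0, 1.0, 0.0, 1.0]  # one constant prediction per chip
starts = chip_starts(size, crop, stride)
blended, counts = merge_1d(size, crop, stride, values, blend=True)
plain, _ = merge_1d(size, crop, stride, values, blend=False)

print(starts)  # [0, 3, 6, 7]
print(counts)  # [1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 1]
print(round(plain[7], 3), round(blended[7], 3))  # 0.5 0.146
```

In this toy example, the pixels near the right border receive two predictions because the last chip is shifted back. At pixel 7, plain averaging weighs both chips equally, whereas the cosine window favors the chip in which the pixel lies closer to the center.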
Tiled inference via Python#
You can use TerraTorch to run tiled inference in a Python script like the following:
import torch
import rioxarray as rxr
from terratorch.tasks import SemanticSegmentationTask
from terratorch.tasks.tiled_inference import tiled_inference
from terratorch.cli_tools import LightningInferenceModel

# Init a TerraTorch task, e.g. for semantic segmentation
model = SemanticSegmentationTask.load_from_checkpoint(
    ckpt_path,  # Pass the checkpoint path
    model_factory="EncoderDecoderFactory",
    model_args=model_args,  # Pass your model args
)

# Alternatively build the model from a config file
model = LightningInferenceModel.from_config(
    config_file,
    ckpt_path,
    # additional params, e.g. predict_dataset_bands
)

# Load your data as a NumPy array with shape [C, H, W]
data = rxr.open_rasterio("input.tif").values

# Apply your standardization values to the input tile
# (means and stds are your per-band NumPy arrays with shape [C])
data = (data - means[:, None, None]) / stds[:, None, None]

# Create input tensor with shape [B, C, H, W] on CPU
input_batch = torch.tensor(data, dtype=torch.float, device="cpu").unsqueeze(0)

# Inference wrapper for TerraTorch task model
def model_forward(x, **kwargs):
    # Returns a torch Tensor
    return model(x, **kwargs).output

# Run tiled inference (data is loaded automatically to GPU)
pred = tiled_inference(
    model_forward,
    input_batch,
    crop=256,
    stride=240,
    batch_size=16,
    verbose=True,
)

# Remove batch dim and compute segmentation map
pred = pred.squeeze(0).argmax(dim=0)
Tip
You can easily modify the script by adjusting the parameters or by passing a custom PyTorch model to tiled_inference instead of model_forward.
The only requirement is that the passed forward function returns a torch tensor.
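As a minimal sketch of such a wrapper, the example below uses a hypothetical TinySegmenter model whose forward pass returns a dictionary, and unwraps it so that a plain tensor comes back. The model, its "logits" key, and the name custom_forward are illustrative assumptions and not part of TerraTorch.

```python
import torch
from torch import nn


class TinySegmenter(nn.Module):
    """Hypothetical stand-in for a custom model: per-pixel logits for 3 classes."""

    def __init__(self, in_channels=6, num_classes=3):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        # Returns a dict, not a tensor, so it needs a wrapper
        return {"logits": self.head(x)}


custom_model = TinySegmenter().eval()


def custom_forward(x, **kwargs):
    # Unwrap the model output so that a plain torch.Tensor is returned
    with torch.no_grad():
        return custom_model(x)["logits"]


chips = torch.rand(2, 6, 32, 32)  # a batch of chips with shape [B, C, H, W]
out = custom_forward(chips)
print(type(out).__name__, tuple(out.shape))  # Tensor (2, 3, 32, 32)
```

A function like custom_forward could then be passed as the first argument of tiled_inference in place of model_forward.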
Function reference#
terratorch.tasks.tiled_inference.tiled_inference(model_forward, input_batch, out_channels=None, inference_parameters=None, crop=224, stride=192, delta=8, h_crop=None, w_crop=None, h_stride=None, w_stride=None, average_patches=True, blend_overlaps=True, batch_size=16, verbose=False, padding='reflect', **kwargs)
#
Divide an image into (potentially) overlapping chips and perform inference on them. Additionally, re-batch the chips for variable GPU utilization, defined by the crop size and batch_size. The overlap between chips is defined as: crop - stride - 2 * delta.
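As a worked example of the overlap formula, the snippet below evaluates it for the default values and for the values used in the Python script earlier on this page. The helper name effective_overlap is only illustrative.

```python
def effective_overlap(crop, stride, delta):
    """Overlap between neighbouring chips after trimming `delta` pixels per side."""
    return crop - stride - 2 * delta


default_overlap = effective_overlap(crop=224, stride=192, delta=8)
script_overlap = effective_overlap(crop=256, stride=240, delta=8)

print(default_overlap)  # 16
print(script_overlap)   # 0
```

According to this formula, the defaults leave 16 overlapping pixels between neighbouring chips, while crop=256 and stride=240 with the default delta leave none once the chip borders are removed.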
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_forward | Callable | Callable that returns the output of the model. | required |
input_batch | Tensor | Input batch to be processed. | required |
out_channels | int | Number of output channels. | None |
inference_parameters | TiledInferenceParameters | Parameters to be used for inference. Deprecated; please pass the parameters directly to tiled_inference. | None |
crop | int | Height and width of the smaller chips. Ignored if h_crop or w_crop is provided. Defaults to 224. | 224 |
stride | int | Size of the stride. Ignored if h_stride or w_stride is provided. Defaults to 192. | 192 |
delta | int | Size of the border cropped from each chip. Defaults to 8. | 8 |
h_crop | int | Height of the smaller chips. | None |
w_crop | int | Width of the smaller chips. | None |
h_stride | int | Size of the stride on the y-axis. | None |
w_stride | int | Size of the stride on the x-axis. | None |
average_patches | bool | Whether to average the overlapping regions. Defaults to True. | True |
batch_size | int | Number of chips per forward pass. Defaults to 16. | 16 |
padding | str \| bool | Padding mode for the input image to reduce artefacts on edges. Deactivate padding with False. Defaults to reflect. | 'reflect' |
Returns: torch.Tensor: The result of the inference