Tasks

Tasks provide a convenient abstraction over training a model for a specific downstream task. They encapsulate the model, optimizer, metrics and loss, as well as the training, validation and testing steps. A task expects a model factory, to which the model_args arguments are passed to instantiate the model that will be trained (alternatively, a custom model can be passed directly through model). The models produced by this model factory should output ModelOutput instances and conform to the Model ABC. Tasks are best used through config files, where they are specified in the model section under class_path. You can check out some examples of config files here. Below are the details of the tasks currently implemented in TerraTorch (Semantic Segmentation, Pixelwise Regression and Classification).
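In a config file, the task is selected by its full import path under class_path, and its constructor arguments go under init_args, following the LightningCLI/jsonargparse convention. The following is a minimal sketch of that resolution step using only the standard library. The instantiate helper is hypothetical and far simpler than what LightningCLI actually does (no type checking, no nested instantiation), and the EncoderDecoderFactory value is only an assumed example of a factory name.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Build an object from a {"class_path": ..., "init_args": {...}} spec.

    Hypothetical, simplified illustration of class_path resolution.
    """
    module_name, _, class_name = spec["class_path"].rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must be fully qualified: {spec['class_path']!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Shape of a task entry in the model section (illustrative values only;
# it is not instantiated here because that would require terratorch).
task_spec = {
    "class_path": "terratorch.tasks.segmentation_tasks.SemanticSegmentationTask",
    "init_args": {
        "model_factory": "EncoderDecoderFactory",  # assumed factory name
        "model_args": {},  # forwarded to the model factory
        "loss": "ce",
        "lr": 0.001,
    },
}

# The same mechanism, demonstrated with a standard-library class.
half = instantiate({"class_path": "fractions.Fraction",
                    "init_args": {"numerator": 1, "denominator": 2}})
print(half)  # 1/2
```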

terratorch.tasks.segmentation_tasks.SemanticSegmentationTask

Bases: TerraTorchTask

Semantic Segmentation Task that accepts models from a range of sources.

This class is analogous in functionality to the SemanticSegmentationTask class defined by TorchGeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows optimizers to be set in the constructor
- Allows evaluation on multiple test dataloaders

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, output_on_inference='prediction', output_most_probable=True, tiled_inference_on_testing=False)

Constructor

Parameters:

Name Type Description Default
model_args Dict

Arguments passed to the model factory.

required
model_factory str

Name of the ModelFactory class used to instantiate the model. Ignored when model is provided.

None
model Module

Custom model.

None
loss str

Loss to be used. Currently supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce".

'ce'
aux_loss dict[str, float] | None

Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

None
class_weights list[float] | None

List of class weights to be applied to the loss. Defaults to None.

None
ignore_index int | None

Label to ignore in the loss computation. Defaults to None.

None
lr float

Learning rate to be used. Defaults to 0.001.

0.001
optimizer str | None

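Name of optimizer class from torch.optim to be used. If None, Adam is used. Defaults to None. Overridden by config / CLI specification through LightningCLI.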

None
optimizer_hparams dict | None

Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI.

None
scheduler str

Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
scheduler_hparams dict | None

Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI.

None
freeze_backbone bool

Whether to freeze the backbone. Defaults to False.

False
freeze_decoder bool

Whether to freeze the decoder. Defaults to False.

False
freeze_head bool

Whether to freeze the segmentation head. Defaults to False.

False
plot_on_val bool | int

Whether to plot visualizations on validation. If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.

10
class_names list[str] | None

List of class names passed to metrics for better naming. Defaults to numeric ordering.

None
tiled_inference_parameters TiledInferenceParameters | None

Inference parameters used to determine if inference is done on the whole image or through tiling.

None
test_dataloaders_names list[str] | None

Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

None
lr_overrides dict[str, float] | None

Dictionary to override the default lr for specific parameters. Each key is matched as a substring of the parameter names (any parameter whose name contains the key is affected), and the value is the new lr. Defaults to None.

None
output_on_inference str

A string defining the kind of output to be saved to file during inference, for example "prediction".

'prediction'
output_most_probable bool

A boolean defining whether the prediction step outputs only the most probable class.

True
tiled_inference_on_testing bool

A boolean defining whether tiled inference is used when full inference fails during the test step.

False
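The class_weights and ignore_index arguments change how the per-pixel loss is averaged. Assuming the 'ce' option wraps torch.nn.CrossEntropyLoss (with weight=class_weights, ignore_index=ignore_index and mean reduction), the arithmetic can be sketched in plain Python: ignored pixels are dropped entirely, and the weighted sum is divided by the total weight of the remaining targets rather than by the pixel count. The weighted_ce function is a hypothetical illustration, not TerraTorch code.

```python
import math
from typing import Optional, Sequence


def weighted_ce(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    class_weights: Optional[Sequence[float]] = None,
    ignore_index: Optional[int] = None,
) -> float:
    """Mean cross-entropy over pixels (one row of class logits per pixel).

    Hypothetical sketch of torch.nn.CrossEntropyLoss semantics for `weight`
    and `ignore_index` with reduction="mean".
    """
    total, norm = 0.0, 0.0
    for row, y in zip(logits, targets):
        if ignore_index is not None and y == ignore_index:
            continue  # ignored pixels add neither loss nor weight
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        w = 1.0 if class_weights is None else class_weights[y]
        total += w * (log_sum_exp - row[y])  # weighted -log softmax(row)[y]
        norm += w
    return total / norm if norm else float("nan")  # NaN if every pixel is ignored


# Two pixels, two classes; the second pixel is labelled 1.
logits = [[2.0, 0.0], [0.0, 0.0]]
print(weighted_ce(logits, [0, 1]))                            # plain mean
print(weighted_ce(logits, [0, 1], class_weights=[1.0, 3.0]))  # class 1 counts 3x
print(weighted_ce(logits, [0, 1], ignore_index=1))            # second pixel dropped
```

Up-weighting a rare class therefore increases its share of the gradient without changing the scale of the loss, because the normalizer grows with the weights.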

configure_losses()

Initialize the loss criterion.

Raises:

Type Description
ValueError

If loss is invalid.

configure_metrics()

Initialize the performance metrics.

predict_step(batch, batch_idx, dataloader_idx=0)

Compute the predicted class probabilities.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Returns:

Type Description
Tensor

Output predicted probabilities.

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0
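Plotting during validation is governed by the plot_on_val constructor argument (True plots every epoch, an int n plots every n epochs). A minimal sketch of that schedule follows. The should_plot helper is hypothetical, and treating False or 0 as "never plot" and counting from epoch 0 are assumptions rather than documented behavior.

```python
from typing import Union


def should_plot(plot_on_val: Union[bool, int], epoch: int) -> bool:
    """Hypothetical helper: decide whether validation plots are produced at `epoch`."""
    if not plot_on_val:  # False or 0 disables plotting (assumption)
        return False
    if plot_on_val < 0:
        raise ValueError("plot_on_val must be a bool or a non-negative int")
    # bool is a subclass of int, so True behaves as 1, i.e. every epoch.
    return epoch % int(plot_on_val) == 0


print([e for e in range(25) if should_plot(10, e)])  # [0, 10, 20]
```

Relying on bool being a subclass of int lets a single modulo check cover both the bool and the int form of the argument.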

terratorch.tasks.regression_tasks.PixelwiseRegressionTask

Bases: TerraTorchTask

Pixelwise Regression Task that accepts models from a range of sources.

This class is analogous in functionality to the PixelwiseRegressionTask class defined by TorchGeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows optimizers to be set in the constructor
- Allows evaluation on multiple test dataloaders

__init__(model_args, model_factory=None, model=None, loss='mse', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, tiled_inference_on_testing=None)

Constructor

Parameters:

Name Type Description Default
model_args Dict

Arguments passed to the model factory.

required
model_factory str

Name of the ModelFactory class used to instantiate the model. Ignored when model is provided.

None
model Module

Custom model.

None
loss str

Loss to be used. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss. Defaults to "mse".

'mse'
aux_loss dict[str, float] | None

Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

None
class_weights list[float] | None

List of class weights to be applied to the loss. Defaults to None.

None
ignore_index int | None

Label to ignore in the loss computation. Defaults to None.

None
lr float

Learning rate to be used. Defaults to 0.001.

0.001
optimizer str | None

Name of optimizer class from torch.optim to be used. If None, Adam is used. Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
optimizer_hparams dict | None

Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI.

None
scheduler str

Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
scheduler_hparams dict | None

Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI.

None
freeze_backbone bool

Whether to freeze the backbone. Defaults to False.

False
freeze_decoder bool

Whether to freeze the decoder. Defaults to False.

False
freeze_head bool

Whether to freeze the regression head. Defaults to False.

False
plot_on_val bool | int

Whether to plot visualizations on validation. If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.

10
tiled_inference_parameters TiledInferenceParameters | None

Inference parameters used to determine if inference is done on the whole image or through tiling.

None
test_dataloaders_names list[str] | None

Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

None
lr_overrides dict[str, float] | None

Dictionary to override the default lr for specific parameters. Each key is matched as a substring of the parameter names (any parameter whose name contains the key is affected), and the value is the new lr. Defaults to None.

None
tiled_inference_on_testing bool

A boolean defining whether tiled inference is used when full inference fails during the test step.

None
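The aux_loss weights combine auxiliary-head losses with the main loss. A plausible reading of the aux_loss description in the table is a weighted sum in which each key selects the matching named output of the model; the sketch below assumes exactly that, applies the same criterion to every output, and uses a hypothetical output name ("aux_head") and hypothetical helpers (total_loss, mse). It is not TerraTorch's implementation.

```python
from typing import Callable, Mapping, Optional, Sequence

Criterion = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over flattened pixel values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def total_loss(
    main_output: Sequence[float],
    aux_outputs: Mapping[str, Sequence[float]],
    target: Sequence[float],
    criterion: Criterion,
    aux_loss: Optional[Mapping[str, float]] = None,
) -> dict[str, float]:
    """Return each individual loss plus the weighted total under "loss"."""
    losses = {"main": criterion(main_output, target)}
    total = losses["main"]
    for name, weight in (aux_loss or {}).items():
        if name not in aux_outputs:  # keys must match the model's named outputs
            raise KeyError(f"no model output named {name!r} for aux_loss")
        losses[name] = criterion(aux_outputs[name], target)
        total += weight * losses[name]
    losses["loss"] = total
    return losses


result = total_loss(
    main_output=[1.0, 1.0],
    aux_outputs={"aux_head": [2.0, 2.0]},
    target=[0.0, 0.0],
    criterion=mse,
    aux_loss={"aux_head": 0.5},
)
print(result)  # {'main': 1.0, 'aux_head': 4.0, 'loss': 3.0}
```

Raising on an unmatched key surfaces a typo in aux_loss immediately instead of silently training without the auxiliary term.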

configure_losses()

Initialize the loss criterion.

Raises:

Type Description
ValueError

If loss is invalid.
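configure_losses maps the loss string onto a criterion and rejects unknown names with ValueError. The sketch below mirrors that pattern for the four regression options in plain Python, assuming the standard definitions of each loss and a Huber threshold of delta = 1.0 (the default of torch.nn.HuberLoss). The LOSSES table and configure_loss function are hypothetical, not TerraTorch's API.

```python
import math
from typing import Callable, Sequence

Criterion = Callable[[Sequence[float], Sequence[float]], float]


def _errors(pred: Sequence[float], target: Sequence[float]) -> list[float]:
    if len(pred) != len(target) or not target:
        raise ValueError("pred and target must be non-empty and equally long")
    return [p - t for p, t in zip(pred, target)]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    errs = _errors(pred, target)
    return sum(e * e for e in errs) / len(errs)


def rmse(pred: Sequence[float], target: Sequence[float]) -> float:
    return math.sqrt(mse(pred, target))


def mae(pred: Sequence[float], target: Sequence[float]) -> float:
    errs = _errors(pred, target)
    return sum(abs(e) for e in errs) / len(errs)


def huber(pred: Sequence[float], target: Sequence[float], delta: float = 1.0) -> float:
    # Quadratic for errors up to `delta`, linear beyond it.
    errs = _errors(pred, target)
    return sum(
        0.5 * e * e if abs(e) <= delta else delta * (abs(e) - 0.5 * delta)
        for e in errs
    ) / len(errs)


LOSSES: dict[str, Criterion] = {"mse": mse, "rmse": rmse, "mae": mae, "huber": huber}


def configure_loss(name: str) -> Criterion:
    """Return the criterion for `name`, raising ValueError if it is unsupported."""
    try:
        return LOSSES[name]
    except KeyError:
        raise ValueError(f"Loss {name!r} is not valid; choose from {sorted(LOSSES)}") from None


pred, target = [3.0, 0.0], [0.0, 0.0]
for name in LOSSES:
    print(name, configure_loss(name)(pred, target))
```

With one large error (3.0), MSE grows quadratically while MAE and Huber grow linearly, which is why Huber is the usual choice when targets contain outliers.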

configure_metrics()

Initialize the performance metrics.

predict_step(batch, batch_idx, dataloader_idx=0)

Compute the pixelwise predictions.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Returns:

Type Description
Tensor

Output predicted values.

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

terratorch.tasks.classification_tasks.ClassificationTask

Bases: TerraTorchTask

Classification Task that accepts models from a range of sources.

This class is analogous in functionality to the ClassificationTask class defined by TorchGeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows optimizers to be set in the constructor
- Provides mIoU with both Micro and Macro averaging
- Allows evaluation on multiple test dataloaders

Note:
- 'Micro' averaging suits overall performance evaluation but may not reflect minority-class accuracy.
- 'Macro' averaging gives equal weight to each class, which is useful for assessing balanced performance across imbalanced classes.
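To make the difference concrete, the sketch below computes IoU with both averaging modes from flat lists of predicted and true classes, using the usual definitions: macro averages the per-class scores TP / (TP + FP + FN), while micro pools the counts across classes first. The iou_scores helper is hypothetical, and TerraTorch's own metric implementation may handle edge cases (for example, classes absent from both predictions and targets, which this sketch skips) differently.

```python
from collections import Counter
from typing import Sequence


def iou_scores(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> tuple[float, float]:
    """Return (micro, macro) IoU for single-label class predictions."""
    if not targets or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and equally long")
    tp, fp, fn = Counter(), Counter(), Counter()
    for p, t in zip(preds, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted where it was wrong
            fn[t] += 1  # the true class t was missed
    per_class = [
        tp[c] / (tp[c] + fp[c] + fn[c])
        for c in range(num_classes)
        if tp[c] + fp[c] + fn[c] > 0  # skip classes that never occur
    ]
    pooled_tp = sum(tp.values())
    micro = pooled_tp / (pooled_tp + sum(fp.values()) + sum(fn.values()))
    macro = sum(per_class) / len(per_class)
    return micro, macro


# Imbalanced toy data: 8 samples of class 0, 2 of class 1; one class-1 sample is misclassified.
targets = [0] * 8 + [1, 1]
preds = [0] * 8 + [0, 1]
micro, macro = iou_scores(preds, targets, num_classes=2)
print(f"micro={micro:.3f} macro={macro:.3f}")  # micro=0.818 macro=0.694
```

Micro IoU stays high because the majority class dominates the pooled counts, whereas macro IoU is pulled down by the single minority-class error, which is the trade-off described in the note.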

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, class_names=None, test_dataloaders_names=None, lr_overrides=None)

Constructor

Parameters:

Name Type Description Default
model_args Dict

Arguments passed to the model factory.

required
model_factory str

Name of the ModelFactory class used to instantiate the model. Ignored when model is provided.

None
model Module

Custom model.

None
loss str

Loss to be used. Currently supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce".

'ce'
aux_loss dict[str, float] | None

Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

None
class_weights list[float] | None

List of class weights to be applied to the loss. Defaults to None.

None
ignore_index int | None

Label to ignore in the loss computation. Defaults to None.

None
lr float

Learning rate to be used. Defaults to 0.001.

0.001
optimizer str | None

Name of optimizer class from torch.optim to be used. If None, Adam is used. Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
optimizer_hparams dict | None

Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI.

None
scheduler str

Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
scheduler_hparams dict | None

Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI.

None
freeze_backbone bool

Whether to freeze the backbone. Defaults to False.

False
freeze_decoder bool

Whether to freeze the decoder. Defaults to False.

False
freeze_head bool

Whether to freeze the classification head. Defaults to False.

False
class_names list[str] | None

List of class names passed to metrics for better naming. Defaults to numeric ordering.

None
test_dataloaders_names list[str] | None

Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

None
lr_overrides dict[str, float] | None

Dictionary to override the default lr for specific parameters. Each key is matched as a substring of the parameter names (any parameter whose name contains the key is affected), and the value is the new lr. Defaults to None.

None
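lr_overrides relies on substring matching against parameter names. PyTorch optimizers accept a list of parameter groups, each a dict with "params" and optional per-group settings such as "lr", so one way to realize the overrides is sketched below. The build_param_groups helper is hypothetical, it works on names rather than tensors to stay self-contained, and the rule that the first matching key wins when several keys match is an assumption.

```python
from typing import Optional, Sequence


def build_param_groups(
    param_names: Sequence[str],
    base_lr: float,
    lr_overrides: Optional[dict[str, float]] = None,
) -> list[dict]:
    """Group parameter names into optimizer-style groups keyed by learning rate."""
    groups: dict[float, list[str]] = {}
    for name in param_names:
        lr = base_lr
        for key, override in (lr_overrides or {}).items():
            if key in name:  # substring match, e.g. "encoder" in "encoder.layer1.weight"
                lr = override
                break  # assumption: the first matching key wins
        groups.setdefault(lr, []).append(name)
    return [{"params": names, "lr": lr} for lr, names in groups.items()]


names = ["encoder.layer1.weight", "decoder.conv.weight", "head.bias"]
for group in build_param_groups(names, base_lr=1e-3, lr_overrides={"encoder": 1e-5}):
    print(group)
```

Because matching is by substring, a short key such as "coder" would match both encoder and decoder parameters, so override keys should be specific enough to select only the intended modules.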

configure_losses()

Initialize the loss criterion.

Raises:

Type Description
ValueError

If loss is invalid.

configure_metrics()

Initialize the performance metrics.

predict_step(batch, batch_idx, dataloader_idx=0)

Compute the predicted class probabilities.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Returns:

Type Description
Tensor

Output predicted probabilities.

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Last update: March 25, 2025