Tasks
Tasks provide a convenient abstraction over the training of a model for a specific downstream task.
They encapsulate the model, optimizer, metrics, loss as well as training, validation and testing steps.
The task expects to be passed a model factory, to which the model_args arguments are passed to instantiate the model that will be trained.
The models produced by this model factory should output ModelOutput instances and conform to the Model ABC.
Tasks are best leveraged using config files, where they are specified in the model section under class_path. You can check out some examples of config files here.
Below are the details of the tasks currently implemented in TerraTorch (Pixelwise Regression, Semantic Segmentation and Classification).
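As a rough illustration of the config-file layout described above, the sketch below builds the model section as a plain Python dictionary and prints it. The class_path value is the segmentation task documented on this page and init_args is the usual LightningCLI key for constructor arguments. The factory name and the keys inside model_args are assumptions chosen for illustration, not values confirmed by this page.

```python
import json

# Hypothetical "model" section of a config file, expressed as a dict.
# Only class_path and the constructor argument names come from this page;
# the factory name and the contents of model_args are illustrative assumptions.
config = {
    "model": {
        "class_path": "terratorch.tasks.segmentation_tasks.SemanticSegmentationTask",
        "init_args": {
            "model_factory": "EncoderDecoderFactory",  # assumed factory name
            "model_args": {
                "backbone": "<backbone name>",  # placeholder, assumed key
                "decoder": "<decoder name>",    # placeholder, assumed key
                "num_classes": 2,               # assumed key
            },
            "loss": "ce",
            "lr": 0.001,
            "freeze_backbone": False,
        },
    }
}

# Print the structure so it can be compared against a real config file.
print(json.dumps(config, indent=2))
```

Everything under init_args maps one-to-one onto the constructor parameters listed in the tables below.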
terratorch.tasks.segmentation_tasks.SemanticSegmentationTask
Bases: TerraTorchTask
Semantic Segmentation Task that accepts models from a range of sources.
This class is analogous in functionality to the SemanticSegmentationTask class defined by torchgeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- Allows evaluation on multiple test dataloaders
__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, output_on_inference='prediction', output_most_probable=True, tiled_inference_on_testing=False)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_args | Dict | Arguments passed to the model factory. | required |
model_factory | str | ModelFactory class to be used to instantiate the model. Is ignored when model is provided. | None |
model | Module | Custom model. | None |
loss | str | Loss to be used. Currently supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce". | 'ce' |
aux_loss | dict[str, float] \| None | Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None. | None |
class_weights | list[float] \| None | List of class weights to be applied to the loss. Defaults to None. | None |
ignore_index | int \| None | Label to ignore in the loss computation. Defaults to None. | None |
lr | float | Learning rate to be used. Defaults to 0.001. | 0.001 |
optimizer | str \| None | Name of optimizer class from torch.optim to be used. | None |
optimizer_hparams | dict \| None | Parameters to be passed for instantiation of the optimizer. Overridden by config / cli specification through LightningCLI. | None |
scheduler | str | Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / cli specification through LightningCLI. | None |
scheduler_hparams | dict \| None | Parameters to be passed for instantiation of the scheduler. Overridden by config / cli specification through LightningCLI. | None |
freeze_backbone | bool | Whether to freeze the backbone. Defaults to False. | False |
freeze_decoder | bool | Whether to freeze the decoder. Defaults to False. | False |
freeze_head | bool | Whether to freeze the segmentation head. Defaults to False. | False |
plot_on_val | bool \| int | Whether to plot visualizations on validation. | 10 |
class_names | list[str] \| None | List of class names passed to metrics for better naming. Defaults to numeric ordering. | None |
tiled_inference_parameters | TiledInferenceParameters \| None | Inference parameters used to determine if inference is done on the whole image or through tiling. | None |
test_dataloaders_names | list[str] \| None | Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used. | None |
lr_overrides | dict[str, float] \| None | Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None. | None |
output_on_inference | str | A string defining the kind of output to be saved to file during the inference, it can be "prediction", | 'prediction' |
output_most_probable | bool | A boolean to define if the prediction step will output just the most probable | True |
tiled_inference_on_testing | bool | A boolean to define whether tiled inference will be used when full inference fails during the test step. | False |
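The aux_loss parameter above describes a dictionary of weights keyed by the name of an auxiliary output. A minimal sketch of how such weights might be combined with the main loss is shown below. It uses plain floats in place of tensors, and the function name, the "aux_head" key and the simple weighted sum are assumptions for illustration rather than TerraTorch's actual implementation.

```python
def weighted_total_loss(main_loss, aux_losses, aux_loss_weights=None):
    """Add each auxiliary loss, scaled by its weight, to the main loss.

    aux_losses: loss value per auxiliary output name (hypothetical input).
    aux_loss_weights: same shape as the aux_loss constructor argument.
    """
    total = main_loss
    for name, weight in (aux_loss_weights or {}).items():
        if name not in aux_losses:
            # The weight key must match a key produced by the model's forward pass.
            raise KeyError(f"No auxiliary output named '{name}'")
        total += weight * aux_losses[name]
    return total


# One auxiliary head called "aux_head" (hypothetical name), weighted by 0.4.
total = weighted_total_loss(0.8, {"aux_head": 0.5}, {"aux_head": 0.4})
print(total)
```

With aux_loss left at None, the sketch returns the main loss unchanged.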
configure_losses()
Initialize the loss criterion.
Raises:
Type | Description |
---|---|
ValueError | If loss is invalid. |
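The check behind this ValueError can be pictured as a simple membership test against the supported loss names. The sketch below is an assumption about the general pattern, with a hypothetical helper name, and lists only the three segmentation losses named in the constructor table.

```python
# Loss names taken from the constructor documentation for this task.
SUPPORTED_LOSSES = ("ce", "jaccard", "focal")


def validate_loss_name(loss):
    """Return the loss name unchanged, or raise ValueError if it is unsupported."""
    if loss not in SUPPORTED_LOSSES:
        supported = ", ".join(SUPPORTED_LOSSES)
        raise ValueError(f"Loss '{loss}' is not valid. Supported losses: {supported}")
    return loss


print(validate_loss_name("focal"))
```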
configure_metrics()
Initialize the performance metrics.
predict_step(batch, batch_idx, dataloader_idx=0)
Compute the predicted class probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
Returns:
Type | Description |
---|---|
Tensor | Output predicted probabilities. |
test_step(batch, batch_idx, dataloader_idx=0)
Compute the test loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
training_step(batch, batch_idx, dataloader_idx=0)
Compute the train loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
validation_step(batch, batch_idx, dataloader_idx=0)
Compute the validation loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
terratorch.tasks.regression_tasks.PixelwiseRegressionTask
Bases: TerraTorchTask
Pixelwise Regression Task that accepts models from a range of sources.
This class is analogous in functionality to PixelwiseRegressionTask defined by torchgeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- Allows evaluation on multiple test dataloaders
__init__(model_args, model_factory=None, model=None, loss='mse', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, tiled_inference_on_testing=None)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_args | Dict | Arguments passed to the model factory. | required |
model_factory | str | Name of ModelFactory class to be used to instantiate the model. Is ignored when model is provided. | None |
model | Module | Custom model. | None |
loss | str | Loss to be used. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss. Defaults to "mse". | 'mse' |
aux_loss | dict[str, float] \| None | Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None. | None |
class_weights | list[float] \| None | List of class weights to be applied to the loss. Defaults to None. | None |
ignore_index | int \| None | Label to ignore in the loss computation. Defaults to None. | None |
lr | float | Learning rate to be used. Defaults to 0.001. | 0.001 |
optimizer | str \| None | Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI. | None |
optimizer_hparams | dict \| None | Parameters to be passed for instantiation of the optimizer. Overridden by config / cli specification through LightningCLI. | None |
scheduler | str | Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / cli specification through LightningCLI. | None |
scheduler_hparams | dict \| None | Parameters to be passed for instantiation of the scheduler. Overridden by config / cli specification through LightningCLI. | None |
freeze_backbone | bool | Whether to freeze the backbone. Defaults to False. | False |
freeze_decoder | bool | Whether to freeze the decoder. Defaults to False. | False |
freeze_head | bool | Whether to freeze the segmentation head. Defaults to False. | False |
plot_on_val | bool \| int | Whether to plot visualizations on validation. If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. | 10 |
tiled_inference_parameters | TiledInferenceParameters \| None | Inference parameters used to determine if inference is done on the whole image or through tiling. | None |
test_dataloaders_names | list[str] \| None | Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used. | None |
lr_overrides | dict[str, float] \| None | Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None. | None |
tiled_inference_on_testing | bool | A boolean to define whether tiled inference will be used when full inference fails during the test step. | None |
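The plot_on_val description above distinguishes a boolean from an integer value. The sketch below shows one way that rule could be evaluated per epoch. The helper name is hypothetical, and both the zero-based epoch counter and plotting on epoch 0 are assumptions; the real task may count epochs differently.

```python
def should_plot(epoch, plot_on_val=10):
    """Decide whether to plot on this validation epoch.

    bool: True plots every epoch, False never plots.
    int:  plots every plot_on_val epochs (assumes a zero-based epoch counter).
    """
    # bool must be checked first because bool is a subclass of int in Python.
    if isinstance(plot_on_val, bool):
        return plot_on_val
    return plot_on_val > 0 and epoch % plot_on_val == 0


for epoch in (0, 7, 10, 20):
    print(epoch, should_plot(epoch))
```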
configure_losses()
Initialize the loss criterion.
Raises:
Type | Description |
---|---|
ValueError | If loss is invalid. |
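For reference, the four regression losses named in the constructor ('mse', 'rmse', 'mae', 'huber') can be written out in a few lines of plain Python. This is a sketch using the standard textbook definitions with mean reduction over flat lists of numbers. The helper names, the Huber delta of 1.0 and the lookup function are assumptions, not the task's own code.

```python
import math


def mse(pred, target):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def rmse(pred, target):
    """Root mean squared error."""
    return math.sqrt(mse(pred, target))


def mae(pred, target):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def huber(pred, target, delta=1.0):
    """Huber loss: quadratic for small errors, linear for large ones."""
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        total += 0.5 * err ** 2 if err <= delta else delta * (err - 0.5 * delta)
    return total / len(pred)


LOSSES = {"mse": mse, "rmse": rmse, "mae": mae, "huber": huber}


def get_loss(name):
    """Look up a loss by name, raising ValueError for unknown names."""
    if name not in LOSSES:
        raise ValueError(f"Loss '{name}' is not valid")
    return LOSSES[name]


pred, target = [1.0, 2.0, 5.0], [1.0, 4.0, 2.0]
for name in LOSSES:
    print(name, get_loss(name)(pred, target))
```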
configure_metrics()
Initialize the performance metrics.
predict_step(batch, batch_idx, dataloader_idx=0)
Compute the predicted values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
Returns:
Type | Description |
---|---|
Tensor | Output predicted values. |
test_step(batch, batch_idx, dataloader_idx=0)
Compute the test loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
training_step(batch, batch_idx, dataloader_idx=0)
Compute the train loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
validation_step(batch, batch_idx, dataloader_idx=0)
Compute the validation loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
terratorch.tasks.classification_tasks.ClassificationTask
Bases: TerraTorchTask
Classification Task that accepts models from a range of sources.
This class is analogous in functionality to the ClassificationTask class defined by torchgeo. However, it has some important differences:
- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- Provides mIoU with both Micro and Macro averaging
- Allows evaluation on multiple test dataloaders

Note:
- 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy.
- 'Macro' averaging gives equal weight to each class, useful for balanced performance assessment across imbalanced classes.
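The difference between the two averaging modes is easiest to see on a small imbalanced example. The sketch below computes micro- and macro-averaged accuracy in plain Python. Accuracy is used here instead of mIoU purely to keep the arithmetic short, and the function name and toy labels are illustrative assumptions.

```python
def micro_macro_accuracy(y_true, y_pred, num_classes):
    """Return (micro, macro) accuracy for integer class labels."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    # Micro: pool all samples, so frequent classes dominate the score.
    micro = sum(correct) / sum(total)
    # Macro: average the per-class scores, so every class counts equally.
    per_class = [c / n for c, n in zip(correct, total) if n > 0]
    macro = sum(per_class) / len(per_class)
    return micro, macro


# Eight samples of class 0 (all correct), two of class 1 (one correct).
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 8 + [0, 1]
micro, macro = micro_macro_accuracy(y_true, y_pred, num_classes=2)
print(micro, macro)  # 0.9 0.75
```

The minority class is only half right, which the macro score reflects (0.75) while the micro score largely hides it (0.9).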
__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, class_names=None, test_dataloaders_names=None, lr_overrides=None)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_args | Dict | Arguments passed to the model factory. | required |
model_factory | str | ModelFactory class to be used to instantiate the model. Is ignored when model is provided. | None |
model | Module | Custom model. | None |
loss | str | Loss to be used. Currently supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce". | 'ce' |
aux_loss | dict[str, float] \| None | Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None. | None |
class_weights | list[float] \| None | List of class weights to be applied to the loss. Defaults to None. | None |
ignore_index | int \| None | Label to ignore in the loss computation. Defaults to None. | None |
lr | float | Learning rate to be used. Defaults to 0.001. | 0.001 |
optimizer | str \| None | Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI. | None |
optimizer_hparams | dict \| None | Parameters to be passed for instantiation of the optimizer. Overridden by config / cli specification through LightningCLI. | None |
scheduler | str | Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / cli specification through LightningCLI. | None |
scheduler_hparams | dict \| None | Parameters to be passed for instantiation of the scheduler. Overridden by config / cli specification through LightningCLI. | None |
freeze_backbone | bool | Whether to freeze the backbone. Defaults to False. | False |
freeze_decoder | bool | Whether to freeze the decoder. Defaults to False. | False |
freeze_head | bool | Whether to freeze the segmentation_head. Defaults to False. | False |
class_names | list[str] \| None | List of class names passed to metrics for better naming. Defaults to numeric ordering. | None |
test_dataloaders_names | list[str] \| None | Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used. | None |
lr_overrides | dict[str, float] \| None | Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None. | None |
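The lr_overrides parameter above matches its keys as substrings of parameter names. A minimal sketch of that matching is given below. It groups parameters by learning rate in the list-of-dicts shape that torch.optim optimizers accept as parameter groups, but uses placeholder strings instead of tensors. The function name, the example parameter names and the rule that the first matching key wins are assumptions for illustration.

```python
def build_param_groups(named_params, default_lr, lr_overrides=None):
    """Group parameters by learning rate using substring matching on names."""
    overrides = lr_overrides or {}
    groups = {}  # learning rate -> list of parameters
    for name, param in named_params:
        lr = default_lr
        for substring, override_lr in overrides.items():
            if substring in name:
                lr = override_lr
                break  # assumption: the first matching key wins
        groups.setdefault(lr, []).append(param)
    return [{"params": params, "lr": lr} for lr, params in groups.items()]


# Placeholder strings stand in for parameter tensors (hypothetical names).
named_params = [
    ("encoder.block1.weight", "w1"),
    ("decoder.conv.weight", "w2"),
    ("head.bias", "b"),
]
groups = build_param_groups(named_params, default_lr=1e-3, lr_overrides={"encoder": 1e-5})
print(groups)
```

Here only the parameter whose name contains "encoder" receives the overridden rate, and the rest keep the default lr.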
configure_losses()
Initialize the loss criterion.
Raises:
Type | Description |
---|---|
ValueError | If loss is invalid. |
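To make the roles of class_weights and ignore_index concrete, the sketch below implements a weighted cross-entropy in plain Python. It follows the mean-reduction convention of torch.nn.CrossEntropyLoss, where the weighted sum is divided by the total weight of the non-ignored targets. The function name is hypothetical, and returning 0.0 when every target is ignored is a simplification of this sketch.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights=None, ignore_index=None):
    """Weighted mean cross-entropy over samples whose target is not ignored."""
    num_classes = len(logits[0])
    weights = class_weights or [1.0] * num_classes
    loss_sum = 0.0
    weight_sum = 0.0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored labels contribute nothing to the loss
        # Numerically stable log-sum-exp for the softmax normalizer.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        loss_sum += weights[target] * (log_norm - row[target])
        weight_sum += weights[target]
    # Simplification: return 0.0 if every target was ignored.
    return loss_sum / weight_sum if weight_sum > 0 else 0.0


logits = [[2.0, 0.0], [0.0, 0.0], [0.0, 3.0]]
targets = [0, 1, 255]  # 255 marks a label to ignore (hypothetical value)
print(weighted_cross_entropy(logits, targets, class_weights=[1.0, 3.0], ignore_index=255))
```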
configure_metrics()
Initialize the performance metrics.
predict_step(batch, batch_idx, dataloader_idx=0)
Compute the predicted class probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
Returns:
Type | Description |
---|---|
Tensor | Output predicted probabilities. |
test_step(batch, batch_idx, dataloader_idx=0)
Compute the test loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
training_step(batch, batch_idx, dataloader_idx=0)
Compute the train loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |
validation_step(batch, batch_idx, dataloader_idx=0)
Compute the validation loss and additional metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Any | The output of your DataLoader. | required |
batch_idx | int | Index of this batch. | required |
dataloader_idx | int | Index of the current dataloader. | 0 |