
Tasks#

Tasks provide a convenient abstraction over the training of a model for a specific downstream task. They encapsulate the model, optimizer, metrics, and loss, as well as the training, validation and testing steps. A task expects to be passed a model_factory, to which the model_args are passed in order to instantiate the model that will be trained. The models produced by this factory should output ModelOutput instances and conform to the Model ABC. Tasks are best leveraged using config files, where they are specified in the model section under class_path. You can check out some examples of config files here. Below are the details of the tasks currently implemented in TerraTorch (Pixelwise Regression, Semantic Segmentation and Classification).
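
The same model_factory / model_args pair that goes under the model section of a config file can also be passed programmatically. A minimal sketch, assuming an EncoderDecoderFactory and the backbone and decoder names below (all illustrative, not prescribed by this page):

from terratorch.tasks import SemanticSegmentationTask

# Illustrative sketch only: factory, backbone and decoder names are assumptions.
task = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_eo_v2_300",  # assumed backbone identifier
        "decoder": "FCNDecoder",          # assumed decoder identifier
        "num_classes": 2,
    },
    loss="ce",
    lr=1e-4,
)
# The task is a LightningModule, so it can be passed directly to a Lightning Trainer.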

terratorch.tasks.SemanticSegmentationTask #

Bases: TerraTorchTask

Semantic Segmentation Task that accepts models from a range of sources.

This class is analogous in functionality to the SemanticSegmentationTask defined by TorchGeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows setting the optimizers in the constructor
- Allows evaluation on multiple test dataloaders

Source code in terratorch/tasks/segmentation_tasks.py
class SemanticSegmentationTask(TerraTorchTask):
    """Semantic Segmentation Task that accepts models from a range of sources.

    This class is analogous in functionality to the SemanticSegmentationTask defined by TorchGeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Allows evaluation on multiple test dataloaders
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so the CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        freeze_head: bool = False,
        plot_on_val: bool | int = 10,
        class_names: list[str] | None = None,
        tiled_inference_parameters: dict | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
        output_on_inference: str | list[str] = "prediction",
        output_most_probable: bool = True,
        path_to_record_metrics: str | None = None,
        tiled_inference_on_testing: bool = False,
        tiled_inference_on_validation: bool = False,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided. Defaults to None.
            model (torch.nn.Module, optional): Custom model. Defaults to None.
            loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
                Defaults to "ce".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / CLI specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / CLI specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / CLI specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If True, plots every epoch. If an int, plots every plot_on_val epochs. Defaults to 10.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
            tiled_inference_parameters (dict | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
            output_on_inference (str | list[str]): A string or a list defining the kind of output to be saved
                to file during inference. For example, it can be "prediction", to save just the most probable
                class, or ["prediction", "probabilities"] to save both the prediction and the probabilities.
            output_most_probable (bool): Whether the prediction step will output just the most probable class
                or all of the logits. This argument is deprecated and has been replaced by `output_on_inference`.
            tiled_inference_on_testing (bool): Whether tiled inference will be used when full inference fails
                during the test step.
            tiled_inference_on_validation (bool): Whether tiled inference will be used during the validation step.
            path_to_record_metrics (str): A path to save the file containing the metrics log.
        """

        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(
            task="segmentation",
            tiled_inference_on_testing=tiled_inference_on_testing,
            tiled_inference_on_validation=tiled_inference_on_validation,
            path_to_record_metrics=path_to_record_metrics,
        )

        if model is not None:
            # Custom model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)
        self.output_on_inference = output_on_inference

        # If the user sets the deprecated `output_most_probable=False`,
        # output the probabilities instead of the prediction.
        if not output_most_probable:
            warnings.warn(
                "The argument `output_most_probable` is deprecated and will be replaced with `output_on_inference='probabilities'`.",
                stacklevel=1,
            )
            output_on_inference = "probabilities"

        # Processing the `output_on_inference` argument.
        self.output_prediction = lambda y: (y.argmax(dim=1), "pred")
        self.output_logits = lambda y: (y, "logits")
        self.output_probabilities = lambda y: (torch.nn.Softmax(dim=1)(y), "probabilities")

        # The possible methods to define outputs.
        self.operation_map = {
            "prediction": self.output_prediction,
            "logits": self.output_logits,
            "probabilities": self.output_probabilities,
        }

        # `output_on_inference` can be a list or a string.
        if isinstance(output_on_inference, list):
            list_of_selectors = ()
            for var in output_on_inference:
                if var in self.operation_map:
                    list_of_selectors += (self.operation_map[var],)
                else:
                    raise ValueError(
                        f"Option {var} is not supported. It must be in ['prediction', 'logits', 'probabilities']"
                    )

            if not len(list_of_selectors):
                raise ValueError(
                    "The list of selectors for the output is empty, please, provide a valid value for `output_on_inference`"
                )

            self.select_classes = lambda y: [op(y) for op in list_of_selectors]
        elif isinstance(output_on_inference, str):
            self.select_classes = self.operation_map[output_on_inference]

        else:
            raise ValueError(f"The value {output_on_inference} isn't supported for `output_on_inference`.")

    def squeeze_ground_truth(self, x):
        return torch.squeeze(x, 1)

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]

        class_weights = (
            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
        )
        if loss == "ce":
            ignore_value = -100 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
        elif loss == "jaccard":
            if ignore_index is not None:
                exception_message = (
                    f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
                )
                raise RuntimeError(exception_message)
            self.criterion = smp.losses.JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
        elif loss == "dice":
            self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
        else:
            exception_message = (
                f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
            )
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "mIoU": MulticlassJaccardIndex(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "mIoU_Micro": MulticlassJaccardIndex(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="micro",
                ),
                "F1_Score": MulticlassF1Score(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Pixel_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="micro",
                ),
                "IoU": ClasswiseWrapper(
                    MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                    labels=class_names,
                    prefix="IoU_",
                ),
                "Class_Accuracy": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                    prefix="Class_Accuracy_",
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
            )
        else:
            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        # Testing because of failures.
        x = batch["image"]
        y = self.squeeze_ground_truth(batch["mask"])
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = self.squeeze_ground_truth(batch["mask"])
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}

        model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_testing, **rest)

        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=y.shape[0],
        )
        y_hat_hard = to_segmentation_prediction(model_output)
        self.test_metrics[dataloader_idx].update(y_hat_hard, y)

        self.record_metrics(dataloader_idx, y_hat_hard, y)

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.
        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = self.squeeze_ground_truth(batch["mask"])

        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        # model_output: ModelOutput = self(x, **rest)
        model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_validation, **rest)

        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat_hard

                if isinstance(batch["image"], dict):
                    rgb_modality = getattr(datamodule, "rgb_modality", None) or list(batch["image"].keys())[0]
                    batch["image"] = batch["image"][rgb_modality]

                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}

        def model_forward(x, **kwargs):
            return self(x, **kwargs).output

        if self.tiled_inference_parameters:
            y_hat: Tensor = tiled_inference(
                model_forward,
                x,
                **self.tiled_inference_parameters,
                **rest,
            )
        else:
            y_hat: Tensor = self(x, **rest).output

        y_hat_ = self.select_classes(y_hat)

        return y_hat_, file_names

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, output_on_inference='prediction', output_most_probable=True, path_to_record_metrics=None, tiled_inference_on_testing=False, tiled_inference_on_validation=False) #

Constructor

Parameters:

Name Type Description Default
model_args Dict

Arguments passed to the model factory.

required
model_factory str

ModelFactory class to be used to instantiate the model. Is ignored when model is provided.

None
model Module

Custom model.

None
loss str

Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss. Defaults to "ce".

'ce'
aux_loss dict[str, float] | None

Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

None
class_weights list[float] | None

List of class weights to be applied to the loss. Defaults to None.

None
ignore_index int | None

Label to ignore in the loss computation. Defaults to None.

None
lr float

Learning rate to be used. Defaults to 0.001.

0.001
optimizer str | None

Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
optimizer_hparams dict | None

Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI.

None
scheduler str

Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / CLI specification through LightningCLI.

None
scheduler_hparams dict | None

Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI.

None
freeze_backbone bool

Whether to freeze the backbone. Defaults to False.

False
freeze_decoder bool

Whether to freeze the decoder. Defaults to False.

False
freeze_head bool

Whether to freeze the segmentation head. Defaults to False.

False
plot_on_val bool | int

Whether to plot visualizations on validation. If True, plots every epoch. If an int, plots every plot_on_val epochs. Defaults to 10.

10
class_names list[str] | None

List of class names passed to metrics for better naming. Defaults to numeric ordering.

None
tiled_inference_parameters dict | None

Inference parameters used to determine if inference is done on the whole image or through tiling.

None
test_dataloaders_names list[str] | None

Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

None
lr_overrides dict[str, float] | None

Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None.

None
output_on_inference str | list[str]

A string or a list defining the kind of output to be saved to file during inference. For example, it can be "prediction", to save just the most probable class, or ["prediction", "probabilities"] to save both the prediction and the probabilities.

'prediction'
output_most_probable bool

Whether the prediction step will output just the most probable class or all of the logits. This argument is deprecated and has been replaced by output_on_inference.

True
tiled_inference_on_testing bool

A boolean to define if tiled inference will be used when full inference fails during the test step.

False
tiled_inference_on_validation bool

A boolean to define if tiled inference will be used during the val step.

False
path_to_record_metrics str

A path to save the file containing the metrics log.

None
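
The constructor options above can be combined as in the following sketch, which is illustrative only (the factory, backbone, decoder and class names are assumptions, not values prescribed by this page):

from terratorch.tasks import SemanticSegmentationTask

# Illustrative sketch: every identifier below is a placeholder for your own setup.
task = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_eo_v2_300",
        "decoder": "FCNDecoder",
        "num_classes": 3,
    },
    loss="ce",
    class_names=["water", "forest", "urban"],          # labels the per-class metrics
    class_weights=[0.2, 0.3, 0.5],                     # weights the cross-entropy loss
    ignore_index=-1,                                   # label excluded from loss and metrics
    test_dataloaders_names=["region_a", "region_b"],   # one metric prefix per test dataloader
    output_on_inference=["prediction", "probabilities"],
)
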
Source code in terratorch/tasks/segmentation_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so the CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    freeze_head: bool = False,
    plot_on_val: bool | int = 10,
    class_names: list[str] | None = None,
    tiled_inference_parameters: dict | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
    output_on_inference: str | list[str] = "prediction",
    output_most_probable: bool = True,
    path_to_record_metrics: str | None = None,
    tiled_inference_on_testing: bool = False,
    tiled_inference_on_validation: bool = False,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided. Defaults to None.
        model (torch.nn.Module, optional): Custom model. Defaults to None.
        loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
            Defaults to "ce".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / CLI specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / CLI specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / CLI specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If True, plots every epoch. If an int, plots every plot_on_val epochs. Defaults to 10.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
        tiled_inference_parameters (dict | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
        output_on_inference (str | list[str]): A string or a list defining the kind of output to be saved
            to file during inference. For example, it can be "prediction", to save just the most probable
            class, or ["prediction", "probabilities"] to save both the prediction and the probabilities.
        output_most_probable (bool): Whether the prediction step will output just the most probable class
            or all of the logits. This argument is deprecated and has been replaced by `output_on_inference`.
        tiled_inference_on_testing (bool): Whether tiled inference will be used when full inference fails
            during the test step.
        tiled_inference_on_validation (bool): Whether tiled inference will be used during the validation step.
        path_to_record_metrics (str): A path to save the file containing the metrics log.
    """

    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(
        task="segmentation",
        tiled_inference_on_testing=tiled_inference_on_testing,
        tiled_inference_on_validation=tiled_inference_on_validation,
        path_to_record_metrics=path_to_record_metrics,
    )

    if model is not None:
        # Custom model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)
    self.output_on_inference = output_on_inference

    # If the user sets the deprecated `output_most_probable=False`,
    # output the probabilities instead of the prediction.
    if not output_most_probable:
        warnings.warn(
            "The argument `output_most_probable` is deprecated and will be replaced with `output_on_inference='probabilities'`.",
            stacklevel=1,
        )
        output_on_inference = "probabilities"

    # Processing the `output_on_inference` argument.
    self.output_prediction = lambda y: (y.argmax(dim=1), "pred")
    self.output_logits = lambda y: (y, "logits")
    self.output_probabilities = lambda y: (torch.nn.Softmax(dim=1)(y), "probabilities")

    # The possible methods to define outputs.
    self.operation_map = {
        "prediction": self.output_prediction,
        "logits": self.output_logits,
        "probabilities": self.output_probabilities,
    }

    # `output_on_inference` can be a list or a string.
    if isinstance(output_on_inference, list):
        list_of_selectors = ()
        for var in output_on_inference:
            if var in self.operation_map:
                list_of_selectors += (self.operation_map[var],)
            else:
                raise ValueError(
                    f"Option {var} is not supported. It must be in ['prediction', 'logits', 'probabilities']"
                )

        if not len(list_of_selectors):
            raise ValueError(
                "The list of selectors for the output is empty, please, provide a valid value for `output_on_inference`"
            )

        self.select_classes = lambda y: [op(y) for op in list_of_selectors]
    elif isinstance(output_on_inference, str):
        self.select_classes = self.operation_map[output_on_inference]

    else:
        raise ValueError(f"The value {output_on_inference} isn't supported for `output_on_inference`.")

configure_losses() #

Initialize the loss criterion.

Raises:

Type Description
ValueError

If loss is invalid.

Source code in terratorch/tasks/segmentation_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    ignore_index = self.hparams["ignore_index"]

    class_weights = (
        torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
    )
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
    elif loss == "jaccard":
        if ignore_index is not None:
            exception_message = (
                f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
            )
            raise RuntimeError(exception_message)
        self.criterion = smp.losses.JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
    elif loss == "dice":
        self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
    else:
        exception_message = (
            f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
        )
        raise ValueError(exception_message)

configure_metrics() #

Initialize the performance metrics.
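
As an illustrative note, with class_names=["water", "land"] the validation metrics logged by this collection end up with names along the lines of the sketch below; the "val/" prefix comes from metrics.clone(prefix="val/") and the per-class names from the ClasswiseWrapper prefixes (the exact flattening follows torchmetrics' behaviour):

# Approximate metric names, assuming class_names=["water", "land"].
expected_val_metric_names = [
    "val/mIoU", "val/mIoU_Micro", "val/F1_Score",
    "val/Accuracy", "val/Pixel_Accuracy",
    "val/IoU_water", "val/IoU_land",
    "val/Class_Accuracy_water", "val/Class_Accuracy_land",
]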

Source code in terratorch/tasks/segmentation_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "mIoU": MulticlassJaccardIndex(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "mIoU_Micro": MulticlassJaccardIndex(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="micro",
            ),
            "F1_Score": MulticlassF1Score(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Pixel_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="micro",
            ),
            "IoU": ClasswiseWrapper(
                MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                labels=class_names,
                prefix="IoU_",
            ),
            "Class_Accuracy": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
                prefix="Class_Accuracy_",
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
        )
    else:
        self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the predicted class probabilities.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Returns:

Type Description
Tensor

Output predicted probabilities.
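
A hedged sketch of consuming these outputs through Lightning's predict loop; task and datamodule are placeholders for objects configured elsewhere:

import lightning.pytorch as pl

# Illustrative sketch: `task` is a configured SemanticSegmentationTask and
# `datamodule` provides a predict_dataloader(); both are assumed to exist.
trainer = pl.Trainer(accelerator="auto", devices=1)
results = trainer.predict(task, datamodule=datamodule)

for outputs, file_names in results:
    # With output_on_inference="prediction", `outputs` is a (tensor, "pred") pair;
    # with a list such as ["prediction", "probabilities"], it is a list of such pairs.
    ...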

Source code in terratorch/tasks/segmentation_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted probabilities.
    """
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}

    def model_forward(x, **kwargs):
        return self(x, **kwargs).output

    if self.tiled_inference_parameters:
        y_hat: Tensor = tiled_inference(
            model_forward,
            x,
            **self.tiled_inference_parameters,
            **rest,
        )
    else:
        y_hat: Tensor = self(x, **rest).output

    y_hat_ = self.select_classes(y_hat)

    return y_hat_, file_names

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0
Source code in terratorch/tasks/segmentation_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = self.squeeze_ground_truth(batch["mask"])
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}

    model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_testing, **rest)

    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=y.shape[0],
    )
    y_hat_hard = to_segmentation_prediction(model_output)
    self.test_metrics[dataloader_idx].update(y_hat_hard, y)

    self.record_metrics(dataloader_idx, y_hat_hard, y)

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the train loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0
Source code in terratorch/tasks/segmentation_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    # Testing because of failures.
    x = batch["image"]
    y = self.squeeze_ground_truth(batch["mask"])
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation loss and additional metrics.

Parameters:

Name Type Description Default
batch Any

The output of your DataLoader.

required
batch_idx int

Integer displaying index of this batch.

required
dataloader_idx int

Index of the current dataloader.

0

Source code in terratorch/tasks/segmentation_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.
    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = self.squeeze_ground_truth(batch["mask"])

    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    # model_output: ModelOutput = self(x, **rest)
    model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_validation, **rest)

    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat_hard

            if isinstance(batch["image"], dict):
                rgb_modality = getattr(datamodule, "rgb_modality", None) or list(batch["image"].keys())[0]
                batch["image"] = batch["image"][rgb_modality]

            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()

terratorch.tasks.PixelwiseRegressionTask #

Bases: TerraTorchTask

Pixelwise Regression Task that accepts models from a range of sources.

This class is analogous in functionality to the PixelwiseRegressionTask defined by TorchGeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows setting the optimizers in the constructor
- Allows evaluation on multiple test dataloaders
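
As a hedged, illustrative sketch (the factory, backbone and decoder names below are assumptions), the regression task is instantiated much like the segmentation one, but with a regression loss and no num_classes:

from terratorch.tasks import PixelwiseRegressionTask

# Illustrative sketch only: every identifier below is a placeholder.
task = PixelwiseRegressionTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_eo_v2_300",
        "decoder": "FCNDecoder",
    },
    loss="rmse",          # supports 'mse', 'rmse', 'mae' or 'huber'
    ignore_index=-9999,   # value excluded from loss and metrics
    plot_on_val=5,        # plot validation samples every 5 epochs
)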

Source code in terratorch/tasks/regression_tasks.py
class PixelwiseRegressionTask(TerraTorchTask):
    """Pixelwise Regression Task that accepts models from a range of sources.

    This class is analogous in functionality to the PixelwiseRegressionTask defined by TorchGeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Allows evaluation on multiple test dataloaders"""

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "mse",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so the CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT001, FBT002
        freeze_head: bool = False,  # noqa: FBT001, FBT002
        plot_on_val: bool | int = 10,
        tiled_inference_parameters: dict | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
        tiled_inference_on_testing: bool = False,
        tiled_inference_on_validation: bool = False,
        path_to_record_metrics: str | None = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided.
            model (torch.nn.Module, optional): Custom model.
            loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss.
                Defaults to "mse".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / CLI specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / CLI specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / CLI specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
            tiled_inference_parameters (dict | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
            tiled_inference_on_testing (bool): A boolean to define if tiled inference will be used during the test step. 
            tiled_inference_on_validation (bool): A boolean to define if tiled inference will be used during the val step. 
            path_to_record_metrics (str): A path to save the file containing the metrics log.
        """

        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(
            task="regression", 
            tiled_inference_on_testing=tiled_inference_on_testing,
            tiled_inference_on_validation=tiled_inference_on_validation,
            path_to_record_metrics=path_to_record_metrics
            )

        if model:
            # Custom_model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"].lower()
        if loss == "mse":
            self.criterion: nn.Module = IgnoreIndexLossWrapper(
                nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
            )
        elif loss == "mae":
            self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
        elif loss == "rmse":
            # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
            self.criterion = RootLossWrapper(
                IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
            )
        elif loss == "huber":
            self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
        else:
            exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'mse', 'rmse' or 'mae' loss."
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

        def instantiate_metrics():
            return {
                "RMSE": MeanSquaredError(squared=False),
                "MSE": MeanSquaredError(squared=True),
                "MAE": MeanAbsoluteError(),
                "R2_Score": R2Score(),
            }

        def wrap_metrics_with_ignore_index(metrics):
            return {
                name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
                for name, metric in metrics.items()
            }

        self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
        self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [
                    MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
                    for dl_name in self.hparams["test_dataloaders_names"]
                ]
            )
        else:
            self.test_metrics = nn.ModuleList(
                [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
            )

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat = model_output.output
        self.train_metrics.update(y_hat, y)

        return loss["loss"]

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        #model_output: ModelOutput = self(x, **rest)
        model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_validation, **rest)

        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat = model_output.output
        self.val_metrics.update(y_hat, y)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat
                if isinstance(batch["image"], dict):
                    rgb_modality = getattr(datamodule, 'rgb_modality', None) or list(batch["image"].keys())[0]
                    batch["image"] = batch["image"][rgb_modality]
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}

        model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_testing, **rest)

        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=y.shape[0],
        )
        y_hat = model_output.output
        self.test_metrics[dataloader_idx].update(y_hat, y)

        self.record_metrics(dataloader_idx, y_hat, y)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted values.
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}

        def model_forward(x, **kwargs):
            return self(x, **kwargs).output

        if self.tiled_inference_parameters:
            y_hat: Tensor = tiled_inference(model_forward, x, **self.tiled_inference_parameters, **rest)
        else:
            y_hat: Tensor = self(x, **rest).output
        return y_hat, file_names

__init__(model_args, model_factory=None, model=None, loss='mse', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, tiled_inference_on_testing=False, tiled_inference_on_validation=False, path_to_record_metrics=None) #

Constructor

Parameters:

- `model_args` (Dict): Arguments passed to the model factory. *required*
- `model_factory` (str): Name of ModelFactory class to be used to instantiate the model. Is ignored when `model` is provided. Default: `None`
- `model` (Module): Custom model. Default: `None`
- `loss` (str): Loss to be used. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss. Default: `'mse'`
- `aux_loss` (dict[str, float] | None): Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Default: `None`
- `class_weights` (list[float] | None): List of class weights to be applied to the loss. Default: `None`
- `ignore_index` (int | None): Label to ignore in the loss computation. Default: `None`
- `lr` (float): Learning rate to be used. Default: `0.001`
- `optimizer` (str | None): Name of optimizer class from torch.optim to be used. If None, will use Adam. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `optimizer_hparams` (dict | None): Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `scheduler` (str): Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Overridden by config / CLI specification through LightningCLI. Default: `None`
- `scheduler_hparams` (dict | None): Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `freeze_backbone` (bool): Whether to freeze the backbone. Default: `False`
- `freeze_decoder` (bool): Whether to freeze the decoder. Default: `False`
- `freeze_head` (bool): Whether to freeze the segmentation head. Default: `False`
- `plot_on_val` (bool | int): Whether to plot visualizations on validation. If True, plots every epoch; if an int, plots every `plot_on_val` epochs. Default: `10`
- `tiled_inference_parameters` (dict | None): Inference parameters used to determine if inference is done on the whole image or through tiling. Default: `None`
- `test_dataloaders_names` (list[str] | None): Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used. Default: `None`
- `lr_overrides` (dict[str, float] | None): Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Default: `None`
- `tiled_inference_on_testing` (bool): Whether tiled inference will be used during the test step. Default: `False`
- `tiled_inference_on_validation` (bool): Whether tiled inference will be used during the val step. Default: `False`
- `path_to_record_metrics` (str): A path to save the file containing the metrics log. Default: `None`
Source code in terratorch/tasks/regression_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "mse",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT001, FBT002
    freeze_head: bool = False,  # noqa: FBT001, FBT002
    plot_on_val: bool | int = 10,
    tiled_inference_parameters: dict | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
    tiled_inference_on_testing: bool = False,
    tiled_inference_on_validation: bool = False,
    path_to_record_metrics: str = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided.
        model (torch.nn.Module, optional): Custom model.
        loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss.
            Defaults to "mse".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
        tiled_inference_parameters (dict | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
        tiled_inference_on_testing (bool): A boolean to define if tiled inference will be used during the test step.
        tiled_inference_on_validation (bool): A boolean to define if tiled inference will be used during the val step.
        path_to_record_metrics (str): A path to save the file containing the metrics log.
    """

    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(
        task="regression", 
        tiled_inference_on_testing=tiled_inference_on_testing,
        tiled_inference_on_validation=tiled_inference_on_validation,
        path_to_record_metrics=path_to_record_metrics
        )

    if model:
        # Custom_model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)
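
For reference, a minimal sketch of instantiating the task directly in Python rather than through a config file. The factory, backbone and decoder names below are assumptions for illustration only and must match components available in your TerraTorch installation:

```python
from terratorch.tasks import PixelwiseRegressionTask

# Hypothetical example: the model_args keys and values are assumptions
# that depend on the chosen model factory and installed backbones.
task = PixelwiseRegressionTask(
    model_args={
        "backbone": "prithvi_eo_v2_300",   # assumed backbone name
        "decoder": "UperNetDecoder",       # assumed decoder name
    },
    model_factory="EncoderDecoderFactory",  # assumed factory name
    loss="rmse",
    ignore_index=-1,
    lr=1e-4,
    freeze_backbone=True,
    plot_on_val=5,
)
```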

configure_losses() #

Initialize the loss criterion.

Raises:

- `ValueError`: If *loss* is invalid.

Source code in terratorch/tasks/regression_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"].lower()
    if loss == "mse":
        self.criterion: nn.Module = IgnoreIndexLossWrapper(
            nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
        )
    elif loss == "mae":
        self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
    elif loss == "rmse":
        # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
        self.criterion = RootLossWrapper(
            IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
        )
    elif loss == "huber":
        self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
    else:
        exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'mse', 'rmse' or 'mae' loss."
        raise ValueError(exception_message)
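
The comment above about applying the root only after the ignore-index masking is the key detail: the mean inside the square root must run over valid pixels only. A plain-PyTorch sketch of the equivalent computation (not the TerraTorch wrapper classes themselves):

```python
import torch

preds = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.5, 2.0, -1.0, 3.0])  # -1 marks pixels to ignore
ignore_index = -1

# Mask first, then mean, then root: RMSE over valid pixels only.
mask = target != ignore_index
rmse = torch.sqrt(torch.mean((preds[mask] - target[mask]) ** 2))
print(rmse)

# Averaging over all pixels (including ignored ones) before the root
# would use the wrong denominator, which is what the comment warns about.
```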

configure_metrics() #

Initialize the performance metrics.

Source code in terratorch/tasks/regression_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""

    def instantiate_metrics():
        return {
            "RMSE": MeanSquaredError(squared=False),
            "MSE": MeanSquaredError(squared=True),
            "MAE": MeanAbsoluteError(),
            "R2_Score": R2Score(),
        }

    def wrap_metrics_with_ignore_index(metrics):
        return {
            name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
            for name, metric in metrics.items()
        }

    self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
    self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [
                MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
                for dl_name in self.hparams["test_dataloaders_names"]
            ]
        )
    else:
        self.test_metrics = nn.ModuleList(
            [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
        )
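
Conceptually, IgnoreIndexMetricWrapper drops positions whose target equals ignore_index before updating the wrapped torchmetrics metric. A rough hand-rolled equivalent, shown only for illustration:

```python
import torch
from torchmetrics import MeanAbsoluteError

ignore_index = -1
mae = MeanAbsoluteError()

preds = torch.tensor([[0.9, 0.1], [0.4, 0.7]])
target = torch.tensor([[1.0, -1.0], [0.5, 0.6]])  # -1.0 is ignored

# Only valid positions contribute to the metric state.
mask = target != ignore_index
mae.update(preds[mask], target[mask])
print(mae.compute())
```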

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the predicted values.

Parameters:

- `batch` (Any): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`

Returns:

- `Tensor`: Output predicted values.

Source code in terratorch/tasks/regression_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted values.
    """
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}

    def model_forward(x, **kwargs):
        return self(x, **kwargs).output

    if self.tiled_inference_parameters:
        y_hat: Tensor = tiled_inference(model_forward, x, **self.tiled_inference_parameters, **rest)
    else:
        y_hat: Tensor = self(x, **rest).output
    return y_hat, file_names
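
In practice predict_step is driven by Lightning. A hedged usage sketch, assuming `task` is a configured regression task and `datamodule` is any LightningDataModule exposing a predict_dataloader (both placeholders here):

```python
import lightning.pytorch as pl

# `task` and `datamodule` are assumed to exist; see the constructor example above.
trainer = pl.Trainer(accelerator="auto", devices=1)
predictions = trainer.predict(task, datamodule=datamodule)

# Each element is the tuple returned by predict_step above:
# (predicted tensor for the batch, file names or None).
for y_hat, file_names in predictions:
    print(y_hat.shape, file_names)
```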

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test loss and additional metrics.

Parameters:

- `batch` (Any): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/regression_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}

    model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_testing, **rest)

    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=y.shape[0],
    )
    y_hat = model_output.output
    self.test_metrics[dataloader_idx].update(y_hat, y)

    self.record_metrics(dataloader_idx, y_hat, y)
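
If the datamodule's test_dataloader returns several loaders, test_dataloaders_names must provide one name per loader so the metric prefixes stay distinct; otherwise the guard above raises. A small sketch of the naming (factory, backbone and loader names are all illustrative assumptions):

```python
from terratorch.tasks import PixelwiseRegressionTask

task = PixelwiseRegressionTask(
    model_args={"backbone": "prithvi_eo_v2_300", "decoder": "UperNetDecoder"},  # assumed names
    model_factory="EncoderDecoderFactory",                                      # assumed factory
    test_dataloaders_names=["region_a", "region_b"],
)
# Metrics are then logged as test/region_a/RMSE, test/region_b/RMSE, etc.,
# and each entry of test_loss_handler uses the matching prefix.
```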

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the train loss and additional metrics.

Parameters:

- `batch` (Any): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/regression_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat = model_output.output
    self.train_metrics.update(y_hat, y)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation loss and additional metrics.

Parameters:

- `batch` (Any): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/regression_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    #model_output: ModelOutput = self(x, **rest)
    model_output = self.handle_full_or_tiled_inference(x, self.tiled_inference_on_validation, **rest)

    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat = model_output.output
    self.val_metrics.update(y_hat, y)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat
            if isinstance(batch["image"], dict):
                rgb_modality = getattr(datamodule, 'rgb_modality', None) or list(batch["image"].keys())[0]
                batch["image"] = batch["image"][rgb_modality]
            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()

terratorch.tasks.ClassificationTask #

Bases: TerraTorchTask

Classification Task that accepts models from a range of sources.

This class is analogous in functionality to the class ClassificationTask defined by torchgeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- It provides mIoU with both Micro and Macro averaging
- Allows to evaluate on multiple test dataloaders

Note:

- 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy.
- 'Macro' averaging gives equal weight to each class, useful for balanced performance assessment across imbalanced classes.
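
To make the micro/macro distinction concrete, a small torchmetrics sketch on an imbalanced toy batch (values chosen purely for illustration):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# 6 samples, heavily imbalanced: five of class 0, one of class 1.
target = torch.tensor([0, 0, 0, 0, 0, 1])
preds = torch.tensor([0, 0, 0, 0, 0, 0])  # always predicts the majority class

micro = MulticlassAccuracy(num_classes=2, average="micro")
macro = MulticlassAccuracy(num_classes=2, average="macro")

print(micro(preds, target))  # ~0.83: overall fraction of correct samples
print(macro(preds, target))  # 0.50: class 0 -> 1.0, class 1 -> 0.0, averaged
```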

Source code in terratorch/tasks/classification_tasks.py
class ClassificationTask(TerraTorchTask):
    """Classification Task that accepts models from a range of sources.

    This class is analogous in functionality to the class ClassificationTask defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - It provides mIoU with both Micro and Macro averaging
        - Allows to evaluate on multiple test dataloaders

    .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        freeze_head: bool = False,  # noqa: FBT002, FBT001
        class_names: list[str] | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
        path_to_record_metrics: str = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided.
            model (torch.nn.Module, optional): Custom model.
            loss (str, optional): Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss.
                Defaults to "ce".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / cli specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / cli specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / cli specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
            path_to_record_metrics (str): A path to save the file containing the metrics log.
        """

        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(task="classification", path_to_record_metrics=path_to_record_metrics)

        if model:
            # Custom model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]

        class_weights = (
            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
        )
        if loss == "ce":
            ignore_value = -100 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
        elif loss == "bce":
            self.criterion = nn.BCEWithLogitsLoss()
        elif loss == "jaccard":
            self.criterion = JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = FocalLoss(mode="multiclass", normalized=True)
        else:
            msg = f"Loss type '{loss}' is not valid."
            raise ValueError(msg)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Accuracy_Micro": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="micro",
                ),
                "F1_Score": MulticlassF1Score(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Precision": MulticlassPrecision(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Recall": MulticlassRecall(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Class_Accuracy": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                    prefix='Class_Accuracy_'
                ),
                "Class_F1": ClasswiseWrapper(
                    MulticlassF1Score(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                    prefix='Class_F1_'
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
            )
        else:
            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

    def training_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"] 
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def validation_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

    def test_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=y.shape[0],
        )
        y_hat_hard = to_class_prediction(model_output)
        self.test_metrics[dataloader_idx].update(y_hat_hard, y)

        self.record_metrics(dataloader_idx, y_hat_hard, y)

    def predict_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:

        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        y_hat = self(x, **rest).output
        y_hat = y_hat.argmax(dim=1)
        return y_hat, file_names

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, class_names=None, test_dataloaders_names=None, lr_overrides=None, path_to_record_metrics=None) #

Constructor

Parameters:

- `model_args` (Dict): Arguments passed to the model factory. *required*
- `model_factory` (str): ModelFactory class to be used to instantiate the model. Is ignored when `model` is provided. Default: `None`
- `model` (Module): Custom model. Default: `None`
- `loss` (str): Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss. Default: `'ce'`
- `aux_loss` (dict[str, float] | None): Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Default: `None`
- `class_weights` (list[float] | None): List of class weights to be applied to the loss. Default: `None`
- `ignore_index` (int | None): Label to ignore in the loss computation. Default: `None`
- `lr` (float): Learning rate to be used. Default: `0.001`
- `optimizer` (str | None): Name of optimizer class from torch.optim to be used. If None, will use Adam. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `optimizer_hparams` (dict | None): Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `scheduler` (str): Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Overridden by config / CLI specification through LightningCLI. Default: `None`
- `scheduler_hparams` (dict | None): Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI. Default: `None`
- `freeze_backbone` (bool): Whether to freeze the backbone. Default: `False`
- `freeze_decoder` (bool): Whether to freeze the decoder. Default: `False`
- `freeze_head` (bool): Whether to freeze the segmentation_head. Default: `False`
- `class_names` (list[str] | None): List of class names passed to metrics for better naming. Defaults to numeric ordering. Default: `None`
- `test_dataloaders_names` (list[str] | None): Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used. Default: `None`
- `lr_overrides` (dict[str, float] | None): Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Default: `None`
- `path_to_record_metrics` (str): A path to save the file containing the metrics log. Default: `None`
Source code in terratorch/tasks/classification_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    freeze_head: bool = False,  # noqa: FBT002, FBT001
    class_names: list[str] | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
    path_to_record_metrics: str = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided.
        model (torch.nn.Module, optional): Custom model.
        loss (str, optional): Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss.
            Defaults to "ce".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
        path_to_record_metrics (str): A path to save the file containing the metrics log.
    """

    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(task="classification", path_to_record_metrics=path_to_record_metrics)

    if model:
        # Custom model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"

configure_losses() #

Initialize the loss criterion.

Raises:

- `ValueError`: If *loss* is invalid.

Source code in terratorch/tasks/classification_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    ignore_index = self.hparams["ignore_index"]

    class_weights = (
        torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
    )
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
    elif loss == "bce":
        self.criterion = nn.BCEWithLogitsLoss()
    elif loss == "jaccard":
        self.criterion = JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = FocalLoss(mode="multiclass", normalized=True)
    else:
        msg = f"Loss type '{loss}' is not valid."
        raise ValueError(msg)
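
For the 'ce' branch, class_weights and ignore_index map directly onto torch.nn.CrossEntropyLoss arguments. A minimal sketch of what the configured criterion computes (toy values, not a TerraTorch-specific API):

```python
import torch
import torch.nn as nn

# Up-weight the rarer class 1 and ignore label -1, mirroring the 'ce' branch above.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 5.0]), ignore_index=-1)

logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.0, 0.0]])
labels = torch.tensor([0, 1, -1])  # the last sample is excluded from the loss

print(criterion(logits, labels))
```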

configure_metrics() #

Initialize the performance metrics.

Source code in terratorch/tasks/classification_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Accuracy_Micro": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="micro",
            ),
            "F1_Score": MulticlassF1Score(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Precision": MulticlassPrecision(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Recall": MulticlassRecall(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Class_Accuracy": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
                prefix='Class_Accuracy_'
            ),
            "Class_F1": ClasswiseWrapper(
                MulticlassF1Score(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
                prefix='Class_F1_'
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
        )
    else:
        self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])
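
class_names is forwarded as the ClasswiseWrapper labels, so per-class metrics are logged under readable keys instead of numeric indices. A small sketch (the class names below are made up):

```python
import torch
from torchmetrics import ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy

class_acc = ClasswiseWrapper(
    MulticlassAccuracy(num_classes=3, average=None),
    labels=["water", "forest", "urban"],  # hypothetical class names
    prefix="Class_Accuracy_",
)

preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
print(class_acc(preds, target))
# e.g. {'Class_Accuracy_water': ..., 'Class_Accuracy_forest': ..., 'Class_Accuracy_urban': ...}
```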

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the predicted class probabilities.

Parameters:

- `batch` (object): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`

Returns:

- `Tensor`: Output predicted probabilities.

Source code in terratorch/tasks/classification_tasks.py
def predict_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:

    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted probabilities.
    """
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    y_hat = self(x, **rest).output
    y_hat = y_hat.argmax(dim=1)
    return y_hat, file_names

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test loss and additional metrics.

Parameters:

- `batch` (object): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/classification_tasks.py
def test_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=y.shape[0],
    )
    y_hat_hard = to_class_prediction(model_output)
    self.test_metrics[dataloader_idx].update(y_hat_hard, y)

    self.record_metrics(dataloader_idx, y_hat_hard, y)

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the train loss and additional metrics.

Parameters:

- `batch` (object): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/classification_tasks.py
def training_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"] 
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation loss and additional metrics.

Parameters:

- `batch` (object): The output of your DataLoader. *required*
- `batch_idx` (int): Integer displaying index of this batch. *required*
- `dataloader_idx` (int): Index of the current dataloader. Default: `0`
Source code in terratorch/tasks/classification_tasks.py
def validation_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)

terratorch.tasks.MultiLabelClassificationTask #

Bases: ClassificationTask

Source code in terratorch/tasks/multilabel_classification_tasks.py
class MultiLabelClassificationTask(ClassificationTask):
    def configure_losses(self) -> None:
        if self.hparams["loss"] == "bce":
            self.criterion: nn.Module = nn.BCEWithLogitsLoss()
        elif self.hparams["loss"] == "balanced_bce":
            self.criterion = _balanced_binary_cross_entropy_with_logits
        else:
            super().configure_losses()

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Multilabel_Accuracy": MultilabelAccuracy(
                    num_labels=num_classes, ignore_index=ignore_index, average="macro"
                ),
                "Multilabel_Accuracy_Micro": MultilabelAccuracy(
                    num_labels=num_classes, ignore_index=ignore_index, average="micro"
                ),
                "Multilabel_F1_Score": MultilabelF1Score(
                    num_labels=num_classes, ignore_index=ignore_index, average="macro"
                ),
                "Multilabel_Precision": MultilabelPrecision(
                    num_labels=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Multilabel_Recall": MultilabelRecall(
                    num_labels=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Multilabel_AUROC": MultilabelAUROC(
                    num_labels=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Class_Accuracy": ClasswiseWrapper(
                    MultilabelAccuracy(
                        num_labels=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                    prefix="Class_Accuracy_",
                ),
                "Class_F1": ClasswiseWrapper(
                    MultilabelF1Score(
                        num_labels=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                    prefix="Class_F1_",
                ),
            }
        )

        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
            )
        else:
            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

    @staticmethod
    def to_multilabel_prediction(y: ModelOutput) -> Tensor:
        y_hat = y.output
        return torch.sigmoid(y_hat)

    def training_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        x = batch["image"]
        y = batch["label"].to(torch.float32)
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}

        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat = self.to_multilabel_prediction(model_output)
        self.train_metrics.update(y_hat, y.to(torch.int))

        return loss["loss"]

    def validation_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
        x = batch["image"]
        y = batch["label"].to(torch.float32)
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat = self.to_multilabel_prediction(model_output)
        self.val_metrics.update(y_hat, y.to(torch.int))

    def test_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> None:
        x = batch["image"]
        y = batch["label"].to(torch.float32)
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=y.shape[0],
        )
        y_hat = self.to_multilabel_prediction(model_output)
        self.test_metrics[dataloader_idx].update(y_hat, y.to(torch.int))

    def predict_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output = self(x, **rest)
        y_hat = self.to_multilabel_prediction(model_output)
        return y_hat, file_names
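
For orientation, a minimal usage sketch follows. The constructor is inherited from ClassificationTask (not reproduced here), so the argument names below other than the hyperparameters referenced in the source above (loss, model_args["num_classes"], class_names, ignore_index) are assumptions, and the factory and backbone identifiers are placeholders.

from terratorch.tasks import MultiLabelClassificationTask

# Hypothetical instantiation -- argument names follow the hparams referenced in the
# source above; the factory and backbone identifiers are placeholders, not verified.
task = MultiLabelClassificationTask(
    model_factory="EncoderDecoderFactory",   # assumed factory name
    model_args={"backbone": "some_backbone", "num_classes": 4},
    loss="bce",                              # selects nn.BCEWithLogitsLoss in configure_losses
    class_names=["water", "forest", "urban", "crops"],
    ignore_index=-1,                         # placeholder value
)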

configure_metrics() #

Initialize the performance metrics.

Source code in terratorch/tasks/multilabel_classification_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Multilabel_Accuracy": MultilabelAccuracy(
                num_labels=num_classes, ignore_index=ignore_index, average="macro"
            ),
            "Multilabel_Accuracy_Micro": MultilabelAccuracy(
                num_labels=num_classes, ignore_index=ignore_index, average="micro"
            ),
            "Multilabel_F1_Score": MultilabelF1Score(
                num_labels=num_classes, ignore_index=ignore_index, average="macro"
            ),
            "Multilabel_Precision": MultilabelPrecision(
                num_labels=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Multilabel_Recall": MultilabelRecall(
                num_labels=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Multilabel_AUROC": MultilabelAUROC(
                num_labels=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Class_Accuracy": ClasswiseWrapper(
                MultilabelAccuracy(
                    num_labels=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
                prefix="Class_Accuracy_",
            ),
            "Class_F1": ClasswiseWrapper(
                MultilabelF1Score(
                    num_labels=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
                prefix="Class_F1_",
            ),
        }
    )

    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
        )
    else:
        self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])
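
To make the per-class metric naming concrete, here is a small standalone sketch of the ClasswiseWrapper pattern used above (assuming a recent torchmetrics version where the prefix argument is available); the class names and tensors are made up.

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.wrappers import ClasswiseWrapper

class_names = ["water", "forest", "urban"]
metrics = MetricCollection({
    "Class_Accuracy": ClasswiseWrapper(
        MultilabelAccuracy(num_labels=3, average=None),
        labels=class_names,
        prefix="Class_Accuracy_",
    ),
}).clone(prefix="val/")

preds = torch.rand(4, 3)               # sigmoid probabilities
target = torch.randint(0, 2, (4, 3))   # multi-hot labels
print(metrics(preds, target))          # keys such as "val/Class_Accuracy_water", one per class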

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the predicted class probabilities.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `object` | The output of your DataLoader. | *required* |
| `batch_idx` | `int` | Integer displaying index of this batch. | *required* |
| `dataloader_idx` | `int` | Index of the current dataloader. | `0` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Output predicted probabilities. |

Source code in terratorch/tasks/multilabel_classification_tasks.py
def predict_step(self, batch: object, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted probabilities.
    """
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output = self(x, **rest)
    y_hat = self.to_multilabel_prediction(model_output)
    return y_hat, file_names
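
Since the predict step returns sigmoid probabilities (plus the file names), callers typically threshold them into multi-hot labels. A minimal sketch, with the 0.5 cut-off being an arbitrary choice made here rather than part of the task:

import torch

probs = torch.tensor([[0.91, 0.12, 0.66],
                      [0.05, 0.80, 0.43]])   # (batch, num_classes) probabilities from predict_step
multi_hot = (probs > 0.5).int()              # the 0.5 threshold is an assumption made here
print(multi_hot)                             # tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.int32)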

terratorch.tasks.ObjectDetectionTask #

Bases: BaseTask

Source code in terratorch/tasks/object_detection_task.py
class ObjectDetectionTask(BaseTask):

    ignore = None
    monitor = 'val_map'
    mode = 'max'

    def __init__(
        self,
        model_factory: str,
        model_args: dict,

        lr: float = 0.001,

        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,

        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        class_names: list[str] | None = None,

        iou_threshold: float = 0.5,
        score_threshold: float = 0.5,

    ) -> None:

        """
        Initialize a new ObjectDetectionTask instance.

        Args:
            model_factory (str): Name of the model factory to use.
            model_args (dict): Arguments for the model factory.
            lr (float, optional): Learning rate for optimizer. Defaults to 0.001.
            optimizer (str | None, optional): Name of the optimizer to use. Defaults to None.
            optimizer_hparams (dict | None, optional): Hyperparameters for the optimizer. Defaults to None.
            scheduler (str | None, optional): Name of the scheduler to use. Defaults to None.
            scheduler_hparams (dict | None, optional): Hyperparameters for the scheduler. Defaults to None.
            freeze_backbone (bool, optional): Freeze the backbone network to fine-tune the detection head. Defaults to False.
            freeze_decoder (bool, optional): Freeze the decoder network to fine-tune the detection head. Defaults to False.
            class_names (list[str] | None, optional): List of class names. Defaults to None.
            iou_threshold (float, optional): Intersection over union threshold for evaluation. Defaults to 0.5.
            score_threshold (float, optional): Score threshold for evaluation. Defaults to 0.5.

        Returns:
            None
        """
        warnings.warn("The Object Detection Task has to be considered experimental. This is less mature than the other tasks and being further improved.")

        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
        self.framework = model_args['framework']
        self.monitor = 'val_segm_map' if self.framework == 'mask-rcnn' else self.monitor

        super().__init__()
        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.lr = lr
        if optimizer_hparams is not None:
            if "lr" in self.hparams["optimizer_hparams"].keys():
                self.lr = float(self.hparams["optimizer_hparams"]["lr"])
                del self.hparams["optimizer_hparams"]["lr"]



    def configure_models(self) -> None:
        """
        It instantiates the model and freezes/unfreezes the backbone and decoder networks.
        """

        self.model: Model = self.model_factory.build_model(
            "object_detection", **self.hparams["model_args"]
        )
        if self.hparams["freeze_backbone"]:
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_metrics(self) -> None:
        """
        Configure metrics for the task.
        """
        if self.framework == 'mask-rcnn':
            metrics = MetricCollection({
                "mAP": MeanAveragePrecision(
                    iou_type=('bbox', 'segm'),
                    average='macro'
                )
            })
        else:
            metrics = MetricCollection({
                "mAP": MeanAveragePrecision(
                    iou_type=('bbox'),
                    average='macro'
                )
            })

        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """
        Configure optimiser for the task.
        """
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.lr,
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def reformat_batch(self, batch: Any, batch_size: int):
        """
        Reformat batch to calculate loss and metrics.

        Args:
            batch: The output of your DataLoader.
            batch_size: Size of your batch
        Returns:
            Reformated batch
        """

        if 'masks' in batch.keys():
            y = [
                {'boxes': batch['boxes'][i], 'labels': batch['labels'][i], 'masks': torch.cat([x[None].to(torch.uint8) for x in batch['masks'][i]])}
                for i in range(batch_size)
            ]
        else:

            y = [
                {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]}
                for i in range(batch_size)
            ]

        return y

    def apply_nms_sample(self, y_hat, iou_threshold=0.5, score_threshold=0.5):
        """
        It applies nms to a sample predictions of the model.

        Args:
            y_hat: Predictions dictionary.
            iou_threshold: IoU threshold for evaluation.
            score_threshold: Score threshold for evaluation.
        Returns:
            fintered predictions for a sample after applying nms batch
        """

        boxes, scores, labels = y_hat['boxes'], y_hat['scores'], y_hat['labels']
        masks = y_hat['masks'] if "masks" in y_hat.keys() else None

        # Filter based on score threshold
        keep_score = scores > score_threshold
        boxes, scores, labels = boxes[keep_score], scores[keep_score], labels[keep_score]
        if masks is not None:
            masks = masks[keep_score]

        # Apply NMS
        keep_nms = nms(boxes, scores, iou_threshold)

        y_hat['boxes'], y_hat['scores'], y_hat['labels'] = boxes[keep_nms], scores[keep_nms], labels[keep_nms]

        if masks is not None:
            y_hat['masks'] = masks[keep_nms]

        return y_hat

    def apply_nms_batch(self, y_hat: Any, batch_size: int):
        """
        It applies nms to a batch predictions of the model.

        Args:
            y_hat: List of predictions dictionaries.
            iou_threshold: IoU threshold for evaluation.
            score_threshold: Score threshold for evaluation.
        Returns:
            fintered predictions for a batch after applying nms batch
        """

        for i in range(batch_size):
            y_hat[i] = self.apply_nms_sample(y_hat[i], iou_threshold=self.iou_threshold, score_threshold=self.score_threshold)

        return y_hat

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """
        Compute the training loss.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss dictionary.
        """

        x = batch['image']
        batch_size = get_batch_size(x)
        y = self.reformat_batch(batch, batch_size)
        loss_dict = self(x, y)
        if isinstance(loss_dict, dict) is False:
            loss_dict = loss_dict.output
        train_loss: Tensor = sum(loss_dict.values())
        self.log_dict(loss_dict, batch_size=batch_size)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """
        Compute the validation metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """

        x = batch['image']
        batch_size = get_batch_size(x)
        y = self.reformat_batch(batch, batch_size)
        y_hat = self(x)
        if isinstance(y_hat, dict) is False:
            y_hat = y_hat.output

        y_hat = self.apply_nms_batch(y_hat, batch_size)

        if self.framework == 'mask-rcnn':

            for i in range(len(y_hat)):
                if y_hat[i]['masks'].shape[0] > 0:

                    y_hat[i]['masks']= (y_hat[i]['masks'] > 0.5).squeeze(1).to(torch.uint8)

        metrics = self.val_metrics(y_hat, y) 

        # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
        metrics.pop('val_classes', None)


        self.log_dict(metrics, batch_size=batch_size)

        if (
            batch_idx < 10
            and hasattr(self.trainer, 'datamodule')
            and hasattr(self.trainer.datamodule, 'plot')
            and self.logger
            and hasattr(self.logger, 'experiment')
            and hasattr(self.logger.experiment, 'add_figure')
        ):

            dataset = self.trainer.datamodule.val_dataset
            batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat]
            batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat]
            batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat]

            if "masks" in y_hat[0].keys():
                batch['prediction_masks'] = [b['masks'].cpu() for b in y_hat]
                if self.framework == 'mask-rcnn':
                    batch['prediction_masks'] = [b.unsqueeze(1) for b in batch['prediction_masks']]

            batch['image'] = batch['image'].cpu()
            sample = unbind_samples(batch)[0]
            fig: Figure | None = None
            try:
                fig = dataset.plot(sample)
            except RGBBandsMissingError:
                pass

            if fig:
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f'image/{batch_idx}', fig, global_step=self.global_step
                )
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """
        Compute the test metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """

        x = batch['image']
        batch_size = get_batch_size(x)
        y = self.reformat_batch(batch, batch_size)
        y_hat = self(x)
        if isinstance(y_hat, dict) is False:
            y_hat = y_hat.output

        y_hat = self.apply_nms_batch(y_hat, batch_size)

        if self.framework == 'mask-rcnn':

            for i in range(len(y_hat)):
                if y_hat[i]['masks'].shape[0] > 0:
                    y_hat[i]['masks']= (y_hat[i]['masks'] > 0.5).squeeze(1).to(torch.uint8)


        metrics = self.test_metrics(y_hat, y)

        # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
        metrics.pop('test_classes', None)

        self.log_dict(metrics, batch_size=batch_size)

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> list[dict[str, Tensor]]:
        """
        Output predicted bounding boxes, classes and masks.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted bounding boxes, classes and masks.
        """
        x = batch['image']
        batch_size = get_batch_size(x)
        y_hat: list[dict[str, Tensor]] = self(x)
        if isinstance(y_hat, dict) is False:
            y_hat = y_hat.output

        y_hat = self.apply_nms_batch(y_hat, batch_size)

        return y_hat

__init__(model_factory, model_args, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, class_names=None, iou_threshold=0.5, score_threshold=0.5) #

Initialize a new ObjectDetectionTask instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_factory` | `str` | Name of the model factory to use. | *required* |
| `model_args` | `dict` | Arguments for the model factory. | *required* |
| `lr` | `float` | Learning rate for optimizer. Defaults to 0.001. | `0.001` |
| `optimizer` | `str \| None` | Name of the optimizer to use. Defaults to None. | `None` |
| `optimizer_hparams` | `dict \| None` | Hyperparameters for the optimizer. Defaults to None. | `None` |
| `scheduler` | `str \| None` | Name of the scheduler to use. Defaults to None. | `None` |
| `scheduler_hparams` | `dict \| None` | Hyperparameters for the scheduler. Defaults to None. | `None` |
| `freeze_backbone` | `bool` | Freeze the backbone network to fine-tune the detection head. Defaults to False. | `False` |
| `freeze_decoder` | `bool` | Freeze the decoder network to fine-tune the detection head. Defaults to False. | `False` |
| `class_names` | `list[str] \| None` | List of class names. Defaults to None. | `None` |
| `iou_threshold` | `float` | Intersection over union threshold for evaluation. Defaults to 0.5. | `0.5` |
| `score_threshold` | `float` | Score threshold for evaluation. Defaults to 0.5. | `0.5` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Source code in terratorch/tasks/object_detection_task.py
def __init__(
    self,
    model_factory: str,
    model_args: dict,

    lr: float = 0.001,

    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,

    freeze_backbone: bool = False,
    freeze_decoder: bool = False,
    class_names: list[str] | None = None,

    iou_threshold: float = 0.5,
    score_threshold: float = 0.5,

) -> None:

    """
    Initialize a new ObjectDetectionTask instance.

    Args:
        model_factory (str): Name of the model factory to use.
        model_args (dict): Arguments for the model factory.
        lr (float, optional): Learning rate for optimizer. Defaults to 0.001.
        optimizer (str | None, optional): Name of the optimizer to use. Defaults to None.
        optimizer_hparams (dict | None, optional): Hyperparameters for the optimizer. Defaults to None.
        scheduler (str | None, optional): Name of the scheduler to use. Defaults to None.
        scheduler_hparams (dict | None, optional): Hyperparameters for the scheduler. Defaults to None.
        freeze_backbone (bool, optional): Freeze the backbone network to fine-tune the detection head. Defaults to False.
        freeze_decoder (bool, optional): Freeze the decoder network to fine-tune the detection head. Defaults to False.
        class_names (list[str] | None, optional): List of class names. Defaults to None.
        iou_threshold (float, optional): Intersection over union threshold for evaluation. Defaults to 0.5.
        score_threshold (float, optional): Score threshold for evaluation. Defaults to 0.5.

    Returns:
        None
    """
    warnings.warn("The Object Detection Task has to be considered experimental. This is less mature than the other tasks and being further improved.")

    self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
    self.framework = model_args['framework']
    self.monitor = 'val_segm_map' if self.framework == 'mask-rcnn' else self.monitor

    super().__init__()
    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler = LossHandler(self.test_metrics.prefix)
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.iou_threshold = iou_threshold
    self.score_threshold = score_threshold
    self.lr = lr
    if optimizer_hparams is not None:
        if "lr" in self.hparams["optimizer_hparams"].keys():
            self.lr = float(self.hparams["optimizer_hparams"]["lr"])
            del self.hparams["optimizer_hparams"]["lr"]
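
A minimal instantiation sketch follows. The constructor arguments are the ones documented above; the model factory name and the model_args keys other than "framework" are placeholders (the "framework" key must be present because __init__ reads it directly).

from terratorch.tasks import ObjectDetectionTask

task = ObjectDetectionTask(
    model_factory="ObjectDetectionModelFactory",   # placeholder factory name, not verified
    model_args={
        "framework": "mask-rcnn",                  # required key; switches monitor to "val_segm_map"
        "backbone": "some_backbone",               # placeholder argument
        "num_classes": 3,                          # placeholder argument
    },
    lr=1e-4,
    freeze_backbone=True,
    iou_threshold=0.5,
    score_threshold=0.3,
)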

apply_nms_batch(y_hat, batch_size) #

Applies NMS to a batch of model predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `y_hat` | `Any` | List of prediction dictionaries. | *required* |
| `batch_size` | `int` | Size of your batch. | *required* |

The IoU and score thresholds are taken from the task's `iou_threshold` and `score_threshold` settings.

Returns: Filtered predictions for the batch after applying NMS.

Source code in terratorch/tasks/object_detection_task.py
def apply_nms_batch(self, y_hat: Any, batch_size: int):
    """
    It applies nms to a batch predictions of the model.

    Args:
        y_hat: List of predictions dictionaries.
        iou_threshold: IoU threshold for evaluation.
        score_threshold: Score threshold for evaluation.
    Returns:
        fintered predictions for a batch after applying nms batch
    """

    for i in range(batch_size):
        y_hat[i] = self.apply_nms_sample(y_hat[i], iou_threshold=self.iou_threshold, score_threshold=self.score_threshold)

    return y_hat

apply_nms_sample(y_hat, iou_threshold=0.5, score_threshold=0.5) #

Applies NMS to a single sample's predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `y_hat` | | Predictions dictionary. | *required* |
| `iou_threshold` | | IoU threshold for evaluation. | `0.5` |
| `score_threshold` | | Score threshold for evaluation. | `0.5` |

Returns: Filtered predictions for the sample after applying NMS.

Source code in terratorch/tasks/object_detection_task.py
def apply_nms_sample(self, y_hat, iou_threshold=0.5, score_threshold=0.5):
    """
    It applies nms to a sample predictions of the model.

    Args:
        y_hat: Predictions dictionary.
        iou_threshold: IoU threshold for evaluation.
        score_threshold: Score threshold for evaluation.
    Returns:
        fintered predictions for a sample after applying nms batch
    """

    boxes, scores, labels = y_hat['boxes'], y_hat['scores'], y_hat['labels']
    masks = y_hat['masks'] if "masks" in y_hat.keys() else None

    # Filter based on score threshold
    keep_score = scores > score_threshold
    boxes, scores, labels = boxes[keep_score], scores[keep_score], labels[keep_score]
    if masks is not None:
        masks = masks[keep_score]

    # Apply NMS
    keep_nms = nms(boxes, scores, iou_threshold)

    y_hat['boxes'], y_hat['scores'], y_hat['labels'] = boxes[keep_nms], scores[keep_nms], labels[keep_nms]

    if masks is not None:
        y_hat['masks'] = masks[keep_nms]

    return y_hat
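
To illustrate the two-stage filtering above (score threshold, then non-maximum suppression), here is a small standalone sketch using torchvision.ops.nms on made-up boxes; the numbers are illustrative only.

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],     # heavily overlaps the first box
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.85, 0.2])
labels = torch.tensor([1, 1, 2])

# 1) drop low-confidence predictions
keep_score = scores > 0.5
boxes, scores, labels = boxes[keep_score], scores[keep_score], labels[keep_score]

# 2) suppress overlapping boxes above the IoU threshold
keep_nms = nms(boxes, scores, iou_threshold=0.5)
print(boxes[keep_nms], scores[keep_nms], labels[keep_nms])  # only the first box survives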

configure_metrics() #

Configure metrics for the task.

Source code in terratorch/tasks/object_detection_task.py
def configure_metrics(self) -> None:
    """
    Configure metrics for the task.
    """
    if self.framework == 'mask-rcnn':
        metrics = MetricCollection({
            "mAP": MeanAveragePrecision(
                iou_type=('bbox', 'segm'),
                average='macro'
            )
        })
    else:
        metrics = MetricCollection({
            "mAP": MeanAveragePrecision(
                iou_type=('bbox'),
                average='macro'
            )
        })

    self.train_metrics = metrics.clone(prefix='train_')
    self.val_metrics = metrics.clone(prefix='val_')
    self.test_metrics = metrics.clone(prefix='test_')
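
The mAP metric consumes predictions and targets as lists of per-image dictionaries in the torchvision detection format. A small standalone sketch of one update/compute cycle, assuming a recent torchmetrics version:

import torch
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(iou_type="bbox", average="macro")

preds = [{
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    "labels": torch.tensor([1]),
}]

metric.update(preds, targets)
results = metric.compute()       # dict with "map", "map_50", "classes", ...
results.pop("classes", None)     # non-scalar tensor, dropped before log_dict in the steps below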

configure_models() #

It instantiates the model and freezes/unfreezes the backbone and decoder networks.

Source code in terratorch/tasks/object_detection_task.py
def configure_models(self) -> None:
    """
    It instantiates the model and freezes/unfreezes the backbone and decoder networks.
    """

    self.model: Model = self.model_factory.build_model(
        "object_detection", **self.hparams["model_args"]
    )
    if self.hparams["freeze_backbone"]:
        self.model.freeze_encoder()
    if self.hparams["freeze_decoder"]:
        self.model.freeze_decoder()

configure_optimizers() #

Configure optimiser for the task.

Source code in terratorch/tasks/object_detection_task.py
def configure_optimizers(
    self,
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
    """
    Configure optimiser for the task.
    """
    optimizer = self.hparams["optimizer"]
    if optimizer is None:
        optimizer = "Adam"
    return optimizer_factory(
        optimizer,
        self.lr,
        self.parameters(),
        self.hparams["optimizer_hparams"],
        self.hparams["scheduler"],
        self.monitor,
        self.hparams["scheduler_hparams"],
    )
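
As a sketch of how the optimizer-related constructor arguments feed into this method: the names are resolved by optimizer_factory at runtime (presumably against torch.optim and torch.optim.lr_scheduler, though that resolution is not shown here), and an "lr" entry inside optimizer_hparams overrides the lr argument, as handled in __init__ above. The concrete values below are assumptions.

from terratorch.tasks import ObjectDetectionTask

# Hypothetical constructor arguments; only the handling of "lr" inside
# optimizer_hparams is taken from __init__ above, the rest is assumed.
task = ObjectDetectionTask(
    model_factory="ObjectDetectionModelFactory",             # placeholder factory name
    model_args={"framework": "faster-rcnn"},                 # placeholder framework value
    optimizer="AdamW",                                       # defaults to "Adam" when left as None
    optimizer_hparams={"lr": 1e-4, "weight_decay": 0.05},    # "lr" here takes precedence over lr=
    scheduler="ReduceLROnPlateau",
    scheduler_hparams={"patience": 5},
)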

predict_step(batch, batch_idx, dataloader_idx=0) #

Output predicted bounding boxes, classes and masks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Any` | The output of your DataLoader. | *required* |
| `batch_idx` | `int` | Integer displaying index of this batch. | *required* |
| `dataloader_idx` | `int` | Index of the current dataloader. | `0` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Tensor]]` | Output predicted bounding boxes, classes and masks. |

Source code in terratorch/tasks/object_detection_task.py
def predict_step(
    self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> list[dict[str, Tensor]]:
    """
    Output predicted bounding boxes, classes and masks.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted bounding boxes, classes and masks.
    """
    x = batch['image']
    batch_size = get_batch_size(x)
    y_hat: list[dict[str, Tensor]] = self(x)
    if isinstance(y_hat, dict) is False:
        y_hat = y_hat.output

    y_hat = self.apply_nms_batch(y_hat, batch_size)

    return y_hat
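
Each prediction step returns a list with one dictionary per image (boxes, scores, labels, and masks for mask-rcnn). A sketch of consuming these through Lightning's Trainer.predict; the task and datamodule objects are assumed to be configured elsewhere:

import lightning.pytorch as pl

trainer = pl.Trainer(accelerator="auto", devices=1)
predictions = trainer.predict(task, datamodule=datamodule)   # task/datamodule defined elsewhere

for per_image in predictions[0]:      # first batch: one dict per image
    print(per_image["boxes"].shape, per_image["scores"], per_image["labels"])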

reformat_batch(batch, batch_size) #

Reformat batch to calculate loss and metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Any` | The output of your DataLoader. | *required* |
| `batch_size` | `int` | Size of your batch. | *required* |

Returns: Reformatted batch.

Source code in terratorch/tasks/object_detection_task.py
def reformat_batch(self, batch: Any, batch_size: int):
    """
    Reformat batch to calculate loss and metrics.

    Args:
        batch: The output of your DataLoader.
        batch_size: Size of your batch
    Returns:
        Reformated batch
    """

    if 'masks' in batch.keys():
        y = [
            {'boxes': batch['boxes'][i], 'labels': batch['labels'][i], 'masks': torch.cat([x[None].to(torch.uint8) for x in batch['masks'][i]])}
            for i in range(batch_size)
        ]
    else:

        y = [
            {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]}
            for i in range(batch_size)
        ]

    return y
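
As an illustration with made-up values, the no-mask branch above turns the collated batch dictionary into the list of per-image target dictionaries expected by torchvision-style detection models:

import torch

# Made-up collated batch for two images
batch = {
    "image": torch.rand(2, 3, 64, 64),
    "boxes": [torch.tensor([[0.0, 0.0, 10.0, 10.0]]), torch.tensor([[5.0, 5.0, 20.0, 20.0]])],
    "labels": [torch.tensor([1]), torch.tensor([2])],
}

y = [{"boxes": batch["boxes"][i], "labels": batch["labels"][i]} for i in range(2)]
print(y[0])   # {'boxes': tensor([[ 0.,  0., 10., 10.]]), 'labels': tensor([1])}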

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Any` | The output of your DataLoader. | *required* |
| `batch_idx` | `int` | Integer displaying index of this batch. | *required* |
| `dataloader_idx` | `int` | Index of the current dataloader. | `0` |
Source code in terratorch/tasks/object_detection_task.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """
    Compute the test metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """

    x = batch['image']
    batch_size = get_batch_size(x)
    y = self.reformat_batch(batch, batch_size)
    y_hat = self(x)
    if isinstance(y_hat, dict) is False:
        y_hat = y_hat.output

    y_hat = self.apply_nms_batch(y_hat, batch_size)

    if self.framework == 'mask-rcnn':

        for i in range(len(y_hat)):
            if y_hat[i]['masks'].shape[0] > 0:
                y_hat[i]['masks']= (y_hat[i]['masks'] > 0.5).squeeze(1).to(torch.uint8)


    metrics = self.test_metrics(y_hat, y)

    # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
    metrics.pop('test_classes', None)

    self.log_dict(metrics, batch_size=batch_size)

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the training loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Any` | The output of your DataLoader. | *required* |
| `batch_idx` | `int` | Integer displaying index of this batch. | *required* |
| `dataloader_idx` | `int` | Index of the current dataloader. | `0` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The training loss (sum of the loss dictionary values). |

Source code in terratorch/tasks/object_detection_task.py
def training_step(
    self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
    """
    Compute the training loss.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        The loss dictionary.
    """

    x = batch['image']
    batch_size = get_batch_size(x)
    y = self.reformat_batch(batch, batch_size)
    loss_dict = self(x, y)
    if isinstance(loss_dict, dict) is False:
        loss_dict = loss_dict.output
    train_loss: Tensor = sum(loss_dict.values())
    self.log_dict(loss_dict, batch_size=batch_size)
    self.log("train_loss", train_loss)
    return train_loss
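
In training mode the underlying detection model returns a dictionary of partial losses rather than predictions, which the step sums into a single scalar. A small sketch with loss keys typical of torchvision's Faster R-CNN (the actual keys depend on the chosen framework):

import torch

loss_dict = {                              # keys are typical of torchvision's Faster R-CNN
    "loss_classifier": torch.tensor(0.42),
    "loss_box_reg": torch.tensor(0.17),
    "loss_objectness": torch.tensor(0.05),
    "loss_rpn_box_reg": torch.tensor(0.03),
}
train_loss = sum(loss_dict.values())       # scalar logged as "train_loss"
print(train_loss)                          # tensor(0.6700)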

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Any` | The output of your DataLoader. | *required* |
| `batch_idx` | `int` | Integer displaying index of this batch. | *required* |
| `dataloader_idx` | `int` | Index of the current dataloader. | `0` |
Source code in terratorch/tasks/object_detection_task.py
def validation_step(
    self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
    """
    Compute the validation metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """

    x = batch['image']
    batch_size = get_batch_size(x)
    y = self.reformat_batch(batch, batch_size)
    y_hat = self(x)
    if isinstance(y_hat, dict) is False:
        y_hat = y_hat.output

    y_hat = self.apply_nms_batch(y_hat, batch_size)

    if self.framework == 'mask-rcnn':

        for i in range(len(y_hat)):
            if y_hat[i]['masks'].shape[0] > 0:

                y_hat[i]['masks']= (y_hat[i]['masks'] > 0.5).squeeze(1).to(torch.uint8)

    metrics = self.val_metrics(y_hat, y) 

    # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
    metrics.pop('val_classes', None)


    self.log_dict(metrics, batch_size=batch_size)

    if (
        batch_idx < 10
        and hasattr(self.trainer, 'datamodule')
        and hasattr(self.trainer.datamodule, 'plot')
        and self.logger
        and hasattr(self.logger, 'experiment')
        and hasattr(self.logger.experiment, 'add_figure')
    ):

        dataset = self.trainer.datamodule.val_dataset
        batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat]
        batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat]
        batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat]

        if "masks" in y_hat[0].keys():
            batch['prediction_masks'] = [b['masks'].cpu() for b in y_hat]
            if self.framework == 'mask-rcnn':
                batch['prediction_masks'] = [b.unsqueeze(1) for b in batch['prediction_masks']]

        batch['image'] = batch['image'].cpu()
        sample = unbind_samples(batch)[0]
        fig: Figure | None = None
        try:
            fig = dataset.plot(sample)
        except RGBBandsMissingError:
            pass

        if fig:
            summary_writer = self.logger.experiment
            summary_writer.add_figure(
                f'image/{batch_idx}', fig, global_step=self.global_step
            )
            plt.close()
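
For mask-rcnn, the raw instance masks are probability maps of shape (N, 1, H, W); the step above binarizes them before computing the segmentation mAP. A standalone sketch of that conversion with made-up shapes:

import torch

masks = torch.rand(3, 1, 64, 64)                    # N probability maps from Mask R-CNN
binary = (masks > 0.5).squeeze(1).to(torch.uint8)   # (N, 64, 64), values in {0, 1}
print(binary.shape, binary.dtype)                   # torch.Size([3, 64, 64]) torch.uint8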