
Tasks

Tasks provide a convenient abstraction over the training of a model for a specific downstream task.

They encapsulate the model, optimizer, metrics and loss, as well as the training, validation and testing steps.

Each task expects to be passed a model factory, to which the model_args arguments are forwarded. The models produced by this factory should output ModelOutput instances and conform to the Model ABC. This makes it easy to extend these tasks to models other than the Prithvi ones produced by the PrithviModelFactory.

Tasks are best leveraged using config files. Check out some examples here.
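
For illustration, here is a minimal sketch of what such a config might look like, assuming the task is exposed under terratorch.tasks and using hypothetical model_args keys (the real keys depend on the model factory you choose):

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: PrithviModelFactory
    model_args:
      backbone: prithvi_vit_100 # hypothetical backbone name
      decoder: FCNDecoder # hypothetical decoder name
      num_classes: 2
    loss: ce
    lr: 0.001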

Argument parsing in configs

Argument parsing of configs relies on argument names and type hints in the code. To pass arguments that do not conform to this (e.g. for classes that make use of **kwargs), put those arguments in dict_args instead of init_args.
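
As a sketch, the transform below passes height and width through init_args as usual, while a hypothetical class consuming **kwargs receives its arguments through dict_args:

train_transform:
  - class_path: albumentations.RandomCrop
    init_args:
      height: 224
      width: 224
  - class_path: MyKwargsTransform # hypothetical class whose constructor takes **kwargs
    dict_args:
      some_kwarg: some_value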


Multi-temporal Inputs

Multi-temporal inputs are also supported. However, we leverage albumentations for augmentations, which does not support multi-temporal input. We currently work around this with the following strategy in the transform:

train_transform:
      - class_path: FlattenTemporalIntoChannels
      # your transforms here, wrapped by these other ones
      # e.g. a random flip
      - class_path: albumentations.Flip
      # end of your transforms
      - class_path: ToTensorV2
      - class_path: UnflattenTemporalFromChannels
        init_args:
          n_timesteps: 3 # your number of timesteps here
          # alternatively, n_channels can be specified
See an example of this here.


terratorch.tasks.regression_tasks.PixelwiseRegressionTask

Bases: BaseTask

Pixelwise Regression Task that accepts models from a range of sources.

This class is analogous in functionality to [PixelwiseRegressionTask](https://torchgeo.readthedocs.io/en/stable/api/trainers.html#torchgeo.trainers.PixelwiseRegressionTask) defined by TorchGeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor

Source code in terratorch/tasks/regression_tasks.py
class PixelwiseRegressionTask(BaseTask):
    """Pixelwise Regression Task that accepts models from a range of sources.

    This class is analogous in functionality to
    [PixelwiseRegressionTask]
    (https://torchgeo.readthedocs.io/en/stable/api/trainers.html#torchgeo.trainers.PixelwiseRegressionTask)
    defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor"""

    def __init__(
        self,
        model_args: dict,
        model_factory: str,
        loss: str = "mse",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT001, FBT002
        plot_on_val: bool | int = 10,
        tiled_inference_parameters: TiledInferenceParameters | None = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str): Name of ModelFactory class to be used to instantiate the model.
            loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss.
                Defaults to "mse".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / cli specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / cli specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / cli specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
            tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
        """
        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads
        self.model_factory = get_factory(model_factory)
        super().__init__()
        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )
        if self.hparams["freeze_backbone"]:
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            self.hparams["optimizer"],
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"].lower()
        if loss == "mse":
            self.criterion: nn.Module = IgnoreIndexLossWrapper(
                nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
            )
        elif loss == "mae":
            self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
        elif loss == "rmse":
            # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
            self.criterion = RootLossWrapper(
                IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
            )
        elif loss == "huber":
            self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
        else:
            exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'mse', 'rmse' or 'mae' loss."
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

        def instantiate_metrics():
            return {
                "RMSE": MeanSquaredError(squared=False),
                "MSE": MeanSquaredError(squared=True),
                "MAE": MeanAbsoluteError(),
            }

        def wrap_metrics_with_ignore_index(metrics):
            return {
                name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
                for name, metric in metrics.items()
            }

        self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
        self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
        self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        model_output: ModelOutput = self(x)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat = model_output.output
        self.train_metrics(y_hat, y)
        self.log_dict(self.train_metrics, on_epoch=True)

        return loss["loss"]

    def _do_plot_samples(self, batch_index):
        if not self.plot_on_val:  # don't plot if self.plot_on_val is 0
            return False

        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        model_output: ModelOutput = self(x)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat = model_output.output
        # exclude pixels whose label is -1 from the metric computation
        out = y_hat[y != -1]
        mask = y[y != -1]
        self.val_metrics(out, mask)
        self.log_dict(self.val_metrics, on_epoch=True)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        model_output: ModelOutput = self(x)
        loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat = model_output.output
        self.test_metrics(y_hat, y)
        self.log_dict(self.test_metrics, on_epoch=True)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predictions.
        """
        x = batch["image"]
        file_names = batch["filename"]

        def model_forward(x):
            return self(x).output

        if self.tiled_inference_parameters:
            y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
        else:
            y_hat: Tensor = self(x).output
        return y_hat, file_names
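
As a usage sketch, the task can also be instantiated directly in Python. This assumes PixelwiseRegressionTask is importable from terratorch.tasks and uses hypothetical model_args keys; consult your ModelFactory for the real ones.

from terratorch.tasks import PixelwiseRegressionTask

task = PixelwiseRegressionTask(
    model_args={
        # hypothetical factory arguments
        "backbone": "prithvi_vit_100",
        "decoder": "FCNDecoder",
    },
    model_factory="PrithviModelFactory",
    loss="rmse",
    ignore_index=-1,
    lr=1e-4,
    freeze_backbone=True,
)
# The task is a LightningModule, so it can be passed to lightning.Trainer.fit().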

__init__(model_args, model_factory, loss='mse', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, plot_on_val=10, tiled_inference_parameters=None)

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_args | dict | Arguments passed to the model factory. | required |
| model_factory | str | Name of the ModelFactory class to be used to instantiate the model. | required |
| loss | str | Loss to be used. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss. | 'mse' |
| aux_loss | dict[str, float] \| None | Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. | None |
| class_weights | list[float] \| None | List of class weights to be applied to the loss. | None |
| ignore_index | int \| None | Label to ignore in the loss computation. | None |
| lr | float | Learning rate to be used. | 0.001 |
| optimizer | str \| None | Name of optimizer class from torch.optim to be used. If None, will use Adam. Overridden by config / CLI specification through LightningCLI. | None |
| optimizer_hparams | dict \| None | Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI. | None |
| scheduler | str \| None | Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Overridden by config / CLI specification through LightningCLI. | None |
| scheduler_hparams | dict \| None | Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI. | None |
| freeze_backbone | bool | Whether to freeze the backbone. | False |
| freeze_decoder | bool | Whether to freeze the decoder and segmentation head. | False |
| plot_on_val | bool \| int | Whether to plot visualizations on validation. If True, plot every epoch; if an int, plot every plot_on_val epochs. | 10 |
| tiled_inference_parameters | TiledInferenceParameters \| None | Inference parameters used to determine if inference is done on the whole image or through tiling. | None |
Source code in terratorch/tasks/regression_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str,
    loss: str = "mse",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT001, FBT002
    plot_on_val: bool | int = 10,
    tiled_inference_parameters: TiledInferenceParameters | None = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str): Name of ModelFactory class to be used to instantiate the model.
        loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss.
            Defaults to "mse".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
        tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
    """
    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads
    self.model_factory = get_factory(model_factory)
    super().__init__()
    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler = LossHandler(self.test_metrics.prefix)
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)

configure_losses()

Initialize the loss criterion.

Raises:

| Type | Description |
| --- | --- |
| ValueError | If loss is invalid. |

Source code in terratorch/tasks/regression_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"].lower()
    if loss == "mse":
        self.criterion: nn.Module = IgnoreIndexLossWrapper(
            nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
        )
    elif loss == "mae":
        self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
    elif loss == "rmse":
        # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
        self.criterion = RootLossWrapper(
            IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
        )
    elif loss == "huber":
        self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
    else:
        exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'mse', 'rmse' or 'mae' loss."
        raise ValueError(exception_message)
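
The comment about the root matters because the square root does not commute with the masked mean. Writing $V$ for the set of pixels whose label is not ignore_index, the wrapped criterion computes

$$\mathrm{RMSE} = \sqrt{\frac{1}{|V|}\sum_{i \in V}\left(\hat{y}_i - y_i\right)^2}$$

i.e. the root is taken only after the ignore-index-aware mean. Taking the root per pixel before averaging would instead produce $\frac{1}{|V|}\sum_{i \in V}\lvert\hat{y}_i - y_i\rvert$, which is the MAE, not the RMSE.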

configure_metrics()

Initialize the performance metrics.

Source code in terratorch/tasks/regression_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""

    def instantiate_metrics():
        return {
            "RMSE": MeanSquaredError(squared=False),
            "MSE": MeanSquaredError(squared=True),
            "MAE": MeanAbsoluteError(),
        }

    def wrap_metrics_with_ignore_index(metrics):
        return {
            name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
            for name, metric in metrics.items()
        }

    self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
    self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
    self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")

predict_step(batch, batch_idx, dataloader_idx=0)

Compute the predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output predictions. |

Source code in terratorch/tasks/regression_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predictions.
    """
    x = batch["image"]
    file_names = batch["filename"]

    def model_forward(x):
        return self(x).output

    if self.tiled_inference_parameters:
        y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
    else:
        y_hat: Tensor = self(x).output
    return y_hat, file_names
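
As a sketch of enabling tiled inference from a config, one could pass the dataclass fields under tiled_inference_parameters. The field names below (h_crop, h_stride, w_crop, w_stride) are assumptions based on common tiling parameters; check the TiledInferenceParameters definition in terratorch for the exact names:

model:
  init_args:
    tiled_inference_parameters:
      h_crop: 224 # assumed field name: tile height
      h_stride: 192 # assumed field name: vertical stride between tiles
      w_crop: 224 # assumed field name: tile width
      w_stride: 192 # assumed field name: horizontal stride between tiles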

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/regression_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    model_output: ModelOutput = self(x)
    loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat = model_output.output
    self.test_metrics(y_hat, y)
    self.log_dict(self.test_metrics, on_epoch=True)

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/regression_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    model_output: ModelOutput = self(x)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat = model_output.output
    self.train_metrics(y_hat, y)
    self.log_dict(self.train_metrics, on_epoch=True)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/regression_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    model_output: ModelOutput = self(x)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat = model_output.output
    # exclude pixels whose label is -1 from the metric computation
    out = y_hat[y != -1]
    mask = y[y != -1]
    self.val_metrics(out, mask)
    self.log_dict(self.val_metrics, on_epoch=True)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat
            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()

terratorch.tasks.segmentation_tasks.SemanticSegmentationTask

Bases: BaseTask

Semantic Segmentation Task that accepts models from a range of sources.

This class is analogous in functionality to [SemanticSegmentationTask](https://torchgeo.readthedocs.io/en/stable/api/trainers.html#torchgeo.trainers.SemanticSegmentationTask) defined by TorchGeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor

Source code in terratorch/tasks/segmentation_tasks.py
class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation Task that accepts models from a range of sources.

    This class is analogous in functionality to :class:`SemanticSegmentationTask` defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        plot_on_val: bool | int = 10,
        class_names: list[str] | None = None,
        tiled_inference_parameters: TiledInferenceParameters | None = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str): Name of the ModelFactory class to be used to instantiate the model.
            loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
                Defaults to "ce".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / cli specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / cli specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / cli specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
            tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
        """
        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads
        self.model_factory = get_factory(model_factory)
        super().__init__()
        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )
        if self.hparams["freeze_backbone"]:
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            self.hparams["optimizer"],
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]

        class_weights = (
            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
        )
        if loss == "ce":
            ignore_value = -100 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
        elif loss == "jaccard":
            if ignore_index is not None:
                exception_message = (
                    f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
                )
                raise RuntimeError(exception_message)
            self.criterion = smp.losses.JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
        elif loss == "dice":
            self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
        else:
            exception_message = (
                f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
            )
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Multiclass_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
                "Multiclass_Accuracy_Class": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        multidim_average="global",
                        average=None,
                    ),
                    labels=class_names,
                ),
                "Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
                ),
                "Multiclass_Jaccard_Index": MulticlassJaccardIndex(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                ),
                "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                    MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                    labels=class_names,
                ),
                "Multiclass_F1_Score": MulticlassF1Score(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]

        model_output: ModelOutput = self(x)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def _do_plot_samples(self, batch_index):
        if not self.plot_on_val:  # don't plot if self.plot_on_val is 0
            return False

        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        model_output: ModelOutput = self(x)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat_hard
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]

        model_output: ModelOutput = self(x)
        loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.test_metrics.update(y_hat_hard, y)

    def on_test_epoch_end(self) -> None:
        self.log_dict(self.test_metrics.compute(), sync_dist=True)
        self.test_metrics.reset()
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted class indices.
        """
        x = batch["image"]
        file_names = batch["filename"]

        def model_forward(x):
            return self(x).output

        if self.tiled_inference_parameters:
            y_hat: Tensor = tiled_inference(
                model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
            )
        else:
            y_hat: Tensor = self(x).output
        y_hat = y_hat.argmax(dim=1)
        return y_hat, file_names
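
As with the regression task, here is a direct instantiation sketch, assuming SemanticSegmentationTask is importable from terratorch.tasks and using hypothetical model_args keys:

from terratorch.tasks import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model_args={
        # hypothetical factory arguments
        "backbone": "prithvi_vit_100",
        "decoder": "FCNDecoder",
        "num_classes": 2,
    },
    model_factory="PrithviModelFactory",
    loss="ce",
    class_weights=[0.3, 0.7],
    ignore_index=-1,
    class_names=["background", "burned"],
)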

__init__(model_args, model_factory, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None)

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_args | dict | Arguments passed to the model factory. | required |
| model_factory | str | Name of the ModelFactory class to be used to instantiate the model. | required |
| loss | str | Loss to be used. Currently supports 'ce', 'jaccard', 'dice' or 'focal' loss. | 'ce' |
| aux_loss | dict[str, float] \| None | Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. | None |
| class_weights | list[float] \| None | List of class weights to be applied to the loss. | None |
| ignore_index | int \| None | Label to ignore in the loss computation. | None |
| lr | float | Learning rate to be used. | 0.001 |
| optimizer | str \| None | Name of optimizer class from torch.optim to be used. If None, will use Adam. Overridden by config / CLI specification through LightningCLI. | None |
| optimizer_hparams | dict \| None | Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI. | None |
| scheduler | str \| None | Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Overridden by config / CLI specification through LightningCLI. | None |
| scheduler_hparams | dict \| None | Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI. | None |
| freeze_backbone | bool | Whether to freeze the backbone. | False |
| freeze_decoder | bool | Whether to freeze the decoder and segmentation head. | False |
| plot_on_val | bool \| int | Whether to plot visualizations on validation. If True, plot every epoch; if an int, plot every plot_on_val epochs. | 10 |
| class_names | list[str] \| None | List of class names passed to metrics for better naming. If None, numeric ordering is used. | None |
| tiled_inference_parameters | TiledInferenceParameters \| None | Inference parameters used to determine if inference is done on the whole image or through tiling. | None |
Source code in terratorch/tasks/segmentation_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    plot_on_val: bool | int = 10,
    class_names: list[str] | None = None,
    tiled_inference_parameters: TiledInferenceParameters | None = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str): Name of the ModelFactory class to be used to instantiate the model.
        loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
            Defaults to "ce".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
        tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
    """
    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads
    self.model_factory = get_factory(model_factory)
    super().__init__()
    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler = LossHandler(self.test_metrics.prefix)
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)

configure_losses()

Initialize the loss criterion.

Raises:

| Type | Description |
| --- | --- |
| ValueError | If loss is invalid. |

Source code in terratorch/tasks/segmentation_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    ignore_index = self.hparams["ignore_index"]

    class_weights = (
        torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
    )
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
    elif loss == "jaccard":
        if ignore_index is not None:
            exception_message = (
                f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
            )
            raise RuntimeError(exception_message)
        self.criterion = smp.losses.JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
    elif loss == "dice":
        self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
    else:
        exception_message = (
            f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
        )
        raise ValueError(exception_message)
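
For instance, a sketch of a fragment of the task's init_args selecting weighted cross-entropy; note from the code above that 'jaccard' would raise a RuntimeError if combined with ignore_index:

loss: ce
class_weights:
  - 0.5
  - 2.0
ignore_index: -1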

configure_metrics()

Initialize the performance metrics.

Source code in terratorch/tasks/segmentation_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Multiclass_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                multidim_average="global",
                average="micro",
            ),
            "Multiclass_Accuracy_Class": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average=None,
                ),
                labels=class_names,
            ),
            "Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex(
                num_classes=num_classes, ignore_index=ignore_index, average="micro"
            ),
            "Multiclass_Jaccard_Index": MulticlassJaccardIndex(
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
            "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                labels=class_names,
            ),
            "Multiclass_F1_Score": MulticlassF1Score(
                num_classes=num_classes,
                ignore_index=ignore_index,
                multidim_average="global",
                average="micro",
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    self.test_metrics = metrics.clone(prefix="test/")
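
Because the per-class metrics are wrapped in ClasswiseWrapper with class_names as labels, the logged keys are derived from those names. A small sketch using torchmetrics directly (the exact key format depends on the torchmetrics version):

import torch
from torchmetrics import ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy

metric = ClasswiseWrapper(
    MulticlassAccuracy(num_classes=2, average=None),
    labels=["background", "water"],
)
preds = torch.tensor([0, 1, 1])
target = torch.tensor([0, 1, 0])
print(metric(preds, target))
# roughly: {'multiclassaccuracy_background': 0.5, 'multiclassaccuracy_water': 1.0}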

predict_step(batch, batch_idx, dataloader_idx=0)

Compute the predicted class labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output predicted class indices. |

Source code in terratorch/tasks/segmentation_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the predicted class probabilities.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Output predicted class indices.
    """
    x = batch["image"]
    file_names = batch["filename"]

    def model_forward(x):
        return self(x).output

    if self.tiled_inference_parameters:
        y_hat: Tensor = tiled_inference(
            model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
        )
    else:
        y_hat: Tensor = self(x).output
    y_hat = y_hat.argmax(dim=1)
    return y_hat, file_names

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/segmentation_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]

    model_output: ModelOutput = self(x)
    loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.test_metrics.update(y_hat_hard, y)

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/segmentation_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]

    model_output: ModelOutput = self(x)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Any | The output of your DataLoader. | required |
| batch_idx | int | Integer displaying index of this batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |
Source code in terratorch/tasks/segmentation_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    model_output: ModelOutput = self(x)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat_hard
            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()
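
Note that the plotting branch dispatches on the attached logger's API rather than its type: a TensorBoard SummaryWriter exposes add_figure, while an MLFlow client exposes log_figure, so the same validation code works with either logger. Any ValueError raised during plotting is swallowed, and the figure is always closed to avoid leaking matplotlib state.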

terratorch.tasks.classification_tasks.ClassificationTask

Bases: BaseTask

Classification Task that accepts models from a range of sources.

This class is analogous in functionality to the ClassificationTask defined by torchgeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- Provides accuracy with both micro and macro averaging

Note: 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy. 'Macro' averaging gives equal weight to each class, which is useful for balanced performance assessment across imbalanced classes.
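
To make the distinction concrete, here is a minimal sketch using torchmetrics directly (the class counts are invented for illustration):

import torch
from torchmetrics.classification import MulticlassAccuracy

# Toy imbalanced batch: 8 samples of class 0 and 2 of class 1.
target = torch.tensor([0] * 8 + [1] * 2)
# The model gets every class-0 sample right and every class-1 sample wrong.
preds = torch.tensor([0] * 10)

micro = MulticlassAccuracy(num_classes=2, average="micro")
macro = MulticlassAccuracy(num_classes=2, average="macro")

print(micro(preds, target))  # tensor(0.8000) -- dominated by the majority class
print(macro(preds, target))  # tensor(0.5000) -- (1.0 + 0.0) / 2, classes weighted equally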

Source code in terratorch/tasks/classification_tasks.py
class ClassificationTask(BaseTask):
    """Classification Task that accepts models from a range of sources.

    This class is analogous in functionality to :class:`ClassificationTask` defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Provides accuracy with both micro and macro averaging

    .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so the CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        class_names: list[str] | None = None,
    ) -> None:
        """Constructor

        Args:

            model_args (dict): Arguments passed to the model factory.
            model_factory (str): ModelFactory class to be used to instantiate the model.
            loss (str, optional): Loss to be used. Currently, supports 'ce', 'bce', 'jaccard' or 'focal' loss.
                Defaults to "ce".
            aux_heads (list[AuxiliaryHead] | None, optional): Auxiliary heads to be attached to the model.
                Defaults to None.
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / CLI specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / CLI specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / CLI specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder and classification head. Defaults to False.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
        """
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads
        self.model_factory = get_factory(model_factory)
        super().__init__()
        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"

    # override TorchGeo's default callbacks: no early stopping is configured
    def configure_callbacks(self) -> list[Callback]:
        return []

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        )
        if self.hparams["freeze_backbone"]:
            self.model.freeze_encoder()
        if self.hparams["freeze_decoder"]:
            self.model.freeze_decoder()

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        # Fall back to Adam when no optimizer is specified.
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(
            optimizer,
            self.hparams["lr"],
            self.parameters(),
            self.hparams["optimizer_hparams"],
            self.hparams["scheduler"],
            self.monitor,
            self.hparams["scheduler_hparams"],
        )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        if loss == "ce":
            self.criterion: nn.Module = nn.CrossEntropyLoss(weight=self.hparams["class_weights"])
        elif loss == "bce":
            self.criterion = nn.BCEWithLogitsLoss()
        elif loss == "jaccard":
            self.criterion = JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = FocalLoss(mode="multiclass", normalized=True)
        else:
            msg = f"Loss type '{loss}' is not valid."
            raise ValueError(msg)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Overall_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="micro",
                ),
                "Average_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Multiclass_Accuracy_Class": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                ),
                "Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
                "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                    MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                    labels=class_names,
                ),
                # F1 score is the F-beta score with beta=1.0
                "Multiclass_F1_Score": MulticlassFBetaScore(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    beta=1.0,
                    average="micro",
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        model_output: dict[str, Tensor] = self(x)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def on_train_epoch_end(self) -> None:
        self.log_dict(self.train_metrics.compute(), sync_dist=True)
        self.train_metrics.reset()
        return super().on_train_epoch_end()

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        model_output: dict[str, Tensor] = self(x)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

    def on_validation_epoch_end(self) -> None:
        self.log_dict(self.val_metrics.compute(), sync_dist=True)
        self.val_metrics.reset()
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        model_output: dict[str, Tensor] = self(x)
        loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.test_metrics.update(y_hat_hard, y)

    def on_test_epoch_end(self) -> None:
        self.log_dict(self.test_metrics.compute(), sync_dist=True)
        self.test_metrics.reset()
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        file_names = batch["filename"]

        y_hat = self(x).output
        y_hat = y_hat.argmax(dim=1)
        return y_hat, file_names

__init__(model_args, model_factory, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, class_names=None)

Constructor

Parameters:

    model_args (dict): Arguments passed to the model factory.
    model_factory (str): ModelFactory class to be used to instantiate the model.
    loss (str, optional): Loss to be used. Currently, supports 'ce', 'bce', 'jaccard' or 'focal' loss.
        Defaults to "ce".
    aux_heads (list[AuxiliaryHead] | None, optional): Auxiliary heads to be attached to the model.
        Defaults to None.
    aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
        Should be a dictionary where the key is the name given to the loss
        and the value is the weight to be applied to that loss.
        The name of the loss should match the key in the dictionary output by the model's forward
        method containing that output. Defaults to None.
    class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
        Defaults to None.
    ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
    lr (float, optional): Learning rate to be used. Defaults to 0.001.
    optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
        If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
    optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
        Overridden by config / CLI specification through LightningCLI.
    scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
        to be used (e.g. ReduceLROnPlateau). Defaults to None.
        Overridden by config / CLI specification through LightningCLI.
    scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
        Overridden by config / CLI specification through LightningCLI.
    freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
    freeze_decoder (bool, optional): Whether to freeze the decoder and classification head. Defaults to False.
    class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
        Defaults to numeric ordering.

Source code in terratorch/tasks/classification_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so the CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    class_names: list[str] | None = None,
) -> None:
    """Constructor

    Args:

        model_args (dict): Arguments passed to the model factory.
        model_factory (str): ModelFactory class to be used to instantiate the model.
        loss (str, optional): Loss to be used. Currently, supports 'ce', 'bce', 'jaccard' or 'focal' loss.
            Defaults to "ce".
        aux_heads (list[AuxiliaryHead] | None, optional): Auxiliary heads to be attached to the model.
            Defaults to None.
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / CLI specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / CLI specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / CLI specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder and classification head. Defaults to False.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
    """
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads
    self.model_factory = get_factory(model_factory)
    super().__init__()
    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler = LossHandler(self.test_metrics.prefix)
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
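
As a usage sketch, the task can also be instantiated directly in Python rather than through a config file. The model_args below are illustrative placeholders (valid keys depend on the chosen model factory and backbone), and the import assumes ClassificationTask is re-exported from terratorch.tasks:

from terratorch.tasks import ClassificationTask

task = ClassificationTask(
    model_args={
        "backbone": "prithvi_vit_100",  # hypothetical backbone name
        "num_classes": 10,
    },
    model_factory="PrithviModelFactory",
    loss="focal",
    lr=1e-3,
    freeze_backbone=True,
    class_names=[f"class_{i}" for i in range(10)],
)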

configure_losses()

Initialize the loss criterion.

Raises:

    ValueError: If loss is invalid.

Source code in terratorch/tasks/classification_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    if loss == "ce":
        self.criterion: nn.Module = nn.CrossEntropyLoss(weight=self.hparams["class_weights"])
    elif loss == "bce":
        self.criterion = nn.BCEWithLogitsLoss()
    elif loss == "jaccard":
        self.criterion = JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = FocalLoss(mode="multiclass", normalized=True)
    else:
        msg = f"Loss type '{loss}' is not valid."
        raise ValueError(msg)
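
Note that only the 'ce' branch consumes class_weights; the other losses ignore them. A minimal sketch of the weighted cross-entropy case (the weights are made up, and nn.CrossEntropyLoss expects them as a tensor):

import torch
from torch import nn

# Up-weight a rare class (illustrative values).
class_weights = torch.tensor([0.2, 0.8])
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(4, 2)           # (batch, num_classes)
target = torch.tensor([0, 1, 1, 0])  # ground-truth class indices
print(criterion(logits, target))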

configure_metrics()

Initialize the performance metrics.

Source code in terratorch/tasks/classification_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Overall_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="micro",
            ),
            "Average_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Multiclass_Accuracy_Class": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
            ),
            "Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
            "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                labels=class_names,
            ),
            # F1 score is the F-beta score with beta=1.0
            "Multiclass_F1_Score": MulticlassFBetaScore(
                num_classes=num_classes,
                ignore_index=ignore_index,
                beta=1.0,
                average="micro",
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    self.test_metrics = metrics.clone(prefix="test/")
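
The ClasswiseWrapper entries are what enable per-class logging: they expand a non-averaged metric into one named value per class. A standalone sketch (class names are invented, and the exact key formatting may vary across torchmetrics versions):

import torch
from torchmetrics import ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy

class_names = ["water", "forest", "urban"]  # illustrative labels
metric = ClasswiseWrapper(
    MulticlassAccuracy(num_classes=3, average=None),
    labels=class_names,
)

preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
# One entry per class, keyed roughly like "multiclassaccuracy_water".
print(metric(preds, target))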

predict_step(batch, batch_idx, dataloader_idx=0)

Compute predicted classes for a batch.

Parameters:

    batch (Any): The output of your DataLoader. Required.
    batch_idx (int): Integer displaying index of this batch. Required.
    dataloader_idx (int): Index of the current dataloader. Defaults to 0.

Returns:

    Predicted class indices and the corresponding file names.

Source code in terratorch/tasks/classification_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, list[str]]:
    """Compute predicted classes for a batch.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Predicted class indices and the corresponding file names.
    """
    x = batch["image"]
    file_names = batch["filename"]

    y_hat = self(x).output
    y_hat = y_hat.argmax(dim=1)
    return y_hat, file_names
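
Since predict_step returns hard class indices alongside file names, mapping back to human-readable labels is a one-liner. A sketch assuming task, batch, and the class_names list from the constructor are in scope:

predictions, file_names = task.predict_step(batch, batch_idx=0)
for idx, name in zip(predictions.tolist(), file_names):
    print(name, class_names[idx])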

test_step(batch, batch_idx, dataloader_idx=0)

Compute the test loss and additional metrics.

Parameters:

    batch (Any): The output of your DataLoader. Required.
    batch_idx (int): Integer displaying index of this batch. Required.
    dataloader_idx (int): Index of the current dataloader. Defaults to 0.

Source code in terratorch/tasks/classification_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    model_output: dict[str, Tensor] = self(x)
    loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.test_metrics.update(y_hat_hard, y)

training_step(batch, batch_idx, dataloader_idx=0)

Compute the train loss and additional metrics.

Parameters:

    batch (Any): The output of your DataLoader. Required.
    batch_idx (int): Integer displaying index of this batch. Required.
    dataloader_idx (int): Index of the current dataloader. Defaults to 0.

Source code in terratorch/tasks/classification_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    model_output: dict[str, Tensor] = self(x)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]

validation_step(batch, batch_idx, dataloader_idx=0)

Compute the validation loss and additional metrics.

Parameters:

    batch (Any): The output of your DataLoader. Required.
    batch_idx (int): Integer displaying index of this batch. Required.
    dataloader_idx (int): Index of the current dataloader. Defaults to 0.

Source code in terratorch/tasks/classification_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    model_output: dict[str, Tensor] = self(x)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)
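
These step methods are never called by hand in normal use; the Lightning Trainer drives them. A minimal end-to-end sketch (task and datamodule are placeholders for configured objects):

from lightning.pytorch import Trainer

trainer = Trainer(max_epochs=50, accelerator="auto")
trainer.fit(task, datamodule=datamodule)   # runs training_step / validation_step
trainer.test(task, datamodule=datamodule)  # runs test_step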