Tasks

Tasks provide a convenient abstraction over the training of a model for a specific downstream task.

They encapsulate the model, optimizer, metrics and loss, as well as the training, validation and testing steps.

The task expects a model factory, to which the model_args arguments are passed to instantiate the model to be trained. Alternatively, a custom torch.nn.Module can be passed through the model argument, in which case the model factory is ignored. The models produced by the model factory should output ModelOutput instances and conform to the Model ABC.
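As a rough illustration of this contract, the sketch below is self-contained and does not import TerraTorch. ToyModelOutput and ToyModelFactory are hypothetical stand-ins for ModelOutput and a registered model factory, and the build_model(task, **model_args) call shape is an assumption about how the model_args dictionary reaches the factory.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class ToyModelOutput:
    """Hypothetical stand-in for ModelOutput: a main output plus optional auxiliary outputs."""

    output: list[float]
    auxiliary_heads: dict[str, list[float]] = field(default_factory=dict)


class ToyModelFactory:
    """Hypothetical factory: builds a model from keyword arguments."""

    def build_model(self, task: str, **model_args) -> Callable[[list[float]], ToyModelOutput]:
        if task != "segmentation":
            raise ValueError(f"Unsupported task: {task}")
        num_classes = model_args["num_classes"]

        def model(x: list[float]) -> ToyModelOutput:
            # One dummy score per class, so the output size is driven by model_args.
            return ToyModelOutput(output=[sum(x) * (c + 1) for c in range(num_classes)])

        return model


# The task keeps model_args as a dict and unpacks it into the factory call.
model_args = {"num_classes": 3}
model = ToyModelFactory().build_model("segmentation", **model_args)
out = model([1.0, 2.0])
print(out.output)  # [3.0, 6.0, 9.0]
```

The task forwards model_args to the factory unchanged. Note that SemanticSegmentationTask also reads model_args["num_classes"] itself when building its metrics, so that key must be present.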

Tasks are best leveraged using config files, where they are specified in the model section under class_path. You can check out some examples of config files here.
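Under the hood, a class_path entry names an importable class and its init_args become constructor keyword arguments. The resolver below is a minimal sketch of that convention (instantiate_from_config is a hypothetical helper, not the LightningCLI implementation). The model_section values are placeholders showing the shape of a SemanticSegmentationTask entry, not a tested configuration.

```python
import importlib
from typing import Any


def instantiate_from_config(section: dict[str, Any]) -> Any:
    """Resolve a {class_path, init_args} mapping into an object."""
    module_name, _, class_name = section["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**section.get("init_args", {}))


# Shape of a model section as it might appear in a config file (values are placeholders).
model_section = {
    "class_path": "terratorch.tasks.SemanticSegmentationTask",
    "init_args": {
        "model_factory": "EncoderDecoderFactory",
        "model_args": {"num_classes": 2},  # plus whatever backbone/decoder keys the factory needs
        "loss": "ce",
        "ignore_index": -1,
    },
}

# The resolver is generic; demonstrate it with a standard-library class.
frac = instantiate_from_config(
    {"class_path": "fractions.Fraction", "init_args": {"numerator": 1, "denominator": 3}}
)
print(frac)  # 1/3
```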

Below are the details of the tasks currently implemented in TerraTorch (Pixelwise Regression, Semantic Segmentation and Classification).


terratorch.tasks.segmentation_tasks.SemanticSegmentationTask

Bases: TerraTorchTask

Semantic Segmentation Task that accepts models from a range of sources.

This class is analogous in functionality to the SemanticSegmentationTask class defined by torchgeo. However, it has some important differences:
  • Accepts the specification of a model factory
  • Logs metrics per class
  • Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
  • Allows the setting of optimizers in the constructor
  • Allows evaluation on multiple test dataloaders

Source code in terratorch/tasks/segmentation_tasks.py
class SemanticSegmentationTask(TerraTorchTask):
    """Semantic Segmentation Task that accepts models from a range of sources.

    This class is analogous in functionality to the SemanticSegmentationTask class defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Allows evaluation on multiple test dataloaders
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        freeze_head: bool = False,
        plot_on_val: bool | int = 10,
        class_names: list[str] | None = None,
        tiled_inference_parameters: TiledInferenceParameters | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
        output_most_probable: bool = True,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided.
            model (torch.nn.Module, optional): Custom model.
            loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
                Defaults to "ce".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / cli specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / cli specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / cli specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
            tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
            output_most_probable (bool, optional): If True, the output during inference is only the index of
                the most probable class; if False, it includes the outputs for all classes. Defaults to True.
        """
        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(task="segmentation")

        if model is not None:
            # Custom model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)
        self.output_most_probable = output_most_probable

        if output_most_probable:
            self.select_classes = lambda y: y.argmax(dim=1) 
        else:
            self.select_classes = lambda y: y

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]

        class_weights = (
            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
        )
        if loss == "ce":
            ignore_value = -100 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
        elif loss == "jaccard":
            if ignore_index is not None:
                exception_message = (
                    f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
                )
                raise RuntimeError(exception_message)
            self.criterion = smp.losses.JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
        elif loss == "dice":
            self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
        else:
            exception_message = (
                f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
            )
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int | None = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Multiclass_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
                "Multiclass_Accuracy_Class": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        multidim_average="global",
                        average=None,
                    ),
                    labels=class_names,
                ),
                "Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
                ),
                "Multiclass_Jaccard_Index": MulticlassJaccardIndex(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                ),
                "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                    MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                    labels=class_names,
                ),
                "Multiclass_F1_Score": MulticlassF1Score(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
            )
        else:
            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}

        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}

        model_output: ModelOutput = self(x, **rest)
        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=y.shape[0],
        )
        y_hat_hard = to_segmentation_prediction(model_output)
        self.test_metrics[dataloader_idx].update(y_hat_hard, y)

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]

        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)

        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat_hard = to_segmentation_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat_hard

                if isinstance(batch["image"], dict):
                    if hasattr(datamodule, "rgb_modality"):
                        # Generic multimodal dataset
                        batch["image"] = batch["image"][datamodule.rgb_modality]
                    else:
                        # Multimodal dataset. Assuming first item to be the modality to visualize.
                        batch["image"] = batch["image"][list(batch["image"].keys())[0]]

                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

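    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
        """Compute predictions for a batch.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            A tuple with the predictions and the file names (None if the batch has no "filename" key).
            If output_most_probable is True, the predictions are the per-pixel indices of the most
            probable class; otherwise they are the model outputs for all classes.
        """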
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "mask", "filename"}

        rest = {k: batch[k] for k in other_keys}

        def model_forward(x):
            return self(x).output

        if self.tiled_inference_parameters:
            y_hat: Tensor = tiled_inference(
                # TODO: tiled inference does not work with additional input data (**rest)
                model_forward,
                x,
                self.hparams["model_args"]["num_classes"],
                self.tiled_inference_parameters,
            )
        else:
            y_hat: Tensor = self(x, **rest).output

        y_hat = self.select_classes(y_hat)

        return y_hat, file_names
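Every step method above follows the same pattern: image is the model input, mask is the target, filename is kept aside, and every remaining key is forwarded to the model's forward as a keyword argument. A standalone sketch of that split, where split_batch and the metadata key are hypothetical:

```python
from typing import Any

RESERVED_KEYS = {"image", "mask", "filename"}


def split_batch(batch: dict[str, Any]) -> tuple[Any, Any, dict[str, Any]]:
    """Split a batch into model input, target (None if absent) and extra keyword inputs."""
    x = batch["image"]
    y = batch.get("mask")  # prediction batches usually have no mask
    rest = {k: batch[k] for k in batch.keys() - RESERVED_KEYS}
    return x, y, rest


batch = {
    "image": [[0.1, 0.2]],
    "mask": [[1, 0]],
    "filename": ["tile_001.tif"],
    "metadata": {"date": "2020-01-01"},  # hypothetical extra input
}
x, y, rest = split_batch(batch)
print(sorted(rest))  # ['metadata']
```

With such a batch, the model would be called as self(x, metadata=...), so its forward method must accept every extra key the datamodule emits.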

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None, output_most_probable=True)

Constructor

Parameters:
  • model_args (Dict) –

    Arguments passed to the model factory.

  • model_factory (str, default: None ) –

    ModelFactory class to be used to instantiate the model. Is ignored when model is provided.

  • model (Module, default: None ) –

    Custom model.

  • loss (str, default: 'ce' ) –

    Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss. Defaults to "ce".

  • aux_loss (dict[str, float] | None, default: None ) –

    Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

  • class_weights (list[float] | None, default: None ) –

    List of class weights to be applied to the loss. Defaults to None.

  • ignore_index (int | None, default: None ) –

    Label to ignore in the loss computation. Defaults to None.

  • lr (float, default: 0.001 ) –

    Learning rate to be used. Defaults to 0.001.

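  • optimizer (str | None, default: None ) –

    Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.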

  • optimizer_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the optimizer. Overridden by config / cli specification through LightningCLI.

  • scheduler (str, default: None ) –

    Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / cli specification through LightningCLI.

  • scheduler_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the scheduler. Overridden by config / cli specification through LightningCLI.

  • freeze_backbone (bool, default: False ) –

    Whether to freeze the backbone. Defaults to False.

  • freeze_decoder (bool, default: False ) –

    Whether to freeze the decoder. Defaults to False.

  • freeze_head (bool, default: False ) –

    Whether to freeze the segmentation head. Defaults to False.

  • plot_on_val (bool | int, default: 10 ) –

    Whether to plot visualizations on validation. If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.

  • class_names (list[str] | None, default: None ) –

    List of class names passed to metrics for better naming. Defaults to numeric ordering.

  • tiled_inference_parameters (TiledInferenceParameters | None, default: None ) –

    Inference parameters used to determine if inference is done on the whole image or through tiling.

  • test_dataloaders_names (list[str] | None, default: None ) –

    Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

  • lr_overrides (dict[str, float] | None, default: None ) –

    Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None.

  • output_most_probable (bool, default: True ) –

    If True, the output during inference is only the index of the most probable class; if False, it includes the outputs for all classes. Defaults to True.

Source code in terratorch/tasks/segmentation_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    freeze_head: bool = False,
    plot_on_val: bool | int = 10,
    class_names: list[str] | None = None,
    tiled_inference_parameters: TiledInferenceParameters | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
    output_most_probable: bool = True,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided.
        model (torch.nn.Module, optional): Custom model.
        loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss.
            Defaults to "ce".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
        tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
        output_most_probable (bool, optional): If True, the output during inference is only the index of
            the most probable class; if False, it includes the outputs for all classes. Defaults to True.
    """
    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(task="segmentation")

    if model is not None:
        # Custom model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)
    self.output_most_probable = output_most_probable

    if output_most_probable:
        self.select_classes = lambda y: y.argmax(dim=1) 
    else:
        self.select_classes = lambda y: y

configure_losses()

Initialize the loss criterion.

Raises:
  • ValueError

    If loss is invalid.

Source code in terratorch/tasks/segmentation_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    ignore_index = self.hparams["ignore_index"]

    class_weights = (
        torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
    )
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
    elif loss == "jaccard":
        if ignore_index is not None:
            exception_message = (
                f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
            )
            raise RuntimeError(exception_message)
        self.criterion = smp.losses.JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
    elif loss == "dice":
        self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
    else:
        exception_message = (
            f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
        )
        raise ValueError(exception_message)

configure_metrics()

Initialize the performance metrics.

Source code in terratorch/tasks/segmentation_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int | None = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Multiclass_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                multidim_average="global",
                average="micro",
            ),
            "Multiclass_Accuracy_Class": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average=None,
                ),
                labels=class_names,
            ),
            "Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex(
                num_classes=num_classes, ignore_index=ignore_index, average="micro"
            ),
            "Multiclass_Jaccard_Index": MulticlassJaccardIndex(
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
            "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                labels=class_names,
            ),
            "Multiclass_F1_Score": MulticlassF1Score(
                num_classes=num_classes,
                ignore_index=ignore_index,
                multidim_average="global",
                average="micro",
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
        )
    else:
        self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

predict_step(batch, batch_idx, dataloader_idx=0)

Compute predictions for a batch.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

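Returns:
  • tuple[Tensor, Any]

    A tuple with the predictions and the file names (None if the batch has no "filename" key). If output_most_probable is True, the predictions are the per-pixel indices of the most probable class; otherwise they are the model outputs for all classes.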
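The select_classes function set up in __init__ decides the shape of these predictions. With output_most_probable=True it applies y.argmax(dim=1), turning a (batch, classes, height, width) output into (batch, height, width) class indices; with False it returns the output unchanged. A pure-Python equivalent of the argmax step on nested lists, for illustration only (argmax_over_classes is hypothetical):

```python
def argmax_over_classes(y: list) -> list:
    """Pure-Python equivalent of y.argmax(dim=1) for a nested [B][C][H][W] list."""
    result = []
    for sample in y:  # sample: [C][H][W]
        num_classes = len(sample)
        height, width = len(sample[0]), len(sample[0][0])
        result.append(
            [
                # max() returns the first index among equal maxima
                [max(range(num_classes), key=lambda c: sample[c][i][j]) for j in range(width)]
                for i in range(height)
            ]
        )
    return result


# One sample, three classes, a 1x2 image.
y = [[[[0.1, 0.9]], [[0.8, 0.05]], [[0.1, 0.05]]]]
print(argmax_over_classes(y))  # [[[1, 0]]]
```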

Source code in terratorch/tasks/segmentation_tasks.py
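def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
    """Compute predictions for a batch.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        A tuple with the predictions and the file names (None if the batch has no "filename" key).
        If output_most_probable is True, the predictions are the per-pixel indices of the most
        probable class; otherwise they are the model outputs for all classes.
    """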
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}

    def model_forward(x):
        return self(x).output

    if self.tiled_inference_parameters:
        y_hat: Tensor = tiled_inference(
            # TODO: tiled inference does not work with additional input data (**rest)
            model_forward,
            x,
            self.hparams["model_args"]["num_classes"],
            self.tiled_inference_parameters,
        )
    else:
        y_hat: Tensor = self(x, **rest).output

    y_hat = self.select_classes(y_hat)

    return y_hat, file_names
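Every step method in this task unpacks the batch the same way. `image` is the model input, `mask` is the target, `filename` is passed through, and every other key is forwarded to the model as a keyword argument. The plain-Python sketch below shows that convention. `split_batch` and the `metadata` key are hypothetical.

```python
from typing import Any


def split_batch(batch: dict[str, Any]) -> tuple[Any, Any, dict[str, Any]]:
    """Split a batch dict the way the task step methods do (hypothetical helper)."""
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    # Every key that is not image/mask/filename becomes an extra model kwarg (**rest).
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    return x, file_names, rest


batch = {"image": "IMG", "mask": "MASK", "filename": ["a.tif"], "metadata": {"date": "2020-01-01"}}
x, file_names, rest = split_batch(batch)
print(x, file_names, rest)
```

In the tiled branch of `predict_step`, `model_forward(x)` calls the model without `rest`. That is the limitation flagged by the TODO in the source.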

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/segmentation_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}

    model_output: ModelOutput = self(x, **rest)
    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=y.shape[0],
    )
    y_hat_hard = to_segmentation_prediction(model_output)
    self.test_metrics[dataloader_idx].update(y_hat_hard, y)
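`test_step` logs through `partial(self.log, add_dataloader_idx=False)`. When several dataloaders are in use, Lightning's `LightningModule.log` appends a `/dataloader_idx_N` suffix to each name unless `add_dataloader_idx=False` is passed. The loss handlers already carry a per-dataloader prefix, so the suffix would be redundant. The sketch below shows the `functools.partial` mechanics. `fake_log` is a hypothetical stand-in that only records its arguments and imitates the suffix. Unlike the real method, it receives `dataloader_idx` explicitly.

```python
from functools import partial

logged: dict[str, float] = {}


def fake_log(name: str, value: float, add_dataloader_idx: bool = True, dataloader_idx: int = 0) -> None:
    """Stand-in for LightningModule.log that imitates the dataloader suffix."""
    if add_dataloader_idx:
        name = f"{name}/dataloader_idx_{dataloader_idx}"
    logged[name] = value


# Pre-bind add_dataloader_idx=False, as test_step does with self.log.
log_without_idx = partial(fake_log, add_dataloader_idx=False)

fake_log("test/loss", 0.5, dataloader_idx=1)
log_without_idx("test/cloudy/loss", 0.4, dataloader_idx=1)
print(logged)  # {'test/loss/dataloader_idx_1': 0.5, 'test/cloudy/loss': 0.4}
```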

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the train loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/segmentation_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}

    rest = {k: batch[k] for k in other_keys}

    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]
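`to_segmentation_prediction` converts the model output into hard labels before the metrics are updated. Its body is not shown on this page. The sketch below assumes it takes the argmax over the class dimension of the output. It is written in plain Python for a single image of shape `[classes][height][width]`, and `argmax_over_classes` is a hypothetical name.

```python
def argmax_over_classes(logits: list[list[list[float]]]) -> list[list[int]]:
    """Return per-pixel class indices for scores shaped [C][H][W] (hypothetical helper)."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    labels = []
    for i in range(height):
        row = []
        for j in range(width):
            scores = [logits[c][i][j] for c in range(num_classes)]
            # Index of the highest score; max() keeps the first (lowest) index on ties.
            row.append(max(range(num_classes), key=scores.__getitem__))
        labels.append(row)
    return labels


logits = [
    [[2.0, 0.1], [0.3, 0.0]],  # scores for class 0
    [[1.0, 0.9], [0.2, 5.0]],  # scores for class 1
]
print(argmax_over_classes(logits))  # [[0, 1], [0, 1]]
```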

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/segmentation_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]

    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)

    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat_hard = to_segmentation_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat_hard

            if isinstance(batch["image"], dict):
                if hasattr(datamodule, "rgb_modality"):
                    # Generic multimodal dataset
                    batch["image"] = batch["image"][datamodule.rgb_modality]
                else:
                    # Multimodal dataset. Assuming first item to be the modality to visualize.
                    batch["image"] = batch["image"][list(batch["image"].keys())[0]]

            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()
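Before plotting, `validation_step` makes two decisions. First, it reduces a multimodal `image` dict to one entry: the datamodule's `rgb_modality` if it defines one, otherwise the first key. Second, it picks a logging call by duck typing: `add_figure` (as on TensorBoard's `SummaryWriter`) or `log_figure` (as on MLflow's `MlflowClient`). The plain-Python sketch below mirrors both decisions. `pick_plot_image`, `log_plot` and `FakeSummaryWriter` are hypothetical stand-ins.

```python
from __future__ import annotations

from typing import Any


def pick_plot_image(image: Any, rgb_modality: str | None = None) -> Any:
    """Choose which modality to plot, mirroring validation_step (hypothetical helper)."""
    if isinstance(image, dict):
        if rgb_modality is not None:
            # The datamodule declares which modality to visualize.
            return image[rgb_modality]
        # Fallback: assume the first modality is the one to visualize.
        return image[list(image.keys())[0]]
    return image


def log_plot(writer: Any, fig: Any, batch_idx: int, global_step: int, epoch: int, run_id: str = "run") -> str:
    """Dispatch on the writer's API like validation_step; return the path taken."""
    if hasattr(writer, "add_figure"):  # e.g. TensorBoard's SummaryWriter
        writer.add_figure(f"image/{batch_idx}", fig, global_step=global_step)
        return "add_figure"
    if hasattr(writer, "log_figure"):  # e.g. MLflow's MlflowClient
        writer.log_figure(run_id, fig, f"epoch_{epoch}_{batch_idx}.png")
        return "log_figure"
    return "skipped"


class FakeSummaryWriter:
    """Stand-in that only records add_figure tags."""

    def __init__(self) -> None:
        self.tags: list[str] = []

    def add_figure(self, tag: str, fig: Any, global_step: int = 0) -> None:
        self.tags.append(tag)


image = {"s2": "S2_TENSOR", "s1": "S1_TENSOR"}
print(pick_plot_image(image))  # S2_TENSOR
print(pick_plot_image(image, rgb_modality="s1"))  # S1_TENSOR

writer = FakeSummaryWriter()
print(log_plot(writer, fig=None, batch_idx=3, global_step=120, epoch=2))  # add_figure
```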

terratorch.tasks.regression_tasks.PixelwiseRegressionTask #

Bases: TerraTorchTask

Pixelwise Regression Task that accepts models from a range of sources.

This class is analogous in functionality to the PixelwiseRegressionTask defined by torchgeo. However, it has some important differences:

  • Accepts the specification of a model factory
  • Logs metrics per class
  • Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
  • Allows the setting of optimizers in the constructor
  • Allows evaluation on multiple test dataloaders

Source code in terratorch/tasks/regression_tasks.py
class PixelwiseRegressionTask(TerraTorchTask):
    """Pixelwise Regression Task that accepts models from a range of sources.

    This class is analogous in functionality to PixelwiseRegressionTask defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Allows evaluation on multiple test dataloaders"""

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "mse",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so the CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT001, FBT002
        freeze_head: bool = False,  # noqa: FBT001, FBT002
        plot_on_val: bool | int = 10,
        tiled_inference_parameters: TiledInferenceParameters | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided.
            model (torch.nn.Module, optional): Custom model.
            loss (str, optional): Loss to be used. Supports 'mse', 'rmse', 'mae' or 'huber'.
                Defaults to "mse".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / CLI specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / CLI specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / CLI specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the head. Defaults to False.
            plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
                If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.
            tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                used to determine if inference is done on the whole image or through tiling.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
        """
        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(task="regression")

        if model:
            # Custom model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"].lower()
        if loss == "mse":
            self.criterion: nn.Module = IgnoreIndexLossWrapper(
                nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
            )
        elif loss == "mae":
            self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
        elif loss == "rmse":
            # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
            self.criterion = RootLossWrapper(
                IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
            )
        elif loss == "huber":
            self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
        else:
            exception_message = (
                f"Loss type '{loss}' is not valid. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss."
            )
            raise ValueError(exception_message)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

        def instantiate_metrics():
            return {
                "RMSE": MeanSquaredError(squared=False),
                "MSE": MeanSquaredError(squared=True),
                "MAE": MeanAbsoluteError(),
            }

        def wrap_metrics_with_ignore_index(metrics):
            return {
                name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
                for name, metric in metrics.items()
            }

        self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
        self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [
                    MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
                    for dl_name in self.hparams["test_dataloaders_names"]
                ]
            )
        else:
            self.test_metrics = nn.ModuleList(
                [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
            )

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}

        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat = model_output.output
        self.train_metrics.update(y_hat, y)

        return loss["loss"]

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
        y_hat = model_output.output
        self.val_metrics.update(y_hat, y)

        if self._do_plot_samples(batch_idx):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat
                if isinstance(batch["image"], dict):
                    # Multimodal input
                    batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality]
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.val_dataset.plot(sample)
                if fig:
                    summary_writer = self.logger.experiment
                    if hasattr(summary_writer, "add_figure"):
                        summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                    elif hasattr(summary_writer, "log_figure"):
                        summary_writer.log_figure(
                            self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                        )
            except ValueError:
                pass
            finally:
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=x.shape[0],
        )
        y_hat = model_output.output
        self.test_metrics[dataloader_idx].update(y_hat, y)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
        """Compute the pixelwise predictions.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The predicted values and the batch file names (None if absent).
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "mask", "filename"}
        rest = {k: batch[k] for k in other_keys}

        def model_forward(x):
            return self(x).output

        if self.tiled_inference_parameters:
            # TODO: tiled inference does not work with additional input data (**rest)
            y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
        else:
            y_hat: Tensor = self(x, **rest).output
        return y_hat, file_names

__init__(model_args, model_factory=None, model=None, loss='mse', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None) #

Constructor

Parameters:
  • model_args (Dict) –

    Arguments passed to the model factory.

  • model_factory (str, default: None ) –

    Name of ModelFactory class to be used to instantiate the model. Is ignored when model is provided.

  • model (Module, default: None ) –

    Custom model.

  • loss (str, default: 'mse' ) –

    Loss to be used. Supports 'mse', 'rmse', 'mae' or 'huber'. Defaults to "mse".

  • aux_loss (dict[str, float] | None, default: None ) –

    Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

  • class_weights (list[float] | None, default: None ) –

    List of class weights to be applied to the loss. Defaults to None.

  • ignore_index (int | None, default: None ) –

    Label to ignore in the loss computation. Defaults to None.

  • lr (float, default: 0.001 ) –

    Learning rate to be used. Defaults to 0.001.

  • optimizer (str | None, default: None ) –

    Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.

  • optimizer_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the optimizer. Overridden by config / CLI specification through LightningCLI.

  • scheduler (str, default: None ) –

    Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / CLI specification through LightningCLI.

  • scheduler_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the scheduler. Overridden by config / CLI specification through LightningCLI.

  • freeze_backbone (bool, default: False ) –

    Whether to freeze the backbone. Defaults to False.

  • freeze_decoder (bool, default: False ) –

    Whether to freeze the decoder. Defaults to False.

  • freeze_head (bool, default: False ) –

    Whether to freeze the head. Defaults to False.

  • plot_on_val (bool | int, default: 10 ) –

    Whether to plot visualizations on validation. If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.

  • tiled_inference_parameters (TiledInferenceParameters | None, default: None ) –

    Inference parameters used to determine if inference is done on the whole image or through tiling.

  • test_dataloaders_names (list[str] | None, default: None ) –

    Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

  • lr_overrides (dict[str, float] | None, default: None ) –

    Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None.
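`lr_overrides` matches parameters by substring. Any parameter whose name contains a key gets that key's learning rate, and all other parameters keep `lr`. The sketch below shows one way to turn that rule into optimizer parameter groups, working on parameter names only. `build_lr_groups` is hypothetical. This page does not say how TerraTorch resolves a name that contains several keys, so the sketch simply takes the first matching key. In a real optimizer each `"params"` entry would hold parameter tensors, for example taken from `model.named_parameters()`.

```python
from __future__ import annotations


def build_lr_groups(param_names: list[str], lr: float, lr_overrides: dict[str, float] | None = None) -> list[dict]:
    """Group parameter names by learning rate using substring matching (hypothetical helper)."""
    lr_overrides = lr_overrides or {}
    groups: dict[float, list[str]] = {}
    for name in param_names:
        # First override key contained in the name wins; otherwise fall back to the default lr.
        group_lr = next((value for key, value in lr_overrides.items() if key in name), lr)
        groups.setdefault(group_lr, []).append(name)
    return [{"params": names, "lr": group_lr} for group_lr, names in groups.items()]


names = ["encoder.blocks.0.weight", "encoder.blocks.0.bias", "decoder.conv.weight", "head.weight"]
for group in build_lr_groups(names, lr=1e-3, lr_overrides={"encoder": 1e-5}):
    print(group)
```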

Source code in terratorch/tasks/regression_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "mse",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so the CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT001, FBT002
    freeze_head: bool = False,  # noqa: FBT001, FBT002
    plot_on_val: bool | int = 10,
    tiled_inference_parameters: TiledInferenceParameters | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided.
        model (torch.nn.Module, optional): Custom model.
        loss (str, optional): Loss to be used. Supports 'mse', 'rmse', 'mae' or 'huber'.
            Defaults to "mse".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / CLI specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / CLI specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / CLI specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / CLI specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the head. Defaults to False.
        plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
            If True, plots every epoch; if an int, plots every plot_on_val epochs. Defaults to 10.
        tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
            used to determine if inference is done on the whole image or through tiling.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
    """
    self.tiled_inference_parameters = tiled_inference_parameters
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(task="regression")

    if model:
        # Custom model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"
    self.plot_on_val = int(plot_on_val)
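The first lines of the constructor decide where the model comes from. An explicit `model` takes precedence, and a warning is logged if a `model_factory` is also given. A `model_factory` name is used only when no `model` is passed. Passing neither raises `ValueError`. The plain-Python sketch below covers only that decision. `resolve_model_source` is hypothetical and returns a label instead of building anything. The factory name in the usage line is just an example string.

```python
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def resolve_model_source(model: object | None, model_factory: str | None) -> str:
    """Mirror the constructor's model / model_factory precedence (hypothetical helper)."""
    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")
    # A custom model wins; otherwise the named factory would be built from the registry.
    return "model" if model is not None else f"factory:{model_factory}"


print(resolve_model_source(None, "EncoderDecoderFactory"))  # factory:EncoderDecoderFactory
```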

configure_losses() #

Initialize the loss criterion.

Raises:
  • ValueError

    If loss is invalid.

Source code in terratorch/tasks/regression_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"].lower()
    if loss == "mse":
        self.criterion: nn.Module = IgnoreIndexLossWrapper(
            nn.MSELoss(reduction="none"), self.hparams["ignore_index"]
        )
    elif loss == "mae":
        self.criterion = IgnoreIndexLossWrapper(nn.L1Loss(reduction="none"), self.hparams["ignore_index"])
    elif loss == "rmse":
        # IMPORTANT! Root is done only after ignore index! Otherwise the mean taken is incorrect
        self.criterion = RootLossWrapper(
            IgnoreIndexLossWrapper(nn.MSELoss(reduction="none"), self.hparams["ignore_index"]), reduction=None
        )
    elif loss == "huber":
        self.criterion = IgnoreIndexLossWrapper(nn.HuberLoss(reduction="none"), self.hparams["ignore_index"])
    else:
        exception_message = (
            f"Loss type '{loss}' is not valid. Currently supports 'mse', 'rmse', 'mae' or 'huber' loss."
        )
        raise ValueError(exception_message)
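The comment in the `rmse` branch is about ordering. The square root must be taken only after ignored pixels are masked out and the squared errors are averaged. Taking the root per pixel first would average absolute errors, which gives the MAE rather than the RMSE. The plain-Python sketch below assumes `IgnoreIndexLossWrapper` averages the unreduced loss over non-ignored pixels only. `masked_rmse` and `root_first` are hypothetical helpers working on flat lists.

```python
from __future__ import annotations

import math


def masked_rmse(preds: list[float], targets: list[float], ignore_index: int | None) -> float:
    """Mask ignored pixels, average the squared errors, then take the root."""
    sq_errors = [(p - t) ** 2 for p, t in zip(preds, targets) if ignore_index is None or t != ignore_index]
    return math.sqrt(sum(sq_errors) / len(sq_errors))


def root_first(preds: list[float], targets: list[float], ignore_index: int) -> float:
    """Wrong order: root each squared error before averaging, which yields the MAE."""
    errors = [math.sqrt((p - t) ** 2) for p, t in zip(preds, targets) if t != ignore_index]
    return sum(errors) / len(errors)


preds = [1.0, 2.0, 5.0, 0.0]
targets = [1.0, 4.0, 1.0, -1.0]  # -1 marks an ignored pixel
print(masked_rmse(preds, targets, ignore_index=-1))  # ~2.582, i.e. sqrt(20 / 3)
print(root_first(preds, targets, ignore_index=-1))  # 2.0, the mean absolute error
```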

configure_metrics() #

Initialize the performance metrics.

Source code in terratorch/tasks/regression_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""

    def instantiate_metrics():
        return {
            "RMSE": MeanSquaredError(squared=False),
            "MSE": MeanSquaredError(squared=True),
            "MAE": MeanAbsoluteError(),
        }

    def wrap_metrics_with_ignore_index(metrics):
        return {
            name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
            for name, metric in metrics.items()
        }

    self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
    self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [
                MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
                for dl_name in self.hparams["test_dataloaders_names"]
            ]
        )
    else:
        self.test_metrics = nn.ModuleList(
            [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
        )
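Each metric is wrapped in `IgnoreIndexMetricWrapper`, so pixels whose target equals `ignore_index` do not contribute to RMSE, MSE or MAE. In addition, `instantiate_metrics()` is called once per collection, so train, validation and each test split keep separate metric instances. The sketch below imitates that wrapper pattern with a tiny stateful MAE in plain Python. `RunningMAE` and `IgnoreIndexWrapper` are hypothetical stand-ins, not TerraTorch or torchmetrics classes.

```python
from __future__ import annotations


class RunningMAE:
    """Minimal stateful metric: accumulates absolute errors across update() calls."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, preds: list[float], targets: list[float]) -> None:
        for p, t in zip(preds, targets):
            self.total += abs(p - t)
            self.count += 1

    def compute(self) -> float:
        return self.total / self.count


class IgnoreIndexWrapper:
    """Drop (pred, target) pairs whose target equals ignore_index, then delegate."""

    def __init__(self, metric: RunningMAE, ignore_index: int | None) -> None:
        self.metric = metric
        self.ignore_index = ignore_index

    def update(self, preds: list[float], targets: list[float]) -> None:
        if self.ignore_index is not None:
            kept = [(p, t) for p, t in zip(preds, targets) if t != self.ignore_index]
            preds = [p for p, _ in kept]
            targets = [t for _, t in kept]
        self.metric.update(preds, targets)

    def compute(self) -> float:
        return self.metric.compute()


# Separate instances per split, so validation updates never touch training state.
train_mae = IgnoreIndexWrapper(RunningMAE(), ignore_index=-1)
val_mae = IgnoreIndexWrapper(RunningMAE(), ignore_index=-1)
train_mae.update([1.0, 2.0, 9.0], [2.0, 2.0, -1.0])  # third pixel is ignored
val_mae.update([0.0], [3.0])
print(train_mae.compute(), val_mae.compute())  # 0.5 3.0
```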

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the pixelwise predictions.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Returns:
  • tuple[Tensor, Any]

    The predicted values and the batch file names (None if the batch has no "filename" key).

Source code in terratorch/tasks/regression_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
    """Compute the pixelwise predictions.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        The predicted values and the batch file names (None if absent).
    """
    x = batch["image"]
    file_names = batch["filename"] if "filename" in batch else None
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}

    def model_forward(x):
        return self(x).output

    if self.tiled_inference_parameters:
        # TODO: tiled inference does not work with additional input data (**rest)
        y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
    else:
        y_hat: Tensor = self(x, **rest).output
    return y_hat, file_names
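All step methods split the incoming batch dictionary the same way. "image" is the model input, "mask" is the regression target, "filename" is set aside, and every remaining key is forwarded to the model as a keyword argument (except in the tiled-inference branch of predict_step, as its TODO notes). The plain-Python sketch below reproduces that split. split_batch and toy_model are hypothetical names used only for illustration.

```python
from typing import Any


def split_batch(batch: dict[str, Any], target_key: str = "mask"):
    """Hypothetical helper: split a batch into input, target, file names and extra model kwargs."""
    x = batch["image"]
    y = batch.get(target_key)  # absent at prediction time
    file_names = batch.get("filename")
    # Set difference on the dict keys view, as in the step methods.
    other_keys = batch.keys() - {"image", target_key, "filename"}
    rest = {k: batch[k] for k in other_keys}
    return x, y, file_names, rest


def toy_model(x, **kwargs):
    """Hypothetical model that reports which extra inputs it received."""
    return {"output": x, "extra_inputs": sorted(kwargs)}


batch = {"image": [0.1, 0.2], "mask": [1.0, 2.0], "filename": ["tile_0.tif"], "time": 3, "location": (0.0, 0.0)}
x, y, file_names, rest = split_batch(batch)
print(toy_model(x, **rest)["extra_inputs"])  # ['location', 'time']
```

ClassificationTask follows the same pattern with "label" as the target key instead of "mask".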

test_step(batch, batch_idx, dataloader_idx=0) #

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/regression_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the test loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=x.shape[0],
    )
    y_hat = model_output.output
    self.test_metrics[dataloader_idx].update(y_hat, y)
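test_step looks up one loss handler and one metric collection per dataloader_idx, and configure_metrics derives one prefix per entry of test_dataloaders_names, or a single "test/" prefix when it is None. The sketch below mirrors that bookkeeping in plain Python. The helper names are hypothetical, and the logged loss key assumes the handler appends "loss" to the prefix, as the monitor key f"{self.val_metrics.prefix}loss" in ClassificationTask.__init__ suggests.

```python
def build_test_prefixes(test_dataloaders_names=None):
    """Hypothetical helper: one metric/loss prefix per test dataloader, as in configure_metrics."""
    if test_dataloaders_names is not None:
        return [f"test/{name}/" for name in test_dataloaders_names]
    return ["test/"]


def loss_key(prefixes, dataloader_idx):
    """Hypothetical helper: key the test loss would be logged under for one dataloader."""
    # Same guard as test_step: more test dataloaders than names is an error.
    if dataloader_idx >= len(prefixes):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    return f"{prefixes[dataloader_idx]}loss"


prefixes = build_test_prefixes(["region_a", "region_b"])
print(loss_key(prefixes, 1))  # test/region_b/loss
```

This is why test_dataloaders_names needs one name per dataloader returned by the datamodule's test_dataloader.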

training_step(batch, batch_idx, dataloader_idx=0) #

Compute the train loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/regression_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute the train loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}

    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat = model_output.output
    self.train_metrics.update(y_hat, y)

    return loss["loss"]
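training_step delegates to LossHandler.compute_loss, passing the main criterion and the aux_loss weights, and returns the "loss" entry of the resulting dictionary for Lightning to backpropagate. Going by the aux_loss description in the constructor docs (a name matching a key of the model output, mapped to a weight), the combination can be sketched as a weighted sum. This is an assumption about the behaviour, not LossHandler's actual code. combine_losses and the output name "aux_head" are hypothetical.

```python
def combine_losses(main_loss, aux_losses, aux_loss_weights=None):
    """Hypothetical helper: main loss plus weighted auxiliary losses, returned under "loss".

    aux_losses maps each auxiliary output name to its (already computed) loss value;
    aux_loss_weights plays the role of the aux_loss argument (name -> weight).
    """
    loss_dict = {}
    total = main_loss
    for name, weight in (aux_loss_weights or {}).items():
        if name not in aux_losses:
            raise KeyError(f"No auxiliary loss computed for output {name!r}")
        loss_dict[name] = aux_losses[name]  # unweighted value, e.g. for logging
        total += weight * aux_losses[name]
    loss_dict["loss"] = total  # the entry training_step returns to Lightning
    return loss_dict


losses = combine_losses(0.8, {"aux_head": 0.5}, aux_loss_weights={"aux_head": 0.4})
print(losses["loss"])
```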

validation_step(batch, batch_idx, dataloader_idx=0) #

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/regression_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    other_keys = batch.keys() - {"image", "mask", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
    y_hat = model_output.output
    self.val_metrics.update(y_hat, y)

    if self._do_plot_samples(batch_idx):
        try:
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat
            if isinstance(batch["image"], dict):
                # Multimodal input
                batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality]
            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]
            fig = datamodule.val_dataset.plot(sample)
            if fig:
                summary_writer = self.logger.experiment
                if hasattr(summary_writer, "add_figure"):
                    summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
                elif hasattr(summary_writer, "log_figure"):
                    summary_writer.log_figure(
                        self.logger.run_id, fig, f"epoch_{self.current_epoch}_{batch_idx}.png"
                    )
        except ValueError:
            pass
        finally:
            plt.close()
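When a sample is plotted, validation_step picks the logging call by duck typing on self.logger.experiment. An object with add_figure (for example TensorBoard's SummaryWriter) receives a figure tagged image/<batch_idx> at the current global step. An object with log_figure (for example an MLflow client) receives the logger's run_id and an artifact file name built from the epoch and batch index. The sketch below reproduces that dispatch with two hypothetical stand-in writers so it runs without either library.

```python
class TensorBoardLikeWriter:
    """Stand-in exposing add_figure(tag, figure, global_step=...), like TensorBoard's SummaryWriter."""

    def __init__(self):
        self.logged = []

    def add_figure(self, tag, figure, global_step=None):
        self.logged.append((tag, global_step))


class MLflowLikeClient:
    """Stand-in exposing log_figure(run_id, figure, artifact_file), like an MLflow client."""

    def __init__(self):
        self.logged = []

    def log_figure(self, run_id, figure, artifact_file):
        self.logged.append((run_id, artifact_file))


def log_sample_figure(writer, fig, batch_idx, global_step, current_epoch, run_id=None):
    """Log fig with whichever API the writer offers; return True if something was logged."""
    if not fig:  # the dataset's plot() may return nothing
        return False
    if hasattr(writer, "add_figure"):
        writer.add_figure(f"image/{batch_idx}", fig, global_step=global_step)
    elif hasattr(writer, "log_figure"):
        writer.log_figure(run_id, fig, f"epoch_{current_epoch}_{batch_idx}.png")
    else:
        return False  # unsupported logger: the figure is silently dropped
    return True


tb = TensorBoardLikeWriter()
log_sample_figure(tb, "fig", batch_idx=3, global_step=120, current_epoch=2)
print(tb.logged)  # [('image/3', 120)]
```

As in validation_step, a logger exposing neither method simply logs nothing.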

terratorch.tasks.classification_tasks.ClassificationTask #

Bases: TerraTorchTask

Classification Task that accepts models from a range of sources.

This class is analogous in functionality to the class ClassificationTask defined by torchgeo. However, it has some important differences:

  • Accepts the specification of a model factory
  • Logs metrics per class
  • Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
  • Allows the setting of optimizers in the constructor
  • Provides accuracy with both micro and macro averaging
  • Allows evaluation on multiple test dataloaders

Note:
  • 'Micro' averaging suits overall performance evaluation but may not reflect minority-class accuracy.
  • 'Macro' averaging gives equal weight to each class, which is useful for assessing balanced performance across imbalanced classes.
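The note above can be made concrete with a small imbalanced example. The sketch below computes micro accuracy (correct predictions over all samples, which is what Overall_Accuracy uses) and macro accuracy (the mean of per-class accuracies, which is what Average_Accuracy uses) in plain Python. It averages only over classes present in the targets; torchmetrics' MulticlassAccuracy may treat absent classes differently. micro_macro_accuracy is a hypothetical helper.

```python
from collections import Counter


def micro_macro_accuracy(preds, targets, ignore_index=None):
    """Hypothetical helper: (micro accuracy, macro accuracy, per-class accuracy) for hard predictions."""
    pairs = [(p, t) for p, t in zip(preds, targets) if t != ignore_index]
    micro = sum(p == t for p, t in pairs) / len(pairs)
    support = Counter(t for _, t in pairs)  # samples per true class
    hits = Counter(t for p, t in pairs if p == t)  # correct samples per true class
    per_class = {c: hits[c] / n for c, n in support.items()}
    macro = sum(per_class.values()) / len(per_class)
    return micro, macro, per_class


# 8 samples of class 0, 2 of class 1; the model always predicts class 0.
targets = [0] * 8 + [1] * 2
preds = [0] * 10
micro, macro, per_class = micro_macro_accuracy(preds, targets)
print(micro, macro, per_class)  # 0.8 0.5 {0: 1.0, 1: 0.0}
```

Micro accuracy looks good (0.8), while macro accuracy (0.5) exposes that the minority class is never predicted correctly.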

Source code in terratorch/tasks/classification_tasks.py
class ClassificationTask(TerraTorchTask):
    """Classification Task that accepts models from a range of sources.

    This class is analogous in functionality to the class ClassificationTask defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - Provides accuracy with both micro and macro averaging
        - Allows evaluation on multiple test dataloaders

    .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.
    """

    def __init__(
        self,
        model_args: dict,
        model_factory: str | None = None,
        model: torch.nn.Module | None = None,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so the CLI doesn't need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        #
        #
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        freeze_head: bool = False,  # noqa: FBT002, FBT001
        class_names: list[str] | None = None,
        test_dataloaders_names: list[str] | None = None,
        lr_overrides: dict[str, float] | None = None,
    ) -> None:
        """Constructor

        Args:
            model_args (Dict): Arguments passed to the model factory.
            model_factory (str, optional): ModelFactory class to be used to instantiate the model.
                Is ignored when model is provided.
            model (torch.nn.Module, optional): Custom model.
            loss (str, optional): Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss.
                Defaults to "ce".
            aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
                Should be a dictionary where the key is the name given to the loss
                and the value is the weight to be applied to that loss.
                The name of the loss should match the key in the dictionary output by the model's forward
                method containing that output. Defaults to None.
            class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
                Defaults to None.
            ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
            lr (float, optional): Learning rate to be used. Defaults to 0.001.
            optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
                If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
            optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                Overridden by config / cli specification through LightningCLI.
            scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
                to be used (e.g. ReduceLROnPlateau). Defaults to None.
                Overridden by config / cli specification through LightningCLI.
            scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
                Overridden by config / cli specification through LightningCLI.
            freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
            freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
            freeze_head (bool, optional): Whether to freeze the classification head. Defaults to False.
            class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                Defaults to numeric ordering.
            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                which assumes only one test dataloader is used.
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
        """
        self.aux_loss = aux_loss
        self.aux_heads = aux_heads

        if model is not None and model_factory is not None:
            logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
        if model is None and model_factory is None:
            raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

        if model_factory and model is None:
            self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

        super().__init__(task="classification")

        if model:
            # Custom model
            self.model = model

        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler: list[LossHandler] = []
        for metrics in self.test_metrics:
            self.test_loss_handler.append(LossHandler(metrics.prefix))
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]

        class_weights = (
            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
        )
        if loss == "ce":
            ignore_value = -100 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
        elif loss == "bce":
            self.criterion = nn.BCEWithLogitsLoss()
        elif loss == "jaccard":
            self.criterion = JaccardLoss(mode="multiclass")
        elif loss == "focal":
            self.criterion = FocalLoss(mode="multiclass", normalized=True)
        else:
            msg = f"Loss type '{loss}' is not valid."
            raise ValueError(msg)

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["model_args"]["num_classes"]
        ignore_index: int = self.hparams["ignore_index"]
        class_names = self.hparams["class_names"]
        metrics = MetricCollection(
            {
                "Overall_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="micro",
                ),
                "Average_Accuracy": MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average="macro",
                ),
                "Multiclass_Accuracy_Class": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes,
                        ignore_index=ignore_index,
                        average=None,
                    ),
                    labels=class_names,
                ),
                "Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
                "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                    MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                    labels=class_names,
                ),
                # F1 score is the F-beta score with beta=1.0
                "Multiclass_F1_Score": MulticlassFBetaScore(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    beta=1.0,
                    average="micro",
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        if self.hparams["test_dataloaders_names"] is not None:
            self.test_metrics = nn.ModuleList(
                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
            )
        else:
            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
        """Compute the train loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}

        model_output: ModelOutput = self(x, **rest)
        loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.train_metrics.update(y_hat_hard, y)

        return loss["loss"]

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
        y_hat_hard = to_class_prediction(model_output)
        self.val_metrics.update(y_hat_hard, y)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["label"]
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)
        if dataloader_idx >= len(self.test_loss_handler):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
        self.test_loss_handler[dataloader_idx].log_loss(
            partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
            loss_dict=loss,
            batch_size=x.shape[0],
        )
        y_hat_hard = to_class_prediction(model_output)
        self.test_metrics[dataloader_idx].update(y_hat_hard, y)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
        """Compute the predicted classes.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The predicted class indices and the file names of the batch (None if the batch has no "filename" key).
        """
        x = batch["image"]
        file_names = batch["filename"] if "filename" in batch else None
        other_keys = batch.keys() - {"image", "label", "filename"}
        rest = {k: batch[k] for k in other_keys}
        model_output: ModelOutput = self(x, **rest)

        # Reuse the forward pass above (which includes the extra inputs) instead of calling the model again.
        y_hat = model_output.output.argmax(dim=1)
        return y_hat, file_names

__init__(model_args, model_factory=None, model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=None, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, class_names=None, test_dataloaders_names=None, lr_overrides=None) #

Constructor

Parameters:
  • model_args (Dict) –

    Arguments passed to the model factory.

  • model_factory (str, default: None ) –

    ModelFactory class to be used to instantiate the model. Is ignored when model is provided.

  • model (Module, default: None ) –

    Custom model.

  • loss (str, default: 'ce' ) –

    Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss. Defaults to "ce".

  • aux_loss (dict[str, float] | None, default: None ) –

    Auxiliary loss weights. Should be a dictionary where the key is the name given to the loss and the value is the weight to be applied to that loss. The name of the loss should match the key in the dictionary output by the model's forward method containing that output. Defaults to None.

  • class_weights (list[float] | None, default: None ) –

    List of class weights to be applied to the loss. Defaults to None.

  • ignore_index (int | None, default: None ) –

    Label to ignore in the loss computation. Defaults to None.

  • lr (float, default: 0.001 ) –

    Learning rate to be used. Defaults to 0.001.

  • optimizer (str | None, default: None ) –

    Name of optimizer class from torch.optim to be used. If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.

  • optimizer_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the optimizer. Overridden by config / cli specification through LightningCLI.

  • scheduler (str, default: None ) –

    Name of Torch scheduler class from torch.optim.lr_scheduler to be used (e.g. ReduceLROnPlateau). Defaults to None. Overridden by config / cli specification through LightningCLI.

  • scheduler_hparams (dict | None, default: None ) –

    Parameters to be passed for instantiation of the scheduler. Overridden by config / cli specification through LightningCLI.

  • freeze_backbone (bool, default: False ) –

    Whether to freeze the backbone. Defaults to False.

  • freeze_decoder (bool, default: False ) –

    Whether to freeze the decoder. Defaults to False.

  • freeze_head (bool, default: False ) –

    Whether to freeze the classification head. Defaults to False.

  • class_names (list[str] | None, default: None ) –

    List of class names passed to metrics for better naming. Defaults to numeric ordering.

  • test_dataloaders_names (list[str] | None, default: None ) –

    Names used to differentiate metrics when multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, which assumes only one test dataloader is used.

  • lr_overrides (dict[str, float] | None, default: None ) –

    Dictionary to override the default lr in specific parameters. The key should be a substring of the parameter names (it will check the substring is contained in the parameter name) and the value should be the new lr. Defaults to None.
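lr_overrides matches its keys as substrings of parameter names. The sketch below shows one way such overrides could be turned into optimizer parameter groups, i.e. the list of {"params": ..., "lr": ...} dictionaries that torch.optim optimizers accept as per-parameter options. It uses plain strings in place of the names returned by model.named_parameters(). build_param_groups is hypothetical, and the rule that the first matching key wins is an assumption of this sketch, not documented TerraTorch behaviour.

```python
def build_param_groups(param_names, base_lr, lr_overrides=None):
    """Hypothetical helper: group parameter names by learning rate, matching override keys as substrings."""
    lr_overrides = lr_overrides or {}
    groups = {}  # lr -> parameter names, in first-seen order
    for name in param_names:
        lr = base_lr
        for key, override_lr in lr_overrides.items():
            if key in name:  # the key only has to be contained in the parameter name
                lr = override_lr
                break  # assumption of this sketch: the first matching key wins
        groups.setdefault(lr, []).append(name)
    return [{"params": members, "lr": lr} for lr, members in groups.items()]


names = ["encoder.blocks.0.weight", "encoder.blocks.0.bias", "decoder.conv.weight", "head.fc.weight"]
for group in build_param_groups(names, base_lr=1e-3, lr_overrides={"encoder": 1e-5}):
    print(group["lr"], group["params"])
```

With a real model, each name list would be replaced by the corresponding parameter tensors before building the optimizer.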

Source code in terratorch/tasks/classification_tasks.py
def __init__(
    self,
    model_args: dict,
    model_factory: str | None = None,
    model: torch.nn.Module | None = None,
    loss: str = "ce",
    aux_heads: list[AuxiliaryHead] | None = None,
    aux_loss: dict[str, float] | None = None,
    class_weights: list[float] | None = None,
    ignore_index: int | None = None,
    lr: float = 0.001,
    # the following are optional so the CLI doesn't need to pass them
    optimizer: str | None = None,
    optimizer_hparams: dict | None = None,
    scheduler: str | None = None,
    scheduler_hparams: dict | None = None,
    #
    #
    freeze_backbone: bool = False,  # noqa: FBT001, FBT002
    freeze_decoder: bool = False,  # noqa: FBT002, FBT001
    freeze_head: bool = False,  # noqa: FBT002, FBT001
    class_names: list[str] | None = None,
    test_dataloaders_names: list[str] | None = None,
    lr_overrides: dict[str, float] | None = None,
) -> None:
    """Constructor

    Args:
        model_args (Dict): Arguments passed to the model factory.
        model_factory (str, optional): ModelFactory class to be used to instantiate the model.
            Is ignored when model is provided.
        model (torch.nn.Module, optional): Custom model.
        loss (str, optional): Loss to be used. Currently supports 'ce', 'bce', 'jaccard' or 'focal' loss.
            Defaults to "ce".
        aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
            Should be a dictionary where the key is the name given to the loss
            and the value is the weight to be applied to that loss.
            The name of the loss should match the key in the dictionary output by the model's forward
            method containing that output. Defaults to None.
        class_weights (list[float] | None, optional): List of class weights to be applied to the loss.
            Defaults to None.
        ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
        lr (float, optional): Learning rate to be used. Defaults to 0.001.
        optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
            If None, will use Adam. Defaults to None. Overridden by config / cli specification through LightningCLI.
        optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
            Overridden by config / cli specification through LightningCLI.
        scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
            to be used (e.g. ReduceLROnPlateau). Defaults to None.
            Overridden by config / cli specification through LightningCLI.
        scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
            Overridden by config / cli specification through LightningCLI.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        freeze_head (bool, optional): Whether to freeze the classification head. Defaults to False.
        class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
            Defaults to numeric ordering.
        test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
            multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
            which assumes only one test dataloader is used.
        lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
            parameters. The key should be a substring of the parameter names (it will check the substring is
            contained in the parameter name) and the value should be the new lr. Defaults to None.
    """
    self.aux_loss = aux_loss
    self.aux_heads = aux_heads

    if model is not None and model_factory is not None:
        logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
    if model is None and model_factory is None:
        raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

    if model_factory and model is None:
        self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

    super().__init__(task="classification")

    if model:
        # Custom model
        self.model = model

    self.train_loss_handler = LossHandler(self.train_metrics.prefix)
    self.test_loss_handler: list[LossHandler] = []
    for metrics in self.test_metrics:
        self.test_loss_handler.append(LossHandler(metrics.prefix))
    self.val_loss_handler = LossHandler(self.val_metrics.prefix)
    self.monitor = f"{self.val_metrics.prefix}loss"

configure_losses() #

Initialize the loss criterion.

Raises:
  • ValueError

    If loss is invalid.

Source code in terratorch/tasks/classification_tasks.py
def configure_losses(self) -> None:
    """Initialize the loss criterion.

    Raises:
        ValueError: If *loss* is invalid.
    """
    loss: str = self.hparams["loss"]
    ignore_index = self.hparams["ignore_index"]

    class_weights = (
        torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
    )
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
    elif loss == "bce":
        self.criterion = nn.BCEWithLogitsLoss()
    elif loss == "jaccard":
        self.criterion = JaccardLoss(mode="multiclass")
    elif loss == "focal":
        self.criterion = FocalLoss(mode="multiclass", normalized=True)
    else:
        msg = f"Loss type '{loss}' is not valid."
        raise ValueError(msg)
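In configure_losses only the "ce" branch uses class_weights and ignore_index, falling back to -100 (the default ignore_index of torch.nn.CrossEntropyLoss) when none is given; the other criteria are built without them. The plain-Python sketch below computes the mean-reduced weighted cross entropy documented for torch.nn.CrossEntropyLoss, i.e. the sum of w[y_n] * nll_n divided by the sum of w[y_n] over non-ignored samples, to show what those two arguments change. weighted_cross_entropy is a hypothetical helper, not the PyTorch implementation.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights=None, ignore_index=-100):
    """Hypothetical helper: mean-reduced cross entropy with class weights and an ignored target value.

    Computes sum(w[y_n] * nll_n) / sum(w[y_n]) over the samples whose target != ignore_index.
    """
    num, den = 0.0, 0.0
    for row, y in zip(logits, targets):
        if y == ignore_index:
            continue  # ignored samples add nothing to the loss or to the normalizer
        # Numerically stable log-sum-exp, so nll = -log softmax(row)[y].
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[y]
        w = 1.0 if class_weights is None else class_weights[y]
        num += w * nll
        den += w
    return num / den if den else float("nan")  # NaN when no sample contributes


logits = [[2.0, 0.0], [2.0, 0.0], [0.0, 9.0]]
targets = [0, 1, -100]  # the last sample is ignored
print(weighted_cross_entropy(logits, targets))
print(weighted_cross_entropy(logits, targets, class_weights=[1.0, 3.0]))
```

Up-weighting class 1, whose only sample is predicted badly, raises the loss from about 1.13 to about 1.63.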

configure_metrics() #

Initialize the performance metrics.

Source code in terratorch/tasks/classification_tasks.py
def configure_metrics(self) -> None:
    """Initialize the performance metrics."""
    num_classes: int = self.hparams["model_args"]["num_classes"]
    ignore_index: int = self.hparams["ignore_index"]
    class_names = self.hparams["class_names"]
    metrics = MetricCollection(
        {
            "Overall_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="micro",
            ),
            "Average_Accuracy": MulticlassAccuracy(
                num_classes=num_classes,
                ignore_index=ignore_index,
                average="macro",
            ),
            "Multiclass_Accuracy_Class": ClasswiseWrapper(
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    average=None,
                ),
                labels=class_names,
            ),
            "Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
            "Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
                MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
                labels=class_names,
            ),
            # F1 score is the F-beta score with beta=1.0
            "Multiclass_F1_Score": MulticlassFBetaScore(
                num_classes=num_classes,
                ignore_index=ignore_index,
                beta=1.0,
                average="micro",
            ),
        }
    )
    self.train_metrics = metrics.clone(prefix="train/")
    self.val_metrics = metrics.clone(prefix="val/")
    if self.hparams["test_dataloaders_names"] is not None:
        self.test_metrics = nn.ModuleList(
            [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
        )
    else:
        self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])
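Besides accuracy, configure_metrics tracks the Jaccard index (IoU) per class through ClasswiseWrapper and as a single Multiclass_Jaccard_Index value. Because no average argument is passed, the latter uses torchmetrics' default averaging, which is macro in recent releases. The sketch below computes per-class IoU = TP / (TP + FP + FN) from hard predictions and its unweighted mean. Classes that never occur in predictions or targets are skipped in the mean here; torchmetrics' handling of such classes may differ. Both function names are hypothetical.

```python
import math


def jaccard_per_class(preds, targets, num_classes, ignore_index=None):
    """Hypothetical helper: per-class IoU = TP / (TP + FP + FN), skipping ignored targets."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # class p predicted where it is absent
            fn[t] += 1  # true class t missed
    ious = []
    for c in range(num_classes):
        denom = tp[c] + fp[c] + fn[c]
        ious.append(tp[c] / denom if denom else float("nan"))  # NaN: class neither present nor predicted
    return ious


def macro_mean(values):
    """Hypothetical helper: unweighted mean over classes, skipping NaN entries."""
    present = [v for v in values if not math.isnan(v)]
    return sum(present) / len(present)


ious = jaccard_per_class([0, 0, 1, 1, 2, 1], [0, 1, 1, 2, 2, 255], num_classes=3, ignore_index=255)
print(ious)  # [0.5, 0.3333333333333333, 0.5]
print(macro_mean(ious))
```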

predict_step(batch, batch_idx, dataloader_idx=0) #

Compute the predicted classes.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Integer displaying index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Returns:
  • tuple[Tensor, Any]

    The predicted class indices (argmax over the class dimension) and the file names of the batch (None if the batch has no "filename" key).

Source code in terratorch/tasks/classification_tasks.py
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> tuple[Tensor, Any]:
    """Compute the predicted class labels for a batch.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        Predicted class indices and the batch file names (None if the batch has no "filename" key).
    """
    x = batch["image"]
    file_names = batch.get("filename")
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)

    # Reuse the single forward pass (which also received the extra keys in `rest`)
    # and keep the highest-scoring class for each sample.
    y_hat = model_output.output.argmax(dim=1)
    return y_hat, file_names
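With a single predict dataloader, Lightning's `Trainer.predict` returns a list holding one `predict_step` result per batch, here a (class indices, file names) pair. The sketch below uses hypothetical helper names: it restates the argmax reduction in plain Python and flattens such per-batch pairs into (file name, label) rows. It assumes the label tensors have already been converted to lists, for example with `y_hat.tolist()`.

```python
from typing import Optional, Sequence


def argmax_per_sample(scores: Sequence[Sequence[float]]) -> list[int]:
    """Index of the largest score in each row, like ``output.argmax(dim=1)``.

    On ties, ``max`` keeps the first maximal index.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def collect_predictions(
    batch_outputs: Sequence[tuple[Sequence[int], Optional[Sequence[str]]]],
) -> list[tuple[Optional[str], int]]:
    """Flatten per-batch (labels, file_names) pairs into (file_name, label) rows."""
    rows: list[tuple[Optional[str], int]] = []
    for labels, file_names in batch_outputs:
        # file_names is None when the batch had no "filename" key.
        names = list(file_names) if file_names is not None else [None] * len(labels)
        if len(names) != len(labels):
            raise ValueError("file_names and labels must have the same length")
        rows.extend(zip(names, (int(label) for label in labels)))
    return rows
```

For instance, `argmax_per_sample([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])` gives `[1, 0]`, and rows without a file name keep `None` in the name column so that predictions are not silently dropped.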

test_step(batch, batch_idx, dataloader_idx=0)

Compute and log the test loss and update the test metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/classification_tasks.py
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute and log the test loss and update the test metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    if dataloader_idx >= len(self.test_loss_handler):
        msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
        raise ValueError(msg)
    loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.test_loss_handler[dataloader_idx].log_loss(
        partial(self.log, add_dataloader_idx=False),  # We don't need the dataloader idx as prefixes are different
        loss_dict=loss,
        batch_size=x.shape[0],
    )
    y_hat_hard = to_class_prediction(model_output)
    self.test_metrics[dataloader_idx].update(y_hat_hard, y)
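When several test dataloaders are used, each batch is routed by `dataloader_idx` to its own loss handler and metric collection. The `test/<name>/` prefix keeps the logged keys apart, which is why `self.log` is called with `add_dataloader_idx=False`. The hypothetical class below sketches that routing in plain Python with a single accuracy metric. It is not TerraTorch's code, and the `accuracy` key name is an assumption.

```python
from typing import Optional, Sequence


class PerDataloaderAccuracy:
    """Hypothetical stand-in for the per-dataloader test metrics (accuracy only)."""

    def __init__(self, test_dataloaders_names: Optional[Sequence[str]] = None):
        # Same prefix scheme as the metric setup: one "test/<name>/" prefix per
        # named dataloader, or a single "test/" prefix when no names are given.
        if test_dataloaders_names is not None:
            self.prefixes = [f"test/{name}/" for name in test_dataloaders_names]
        else:
            self.prefixes = ["test/"]
        self.correct = [0] * len(self.prefixes)
        self.total = [0] * len(self.prefixes)

    def update(self, dataloader_idx: int, preds: Sequence[int], target: Sequence[int]) -> None:
        # Mirrors the guard in test_step: an index without a matching name is an error.
        if dataloader_idx >= len(self.prefixes):
            msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
            raise ValueError(msg)
        self.correct[dataloader_idx] += sum(p == t for p, t in zip(preds, target))
        self.total[dataloader_idx] += len(target)

    def compute(self) -> dict[str, float]:
        # One key per dataloader that received data; the prefix keeps names distinct.
        return {
            f"{prefix}accuracy": c / n
            for prefix, c, n in zip(self.prefixes, self.correct, self.total)
            if n > 0
        }
```

With names `["a", "b"]`, batches from the second dataloader land under `test/b/accuracy`, while a third dataloader (index 2) raises the same `ValueError` as `test_step`.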

training_step(batch, batch_idx, dataloader_idx=0)

Compute and log the training loss and update the training metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Index of this batch.

  • dataloader_idx (int, default: 0 ) –

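    Index of the current dataloader.

Returns:
  • Tensor

    The "loss" entry of the loss dictionary computed by the loss handler, which Lightning backpropagates.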

Source code in terratorch/tasks/classification_tasks.py
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    """Compute and log the training loss and update the training metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        The "loss" entry of the loss dictionary, which Lightning backpropagates.
    """
    x = batch["image"]
    y = batch["label"]
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}

    model_output: ModelOutput = self(x, **rest)
    loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.train_metrics.update(y_hat_hard, y)

    return loss["loss"]
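`training_step` returns the `"loss"` entry of the dictionary produced by `train_loss_handler.compute_loss`, which also receives `self.aux_loss`. The handler's internals are not shown in this excerpt. The following is therefore a hypothetical plain-Python sketch, assuming that `aux_loss` maps auxiliary-head names to weights and that `"loss"` holds the main loss plus the weighted auxiliary losses. The `main` key, the helper names, and the use of MSE as the criterion are illustrative only.

```python
from typing import Callable, Optional, Sequence

Criterion = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error, used here only as an example criterion."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def compute_loss_dict(
    main_output: Sequence[float],
    aux_outputs: dict[str, Sequence[float]],
    target: Sequence[float],
    criterion: Criterion,
    aux_loss: Optional[dict[str, float]] = None,
) -> dict[str, float]:
    """Hypothetical: main loss plus weighted auxiliary losses, each kept for logging."""
    losses = {"main": criterion(main_output, target)}
    total = losses["main"]
    for head_name, weight in (aux_loss or {}).items():
        # In this sketch each auxiliary head is scored against the same target.
        losses[head_name] = criterion(aux_outputs[head_name], target)
        total += weight * losses[head_name]
    # The combined value is what a training step would hand back to Lightning.
    losses["loss"] = total
    return losses
```

Keeping the individual terms next to the total lets each one be logged separately, while only the `"loss"` entry drives optimization.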

validation_step(batch, batch_idx, dataloader_idx=0)

Compute and log the validation loss and update the validation metrics.

Parameters:
  • batch (Any) –

    The output of your DataLoader.

  • batch_idx (int) –

    Index of this batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the current dataloader.

Source code in terratorch/tasks/classification_tasks.py
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute and log the validation loss and update the validation metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["label"]
    other_keys = batch.keys() - {"image", "label", "filename"}
    rest = {k: batch[k] for k in other_keys}
    model_output: ModelOutput = self(x, **rest)
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
    self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
    y_hat_hard = to_class_prediction(model_output)
    self.val_metrics.update(y_hat_hard, y)
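All four steps above unpack the batch dictionary in the same way. `"image"` is the model input, `"label"` is the target, `"filename"` is kept aside, and every other key is forwarded to the model as a keyword argument. The hypothetical `split_batch` helper below restates that convention in plain Python. It uses `batch.get("label")` so that it also covers prediction batches. The training, validation and test steps index `batch["label"]` directly and therefore require it.

```python
from typing import Any


def split_batch(batch: dict[str, Any]) -> tuple[Any, Any, Any, dict[str, Any]]:
    """Hypothetical helper: split a batch dict as the task steps do.

    Returns (model input, target or None, file names or None, extra kwargs).
    """
    x = batch["image"]
    y = batch.get("label")  # the train/val/test steps use batch["label"] instead
    file_names = batch.get("filename")
    # Every key other than image/label/filename is passed on to the model.
    rest = {k: batch[k] for k in batch.keys() - {"image", "label", "filename"}}
    return x, y, file_names, rest
```

Because `self(x, **rest)` receives these extra keys, the underlying model presumably has to accept them (for example a hypothetical `time` entry) as keyword arguments of its forward pass.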