Overview (for developers)

The main goal of the design is to extend TorchGeo's existing tasks to be able to handle Prithvi backbones with appropriate decoders and heads. At the same time, we wish to keep the existing TorchGeo functionality intact so it can be leveraged with pretrained models that are already included.

We achieve this by creating new tasks that accept model factory classes, each of which provides a build_model method. In principle, this strategy allows arbitrary models to be trained for these tasks, provided they respect a reasonable minimal interface.
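As a rough illustration of the factory idea, the sketch below uses plain Python stand-ins. The names ModelFactory, FACTORY_REGISTRY, register_factory, ToyFactory and ToyTask are hypothetical and are not taken from the library; only the build_model method name comes from the text above.

```python
from typing import Any, Protocol


class ModelFactory(Protocol):
    """Hypothetical minimal factory interface: anything with a build_model method."""

    def build_model(self, task: str, **kwargs: Any) -> Any: ...


# Hypothetical registry mapping a factory name to a factory class.
FACTORY_REGISTRY: dict[str, type] = {}


def register_factory(name: str):
    """Class decorator that stores a factory class under the given name."""

    def decorator(cls):
        FACTORY_REGISTRY[name] = cls
        return cls

    return decorator


@register_factory("ToyFactory")
class ToyFactory:
    def build_model(self, task: str, **kwargs: Any) -> dict:
        # A real factory would assemble a backbone, decoder and head here.
        # This stand-in just records what it was asked to build.
        return {"task": task, **kwargs}


class ToyTask:
    """Stand-in for a task: it only knows a factory name and the model arguments."""

    def __init__(self, model_factory: str, model_args: dict) -> None:
        factory: ModelFactory = FACTORY_REGISTRY[model_factory]()
        self.model = factory.build_model("segmentation", **model_args)


task = ToyTask("ToyFactory", {"backbone": "toy_backbone", "num_classes": 2})
print(task.model)
```

Because the task only depends on the build_model call, a second task could reuse ToyFactory unchanged, and a new model family would only need a new registered factory.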

Additionally, we extend TorchGeo with generic datasets and datamodules which can be defined at runtime, rather than requiring classes to be defined beforehand.

The glue that holds everything together is LightningCLI, allowing the model, datamodule and Lightning Trainer to be instantiated from a config file or from the CLI. We make extensive use of it for training and inference.

Initial reading for a full understanding of the platform includes:

Tasks

Tasks are the main coordinators for training and inference for specific tasks. They are LightningModules that contain a model and abstract away all the logic for training steps, metric computation and inference.

One of the most important design decisions was delegating the model construction to a model factory. This has a few advantages:

  • Avoids code repetition among tasks - different tasks can use the same factory
  • Prefers composition over inheritance
  • Allows new models to be easily added by introducing new factories

Models are expected to be torch.nn.Modules and implement the Model interface, providing:

  • freeze_encoder()
  • freeze_decoder()
  • forward()

terratorch.models.model.Model

Bases: ABC

Source code in terratorch/models/model.py
class Model(ABC):
    @abstractmethod
    def freeze_encoder(self):
        pass

    @abstractmethod
    def freeze_decoder(self):
        pass

    @abstractmethod
    def forward(self) -> ModelOutput:
        pass

Additionally, the forward() method is expected to return an object of type ModelOutput, containing the main head's output, as well as any additional auxiliary outputs. The names of these auxiliary heads are matched with the names of the provided auxiliary losses.

terratorch.models.model.ModelOutput dataclass

Source code in terratorch/models/model.py
@dataclass
class ModelOutput:
    output: Tensor
    auxiliary_heads: dict[str, Tensor] | None = None
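To make the contract concrete, here is a minimal torch-free sketch of a class satisfying the interface. It is an illustration under assumptions, not the library's code: lists of floats stand in for tensors, boolean flags stand in for parameter freezing, ToyModel and the head name "aux1" are hypothetical, and the default_factory default is a choice made for this sketch.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class ModelOutput:
    # Lists of floats stand in for tensors in this sketch.
    output: list
    auxiliary_heads: dict = field(default_factory=dict)


class Model(ABC):
    @abstractmethod
    def freeze_encoder(self): ...

    @abstractmethod
    def freeze_decoder(self): ...

    @abstractmethod
    def forward(self, x) -> ModelOutput: ...


class ToyModel(Model):
    """Hypothetical model: flags stand in for freezing real parameters."""

    def __init__(self) -> None:
        self.encoder_trainable = True
        self.decoder_trainable = True

    def freeze_encoder(self):
        self.encoder_trainable = False

    def freeze_decoder(self):
        self.decoder_trainable = False

    def forward(self, x) -> ModelOutput:
        features = [2.0 * v for v in x]     # "encoder"
        main = [v + 1.0 for v in features]  # "decoder + main head"
        aux = [v - 1.0 for v in features]   # auxiliary head on the same features
        # The key "aux1" is the name an auxiliary loss would be matched against.
        return ModelOutput(output=main, auxiliary_heads={"aux1": aux})


model = ToyModel()
model.freeze_encoder()
out = model.forward([1.0, 2.0])
print(out.output, out.auxiliary_heads)
```

A subclass that omits any of the three methods cannot be instantiated, which is what makes the interface enforceable.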

Models

In the currently existing model implementations, we explicitly divide the models into backbones, decoders and heads. This structure is provided by the PixelWiseModel and ScalarOutputModel classes.

However, as long as models implement the Model interface, and return ModelOutput in their forward method, they can take on any structure.

terratorch.models.pixel_wise_model.PixelWiseModel

Bases: Model, SegmentationModel

Model that encapsulates encoder, decoder and heads. Expects decoder to have a "forward_features" method, an embed_dims property and optionally a "prepare_features_for_image_model" method.

Source code in terratorch/models/pixel_wise_model.py
class PixelWiseModel(Model, SegmentationModel):
    """Model that encapsulates encoder and decoder and heads
    Expects decoder to have a "forward_features" method, an embed_dims property
    and optionally a "prepare_features_for_image_model" method.
    """

    def __init__(
        self,
        task: str,
        encoder: nn.Module,
        decoder: nn.Module,
        head_kwargs: dict,
        auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
        prepare_features_for_image_model: Callable | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
    ) -> None:
        """Constructor

        Args:
            task (str): Task to be performed. One of segmentation or regression.
            encoder (nn.Module): Encoder to be used
            decoder (nn.Module): Decoder to be used
            head_kwargs (dict): Arguments to be passed at instantiation of the head.
            auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
                AuxiliaryHeads with heads to be instantiated. Defaults to None.
            prepare_features_for_image_model (Callable | None, optional): Function applied to encoder outputs.
                Defaults to None.
            rescale (bool, optional): Rescale the output of the model if it has a different size than the ground truth.
                Uses bilinear interpolation. Defaults to True.
        """
        super().__init__()

        if "multiple_embed" in head_kwargs:
            self.multiple_embed = head_kwargs.pop("multiple_embed")
        else:
            self.multiple_embed = False

        self.task = task
        self.encoder = encoder
        self.decoder = decoder
        self.head = self._get_head(task, decoder.output_embed_dim, head_kwargs)

        if auxiliary_heads is not None:
            aux_heads = {}
            for aux_head_to_be_instantiated in auxiliary_heads:
                aux_head: nn.Module = self._get_head(
                    task, aux_head_to_be_instantiated.decoder.output_embed_dim, head_kwargs
                )
                aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
                aux_heads[aux_head_to_be_instantiated.name] = aux_head
        else:
            aux_heads = {}
        self.aux_heads = nn.ModuleDict(aux_heads)

        self.prepare_features_for_image_model = prepare_features_for_image_model
        self.rescale = rescale

        # Defining the method for dealing with the encoder embedding
        if self.multiple_embed:
            self.embed_handler = self._multiple_embedding_outputs
        else:
            self.embed_handler = self._single_embedding_output

    def freeze_encoder(self):
        freeze_module(self.encoder)

    def freeze_decoder(self):
        freeze_module(self.decoder)
        freeze_module(self.head)

    def _single_embedding_output(self, features: torch.Tensor) -> torch.Tensor:
        decoder_output = self.decoder([f.clone() for f in features])

        return decoder_output

    def _multiple_embedding_outputs(self, features: tuple[torch.Tensor]) -> torch.Tensor:
        decoder_output = self.decoder(*features)

        return decoder_output

    # TODO: do this properly
    def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
        return True

    @staticmethod
    def _check_for_single_channel_and_squeeze(x):
        if x.shape[1] == 1:
            x = x.squeeze(1)
        return x

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Sequentially pass `x` through model`s encoder, decoder and heads"""
        self.check_input_shape(x)
        input_size = x.shape[-2:]
        features = self.encoder(x)

        # some models need their features reshaped

        if self.prepare_features_for_image_model:
            prepare = self.prepare_features_for_image_model
        else:
            prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)

        # Dealing with cases in which the encoder returns more than one
        # output
        features = prepare(features)

        decoder_output = self.embed_handler(features=features)
        mask = self.head(decoder_output)
        if self.rescale and mask.shape[-2:] != input_size:
            mask = F.interpolate(mask, size=input_size, mode="bilinear")
        mask = self._check_for_single_channel_and_squeeze(mask)
        aux_outputs = {}
        for name, decoder in self.aux_heads.items():
            aux_output = decoder([f.clone() for f in features])
            if self.rescale and aux_output.shape[-2:] != input_size:
                aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
            aux_output = self._check_for_single_channel_and_squeeze(aux_output)
            aux_outputs[name] = aux_output
        return ModelOutput(output=mask, auxiliary_heads=aux_outputs)

    def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
        if task == "segmentation":
            if "num_classes" not in head_kwargs:
                msg = "num_classes must be defined for segmentation task"
                raise Exception(msg)
            return SegmentationHead(input_embed_dim, **head_kwargs)
        if task == "regression":
            return RegressionHead(input_embed_dim, **head_kwargs)
        msg = "Task must be one of segmentation or regression."
        raise Exception(msg)
__init__(task, encoder, decoder, head_kwargs, auxiliary_heads=None, prepare_features_for_image_model=None, rescale=True)

Constructor

Parameters:

  • task (str, required): Task to be performed. One of segmentation or regression.
  • encoder (Module, required): Encoder to be used.
  • decoder (Module, required): Decoder to be used.
  • head_kwargs (dict, required): Arguments to be passed at instantiation of the head.
  • auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, default None): List of AuxiliaryHeads with heads to be instantiated.
  • prepare_features_for_image_model (Callable | None, default None): Function applied to encoder outputs.
  • rescale (bool, default True): Rescale the output of the model if it has a different size than the ground truth. Uses bilinear interpolation.

forward(x)

Sequentially pass x through the model's encoder, decoder and heads.


terratorch.models.scalar_output_model.ScalarOutputModel

Bases: Model, SegmentationModel

Model that encapsulates encoder, decoder and heads for a scalar output. Expects decoder to have a "forward_features" method, an embed_dims property and optionally a "prepare_features_for_image_model" method.

Source code in terratorch/models/scalar_output_model.py
class ScalarOutputModel(Model, SegmentationModel):
    """Model that encapsulates encoder and decoder and heads for a scalar output
    Expects decoder to have a "forward_features" method, an embed_dims property
    and optionally a "prepare_features_for_image_model" method.
    """

    def __init__(
        self,
        task: str,
        encoder: nn.Module,
        decoder: nn.Module,
        head_kwargs: dict,
        auxiliary_heads: dict[str, nn.Module] | None = None,
        prepare_features_for_image_model: Callable | None = None,
    ) -> None:
        """Constructor

        Args:
            task (str): Task to be performed. Must be "classification".
            encoder (nn.Module): Encoder to be used
            decoder (nn.Module): Decoder to be used
            head_kwargs (dict): Arguments to be passed at instantiation of the head.
            auxiliary_heads (dict[str, nn.Module] | None, optional): Names mapped to auxiliary heads. Defaults to None.
            prepare_features_for_image_model (Callable | None, optional): Function applied to encoder outputs.
                Defaults to None.
        """
        super().__init__()
        self.task = task
        self.encoder = encoder
        self.decoder = decoder
        self.head = self._get_head(task, decoder.output_embed_dim, head_kwargs)

        if auxiliary_heads is not None:
            aux_heads = {}
            for aux_head_to_be_instantiated in auxiliary_heads:
                aux_head: nn.Module = self._get_head(
                    task, aux_head_to_be_instantiated.decoder.output_embed_dim, head_kwargs
                )
                aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
                aux_heads[aux_head_to_be_instantiated.name] = aux_head
        else:
            aux_heads = {}
        self.aux_heads = nn.ModuleDict(aux_heads)

        self.prepare_features_for_image_model = prepare_features_for_image_model

    def freeze_encoder(self):
        freeze_module(self.encoder)

    def freeze_decoder(self):
        freeze_module(self.decoder)
        freeze_module(self.head)

    # TODO: do this properly
    def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
        return True

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Sequentially pass `x` through model`s encoder, decoder and heads"""

        self.check_input_shape(x)
        features = self.encoder(x)

        # some models need their features reshaped

        if self.prepare_features_for_image_model:
            prepare = self.prepare_features_for_image_model
        else:
            prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)
        features = prepare(features)

        decoder_output = self.decoder([f.clone() for f in features])
        mask = self.head(decoder_output)
        aux_outputs = {}
        for name, decoder in self.aux_heads.items():
            aux_output = decoder([f.clone() for f in features])
            aux_outputs[name] = aux_output
        return ModelOutput(output=mask, auxiliary_heads=aux_outputs)

    def _get_head(self, task: str, input_embed_dim: int, head_kwargs: dict):
        if task == "classification":
            if "num_classes" not in head_kwargs:
                msg = "num_classes must be defined for classification task"
                raise Exception(msg)
            return ClassificationHead(input_embed_dim, **head_kwargs)
        msg = "Task must be classification."
        raise Exception(msg)
__init__(task, encoder, decoder, head_kwargs, auxiliary_heads=None, prepare_features_for_image_model=None)

Constructor

Parameters:

  • task (str, required): Task to be performed. Must be "classification".
  • encoder (Module, required): Encoder to be used.
  • decoder (Module, required): Decoder to be used.
  • head_kwargs (dict, required): Arguments to be passed at instantiation of the head.
  • auxiliary_heads (dict[str, Module] | None, default None): Names mapped to auxiliary heads.
  • prepare_features_for_image_model (Callable | None, default None): Function applied to encoder outputs.

forward(x)

Sequentially pass x through the model's encoder, decoder and heads.


Backbones

We decided to leverage timm for handling backbones. It is important to understand the advantages and disadvantages of this decision:

Advantages

  1. timm provides an incredibly rich variety of already implemented and validated architectures.
  2. timm provides an API for creating backbones directly as feature extractors, with the features_only=True argument.
  3. timm provides an existing and powerful model registry and factory.

Disadvantages

  1. The documentation on using timm is reasonable, but the documentation on how to develop with timm, particularly for adding new models, is not great.
  2. A few hacks had to be put in place to make our existing model architectures play nicely with the features_only=True functionality.
    1. In particular, we allow models to define a prepare_features_for_image_model function, which is called on the model features just before they are passed to the decoder. This function can be defined on the backbone code itself, or passed to the PixelWiseModel constructor.
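The lookup order for this hook can be sketched in isolation. The helper resolve_prepare and the two encoder classes below are hypothetical and exist only for this illustration; the order (explicit callable first, then a method on the backbone, then identity) follows the forward code shown earlier.

```python
from typing import Any, Callable, Optional


def resolve_prepare(encoder: Any, override: Optional[Callable] = None) -> Callable:
    """Pick the feature-preparation function to apply before the decoder."""
    if override is not None:
        # 1. A callable passed explicitly to the model constructor wins.
        return override
    # 2. Otherwise use the hook defined on the backbone, if it has one;
    # 3. else fall back to the identity function.
    return getattr(encoder, "prepare_features_for_image_model", lambda x: x)


class PlainEncoder:
    """Hypothetical backbone that defines no hook."""


class TokenEncoder:
    """Hypothetical backbone whose features need a fix-up before decoding."""

    def prepare_features_for_image_model(self, features):
        # Toy fix-up: drop the first element of every feature (e.g. a class token).
        return [f[1:] for f in features]


features = [[0, 1, 2], [0, 3, 4]]
identity_result = resolve_prepare(PlainEncoder())(features)
hook_result = resolve_prepare(TokenEncoder())(features)
override_result = resolve_prepare(TokenEncoder(), override=lambda fs: fs[-1:])(features)
print(identity_result, hook_result, override_result)
```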

Decoders

Currently, we have implemented a simple Fully Convolutional Decoder as well as an UperNetDecoder, which exactly match the definitions in the MMSeg framework. This was mostly done to ensure we could replicate the results from that framework.

However, libraries such as pytorch-segmentation or segmentation_models.pytorch provide a large set of already implemented decoders that can be leveraged.

This is probably a reasonable next step in the implementation. In order to do this, a new factory can simply be created which leverages these libraries. See as an example this section

Heads

In the current implementation, the heads perform the final step in going from the output of the decoder to the final desired output. Often this can just be e.g. a single convolutional head going from the final decoder depth to the number of classes, in the case of segmentation, or to a depth of 1, in the case of regression.
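As a worked illustration of that final projection, the sketch below applies a convolution with a 1x1 kernel using nested lists instead of tensors. The 1x1 kernel size, the function name pointwise_head and all numbers are assumptions made for this example, not details taken from the library.

```python
def pointwise_head(features, weights, bias):
    """Apply a 1x1 convolution to features of shape [C_in][H][W].

    weights has shape [C_out][C_in] and bias has shape [C_out];
    the result has shape [C_out][H][W].
    """
    height, width = len(features[0]), len(features[0][0])
    return [
        [
            [
                # Each output pixel is a weighted sum over the input channels.
                b + sum(w * features[c][i][j] for c, w in enumerate(row))
                for j in range(width)
            ]
            for i in range(height)
        ]
        for row, b in zip(weights, bias)
    ]


# Decoder output with depth 2 on a 1x2 image.
decoder_output = [[[1.0, 2.0]], [[3.0, 4.0]]]

# Segmentation: project depth 2 to num_classes = 3 logits per pixel.
seg_logits = pointwise_head(
    decoder_output,
    weights=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    bias=[0.0, 0.0, 0.0],
)

# Regression: project depth 2 to a single output channel.
reg_output = pointwise_head(decoder_output, weights=[[0.5, 0.5]], bias=[1.0])

print(len(seg_logits), len(reg_output))
```

The only difference between the two cases is the number of output channels: one per class for segmentation, a single one for regression.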

Loss

For convenience, we provide a loss handler that can be used to compute the full loss (from the main head and auxiliary heads as well).
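The weighting performed by the loss handler can be illustrated with plain floats in place of tensors. The function combine_losses is a hypothetical stand-in written for this sketch, and it raises ValueError where the handler raises a generic Exception; the head names and weights are made up.

```python
from typing import Optional


def combine_losses(
    main_loss: float,
    aux_losses: dict,
    aux_loss_weights: Optional[dict],
) -> dict:
    """Return the main loss plus the weighted sum of auxiliary losses."""
    if not aux_losses:
        # No auxiliary heads: only the total loss is reported.
        return {"loss": main_loss}
    if aux_loss_weights is None:
        raise ValueError("Auxiliary heads given with no aux_loss_weights")
    if set(aux_losses) != set(aux_loss_weights):
        # Head names and weight names must match exactly.
        raise ValueError("Auxiliary head names and loss weight names differ")

    all_losses = {"decode_head": main_loss}
    total = main_loss
    for name, weight in aux_loss_weights.items():
        all_losses[name] = aux_losses[name]
        total += weight * aux_losses[name]
    all_losses["loss"] = total
    return all_losses


losses = combine_losses(
    main_loss=1.0,
    aux_losses={"aux1": 2.0, "aux2": 4.0},
    aux_loss_weights={"aux1": 0.5, "aux2": 0.25},
)
print(losses)  # total = 1.0 + 0.5 * 2.0 + 0.25 * 4.0 = 3.0
```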

terratorch.tasks.loss_handler

LossHandler

Class to help handle the computation and logging of loss

Source code in terratorch/tasks/loss_handler.py
class LossHandler:
    """Class to help handle the computation and logging of loss"""

    def __init__(self, loss_prefix: str) -> None:
        """Constructor

        Args:
            loss_prefix (str): Prefix to be prepended to all the metrics (e.g. training).
        """
        self.loss_prefix = loss_prefix

    def compute_loss(
        self,
        model_output: ModelOutput,
        ground_truth: Tensor,
        criterion: Callable,
        aux_loss_weights: dict[str, float] | None,
    ) -> dict[str, Tensor]:
        """Compute the loss for the mean decode head as well as other heads

        Args:
            model_output (ModelOutput): Output from the model
            ground_truth (Tensor): Tensor with labels
            criterion (Callable): Loss function to be applied
            aux_loss_weights (Union[dict[str, float], None]): Dictionary of names of model auxiliary
                heads and their weights

        Raises:
            Exception: If the keys in aux_loss_weights and the model output do not match, will raise an exception.

        Returns:
            dict[str, Tensor]: Dictionary of computed losses. Total loss is returned under the key "loss".
                If there are auxiliary heads, the main decode head is returned under the key "decode_head".
                All other heads are returned with the same key as their name.
        """
        loss = self._compute_loss(model_output.output, ground_truth, criterion)
        if not model_output.auxiliary_heads:
            return {"loss": loss}

        if aux_loss_weights is None:
            msg = "Auxiliary heads given with no aux_loss_weights"
            raise Exception(msg)
        all_losses = {}
        all_losses["decode_head"] = loss
        total_loss = loss.clone()
        # incorporate aux heads
        model_output_names = set(model_output.auxiliary_heads.keys())
        aux_loss_names = set(aux_loss_weights.keys())
        if aux_loss_names != model_output_names:
            msg = f"Found difference in declared auxiliary losses and model outputs.\n \
                Found in declared losses but not in model output: {aux_loss_names - model_output_names}. \n \
                Found in model output but not in declared losses: {model_output_names - aux_loss_names}"
            raise Exception(msg)

        for loss_name, loss_weight in aux_loss_weights.items():
            output = model_output.auxiliary_heads[loss_name]
            loss_value: Tensor = self._compute_loss(output, ground_truth, criterion)
            all_losses[loss_name] = loss_value
            total_loss = total_loss + loss_value * loss_weight

        all_losses["loss"] = total_loss
        return all_losses

    def _compute_loss(self, y_hat: Tensor, ground_truth: Tensor, criterion: Callable):
        loss: Tensor = criterion(y_hat, ground_truth)
        return loss

    def log_loss(
        self, log_function: Callable, loss_dict: dict[str, Tensor] | None = None, batch_size: int | None = None
    ) -> None:
        """Log the loss. If auxiliary heads exist, log the full loss suffix "loss", and then all other losses.

        Args:
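            log_function (Callable): Logging function, called as log_function(name, value, **kwargs).
            loss_dict (dict[str, Tensor], optional): Losses in the format returned by compute_loss. Defaults to None.
            batch_size (int, optional): Batch size forwarded to log_function. Defaults to None.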
        """

        # dont alter passed dict
        all_losses = dict(loss_dict)
        full_loss = all_losses.pop("loss")
        log_function(f"{self.loss_prefix}loss", full_loss, sync_dist=True, batch_size=batch_size)

        for loss_name, loss_value in all_losses.items():
            log_function(
                f"{self.loss_prefix}{loss_name}",
                loss_value,
                on_epoch=True,
                on_step=True,
                sync_dist=True,
                batch_size=batch_size,
            )

__init__(loss_prefix)

Constructor

Parameters:

  • loss_prefix (str, required): Prefix to be prepended to all the metrics (e.g. training).


compute_loss(model_output, ground_truth, criterion, aux_loss_weights)

Compute the loss for the main decode head as well as other heads.

Parameters:

  • model_output (ModelOutput, required): Output from the model.
  • ground_truth (Tensor, required): Tensor with labels.
  • criterion (Callable, required): Loss function to be applied.
  • aux_loss_weights (Union[dict[str, float], None], required): Dictionary of names of model auxiliary heads and their weights.

Raises:

  • Exception: If the keys in aux_loss_weights and the model output do not match.

Returns:

  • dict[str, Tensor]: Dictionary of computed losses. Total loss is returned under the key "loss". If there are auxiliary heads, the main decode head is returned under the key "decode_head". All other heads are returned with the same key as their name.


log_loss(log_function, loss_dict=None, batch_size=None)

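Log the loss. If auxiliary heads exist, log the full loss with the suffix "loss", and then all other losses.

Parameters:

  • log_function (Callable, required): Logging function, called as log_function(name, value, ...) with keyword arguments such as sync_dist and batch_size.
  • loss_dict (dict[str, Tensor], default None): Dictionary of losses in the format returned by compute_loss. It must contain the key "loss".
  • batch_size (int | None, default None): Batch size forwarded to log_function.
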
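To show which names end up being logged, the following sketch replays the same logic against a recording callable. RecordingLogger, log_losses and the "train/" prefix are hypothetical and exist only for this illustration; in a real task the log function would typically be the LightningModule's own logging method.

```python
class RecordingLogger:
    """Hypothetical stand-in for a logging function: stores every call."""

    def __init__(self) -> None:
        self.calls = []

    def __call__(self, name, value, **kwargs):
        self.calls.append((name, value))


def log_losses(log_function, loss_prefix, loss_dict, batch_size=None):
    """Log the total loss first, then every remaining entry, all with the prefix."""
    all_losses = dict(loss_dict)  # copy so the caller's dict is not altered
    full_loss = all_losses.pop("loss")
    log_function(f"{loss_prefix}loss", full_loss, sync_dist=True, batch_size=batch_size)
    for loss_name, loss_value in all_losses.items():
        log_function(
            f"{loss_prefix}{loss_name}",
            loss_value,
            on_epoch=True,
            on_step=True,
            sync_dist=True,
            batch_size=batch_size,
        )


logger = RecordingLogger()
loss_dict = {"decode_head": 1.0, "aux1": 2.0, "loss": 3.0}
log_losses(logger, "train/", loss_dict, batch_size=4)
print(logger.calls)
```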

Generic datasets / datamodules

Refer to the section on data.

Exporting models

A possible future feature is exporting models in ONNX format. This would bring all the benefits of ONNX.