
Overview (for developers)

The main goal of the design is to extend TorchGeo's existing tasks so that they can handle Prithvi backbones with appropriate decoders and heads. At the same time, we want to keep the existing TorchGeo functionality intact so that it can still be used with the pretrained models that are already included.

We achieve this by creating new tasks that accept a model factory, i.e. a class with a build_model method. In principle, this strategy allows arbitrary models to be trained for these tasks, provided they respect a reasonable minimal interface. Alongside this, we provide the EncoderDecoderFactory, which should let users plug together different Encoders and Decoders, using Necks for intermediate operations.
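
The factory idea can be sketched as below, assuming only that a factory is any object whose build_model method returns a torch.nn.Module. ModelFactory, TinyConvFactory and ToyTask are hypothetical names used purely for illustration: they are not TerraTorch's API, the real build_model signature may differ, and ToyTask is a plain class rather than a LightningModule.

```python
from typing import Protocol

import torch
from torch import nn


class ModelFactory(Protocol):
    """Hypothetical minimal factory interface: anything with a build_model method."""

    def build_model(self, task: str, **kwargs) -> nn.Module: ...


class TinyConvFactory:
    """Hypothetical factory producing a single-convolution model for illustration."""

    def build_model(self, task: str, in_channels: int = 3, num_classes: int = 2) -> nn.Module:
        if task != "segmentation":
            raise ValueError(f"Unsupported task: {task}")
        return nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)


class ToyTask:
    """Stand-in for a task: it never constructs the model itself, it asks the factory."""

    def __init__(self, model_factory: ModelFactory, model_args: dict) -> None:
        self.model = model_factory.build_model("segmentation", **model_args)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Per-pixel class prediction from the logits produced by the factory-built model.
        return self.model(x).argmax(dim=1)


task = ToyTask(TinyConvFactory(), {"in_channels": 4, "num_classes": 3})
pred = task.predict(torch.randn(2, 4, 16, 16))
print(pred.shape)  # torch.Size([2, 16, 16])
```

Swapping in a different model only requires passing a different factory; the task code stays untouched.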

Additionally, we extend TorchGeo with generic datasets and datamodules which can be defined at runtime, rather than requiring classes to be defined beforehand.

The glue that holds everything together is LightningCLI, allowing the model, datamodule and Lightning Trainer to be instantiated from a config file or from the CLI. We make extensive use of for training and inference.

Initial reading for a full understanding of the platform includes:

Tasks

Tasks are the main coordinators of training and inference for a given type of problem. They are LightningModules that contain a model and abstract away all the logic for training steps, metric computation and inference.

One of the most important design decisions was delegating the model construction to a model factory. This has a few advantages:

  • Avoids code repetition among tasks - different tasks can use the same factory
  • Prefers composition over inheritance
  • Allows new models to be easily added by introducing new factories

Models are expected to be torch.nn.Modules and implement the Model interface, providing:

  • freeze_encoder()
  • freeze_decoder()
  • forward()

Additionally, the forward() method is expected to return an object of type ModelOutput, containing the main head's output, as well as any additional auxiliary outputs. The names of these auxiliary heads are matched with the names of the provided auxiliary losses.
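
The contract can be sketched as follows. ModelOutput here is a local stand-in dataclass with the two fields (output and auxiliary_heads) that the model source below passes to the real class; ToyModel and the freeze helper are hypothetical and only illustrate what freeze_encoder(), freeze_decoder() and forward() are expected to do.

```python
from dataclasses import dataclass, field

import torch
from torch import nn


@dataclass
class ModelOutput:
    """Local stand-in: main head output plus named auxiliary outputs."""

    output: torch.Tensor
    auxiliary_heads: dict[str, torch.Tensor] = field(default_factory=dict)


def freeze(module: nn.Module) -> None:
    """Hypothetical helper: stop gradient updates for every parameter of a module."""
    for param in module.parameters():
        param.requires_grad_(False)


class ToyModel(nn.Module):
    def __init__(self, in_channels: int = 3, num_classes: int = 2) -> None:
        super().__init__()
        self.encoder = nn.Conv2d(in_channels, 8, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(8, num_classes, kernel_size=1)
        # One auxiliary head; its name must match a key of the auxiliary loss weights.
        self.aux_heads = nn.ModuleDict({"aux": nn.Conv2d(8, num_classes, kernel_size=1)})

    def freeze_encoder(self) -> None:
        freeze(self.encoder)

    def freeze_decoder(self) -> None:
        freeze(self.decoder)

    def forward(self, x: torch.Tensor) -> ModelOutput:
        features = self.encoder(x)
        aux = {name: head(features) for name, head in self.aux_heads.items()}
        return ModelOutput(output=self.decoder(features), auxiliary_heads=aux)


model = ToyModel()
model.freeze_encoder()
out = model(torch.randn(1, 3, 8, 8))
print(out.output.shape, list(out.auxiliary_heads))  # torch.Size([1, 2, 8, 8]) ['aux']
```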

Models

Models constructed by the EncoderDecoderFactory have an internal structure explicitly divided into backbones, necks, decoders and heads. This structure is provided by the PixelWiseModel and ScalarOutputModel classes.
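
To illustrate what a neck can do, the sketch below turns token sequences shaped (B, N, C), as ViT-style backbones typically emit them, into (B, C, H, W) feature maps that a convolutional decoder can consume. In the forward() methods shown below, this kind of reshaping is what encoder.prepare_features_for_image_model is used for when no neck is given. TokensToImageNeck is a hypothetical module, and it assumes a square patch grid with no class token.

```python
import math

import torch
from torch import nn


class TokensToImageNeck(nn.Module):
    """Hypothetical neck: reshape each (B, N, C) token tensor into a (B, C, H, W) map.

    Assumes N is a perfect square (square patch grid) and that no class token is present.
    """

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        maps = []
        for tokens in features:
            batch, num_tokens, channels = tokens.shape
            side = math.isqrt(num_tokens)
            if side * side != num_tokens:
                raise ValueError(f"Expected a square number of tokens, got {num_tokens}")
            # (B, N, C) -> (B, C, N) -> (B, C, H, W); token n lands at row n // side, column n % side.
            maps.append(tokens.transpose(1, 2).reshape(batch, channels, side, side))
        return maps


neck = TokensToImageNeck()
features = [torch.randn(2, 196, 768) for _ in range(4)]  # e.g. four blocks, 14x14 patches
maps = neck(features)
print(maps[0].shape)  # torch.Size([2, 768, 14, 14])
```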

However, as long as models implement the Model interface, and return ModelOutput in their forward method, they can take on any structure.

terratorch.models.pixel_wise_model.PixelWiseModel

Bases: Model, SegmentationModel

Model that encapsulates an encoder, an optional neck, a decoder and heads. Expects the decoder to expose an out_channels attribute, which sizes the head. When no neck is given, the encoder's "prepare_features_for_image_model" method is used if it exists.

Source code in terratorch/models/pixel_wise_model.py
class PixelWiseModel(Model, SegmentationModel):
    """Model that encapsulates an encoder, an optional neck, a decoder and heads.
    Expects the decoder to expose an out_channels attribute, which sizes the head.
    When no neck is given, the encoder's "prepare_features_for_image_model" method is used if it exists.
    """

    def __init__(
        self,
        task: str,
        encoder: nn.Module,
        decoder: nn.Module,
        head_kwargs: dict,
        decoder_includes_head: bool = False,
        auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
        neck: nn.Module | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
    ) -> None:
        """Constructor

        Args:
            task (str): Task to be performed. One of segmentation or regression.
            encoder (nn.Module): Encoder to be used
            decoder (nn.Module): Decoder to be used
            head_kwargs (dict): Arguments to be passed at instantiation of the head.
            decoder_includes_head (bool): Whether the decoder already includes a head. If True, no head will be added. Defaults to False.
            auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
                AuxiliaryHeads with heads to be instantiated. Defaults to None.
            neck (nn.Module | None): Module applied between backbone and decoder. Defaults to None, in which
                case the encoder's prepare_features_for_image_model is used if defined, otherwise the identity.
            rescale (bool, optional): Resize the output to the spatial size of the input if the two differ.
                Uses bilinear interpolation. Defaults to True.
        """
        super().__init__()

        self.task = task
        self.encoder = encoder
        self.decoder = decoder
        self.head = (
            self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
        )

        if auxiliary_heads is not None:
            aux_heads = {}
            for aux_head_to_be_instantiated in auxiliary_heads:
                aux_head: nn.Module = self._get_head(
                    task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
                ) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
                aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
                aux_heads[aux_head_to_be_instantiated.name] = aux_head
        else:
            aux_heads = {}
        self.aux_heads = nn.ModuleDict(aux_heads)

        self.neck = neck
        self.rescale = rescale

    def freeze_encoder(self):
        freeze_module(self.encoder)

    def freeze_decoder(self):
        freeze_module(self.decoder)
        freeze_module(self.head)

    # TODO: do this properly
    def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
        return True

    @staticmethod
    def _check_for_single_channel_and_squeeze(x):
        if x.shape[1] == 1:
            x = x.squeeze(1)
        return x

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Sequentially pass `x` through the model's encoder, neck, decoder and heads."""
        self.check_input_shape(x)
        input_size = x.shape[-2:]
        features = self.encoder(x)

        ## only for backwards compatibility with pre-neck times.
        if self.neck:
            prepare = self.neck
        else:
            # for backwards compatibility, if this is defined in the encoder, use it
            prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)

        features = prepare(features)
        decoder_output = self.decoder([f.clone() for f in features])
        mask = self.head(decoder_output)
        if self.rescale and mask.shape[-2:] != input_size:
            mask = F.interpolate(mask, size=input_size, mode="bilinear")
        mask = self._check_for_single_channel_and_squeeze(mask)
        aux_outputs = {}
        for name, decoder in self.aux_heads.items():
            aux_output = decoder([f.clone() for f in features])
            if self.rescale and aux_output.shape[-2:] != input_size:
                aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
            aux_output = self._check_for_single_channel_and_squeeze(aux_output)
            aux_outputs[name] = aux_output
        return ModelOutput(output=mask, auxiliary_heads=aux_outputs)

    def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
        if task == "segmentation":
            if "num_classes" not in head_kwargs:
                msg = "num_classes must be defined for segmentation task"
                raise Exception(msg)
            return SegmentationHead(input_embed_dim, **head_kwargs)
        if task == "regression":
            return RegressionHead(input_embed_dim, **head_kwargs)
        msg = "Task must be one of segmentation or regression."
        raise Exception(msg)
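
The output handling in forward() above can be isolated as follows: if rescale is enabled and the spatial size differs from the input's, the output is bilinearly upsampled, and a single-channel output (typical for regression) is squeezed to (B, H, W). postprocess is a hypothetical standalone helper that mirrors those steps; it is not part of TerraTorch.

```python
import torch
import torch.nn.functional as F


def postprocess(output: torch.Tensor, input_size: tuple[int, int], rescale: bool = True) -> torch.Tensor:
    """Mirror of PixelWiseModel's output handling (hypothetical standalone helper)."""
    if rescale and tuple(output.shape[-2:]) != tuple(input_size):
        # Upsample the head output back to the input resolution.
        output = F.interpolate(output, size=input_size, mode="bilinear")
    if output.shape[1] == 1:
        # A single channel (e.g. a regression map) is dropped: (B, 1, H, W) -> (B, H, W).
        output = output.squeeze(1)
    return output


seg_logits = torch.randn(2, 5, 56, 56)  # 5 classes at 1/4 of the input resolution
reg_map = torch.randn(2, 1, 56, 56)     # single-channel regression output
print(postprocess(seg_logits, (224, 224)).shape)  # torch.Size([2, 5, 224, 224])
print(postprocess(reg_map, (224, 224)).shape)     # torch.Size([2, 224, 224])
```
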
__init__(task, encoder, decoder, head_kwargs, decoder_includes_head=False, auxiliary_heads=None, neck=None, rescale=True)

Constructor

Parameters:

  • task (str, required): Task to be performed. One of "segmentation" or "regression".
  • encoder (Module, required): Encoder to be used.
  • decoder (Module, required): Decoder to be used.
  • head_kwargs (dict, required): Arguments to be passed at instantiation of the head.
  • decoder_includes_head (bool, default False): Whether the decoder already includes a head. If True, no head will be added.
  • auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, default None): List of AuxiliaryHeads with heads to be instantiated.
  • neck (Module | None, default None): Module applied between backbone and decoder. If None, the encoder's prepare_features_for_image_model is used if defined, otherwise the identity.
  • rescale (bool, default True): Resize the output to the spatial size of the input if the two differ. Uses bilinear interpolation.

forward(x)

Sequentially pass x through the model's encoder, neck, decoder and heads.


terratorch.models.scalar_output_model.ScalarOutputModel

Bases: Model, SegmentationModel

Model that encapsulates an encoder, an optional neck, a decoder and heads for a scalar output. Expects the decoder to expose an out_channels attribute, which sizes the head. When no neck is given, the encoder's "prepare_features_for_image_model" method is used if it exists.

Source code in terratorch/models/scalar_output_model.py
class ScalarOutputModel(Model, SegmentationModel):
    """Model that encapsulates an encoder, an optional neck, a decoder and heads for a scalar output.
    Expects the decoder to expose an out_channels attribute, which sizes the head.
    When no neck is given, the encoder's "prepare_features_for_image_model" method is used if it exists.
    """

    def __init__(
        self,
        task: str,
        encoder: nn.Module,
        decoder: nn.Module,
        head_kwargs: dict,
        decoder_includes_head: bool = False,
        auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
        neck: nn.Module | None = None,
    ) -> None:
        """Constructor

        Args:
            task (str): Task to be performed. Must be "classification".
            encoder (nn.Module): Encoder to be used
            decoder (nn.Module): Decoder to be used
            head_kwargs (dict): Arguments to be passed at instantiation of the head.
            decoder_includes_head (bool): Whether the decoder already includes a head. If True, no head will be added. Defaults to False.
            auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
                AuxiliaryHeads with heads to be instantiated. Defaults to None.
            neck (nn.Module | None): Module applied between backbone and decoder. Defaults to None, in which
                case the encoder's prepare_features_for_image_model is used if defined, otherwise the identity.
        """
        super().__init__()
        self.task = task
        self.encoder = encoder
        self.decoder = decoder
        self.head = (
            self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
        )

        if auxiliary_heads is not None:
            aux_heads = {}
            for aux_head_to_be_instantiated in auxiliary_heads:
                aux_head: nn.Module = self._get_head(
                    task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
                ) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
                aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
                aux_heads[aux_head_to_be_instantiated.name] = aux_head
        else:
            aux_heads = {}
        self.aux_heads = nn.ModuleDict(aux_heads)

        self.neck = neck

    def freeze_encoder(self):
        freeze_module(self.encoder)

    def freeze_decoder(self):
        freeze_module(self.decoder)
        freeze_module(self.head)

    # TODO: do this properly
    def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
        return True

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Sequentially pass `x` through the model's encoder, neck, decoder and heads."""

        self.check_input_shape(x)
        features = self.encoder(x)

        ## only for backwards compatibility with pre-neck times.
        if self.neck:
            prepare = self.neck
        else:
            # for backwards compatibility, if this is defined in the encoder, use it
            prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)

        features = prepare(features)

        decoder_output = self.decoder([f.clone() for f in features])
        mask = self.head(decoder_output)
        aux_outputs = {}
        for name, decoder in self.aux_heads.items():
            aux_output = decoder([f.clone() for f in features])
            aux_outputs[name] = aux_output
        return ModelOutput(output=mask, auxiliary_heads=aux_outputs)

    def _get_head(self, task: str, input_embed_dim: int, head_kwargs: dict):
        if task == "classification":
            if "num_classes" not in head_kwargs:
                msg = "num_classes must be defined for classification task"
                raise Exception(msg)
            return ClassificationHead(input_embed_dim, **head_kwargs)
        msg = "Task must be classification."
        raise Exception(msg)
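
Both model classes build their auxiliary heads the same way: each auxiliary decoder is chained with its own freshly instantiated head in an nn.Sequential, and the pairs are stored in an nn.ModuleDict keyed by name so that their parameters are registered with the model. The sketch below reproduces that pattern for classification with plain torch layers. AuxSpec, PoolLastFeature and the nn.Linear standing in for ClassificationHead are hypothetical.

```python
from dataclasses import dataclass

import torch
from torch import nn


class PoolLastFeature(nn.Module):
    """Hypothetical decoder: global-average-pool the last feature map to (B, C)."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.out_channels = channels

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        return features[-1].mean(dim=(-2, -1))


@dataclass
class AuxSpec:
    """Hypothetical stand-in for an auxiliary head description."""

    name: str
    decoder: nn.Module
    decoder_includes_head: bool = False


def build_aux_heads(specs: list[AuxSpec], num_classes: int) -> nn.ModuleDict:
    heads = {}
    for spec in specs:
        # Same rule as the models: no extra head if the decoder already produces final outputs.
        head = nn.Identity() if spec.decoder_includes_head else nn.Linear(spec.decoder.out_channels, num_classes)
        # nn.Sequential passes the feature list straight to the decoder, then its output to the head.
        heads[spec.name] = nn.Sequential(spec.decoder, head)
    return nn.ModuleDict(heads)


features = [torch.randn(2, 16, 8, 8), torch.randn(2, 32, 4, 4)]
aux_heads = build_aux_heads([AuxSpec("aux_pool", PoolLastFeature(32))], num_classes=10)
# Each head gets its own copy of the features, presumably so in-place ops cannot leak between heads.
aux_outputs = {name: head([f.clone() for f in features]) for name, head in aux_heads.items()}
print({k: tuple(v.shape) for k, v in aux_outputs.items()})  # {'aux_pool': (2, 10)}
```
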
__init__(task, encoder, decoder, head_kwargs, decoder_includes_head=False, auxiliary_heads=None, neck=None)

Constructor

Parameters:

  • task (str, required): Task to be performed. Must be "classification".
  • encoder (Module, required): Encoder to be used.
  • decoder (Module, required): Decoder to be used.
  • head_kwargs (dict, required): Arguments to be passed at instantiation of the head.
  • decoder_includes_head (bool, default False): Whether the decoder already includes a head. If True, no head will be added.
  • auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, default None): List of AuxiliaryHeads with heads to be instantiated.
  • neck (Module | None, default None): Module applied between backbone and decoder. If None, the encoder's prepare_features_for_image_model is used if defined, otherwise the identity.

forward(x)

Sequentially pass x through the model's encoder, neck, decoder and heads.


EncoderDecoderFactory

We expect this factory to be widely employed by users. With that in mind, we dive deeper into it here.

Loss

For convenience, we provide a loss handler that can be used to compute the full loss (from the main head and auxiliary heads as well).
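
The combination rule amounts to total = main loss + sum over auxiliary heads of (weight × head loss), with the main term reported as "decode_head" whenever auxiliary heads exist. combine_losses below is a simplified, hypothetical restatement of that logic (no logging, plain dictionaries instead of a ModelOutput) so the arithmetic can be checked in isolation.

```python
from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor


def combine_losses(
    output: Tensor,
    aux_outputs: dict[str, Tensor],
    target: Tensor,
    criterion: Callable[[Tensor, Tensor], Tensor],
    aux_loss_weights: dict[str, float] | None,
) -> dict[str, Tensor]:
    main = criterion(output, target)
    if not aux_outputs:
        return {"loss": main}
    if aux_loss_weights is None or set(aux_loss_weights) != set(aux_outputs):
        raise ValueError("aux_loss_weights must name exactly the auxiliary heads")
    losses = {"decode_head": main}
    total = main.clone()
    for name, weight in aux_loss_weights.items():
        losses[name] = criterion(aux_outputs[name], target)
        total = total + weight * losses[name]
    losses["loss"] = total
    return losses


losses = combine_losses(
    output=torch.zeros(2, 4),                      # MSE against the target: 1.0
    aux_outputs={"aux": torch.full((2, 4), 3.0)},  # MSE against the target: 4.0
    target=torch.ones(2, 4),
    criterion=torch.nn.functional.mse_loss,
    aux_loss_weights={"aux": 0.25},
)
print({k: v.item() for k, v in losses.items()})  # {'decode_head': 1.0, 'aux': 4.0, 'loss': 2.0}
```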

terratorch.tasks.loss_handler

LossHandler

Class to help handle the computation and logging of loss

Source code in terratorch/tasks/loss_handler.py
class LossHandler:
    """Class to help handle the computation and logging of loss"""

    def __init__(self, loss_prefix: str) -> None:
        """Constructor

        Args:
            loss_prefix (str): Prefix prepended verbatim to every logged loss name, so it should include any separator (e.g. "train/").
        """
        self.loss_prefix = loss_prefix

    def compute_loss(
        self,
        model_output: ModelOutput,
        ground_truth: Tensor,
        criterion: Callable,
        aux_loss_weights: dict[str, float] | None,
    ) -> dict[str, Tensor]:
        """Compute the loss for the main decode head as well as the auxiliary heads

        Args:
            model_output (ModelOutput): Output from the model
            ground_truth (Tensor): Tensor with labels
            criterion (Callable): Loss function to be applied
            aux_loss_weights (Union[dict[str, float], None]): Dictionary of names of model auxiliary
                heads and their weights

        Raises:
            Exception: If auxiliary heads are present but aux_loss_weights is None, or if the keys in aux_loss_weights and the model output do not match.

        Returns:
            dict[str, Tensor]: Dictionary of computed losses. Total loss is returned under the key "loss".
                If there are auxiliary heads, the main decode head is returned under the key "decode_head".
                All other heads are returned with the same key as their name.
        """

        loss = self._compute_loss(model_output.output, ground_truth, criterion)
        if not model_output.auxiliary_heads:
            return {"loss": loss}

        if aux_loss_weights is None:
            msg = "Auxiliary heads given with no aux_loss_weights"
            raise Exception(msg)
        all_losses = {}
        all_losses["decode_head"] = loss
        total_loss = loss.clone()
        # incorporate aux heads
        model_output_names = set(model_output.auxiliary_heads.keys())
        aux_loss_names = set(aux_loss_weights.keys())
        if aux_loss_names != model_output_names:
            msg = (
                "Found difference in declared auxiliary losses and model outputs.\n"
                f"Found in declared losses but not in model output: {aux_loss_names - model_output_names}.\n"
                f"Found in model output but not in declared losses: {model_output_names - aux_loss_names}"
            )
            raise Exception(msg)

        for loss_name, loss_weight in aux_loss_weights.items():
            output = model_output.auxiliary_heads[loss_name]
            loss_value: Tensor = self._compute_loss(output, ground_truth, criterion)
            all_losses[loss_name] = loss_value
            total_loss = total_loss + loss_value * loss_weight

        all_losses["loss"] = total_loss
        return all_losses

    def _compute_loss(self, y_hat: Tensor, ground_truth: Tensor, criterion: Callable):
        loss: Tensor = criterion(y_hat, ground_truth)
        return loss

    def log_loss(
        self, log_function: Callable, loss_dict: dict[str, Tensor] | None = None, batch_size: int | None = None
    ) -> None:
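        """Log the loss. The full loss is always logged under "<loss_prefix>loss"; every other entry
        (the main "decode_head" and any auxiliary heads) is logged under "<loss_prefix><name>".

        Args:
            log_function (Callable): Logging function, called as log_function(name, value, **kwargs),
                e.g. a LightningModule's log method.
            loss_dict (dict[str, Tensor] | None, optional): Losses as returned by compute_loss; must
                contain the key "loss". Raises ValueError if None. Defaults to None.
            batch_size (int | None, optional): Batch size forwarded to log_function. Defaults to None.
        """

        if loss_dict is None:
            msg = "loss_dict must be provided"
            raise ValueError(msg)
        # don't alter the passed dict
        all_losses = dict(loss_dict)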
        full_loss = all_losses.pop("loss")
        log_function(f"{self.loss_prefix}loss", full_loss, sync_dist=True, batch_size=batch_size)

        for loss_name, loss_value in all_losses.items():
            log_function(
                f"{self.loss_prefix}{loss_name}",
                loss_value,
                on_epoch=True,
                on_step=True,
                sync_dist=True,
                batch_size=batch_size,
            )
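
Because log_loss prepends loss_prefix verbatim, the prefix should carry its own separator. The snippet below mirrors the body of log_loss above as a standalone function and passes a recording function in place of a real logger (in a LightningModule this would presumably be self.log) to show which names get logged. The "train/" prefix and the loss values are made up for illustration.

```python
from __future__ import annotations

from typing import Any, Callable


def log_loss(loss_prefix: str, log_function: Callable, loss_dict: dict[str, Any], batch_size: int | None = None) -> None:
    """Standalone copy of LossHandler.log_loss's naming logic, for illustration."""
    all_losses = dict(loss_dict)  # copy so the caller's dict is left intact
    log_function(f"{loss_prefix}loss", all_losses.pop("loss"), sync_dist=True, batch_size=batch_size)
    for name, value in all_losses.items():
        log_function(f"{loss_prefix}{name}", value, on_epoch=True, on_step=True, sync_dist=True, batch_size=batch_size)


logged = {}


def record(name: str, value: Any, **kwargs: Any) -> None:
    # Stand-in logger: remember each logged name and value.
    logged[name] = value


losses = {"decode_head": 1.0, "aux": 4.0, "loss": 2.0}
log_loss("train/", record, losses, batch_size=8)
print(list(logged))  # ['train/loss', 'train/decode_head', 'train/aux']
```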

__init__(loss_prefix)

Constructor

Parameters:

  • loss_prefix (str, required): Prefix prepended verbatim to every logged loss name, so it should include any separator (e.g. "train/").


compute_loss(model_output, ground_truth, criterion, aux_loss_weights)

Compute the loss for the main decode head as well as the auxiliary heads.

Parameters:

  • model_output (ModelOutput, required): Output from the model.
  • ground_truth (Tensor, required): Tensor with labels.
  • criterion (Callable, required): Loss function to be applied.
  • aux_loss_weights (dict[str, float] | None, required): Dictionary mapping the names of the model's auxiliary heads to their loss weights.

Raises:

  • Exception: If auxiliary heads are present but aux_loss_weights is None, or if the keys in aux_loss_weights and the model output do not match.

Returns:

  • dict[str, Tensor]: Dictionary of computed losses. The total loss is returned under the key "loss". If there are auxiliary heads, the main decode head's loss is returned under the key "decode_head", and each auxiliary head's loss under its own name.


log_loss(log_function, loss_dict=None, batch_size=None)

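Log the loss. The full loss is always logged under "<loss_prefix>loss"; every other entry (the main "decode_head" and any auxiliary heads) is logged under "<loss_prefix><name>".

Parameters:

  • log_function (Callable, required): Logging function, called as log_function(name, value, **kwargs), e.g. a LightningModule's log method.
  • loss_dict (dict[str, Tensor], default None): Losses as returned by compute_loss; must be provided and must contain the key "loss".
  • batch_size (int | None, default None): Batch size forwarded to log_function.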

Generic datasets / datamodules

Refer to the section on data

Exporting models

A possible future feature is saving models in ONNX format and exporting them that way. This would bring the usual benefits of ONNX, such as being able to run the exported models with any ONNX-compatible runtime.