Overview (for developers)
The main goal of the design is to extend TorchGeo's existing tasks so that they can handle Prithvi backbones with appropriate decoders and heads.
At the same time, we wish to keep the existing TorchGeo functionality intact so it can be leveraged with pretrained models that are already included.
We achieve this by introducing new tasks that accept model factory classes exposing a build_model method. In principle, this strategy allows arbitrary models to be trained for these tasks, as long as they respect a reasonable minimal interface.
Together with this, we provide the EncoderDecoderFactory, which should enable users to plug together different Encoders and Decoders, with the aid of Necks for intermediate operations.
Additionally, we extend TorchGeo with generic datasets and datamodules which can be defined at runtime, rather than requiring classes to be defined beforehand.
The glue that holds everything together is LightningCLI, which allows the model, datamodule and Lightning Trainer to be instantiated from a config file or from the CLI. We make extensive use of it for training and inference.
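The snippet below is a minimal sketch of this pattern. MyTask and MyDataModule are placeholders rather than TerraTorch classes; TerraTorch's own CLI builds on the same mechanism.

```python
# Minimal sketch of the LightningCLI pattern the platform builds on.
# MyTask and MyDataModule are placeholders; a real task would implement
# training_step, configure_optimizers, etc.
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI


class MyTask(LightningModule):
    """Placeholder for a TerraTorch task."""


class MyDataModule(LightningDataModule):
    """Placeholder for a (generic) datamodule."""


if __name__ == "__main__":
    # Parses a YAML config / CLI arguments, instantiates the model,
    # datamodule and Trainer, and runs the requested subcommand
    # (fit, validate, test, predict).
    LightningCLI(model_class=MyTask, datamodule_class=MyDataModule)
```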
Recommended initial reading for a full understanding of the platform includes the TorchGeo and Lightning (in particular LightningCLI) documentation.
Tasks
Tasks coordinate training and inference for a specific kind of problem (e.g. segmentation, regression or classification). They are LightningModules that contain a model and abstract away the logic for training steps, metric computation and inference.
One of the most important design decisions was delegating model construction to a model factory (sketched after the list below). This has a few advantages:
- Avoids code repetition among tasks - different tasks can use the same factory
- Prefers composition over inheritance
- Allows new models to be easily added by introducing new factories
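The sketch below illustrates the factory contract. ToyModelFactory and everything inside it are hypothetical; only the existence of a build_model method reflects the design described above.

```python
# Hypothetical factory illustrating the build_model contract tasks rely on.
# Names and arguments are illustrative; a real factory (e.g. the
# EncoderDecoderFactory) assembles a backbone, necks, a decoder and a head,
# and returns a model implementing the Model interface described below.
from torch import nn


class ToyModelFactory:
    def build_model(self, task: str, **kwargs) -> nn.Module:
        if task == "segmentation":
            in_channels = kwargs.get("in_channels", 6)
            num_classes = kwargs.get("num_classes", 2)
            return nn.Conv2d(in_channels, num_classes, kernel_size=1)
        msg = f"Task {task} not supported by this toy factory"
        raise NotImplementedError(msg)


# A task holds such a factory and calls it when building its model:
model = ToyModelFactory().build_model("segmentation", in_channels=6, num_classes=4)
```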
Models are expected to be torch.nn.Modules and to implement the Model interface, providing:
- freeze_encoder()
- freeze_decoder()
- forward()
Additionally, the forward() method is expected to return an object of type ModelOutput, containing the main head's output as well as any additional auxiliary outputs. The names of these auxiliary heads are matched with the names of the provided auxiliary losses.
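To make this contract concrete, here is a minimal sketch of a model satisfying the interface. The import path for Model and ModelOutput is an assumption, and the toy layers are purely illustrative.

```python
# Minimal sketch of a model implementing the interface described above.
import torch
from torch import nn

from terratorch.models.model import Model, ModelOutput  # assumed import path


class ToySegmentationModel(Model, nn.Module):
    def __init__(self, in_channels: int = 6, num_classes: int = 2):
        super().__init__()
        self.encoder = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(16, num_classes, kernel_size=1)

    def freeze_encoder(self):
        for p in self.encoder.parameters():
            p.requires_grad = False

    def freeze_decoder(self):
        for p in self.decoder.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> ModelOutput:
        logits = self.decoder(self.encoder(x))
        # No auxiliary heads in this toy model.
        return ModelOutput(output=logits, auxiliary_heads={})
```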
Models
Models constructed by the EncoderDecoderFactory have an internal structure explicitly divided into backbones, necks, decoders and heads. This structure is provided by the PixelWiseModel and ScalarOutputModel classes.
However, as long as models implement the Model interface and return a ModelOutput from their forward method, they can take on any structure.
terratorch.models.pixel_wise_model.PixelWiseModel
Bases: Model, SegmentationModel
Model that encapsulates encoder and decoder and heads
Expects decoder to have a "forward_features" method, an embed_dims property
and optionally a "prepare_features_for_image_model" method.
Source code in terratorch/models/pixel_wise_model.py
```python
class PixelWiseModel(Model, SegmentationModel):
"""Model that encapsulates encoder and decoder and heads
Expects decoder to have a "forward_features" method, an embed_dims property
and optionally a "prepare_features_for_image_model" method.
"""
def __init__(
self,
task: str,
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
rescale: bool = True, # noqa: FBT002, FBT001
) -> None:
"""Constructor
Args:
task (str): Task to be performed. One of segmentation or regression.
encoder (nn.Module): Encoder to be used
decoder (nn.Module): Decoder to be used
head_kwargs (dict): Arguments to be passed at instantiation of the head.
decoder_includes_head (bool): Whether the decoder already includes a head. If true, a head will not be added. Defaults to False.
auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
AuxiliaryHeads with heads to be instantiated. Defaults to None.
neck (nn.Module | None): Module applied between backbone and decoder.
Defaults to None, which applies the identity.
rescale (bool, optional): Rescale the output of the model if it has a different size than the ground truth.
Uses bilinear interpolation. Defaults to True.
"""
super().__init__()
self.task = task
self.encoder = encoder
self.decoder = decoder
self.head = (
self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
)
if auxiliary_heads is not None:
aux_heads = {}
for aux_head_to_be_instantiated in auxiliary_heads:
aux_head: nn.Module = self._get_head(
task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
aux_heads[aux_head_to_be_instantiated.name] = aux_head
else:
aux_heads = {}
self.aux_heads = nn.ModuleDict(aux_heads)
self.neck = neck
self.rescale = rescale
def freeze_encoder(self):
freeze_module(self.encoder)
def freeze_decoder(self):
freeze_module(self.decoder)
freeze_module(self.head)
# TODO: do this properly
def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002
return True
@staticmethod
def _check_for_single_channel_and_squeeze(x):
if x.shape[1] == 1:
x = x.squeeze(1)
return x
def forward(self, x: torch.Tensor) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)
input_size = x.shape[-2:]
features = self.encoder(x)
## only for backwards compatibility with pre-neck times.
if self.neck:
prepare = self.neck
else:
# for backwards compatibility, if this is defined in the encoder, use it
prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)
features = prepare(features)
decoder_output = self.decoder([f.clone() for f in features])
mask = self.head(decoder_output)
if self.rescale and mask.shape[-2:] != input_size:
mask = F.interpolate(mask, size=input_size, mode="bilinear")
mask = self._check_for_single_channel_and_squeeze(mask)
aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
if self.rescale and aux_output.shape[-2:] != input_size:
aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
aux_output = self._check_for_single_channel_and_squeeze(aux_output)
aux_outputs[name] = aux_output
return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
if task == "segmentation":
if "num_classes" not in head_kwargs:
msg = "num_classes must be defined for segmentation task"
raise Exception(msg)
return SegmentationHead(input_embed_dim, **head_kwargs)
if task == "regression":
return RegressionHead(input_embed_dim, **head_kwargs)
msg = "Task must be one of segmentation or regression."
raise Exception(msg)
```
__init__(task, encoder, decoder, head_kwargs, decoder_includes_head=False, auxiliary_heads=None, neck=None, rescale=True)
Constructor
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| task | str | Task to be performed. One of segmentation or regression. | required |
| encoder | Module | Encoder to be used. | required |
| decoder | Module | Decoder to be used. | required |
| head_kwargs | dict | Arguments to be passed at instantiation of the head. | required |
| decoder_includes_head | bool | Whether the decoder already includes a head. If true, a head will not be added. | False |
| auxiliary_heads | list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] \| None | List of AuxiliaryHeads with heads to be instantiated. | None |
| neck | Module \| None | Module applied between backbone and decoder. Defaults to None, which applies the identity. | None |
| rescale | bool | Rescale the output of the model if it has a different size than the ground truth. Uses bilinear interpolation. | True |
Source code in terratorch/models/pixel_wise_model.py
```python
def __init__(
self,
task: str,
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
rescale: bool = True, # noqa: FBT002, FBT001
) -> None:
"""Constructor
Args:
task (str): Task to be performed. One of segmentation or regression.
encoder (nn.Module): Encoder to be used
decoder (nn.Module): Decoder to be used
head_kwargs (dict): Arguments to be passed at instantiation of the head.
decoder_includes_head (bool): Whether the decoder already includes a head. If true, a head will not be added. Defaults to False.
auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
AuxiliaryHeads with heads to be instantiated. Defaults to None.
neck (nn.Module | None): Module applied between backbone and decoder.
Defaults to None, which applies the identity.
rescale (bool, optional): Rescale the output of the model if it has a different size than the ground truth.
Uses bilinear interpolation. Defaults to True.
"""
super().__init__()
self.task = task
self.encoder = encoder
self.decoder = decoder
self.head = (
self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
)
if auxiliary_heads is not None:
aux_heads = {}
for aux_head_to_be_instantiated in auxiliary_heads:
aux_head: nn.Module = self._get_head(
task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
aux_heads[aux_head_to_be_instantiated.name] = aux_head
else:
aux_heads = {}
self.aux_heads = nn.ModuleDict(aux_heads)
self.neck = neck
self.rescale = rescale
```
forward(x)
Sequentially pass x through the model's encoder, decoder and heads.
Source code in terratorch/models/pixel_wise_model.py
```python
def forward(self, x: torch.Tensor) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)
input_size = x.shape[-2:]
features = self.encoder(x)
## only for backwards compatibility with pre-neck times.
if self.neck:
prepare = self.neck
else:
# for backwards compatibility, if this is defined in the encoder, use it
prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)
features = prepare(features)
decoder_output = self.decoder([f.clone() for f in features])
mask = self.head(decoder_output)
if self.rescale and mask.shape[-2:] != input_size:
mask = F.interpolate(mask, size=input_size, mode="bilinear")
mask = self._check_for_single_channel_and_squeeze(mask)
aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
if self.rescale and aux_output.shape[-2:] != input_size:
aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
aux_output = self._check_for_single_channel_and_squeeze(aux_output)
aux_outputs[name] = aux_output
return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
```
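For orientation, below is a hedged sketch of wiring a PixelWiseModel together by hand with toy modules; in practice this assembly is performed by the EncoderDecoderFactory. The toy encoder/decoder, the channel counts and the use of decoder_includes_head=True (so that no real head needs to be instantiated) are illustrative choices, not prescribed usage.

```python
# Hedged sketch: assembling a PixelWiseModel by hand with toy modules.
import torch
from torch import nn

from terratorch.models.pixel_wise_model import PixelWiseModel


class ToyEncoder(nn.Module):
    """Returns a list of feature maps, as the decoder expects."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return [self.conv(x)]


class ToyDecoderWithHead(nn.Module):
    """Toy decoder that already maps features to class logits."""

    out_channels = 2  # what a head would consume, if one were added

    def __init__(self):
        super().__init__()
        self.classify = nn.Conv2d(16, self.out_channels, kernel_size=1)

    def forward(self, features):
        return self.classify(features[0])


model = PixelWiseModel(
    task="segmentation",
    encoder=ToyEncoder(),
    decoder=ToyDecoderWithHead(),
    head_kwargs={},
    decoder_includes_head=True,  # the head becomes nn.Identity
)
out = model(torch.randn(1, 6, 64, 64))
print(out.output.shape)  # torch.Size([1, 2, 64, 64])
```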
terratorch.models.scalar_output_model.ScalarOutputModel
Bases: Model, SegmentationModel
Model that encapsulates encoder and decoder and heads for a scalar output
Expects decoder to have a "forward_features" method, an embed_dims property
and optionally a "prepare_features_for_image_model" method.
Source code in terratorch/models/scalar_output_model.py
```python
class ScalarOutputModel(Model, SegmentationModel):
"""Model that encapsulates encoder and decoder and heads for a scalar output
Expects decoder to have a "forward_features" method, an embed_dims property
and optionally a "prepare_features_for_image_model" method.
"""
def __init__(
self,
task: str,
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
) -> None:
"""Constructor
Args:
task (str): Task to be performed. Must be "classification".
encoder (nn.Module): Encoder to be used
decoder (nn.Module): Decoder to be used
head_kwargs (dict): Arguments to be passed at instantiation of the head.
decoder_includes_head (bool): Whether the decoder already includes a head. If true, a head will not be added. Defaults to False.
auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
AuxiliaryHeads with heads to be instantiated. Defaults to None.
neck (nn.Module | None): Module applied between backbone and decoder.
Defaults to None, which applies the identity.
"""
super().__init__()
self.task = task
self.encoder = encoder
self.decoder = decoder
self.head = (
self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
)
if auxiliary_heads is not None:
aux_heads = {}
for aux_head_to_be_instantiated in auxiliary_heads:
aux_head: nn.Module = self._get_head(
task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
aux_heads[aux_head_to_be_instantiated.name] = aux_head
else:
aux_heads = {}
self.aux_heads = nn.ModuleDict(aux_heads)
self.neck = neck
def freeze_encoder(self):
freeze_module(self.encoder)
def freeze_decoder(self):
freeze_module(self.decoder)
freeze_module(self.head)
# TODO: do this properly
def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002
return True
def forward(self, x: torch.Tensor) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)
features = self.encoder(x)
## only for backwards compatibility with pre-neck times.
if self.neck:
prepare = self.neck
else:
# for backwards compatibility, if this is defined in the encoder, use it
prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)
features = prepare(features)
decoder_output = self.decoder([f.clone() for f in features])
mask = self.head(decoder_output)
aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
aux_outputs[name] = aux_output
return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
def _get_head(self, task: str, input_embed_dim: int, head_kwargs: dict):
if task == "classification":
if "num_classes" not in head_kwargs:
msg = "num_classes must be defined for classification task"
raise Exception(msg)
return ClassificationHead(input_embed_dim, **head_kwargs)
msg = "Task must be classification."
raise Exception(msg)
```
__init__(task, encoder, decoder, head_kwargs, decoder_includes_head=False, auxiliary_heads=None, neck=None)
Constructor
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| task | str | Task to be performed. Must be "classification". | required |
| encoder | Module | Encoder to be used. | required |
| decoder | Module | Decoder to be used. | required |
| head_kwargs | dict | Arguments to be passed at instantiation of the head. | required |
| decoder_includes_head | bool | Whether the decoder already includes a head. If true, a head will not be added. | False |
| auxiliary_heads | list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] \| None | List of AuxiliaryHeads with heads to be instantiated. | None |
| neck | Module \| None | Module applied between backbone and decoder. Defaults to None, which applies the identity. | None |
Source code in terratorch/models/scalar_output_model.py
```python
def __init__(
self,
task: str,
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
) -> None:
"""Constructor
Args:
task (str): Task to be performed. Must be "classification".
encoder (nn.Module): Encoder to be used
decoder (nn.Module): Decoder to be used
head_kwargs (dict): Arguments to be passed at instantiation of the head.
decoder_includes_head (bool): Whether the decoder already includes a head. If true, a head will not be added. Defaults to False.
auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
AuxiliaryHeads with heads to be instantiated. Defaults to None.
neck (nn.Module | None): Module applied between backbone and decoder.
Defaults to None, which applies the identity.
"""
super().__init__()
self.task = task
self.encoder = encoder
self.decoder = decoder
self.head = (
self._get_head(task, decoder.out_channels, head_kwargs) if not decoder_includes_head else nn.Identity()
)
if auxiliary_heads is not None:
aux_heads = {}
for aux_head_to_be_instantiated in auxiliary_heads:
aux_head: nn.Module = self._get_head(
task, aux_head_to_be_instantiated.decoder.out_channels, head_kwargs
) if not aux_head_to_be_instantiated.decoder_includes_head else nn.Identity()
aux_head = nn.Sequential(aux_head_to_be_instantiated.decoder, aux_head)
aux_heads[aux_head_to_be_instantiated.name] = aux_head
else:
aux_heads = {}
self.aux_heads = nn.ModuleDict(aux_heads)
self.neck = neck
```
forward(x)
Sequentially pass x through the model's encoder, decoder and heads.
Source code in terratorch/models/scalar_output_model.py
```python
def forward(self, x: torch.Tensor) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)
features = self.encoder(x)
## only for backwards compatibility with pre-neck times.
if self.neck:
prepare = self.neck
else:
# for backwards compatibility, if this is defined in the encoder, use it
prepare = getattr(self.encoder, "prepare_features_for_image_model", lambda x: x)
features = prepare(features)
decoder_output = self.decoder([f.clone() for f in features])
mask = self.head(decoder_output)
aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
aux_outputs[name] = aux_output
return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
```
EncoderDecoderFactory
We expect this factory to be widely employed by users. With that in mind, we dive deeper into it here.
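As an illustration of the intended usage pattern, here is a hedged sketch; the argument names and the backbone/decoder identifiers are assumptions and should be checked against the EncoderDecoderFactory documentation.

```python
# Hedged sketch of building a model with the EncoderDecoderFactory.
# Argument names and registry identifiers below are illustrative assumptions.
from terratorch.models import EncoderDecoderFactory  # assumed import path

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",  # illustrative backbone name
    decoder="FCNDecoder",          # illustrative decoder name
    num_classes=2,
)
```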
Loss
For convenience, we provide a loss handler that can be used to compute the full loss, combining the loss from the main head with those from any auxiliary heads.
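Below is a minimal, self-contained usage sketch; the prefix, the auxiliary head name and its weight are illustrative, and the import path for ModelOutput is an assumption. compute_loss and log_loss are documented in the API reference that follows.

```python
# Minimal usage sketch of the loss handler with fake model outputs.
import torch
from torch import nn

from terratorch.models.model import ModelOutput  # assumed import path
from terratorch.tasks.loss_handler import LossHandler

criterion = nn.CrossEntropyLoss()
handler = LossHandler("train/")  # prefix prepended to logged metric names

# Fake tensors standing in for a model's forward pass and the labels.
main_logits = torch.randn(4, 2, 16, 16)
aux_logits = torch.randn(4, 2, 16, 16)
labels = torch.randint(0, 2, (4, 16, 16))

model_output = ModelOutput(output=main_logits, auxiliary_heads={"aux_head": aux_logits})
loss_dict = handler.compute_loss(
    model_output, labels, criterion, aux_loss_weights={"aux_head": 0.4}
)
# loss_dict contains "decode_head", "aux_head" and the total under "loss".
print({name: float(value) for name, value in loss_dict.items()})

# Inside a task, the result would then be logged, e.g.:
# handler.log_loss(self.log, loss_dict=loss_dict, batch_size=labels.shape[0])
```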
terratorch.tasks.loss_handler
LossHandler
Class to help handle the computation and logging of loss
Source code in terratorch/tasks/loss_handler.py
```python
class LossHandler:
"""Class to help handle the computation and logging of loss"""
def __init__(self, loss_prefix: str) -> None:
"""Constructor
Args:
loss_prefix (str): Prefix to be prepended to all the metrics (e.g. training).
"""
self.loss_prefix = loss_prefix
def compute_loss(
self,
model_output: ModelOutput,
ground_truth: Tensor,
criterion: Callable,
aux_loss_weights: dict[str, float] | None,
) -> dict[str, Tensor]:
"""Compute the loss for the mean decode head as well as other heads
Args:
model_output (ModelOutput): Output from the model
ground_truth (Tensor): Tensor with labels
criterion (Callable): Loss function to be applied
aux_loss_weights (Union[dict[str, float], None]): Dictionary of names of model auxiliary
heads and their weights
Raises:
Exception: If the keys in aux_loss_weights and the model output do not match, will raise an exception.
Returns:
dict[str, Tensor]: Dictionary of computed losses. Total loss is returned under the key "loss".
If there are auxiliary heads, the main decode head is returned under the key "decode_head".
All other heads are returned with the same key as their name.
"""
loss = self._compute_loss(model_output.output, ground_truth, criterion)
if not model_output.auxiliary_heads:
return {"loss": loss}
if aux_loss_weights is None:
msg = "Auxiliary heads given with no aux_loss_weights"
raise Exception(msg)
all_losses = {}
all_losses["decode_head"] = loss
total_loss = loss.clone()
# incorporate aux heads
model_output_names = set(model_output.auxiliary_heads.keys())
aux_loss_names = set(aux_loss_weights.keys())
if aux_loss_names != model_output_names:
msg = f"Found difference in declared auxiliary losses and model outputs.\n \
Found in declared losses but not in model output: {aux_loss_names - model_output_names}. \n \
Found in model output but not in declared losses: {model_output_names - aux_loss_names}"
raise Exception(msg)
for loss_name, loss_weight in aux_loss_weights.items():
output = model_output.auxiliary_heads[loss_name]
loss_value: Tensor = self._compute_loss(output, ground_truth, criterion)
all_losses[loss_name] = loss_value
total_loss = total_loss + loss_value * loss_weight
all_losses["loss"] = total_loss
return all_losses
def _compute_loss(self, y_hat: Tensor, ground_truth: Tensor, criterion: Callable):
loss: Tensor = criterion(y_hat, ground_truth)
return loss
def log_loss(
self, log_function: Callable, loss_dict: dict[str, Tensor] | None = None, batch_size: int | None = None
) -> None:
"""Log the loss. If auxiliary heads exist, log the full loss suffix "loss", and then all other losses.
Args:
log_function (Callable): _description_
loss_dict (dict[str, Tensor], optional): _description_. Defaults to None.
"""
# dont alter passed dict
all_losses = dict(loss_dict)
full_loss = all_losses.pop("loss")
log_function(f"{self.loss_prefix}loss", full_loss, sync_dist=True, batch_size=batch_size)
for loss_name, loss_value in all_losses.items():
log_function(
f"{self.loss_prefix}{loss_name}",
loss_value,
on_epoch=True,
on_step=True,
sync_dist=True,
batch_size=batch_size,
)
```
__init__(loss_prefix)
Constructor
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_prefix | str | Prefix to be prepended to all the metrics (e.g. training). | required |
Source code in terratorch/tasks/loss_handler.py
```python
def __init__(self, loss_prefix: str) -> None:
"""Constructor
Args:
loss_prefix (str): Prefix to be prepended to all the metrics (e.g. training).
"""
self.loss_prefix = loss_prefix
```
compute_loss(model_output, ground_truth, criterion, aux_loss_weights)
Compute the loss for the main decode head as well as other heads.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_output | ModelOutput | Output from the model. | required |
| ground_truth | Tensor | Tensor with labels. | required |
| criterion | Callable | Loss function to be applied. | required |
| aux_loss_weights | Union[dict[str, float], None] | Dictionary of names of model auxiliary heads and their weights. | required |

Raises:

| Type | Description |
| --- | --- |
| Exception | If the keys in aux_loss_weights and the model output do not match. |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Tensor] | Dictionary of computed losses. The total loss is returned under the key "loss". If there are auxiliary heads, the main decode head is returned under the key "decode_head". All other heads are returned with the same key as their name. |
Source code in terratorch/tasks/loss_handler.py
```python
def compute_loss(
self,
model_output: ModelOutput,
ground_truth: Tensor,
criterion: Callable,
aux_loss_weights: dict[str, float] | None,
) -> dict[str, Tensor]:
"""Compute the loss for the mean decode head as well as other heads
Args:
model_output (ModelOutput): Output from the model
ground_truth (Tensor): Tensor with labels
criterion (Callable): Loss function to be applied
aux_loss_weights (Union[dict[str, float], None]): Dictionary of names of model auxiliary
heads and their weights
Raises:
Exception: If the keys in aux_loss_weights and the model output do not match, will raise an exception.
Returns:
dict[str, Tensor]: Dictionary of computed losses. Total loss is returned under the key "loss".
If there are auxiliary heads, the main decode head is returned under the key "decode_head".
All other heads are returned with the same key as their name.
"""
loss = self._compute_loss(model_output.output, ground_truth, criterion)
if not model_output.auxiliary_heads:
return {"loss": loss}
if aux_loss_weights is None:
msg = "Auxiliary heads given with no aux_loss_weights"
raise Exception(msg)
all_losses = {}
all_losses["decode_head"] = loss
total_loss = loss.clone()
# incorporate aux heads
model_output_names = set(model_output.auxiliary_heads.keys())
aux_loss_names = set(aux_loss_weights.keys())
if aux_loss_names != model_output_names:
msg = f"Found difference in declared auxiliary losses and model outputs.\n \
Found in declared losses but not in model output: {aux_loss_names - model_output_names}. \n \
Found in model output but not in declared losses: {model_output_names - aux_loss_names}"
raise Exception(msg)
for loss_name, loss_weight in aux_loss_weights.items():
output = model_output.auxiliary_heads[loss_name]
loss_value: Tensor = self._compute_loss(output, ground_truth, criterion)
all_losses[loss_name] = loss_value
total_loss = total_loss + loss_value * loss_weight
all_losses["loss"] = total_loss
return all_losses
```
log_loss(log_function, loss_dict=None, batch_size=None)
Log the loss. If auxiliary heads exist, the full loss is logged with the suffix "loss", followed by all other losses.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_function | Callable | Logging function to call (e.g. a LightningModule's log method). | required |
| loss_dict | dict[str, Tensor] \| None | Dictionary of computed losses, as returned by compute_loss. | None |
| batch_size | int \| None | Batch size, passed on to the logging function. | None |
Source code in terratorch/tasks/loss_handler.py
```python
def log_loss(
self, log_function: Callable, loss_dict: dict[str, Tensor] | None = None, batch_size: int | None = None
) -> None:
"""Log the loss. If auxiliary heads exist, log the full loss suffix "loss", and then all other losses.
Args:
log_function (Callable): _description_
loss_dict (dict[str, Tensor], optional): _description_. Defaults to None.
"""
# dont alter passed dict
all_losses = dict(loss_dict)
full_loss = all_losses.pop("loss")
log_function(f"{self.loss_prefix}loss", full_loss, sync_dist=True, batch_size=batch_size)
for loss_name, loss_value in all_losses.items():
log_function(
f"{self.loss_prefix}{loss_name}",
loss_value,
on_epoch=True,
on_step=True,
sync_dist=True,
batch_size=batch_size,
)
```
Generic datasets / datamodules
Refer to the section on data
Exporting models
A planned future feature is the ability to save and export models in the ONNX format, which would bring the portability and interoperability benefits of ONNX.
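As a hedged sketch of what such an export could look like with plain torch.onnx.export: the wrapper below is needed because forward returns a ModelOutput rather than a raw tensor, and all names here are illustrative rather than part of any planned API.

```python
# Hedged sketch of a possible ONNX export path (not an existing feature).
# The wrapper unpacks ModelOutput so the exported graph returns a plain tensor.
import torch
from torch import nn


class ExportWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x).output  # keep only the main head's output


# Given some trained TerraTorch `model` and an `example` input tensor:
# torch.onnx.export(ExportWrapper(model), example, "model.onnx", opset_version=17)
```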