Loss

terratorch.tasks.loss_handler

LossHandler

Class that helps handle the computation and logging of losses.

__init__(loss_prefix)

Constructor

Parameters:

- loss_prefix (str, required): Prefix to be prepended to all the metrics (e.g. training).

compute_loss(model_output, ground_truth, criterion, aux_loss_weights)

Compute the loss for the main decode head as well as for any auxiliary heads.

Parameters:

- model_output (ModelOutput, required): Output from the model.
- ground_truth (Tensor, required): Tensor with labels.
- criterion (Callable, required): Loss function to be applied.
- aux_loss_weights (Union[dict[str, float], None], required): Dictionary mapping the names of the model's auxiliary heads to their loss weights.

Raises:

- Exception: If the keys in aux_loss_weights do not match the auxiliary heads in the model output.

Returns:

- dict[str, Tensor]: Dictionary of computed losses. The total loss is returned under the key "loss". If there are auxiliary heads, the loss of the main decode head is returned under the key "decode_head", and the loss of every other head is returned under that head's name.
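
The following sketch mirrors the documented contract of compute_loss. It uses plain Python floats instead of tensors so that it runs without PyTorch. This page does not specify how the losses are combined, so the sketch assumes that the total loss is the main-head loss plus a weighted sum of the auxiliary-head losses. It also assumes that "keys do not match" means a weighted name with no corresponding auxiliary head. The `SimpleOutput` class and the `compute_loss_sketch` function are hypothetical stand-ins, not terratorch's API.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class SimpleOutput:
    """Hypothetical stand-in for ModelOutput: main prediction plus auxiliary heads."""
    output: float
    auxiliary_heads: dict[str, float] = field(default_factory=dict)


def compute_loss_sketch(
    model_output: SimpleOutput,
    ground_truth: float,
    criterion: Callable[[float, float], float],
    aux_loss_weights: Optional[dict[str, float]],
) -> dict[str, float]:
    # Loss of the main decode head.
    main_loss = criterion(model_output.output, ground_truth)
    if not model_output.auxiliary_heads:
        # No auxiliary heads: only the total loss is reported.
        return {"loss": main_loss}
    if aux_loss_weights is None:
        raise Exception("Auxiliary heads present but no aux_loss_weights given")
    losses = {"decode_head": main_loss}
    total = main_loss
    for name, weight in aux_loss_weights.items():
        # Assumption: every weighted name must be an auxiliary head of the model.
        if name not in model_output.auxiliary_heads:
            raise Exception(f"No auxiliary head named {name!r}")
        head_loss = criterion(model_output.auxiliary_heads[name], ground_truth)
        losses[name] = head_loss
        # Assumption: auxiliary losses are added to the total as a weighted sum.
        total += weight * head_loss
    losses["loss"] = total
    return losses


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


out = SimpleOutput(output=1.0, auxiliary_heads={"aux_head": 3.0})
print(compute_loss_sketch(out, 0.0, squared_error, {"aux_head": 0.5}))
# {'decode_head': 1.0, 'aux_head': 9.0, 'loss': 5.5}
```

The total here is 1.0 + 0.5 × 9.0 = 5.5. Keeping the unweighted per-head losses in the dictionary means each head can still be monitored on its own scale.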

log_loss(log_function, loss_dict=None, batch_size=None)

Log the loss. If auxiliary heads exist, log the full loss under the suffix "loss", then log each of the other losses.

Parameters:

- log_function (Callable, required): Function used to log each loss value.
- loss_dict (dict[str, Tensor], optional): Dictionary of losses to log, such as the one returned by compute_loss. Defaults to None.
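
Below is a minimal sketch of the logging order described above. It assumes that log_function accepts a metric name, a value and a `batch_size` keyword, as PyTorch Lightning's `LightningModule.log` does. It also assumes that metric names are formed by plain concatenation of the prefix and the loss name. The `log_loss_sketch` function, the `"train/"` prefix and the recording logger are hypothetical and used for illustration only. Unlike the documented method, the sketch requires loss_dict.

```python
from typing import Callable, Optional


def log_loss_sketch(
    loss_prefix: str,
    log_function: Callable[..., None],
    loss_dict: dict[str, float],
    batch_size: Optional[int] = None,
) -> None:
    # Copy so the caller's dictionary is not modified by pop().
    remaining = dict(loss_dict)
    # The full loss is logged first, under "<prefix>loss".
    log_function(f"{loss_prefix}loss", remaining.pop("loss"), batch_size=batch_size)
    # Then each remaining per-head loss, keyed by its head name.
    for name, value in remaining.items():
        log_function(f"{loss_prefix}{name}", value, batch_size=batch_size)


# Usage with a stand-in logger that records calls instead of writing them anywhere.
logged = []


def record(name, value, batch_size=None):
    logged.append((name, value, batch_size))


log_loss_sketch("train/", record, {"loss": 5.5, "decode_head": 1.0, "aux_head": 9.0}, batch_size=8)
print(logged)
# [('train/loss', 5.5, 8), ('train/decode_head', 1.0, 8), ('train/aux_head', 9.0, 8)]
```

Passing the batch size through explicitly lets a logger weight epoch-level averages correctly when batch sizes vary, instead of inferring the size from the tensors.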

Last update: March 24, 2025