Loss
terratorch.tasks.loss_handler
LossHandler
Class that helps handle the computation and logging of losses.
__init__(loss_prefix)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_prefix | str | Prefix to be prepended to all the metrics (e.g. training). | required |
compute_loss(model_output, ground_truth, criterion, aux_loss_weights)
Compute the loss for the main decode head as well as for any auxiliary heads.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_output | ModelOutput | Output from the model | required |
ground_truth | Tensor | Tensor with labels | required |
criterion | Callable | Loss function to be applied | required |
aux_loss_weights | Union[dict[str, float], None] | Dictionary mapping the names of the model's auxiliary heads to their loss weights | required |
Raises:
Type | Description |
---|---|
Exception | Raised if the keys of aux_loss_weights do not match the names of the auxiliary heads in the model output. |
Returns:
Type | Description |
---|---|
dict[str, Tensor] | Dictionary of computed losses. The total loss is returned under the key "loss". If there are auxiliary heads, the loss of the main decode head is returned under the key "decode_head", and each auxiliary head's loss is returned under that head's name. |
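The following is a minimal, framework-free sketch of the behavior described above, not terratorch's implementation. It assumes the total loss is the main decode-head loss plus the weighted sum of the auxiliary-head losses. `SketchModelOutput`, `compute_loss_sketch` and `squared_error` are hypothetical names, and plain floats stand in for tensors so the example runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class SketchModelOutput:
    # Hypothetical stand-in for terratorch's ModelOutput: a main output
    # plus a mapping from auxiliary head names to their outputs.
    output: Any
    auxiliary_heads: dict[str, Any] = field(default_factory=dict)


def compute_loss_sketch(
    model_output: SketchModelOutput,
    ground_truth: Any,
    criterion: Callable[[Any, Any], float],
    aux_loss_weights: Optional[dict[str, float]],
) -> dict[str, float]:
    # Loss of the main decode head.
    main_loss = criterion(model_output.output, ground_truth)
    if not model_output.auxiliary_heads:
        # No auxiliary heads: only the total loss is returned.
        return {"loss": main_loss}

    # The declared weights must cover exactly the heads the model produced.
    declared = set(aux_loss_weights or {})
    produced = set(model_output.auxiliary_heads)
    if declared != produced:
        raise Exception(
            "aux_loss_weights and model output heads do not match: "
            f"only in weights {sorted(declared - produced)}, "
            f"only in model output {sorted(produced - declared)}"
        )

    losses = {"decode_head": main_loss}
    total = main_loss
    # ASSUMPTION: total = main loss + sum of (weight * auxiliary loss).
    for name, weight in aux_loss_weights.items():
        aux_loss = criterion(model_output.auxiliary_heads[name], ground_truth)
        losses[name] = aux_loss
        total = total + weight * aux_loss
    losses["loss"] = total
    return losses


def squared_error(prediction: float, target: float) -> float:
    return (prediction - target) ** 2


out = SketchModelOutput(output=3.0, auxiliary_heads={"aux_head": 1.0})
losses = compute_loss_sketch(out, 2.0, squared_error, {"aux_head": 0.5})
print(losses)  # {'decode_head': 1.0, 'aux_head': 1.0, 'loss': 1.5}
```

Here both heads miss the target by 1, so each head's loss is 1.0 and the total is 1.0 + 0.5 × 1.0 = 1.5. Checking the key sets up front means a misnamed head fails loudly instead of being silently dropped from the total.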
log_loss(log_function, loss_dict=None, batch_size=None)
Log the losses. The full loss is logged with the suffix "loss"; if auxiliary heads exist, all other losses are then logged as well.
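A minimal sketch of the logging order described above, written as a free function so it runs without terratorch or Lightning. `log_loss_sketch` and `record` are hypothetical names, the "train/" prefix is illustrative, and the sketch assumes each logged name is the prefix concatenated directly with the dictionary key. The real method may pass additional keyword arguments to log_function.

```python
from typing import Any, Callable, Optional


def log_loss_sketch(
    loss_prefix: str,
    log_function: Callable[..., None],
    loss_dict: dict[str, Any],
    batch_size: Optional[int] = None,
) -> None:
    # Copy first so removing "loss" does not mutate the caller's dict.
    remaining = dict(loss_dict)
    full_loss = remaining.pop("loss")
    # ASSUMPTION: names are formed by plain concatenation of prefix and key.
    log_function(f"{loss_prefix}loss", full_loss, batch_size=batch_size)
    # Remaining entries ("decode_head" and auxiliary heads) keep their own keys.
    for name, value in remaining.items():
        log_function(f"{loss_prefix}{name}", value, batch_size=batch_size)


# Recording stand-in for a logger such as LightningModule.log.
logged: list[tuple[str, Any]] = []


def record(name: str, value: Any, **kwargs: Any) -> None:
    logged.append((name, value))


log_loss_sketch("train/", record, {"decode_head": 1.0, "aux_head": 1.0, "loss": 1.5})
print(logged)  # [('train/loss', 1.5), ('train/decode_head', 1.0), ('train/aux_head', 1.0)]
```

Without auxiliary heads the dictionary holds only "loss", so exactly one value is logged.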
Parameters:
Name | Type | Description | Default |
---|---|---|---|
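log_function | Callable | Function used to log each loss value, for example the log method of a LightningModule. | required |
loss_dict | dict[str, Tensor] | Dictionary of losses to log, as returned by compute_loss; the full loss is read from the key "loss". | None |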