Tasks provide a convenient abstraction over the training of a model for a specific downstream task.

They encapsulate the model, optimizer, metrics, loss as well as training, validation and testing steps.

The task expects to be passed a model factory, to which the model_args arguments are passed. The models produced by this model factory should output ModelOutput instances and conform to the Model ABC. This makes it easy to extend these tasks to models other than the Prithvi ones produced by the PrithviModelFactory.

Tasks are best leveraged using config files. Check out some examples here.

Argument parsing in configs

Argument parsing of configs relies on argument names and type hints in the code. To pass arguments that do not conform to this (e.g. for classes that make use of **kwargs), put those arguments in dict_args instead of init_args.

Multi Temporal Inputs

Multi-temporal inputs are also supported! However, we leverage albumentations for augmentations, and it does not support multi-temporal input. We currently get around this by using the following strategy in the transform:

      - class_path: FlattenTemporalIntoChannels
      # your transforms here, wrapped by these other ones
      # e.g. a random flip
      - class_path: albumentations.Flip
      # end of your transforms
      - class_path: ToTensorV2
      - class_path: UnflattenTemporalFromChannels
        init_args:
          n_timesteps: 3 # your number of timesteps here
          # alternatively, n_channels can be specified
See an example of this here.


Bases: BaseTask

Pixelwise Regression Task that accepts models from a range of sources.

This class is analogous in functionality to the PixelwiseRegressionTask defined by torchgeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor

class PixelwiseRegressionTask(BaseTask):
    """Pixelwise Regression Task that accepts models from a range of sources.

    This class is analog in functionality to
    defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor"""

    def __init__(
        model_args: dict,
        model_factory: str,
        loss: str = "mse",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesnt need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT001, FBT002
        plot_on_val: bool | int = 10,
        tiled_inference_parameters: TiledInferenceParameters | None = None,
    ) -> None:

        self.aux_loss = aux_loss
        self.aux_heads = aux_heads
        self.model_factory = get_factory(model_factory)
        self.train_loss_handler = LossHandler(self.train_metrics.prefix)
        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)

    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        """Return an empty callback list so the task registers no callbacks of its own.

        Per the class description, this deliberately avoids the early stopping
        that TorchGeo tasks set up by default.
        """
        no_callbacks: list[Callback] = []
        return no_callbacks

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        if self.hparams["freeze_backbone"]:
        if self.hparams["freeze_decoder"]:

    def configure_optimizers(
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(

    def configure_losses(self) -> None:
    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

        def instantiate_metrics():
            return {
                "RMSE": MeanSquaredError(squared=False),
                "MSE": MeanSquaredError(squared=True),
                "MAE": MeanAbsoluteError(),

        def wrap_metrics_with_ignore_index(metrics):
            return {
                name: IgnoreIndexMetricWrapper(metric, ignore_index=self.hparams["ignore_index"])
                for name, metric in metrics.items()

        self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
        self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
        self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    def _do_plot_samples(self, batch_index):
        """Decide whether samples from the batch at ``batch_index`` should be plotted.

        Args:
            batch_index: Index of the current batch; only batches below
                ``BATCH_IDX_FOR_VALIDATION_PLOTTING`` are eligible.

        Returns:
            ``False`` when plotting is disabled (``self.plot_on_val`` is 0).
            Otherwise the result of the chained ``and`` checks below, which is
            truthy only if every condition holds.
        """
        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
            return False

        # The closing parenthesis of this expression was missing, which made the
        # method a syntax error; it is restored here.
        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            # the logger backend must expose one of the two figure-logging methods
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
Bases: BaseTask

Semantic Segmentation Task that accepts models from a range of sources.

This class is analogous in functionality to the SemanticSegmentationTask defined by torchgeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor

class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation Task that accepts models from a range of sources.

    This class is analog in functionality to class:SemanticSegmentationTask defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor

    def __init__(
        model_args: dict,
        model_factory: str,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesnt need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        plot_on_val: bool | int = 10,
        class_names: list[str] | None = None,
        tiled_inference_parameters: TiledInferenceParameters = None,
    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        """Return an empty callback list so the task registers no callbacks of its own.

        Per the class description, this deliberately avoids the early stopping
        that TorchGeo tasks set up by default.
        """
        no_callbacks: list[Callback] = []
        return no_callbacks

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        if self.hparams["freeze_backbone"]:
        if self.hparams["freeze_decoder"]:

    def configure_optimizers(
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    def on_train_epoch_end(self) -> None:
        """Log the computed training metrics, then defer to the parent hook."""
        epoch_metrics = self.train_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_train_epoch_end()

    def _do_plot_samples(self, batch_index):
        """Decide whether samples from the batch at ``batch_index`` should be plotted.

        Args:
            batch_index: Index of the current batch; only batches below
                ``BATCH_IDX_FOR_VALIDATION_PLOTTING`` are eligible.

        Returns:
            ``False`` when plotting is disabled (``self.plot_on_val`` is 0).
            Otherwise the result of the chained ``and`` checks below, which is
            truthy only if every condition holds.
        """
        if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
            return False

        # The closing parenthesis of this expression was missing, which made the
        # method a syntax error; it is restored here.
        return (
            batch_index < BATCH_IDX_FOR_VALIDATION_PLOTTING
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and not self.current_epoch % self.plot_on_val  # will be True every self.plot_on_val epochs
            and hasattr(self.logger, "experiment")
            # the logger backend must expose one of the two figure-logging methods
            and (hasattr(self.logger.experiment, "add_figure") or hasattr(self.logger.experiment, "log_figure"))
        )

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def on_validation_epoch_end(self) -> None:
        """Log the computed validation metrics, then defer to the parent hook."""
        epoch_metrics = self.val_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def on_test_epoch_end(self) -> None:
        """Log the computed test metrics, then defer to the parent hook."""
        epoch_metrics = self.test_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
Bases: BaseTask

Classification Task that accepts models from a range of sources.

This class is analogous in functionality to the ClassificationTask defined by torchgeo. However, it has some important differences:

- Accepts the specification of a model factory
- Logs metrics per class
- Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
- Allows the setting of optimizers in the constructor
- It provides mIoU with both Micro and Macro averaging

Note:

- 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy.
- 'Macro' averaging gives equal weight to each class, useful for balanced performance assessment across imbalanced classes.

class ClassificationTask(BaseTask):
    """Classification Task that accepts models from a range of sources.

    This class is analog in functionality to class:ClassificationTask defined by torchgeo.
    However, it has some important differences:
        - Accepts the specification of a model factory
        - Logs metrics per class
        - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
        - Allows the setting of optimizers in the constructor
        - It provides mIoU with both Micro and Macro averaging

    .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.

    def __init__(
        model_args: dict,
        model_factory: str,
        loss: str = "ce",
        aux_heads: list[AuxiliaryHead] | None = None,
        aux_loss: dict[str, float] | None = None,
        class_weights: list[float] | None = None,
        ignore_index: int | None = None,
        lr: float = 0.001,
        # the following are optional so CLI doesnt need to pass them
        optimizer: str | None = None,
        optimizer_hparams: dict | None = None,
        scheduler: str | None = None,
        scheduler_hparams: dict | None = None,
        freeze_backbone: bool = False,  # noqa: FBT001, FBT002
        freeze_decoder: bool = False,  # noqa: FBT002, FBT001
        class_names: list[str] | None = None,
    # overwrite early stopping
    def configure_callbacks(self) -> list[Callback]:
        """Return an empty callback list so the task registers no callbacks of its own.

        Per the class description, this deliberately avoids the early stopping
        that TorchGeo tasks set up by default.
        """
        no_callbacks: list[Callback] = []
        return no_callbacks

    def configure_models(self) -> None:
        self.model: Model = self.model_factory.build_model(
            "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"]
        if self.hparams["freeze_backbone"]:
        if self.hparams["freeze_decoder"]:

    def configure_optimizers(
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        optimizer = self.hparams["optimizer"]
        if optimizer is None:
            optimizer = "Adam"
        return optimizer_factory(

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
    def on_train_epoch_end(self) -> None:
        """Log the computed training metrics, then defer to the parent hook."""
        epoch_metrics = self.train_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_train_epoch_end()

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def on_validation_epoch_end(self) -> None:
        """Log the computed validation metrics, then defer to the parent hook."""
        epoch_metrics = self.val_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_validation_epoch_end()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    def on_test_epoch_end(self) -> None:
        """Log the computed test metrics, then defer to the parent hook."""
        epoch_metrics = self.test_metrics.compute()
        self.log_dict(epoch_metrics, sync_dist=True)
        return super().on_test_epoch_end()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
