Model Factories

terratorch.models.model.ModelFactory #

Bases: Protocol
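
Because ModelFactory is declared as a Protocol, a class can satisfy it structurally by providing a compatible build_model method. The following is a minimal, dependency-free sketch of that pattern. `SketchModelFactory`, `ToyFactory` and the returned dictionary are hypothetical stand-ins for illustration, not terratorch classes.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SketchModelFactory(Protocol):
    """Hypothetical stand-in for a factory protocol (not the terratorch class)."""

    def build_model(self, *args: Any, **kwargs: Any) -> Any: ...


class ToyFactory:
    """Satisfies the protocol structurally, without inheriting from it."""

    def build_model(self, task: str, backbone: str, decoder: str, **kwargs: Any) -> dict:
        # A real factory would return an nn.Module; a dict keeps the sketch dependency-free.
        return {"task": task, "backbone": backbone, "decoder": decoder, "extra": kwargs}


factory = ToyFactory()
model = factory.build_model("segmentation", "toy_backbone", "toy_decoder", num_classes=2)
```

The factories documented below inherit from ModelFactory explicitly; the sketch only shows why a shared build_model entry point makes them interchangeable.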

terratorch.models.clay_model_factory.ClayModelFactory #

Bases: ModelFactory

build_model(task, backbone, decoder, in_channels, bands=[], num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs) #

Model factory for Clay models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.
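
The prefix convention above can be pictured as grouping keyword arguments by prefix and stripping the prefix before forwarding them. This is a sketch of that idea only, not the library's implementation. `split_prefixed_kwargs` and the argument names in the usage lines (`drop_path_rate`, `channels`, `dropout`) are hypothetical.

```python
def split_prefixed_kwargs(kwargs, prefixes=("backbone_", "decoder_", "head_")):
    """Group keyword arguments by prefix and strip the prefix from each key."""
    groups = {prefix: {} for prefix in prefixes}
    remaining = {}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix][key[len(prefix):]] = value
                break
        else:
            # No known prefix matched; keep the argument untouched.
            remaining[key] = value
    return groups, remaining


groups, remaining = split_prefixed_kwargs(
    {"backbone_drop_path_rate": 0.1, "decoder_channels": 256, "head_dropout": 0.2, "other": 1}
)
```

Under this reading, `backbone_drop_path_rate=0.1` would reach the backbone constructor as `drop_path_rate=0.1`.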

Parameters:

Name Type Description Default
task str

Task to be performed. Currently supports "segmentation" and "regression".

required
backbone (str, Module)

Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100".

required
decoder Union[str, Module]

Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property decoder.out_channels. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".

required
in_channels int

Number of input channels. Defaults to 3.

required
num_classes int

Number of classes. None for regression tasks.

None
pretrained Union[bool, Path]

Whether to load pretrained weights for the backbone, if available. Defaults to True.

True
num_frames int

Number of timesteps for the model to handle. Defaults to 1.

1
prepare_features_for_image_model Callable | None

Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function.

None
aux_decoders list[AuxiliaryHead] | None

List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well.

None
rescale bool

Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

True
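
The rescale behaviour described above (bilinear interpolation applied only when the output size differs from the ground truth) can be illustrated without any tensor library. This is a pure-Python sketch on nested lists, assuming pixel centres are aligned and edges are clamped. `bilinear_resize` and `rescale_if_needed` are hypothetical helpers; the library itself operates on tensors and may differ in detail.

```python
def bilinear_resize(image, out_h, out_w):
    """Resize a 2D list of floats with bilinear interpolation (pixel centres aligned)."""
    in_h, in_w = len(image), len(image[0])
    result = []
    for i in range(out_h):
        # Map the output pixel centre back into input coordinates, clamped to the image.
        y = min(max((i + 0.5) * in_h / out_h - 0.5, 0.0), in_h - 1.0)
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        wy = y - y0
        row = []
        for j in range(out_w):
            x = min(max((j + 0.5) * in_w / out_w - 0.5, 0.0), in_w - 1.0)
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            wx = x - x0
            top = image[y0][x0] * (1 - wx) + image[y0][x1] * wx
            bottom = image[y1][x0] * (1 - wx) + image[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        result.append(row)
    return result


def rescale_if_needed(output, target_h, target_w, rescale=True):
    """Return `output` unchanged unless rescaling is enabled and the sizes differ."""
    if not rescale or (len(output), len(output[0])) == (target_h, target_w):
        return output
    return bilinear_resize(output, target_h, target_w)


coarse = [[0.0, 1.0], [2.0, 3.0]]
fine = rescale_if_needed(coarse, 4, 4)  # 2x2 prediction stretched to a 4x4 target
```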

Raises:

Type Description
NotImplementedError

description

DecoderNotFoundException

description

Returns:

Type Description
Model

nn.Module: description

terratorch.models.generic_unet_model_factory.GenericUnetModelFactory #

Bases: ModelFactory

build_model(task='segmentation', backbone=None, decoder=None, dilations=(1, 6, 12, 18), in_channels=6, pretrained=True, num_classes=1, regression_relu=False, **kwargs) #

Factory to create a model based on SMP.

Parameters:

Name Type Description Default
task str

Must be "segmentation".

'segmentation'
model str

Decoder architecture. Currently only supports "unet".

required
in_channels int

Number of input channels.

6
pretrained str | bool

Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.

True
num_classes int

Number of classes.

1
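
The pretrained rule in the table above (True selects "imagenet", False or None selects random weights) can be written as a small mapping. This is a sketch under assumptions: `resolve_encoder_weights` is a hypothetical helper, and passing a string through unchanged is an assumed reading of the `str | bool` type rather than documented behaviour.

```python
from typing import Optional, Union


def resolve_encoder_weights(pretrained: Union[str, bool, None]) -> Optional[str]:
    """Map the `pretrained` argument to a weights name, following the description above."""
    if pretrained is True:
        return "imagenet"
    if pretrained is False or pretrained is None:
        return None  # None stands for random initialisation in this sketch
    return pretrained  # assumption: a string names the weights explicitly
```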

Returns:

Name Type Description
Model Model

SMP model wrapped in SMPModelWrapper.

terratorch.models.prithvi_model_factory.PrithviModelFactory #

Bases: ModelFactory

build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs) #

Model factory for Prithvi models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

Name Type Description Default
task str

Task to be performed. Currently supports "segmentation" and "regression".

required
backbone (str, Module)

Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100".

required
decoder Union[str, Module]

Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property decoder.out_channels. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".

required
in_channels int

Number of input channels. Defaults to 3.

None
bands list[HLSBands]

Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].

required
num_classes int

Number of classes. None for regression tasks.

None
pretrained Union[bool, Path]

Whether to load pretrained weights for the backbone, if available. Defaults to True.

True
num_frames int

Number of timesteps for the model to handle. Defaults to 1.

1
prepare_features_for_image_model Callable | None

Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function.

None
aux_decoders list[AuxiliaryHead] | None

List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well.

None
rescale bool

Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

True

Returns:

Type Description
Model

nn.Module: Full model with encoder, decoder and head.

terratorch.models.satmae_model_factory.SatMAEModelFactory #

Bases: ModelFactory

build_model(task, backbone, decoder, in_channels, bands, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs) #

Model factory for SatMAE models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

Name Type Description Default
task str

Task to be performed. Currently supports "segmentation" and "regression".

required
backbone (str, Module)

Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100".

required
decoder Union[str, Module]

Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property decoder.out_channels. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".

required
in_channels int

Number of input channels. Defaults to 3.

required
bands list[HLSBands]

Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].

required
num_classes int

Number of classes. None for regression tasks.

None
pretrained Union[bool, Path]

Whether to load pretrained weights for the backbone, if available. Defaults to True.

True
num_frames int

Number of timesteps for the model to handle. Defaults to 1.

1
prepare_features_for_image_model Callable | None

Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function.

None
aux_decoders list[AuxiliaryHead] | None

List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well.

None
rescale bool

Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

True

Raises:

Type Description
NotImplementedError

description

DecoderNotFoundException

description

Returns:

Type Description
Model

nn.Module: description

terratorch.models.smp_model_factory.SMPModelFactory #

Bases: ModelFactory

build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs) #

Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.

Attributes:

Name Type Description
task str

Specifies the task for which the model is being built. Supported tasks are "segmentation".

backbone str

Specifies the backbone model to be used.

decoder str

Specifies the decoder to be used for constructing the segmentation model.

bands list[HLSBands | int]

A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands.

in_channels int

Specifies the number of input channels. Defaults to None.

num_classes int

The number of output classes for the model.

pretrained bool | Path

Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True.

num_frames int

Specifies the number of timesteps the model should handle. Useful for temporal models.

regression_relu bool

Whether to apply ReLU activation in the case of regression tasks.

**kwargs bool

Additional arguments that might be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately.

Raises:

Type Description
ValueError

If the specified decoder is not supported by SMP.

Exception

If the specified task is not "segmentation".

Returns:

Type Description
Model

nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified parameters and tasks.

terratorch.models.encoder_decoder_factory.EncoderDecoderFactory #

Bases: ModelFactory

build_model(task, backbone, decoder, backbone_kwargs=None, decoder_kwargs=None, head_kwargs=None, num_classes=None, necks=None, aux_decoders=None, rescale=True, peft_config=None, **kwargs) #

Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

Name Type Description Default
task str

Task to be performed. Currently supports "segmentation", "regression" and "classification".

required
backbone (str, Module)

Backbone to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it directly. The backbone should have an out_channels attribute and its forward should return a list[Tensor].

required
decoder Union[str, Module]

Decoder to be used for the segmentation model. If a string, will look for such decoders in the different registries supported (internal terratorch registry, smp, ...). If an nn.Module, we expect it to expose a property decoder.out_channels. Pixel wise tasks will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".

required
backbone_kwargs (dict, optional)

Arguments to be passed to instantiate the backbone.

None
decoder_kwargs (dict, optional)

Arguments to be passed to instantiate the decoder.

None
head_kwargs (dict, optional)

Arguments to be passed to the head network.

None
num_classes int

Number of classes. None for regression tasks.

None
necks list[dict]

nn.Modules to be called in succession on encoder features before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry. Expects each one to have a key "name" and subsequent keys for arguments, if any. Defaults to None, which applies the identity function.

None
aux_decoders list[AuxiliaryHead] | None

List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well.

None
rescale bool

Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

True
peft_config dict

Configuration options for using PEFT. The dictionary should have the following keys:

  • "method": Which PEFT method to use. Should be one of the methods implemented in the PEFT library.
  • "replace_qkv": A substring of the names of the submodules to replace with QKVSep. Use this when the query, key and value matrices are merged into a single linear layer and the PEFT method should be applied to them separately (e.g. if LoRA is only desired in the Q and V matrices). For example, with Prithvi this should be "qkv".
  • "peft_config_kwargs": Dictionary of keyword arguments that will be passed to PeftConfig.
None
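
A peft_config dictionary with the three keys listed above might look like the following sketch. Only the key names and the "qkv" value for Prithvi come from the description. The method name "LORA", the contents of peft_config_kwargs, and the `check_peft_config` helper are assumptions made for illustration.

```python
# Keys named in the peft_config description; everything else in this sketch is illustrative.
PEFT_CONFIG_KEYS = {"method", "replace_qkv", "peft_config_kwargs"}


def check_peft_config(config: dict) -> dict:
    """Reject keys that the description does not mention and require "method"."""
    unknown = set(config) - PEFT_CONFIG_KEYS
    if unknown:
        raise ValueError(f"Unexpected peft_config keys: {sorted(unknown)}")
    if "method" not in config:
        raise ValueError('peft_config needs a "method" entry')
    return config


peft_config = check_peft_config(
    {
        "method": "LORA",  # assumed method name
        "replace_qkv": "qkv",  # the value the description suggests for Prithvi
        "peft_config_kwargs": {"r": 8, "lora_alpha": 16},  # assumed LoRA settings
    }
)
```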

Returns:

Type Description
Model

nn.Module: Full model with encoder, decoder and head.
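
The necks parameter above takes a list of dictionaries, each with a "name" key plus optional arguments, and the resulting modules are applied in succession to the encoder features. This sketch shows one way such a list could be resolved against a registry. `SKETCH_NECK_REGISTRY`, `register_neck`, `build_necks` and the two toy necks are hypothetical and are not the library's NECKS_REGISTRY. Features are plain numbers here rather than tensors.

```python
# Hypothetical stand-in for a neck registry: maps a name to a constructor.
SKETCH_NECK_REGISTRY = {}


def register_neck(name):
    def decorator(constructor):
        SKETCH_NECK_REGISTRY[name] = constructor
        return constructor
    return decorator


@register_neck("toy_select")
def toy_select(indices):
    # Keep only the features at the given positions.
    return lambda features: [features[i] for i in indices]


@register_neck("toy_scale")
def toy_scale(factor):
    # Multiply every feature by a constant.
    return lambda features: [f * factor for f in features]


def build_necks(necks=None):
    """Turn a list of {"name": ..., **args} dicts into one function applied in succession."""
    built = []
    for spec in necks or []:
        args = {key: value for key, value in spec.items() if key != "name"}
        built.append(SKETCH_NECK_REGISTRY[spec["name"]](**args))

    def apply(features):
        for neck in built:
            features = neck(features)
        return features  # with no necks this is the identity

    return apply


pipeline = build_necks(
    [{"name": "toy_select", "indices": [0, 2]}, {"name": "toy_scale", "factor": 2.0}]
)
selected = pipeline([1.0, 2.0, 3.0])
```

With necks left as None, `build_necks` returns a function that passes the features through unchanged, matching the identity default described in the table.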


Last update: March 23, 2025