Model Factories
terratorch.models.model.ModelFactory
Bases: Protocol
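Because `ModelFactory` is a `typing.Protocol`, a factory does not need to inherit from it: any object exposing a compatible `build_model` method satisfies it structurally. The sketch below illustrates this with a hypothetical `ModelFactoryLike` protocol and a toy factory. It assumes `build_model` is the only required method and returns a plain dict instead of an `nn.Module`, so it is not the terratorch definition.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ModelFactoryLike(Protocol):
    """Hypothetical stand-in for a factory protocol with a single build_model method."""

    def build_model(self, task: str, **kwargs: Any) -> Any: ...


class ToyFactory:
    """Hypothetical factory: it never subclasses ModelFactoryLike."""

    def build_model(self, task: str, **kwargs: Any) -> dict:
        # Return a description of the requested model instead of a real nn.Module.
        return {"task": task, **kwargs}


class NotAFactory:
    """Has no build_model method, so it does not satisfy the protocol."""


factory = ToyFactory()
print(isinstance(factory, ModelFactoryLike))        # True: matched by method name
print(isinstance(NotAFactory(), ModelFactoryLike))  # False
print(factory.build_model("segmentation", num_classes=2))
```

Note that `runtime_checkable` only checks that the method exists, not its signature; a static type checker performs the stricter comparison.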
terratorch.models.clay_model_factory.ClayModelFactory
Bases: ModelFactory
build_model(task, backbone, decoder, in_channels, bands=[], num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs)
Model factory for Clay models.
Further arguments for the backbone, decoder or head can be passed as keyword arguments prefixed with `backbone_`, `decoder_` and `head_`, respectively.
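The prefix convention can be pictured as a routing step that strips each prefix and groups the remaining keys per component. The helper `split_prefixed_kwargs` below is hypothetical, as are the example argument names (`drop_rate`, `channels`, `dropout`, `seed`); it only illustrates the convention and is not how terratorch routes arguments internally.

```python
from typing import Any, Dict, Tuple

PREFIXES = ("backbone_", "decoder_", "head_")


def split_prefixed_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """Group keyword arguments by component prefix (hypothetical helper).

    Returns {"backbone": {...}, "decoder": {...}, "head": {...}} with the
    prefixes stripped, plus the arguments that matched no prefix.
    """
    grouped: Dict[str, Dict[str, Any]] = {p.rstrip("_"): {} for p in PREFIXES}
    rest: Dict[str, Any] = {}
    for key, value in kwargs.items():
        for prefix in PREFIXES:
            if key.startswith(prefix):
                grouped[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:
            # No prefix matched: leave the argument for the factory itself.
            rest[key] = value
    return grouped, rest


grouped, rest = split_prefixed_kwargs(
    {"backbone_drop_rate": 0.1, "decoder_channels": 256, "head_dropout": 0.2, "seed": 0}
)
print(grouped)  # {'backbone': {'drop_rate': 0.1}, 'decoder': {'channels': 256}, 'head': {'dropout': 0.2}}
print(rest)     # {'seed': 0}
```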
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If a string, it should be parsable by the specified factory. | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from the class with the same name exposed in `decoder.__init__.py`. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take the encoder output as input. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size differs from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
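The `rescale` option can be illustrated with a pure-Python bilinear resize of a single 2D prediction map. This is a sketch under assumptions: the helper names `bilinear_resize` and `rescale_if_needed` are hypothetical, the input is a nested list rather than a tensor, and sampling uses the half-pixel-centre convention (the one PyTorch's `F.interpolate` applies with `align_corners=False`). The factory itself presumably operates on tensors.

```python
from typing import List

Grid = List[List[float]]


def bilinear_resize(grid: Grid, out_h: int, out_w: int) -> Grid:
    """Resize a 2D grid with bilinear interpolation (half-pixel centres)."""
    in_h, in_w = len(grid), len(grid[0])
    scale_y, scale_x = in_h / out_h, in_w / out_w
    out: Grid = []
    for i in range(out_h):
        # Map the output pixel centre back into input coordinates, clamped to the grid.
        y = min(max((i + 0.5) * scale_y - 0.5, 0.0), in_h - 1)
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        wy = y - y0
        row = []
        for j in range(out_w):
            x = min(max((j + 0.5) * scale_x - 0.5, 0.0), in_w - 1)
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            wx = x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bottom = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


def rescale_if_needed(pred: Grid, target_h: int, target_w: int) -> Grid:
    """Resize only when the prediction size differs from the ground truth."""
    if len(pred) == target_h and len(pred[0]) == target_w:
        return pred
    return bilinear_resize(pred, target_h, target_w)


pred = [[0.0, 1.0], [2.0, 3.0]]
for row in rescale_if_needed(pred, 4, 4):
    print(row)
```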
Raises:
Type | Description |
---|---|
NotImplementedError | No description provided. |
DecoderNotFoundException | No description provided. |
Returns:
Type | Description |
---|---|
Model | nn.Module (no description provided). |
terratorch.models.generic_unet_model_factory.GenericUnetModelFactory
Bases: ModelFactory
build_model(task='segmentation', backbone=None, decoder=None, dilations=(1, 6, 12, 18), in_channels=6, pretrained=True, num_classes=1, regression_relu=False, **kwargs)
Factory to create a model based on SMP (Segmentation Models PyTorch).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Must be "segmentation". | 'segmentation' |
model | str | Decoder architecture. Currently only supports "unet". | required |
in_channels | int | Number of input channels. | 6 |
pretrained | str \| bool | Which weights to use for the backbone. If True, "imagenet" weights are used; if False or None, weights are initialized randomly. Defaults to True. | True |
num_classes | int | Number of classes. | 1 |
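The rule for `pretrained` amounts to a small mapping onto SMP's `encoder_weights` argument, where `None` means random initialization. The helper `resolve_encoder_weights` below is hypothetical, and passing a string through unchanged is an assumption based on the `str | bool` type rather than documented behavior.

```python
from typing import Optional, Union


def resolve_encoder_weights(pretrained: Union[str, bool, None]) -> Optional[str]:
    """Map `pretrained` to an SMP-style encoder_weights value (hypothetical helper)."""
    if pretrained is True:
        return "imagenet"  # documented: True selects ImageNet weights
    if pretrained is False or pretrained is None:
        return None        # documented: random initialization
    return pretrained      # assumption: a string already names the weights to use


print(resolve_encoder_weights(True))   # imagenet
print(resolve_encoder_weights(False))  # None
```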
Returns:
Name | Type | Description |
---|---|---|
Model | Model | SMP model wrapped in SMPModelWrapper. |
terratorch.models.prithvi_model_factory.PrithviModelFactory
Bases: ModelFactory
build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)
Model factory for Prithvi models.
Further arguments for the backbone, decoder or head can be passed as keyword arguments prefixed with `backbone_`, `decoder_` and `head_`, respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If a string, it should be parsable by the specified factory. | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from the class with the same name exposed in `decoder.__init__.py`. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. | None |
bands | list[HLSBands] | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take the encoder output as input. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size differs from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
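With `in_channels` defaulting to `None`, a natural reading is that the channel count is derived from `bands`, one channel per band. The sketch below assumes exactly that fallback. `Band` is a hypothetical stand-in for `terratorch.datasets.HLSBands` (only a few members), and `normalize_bands` and `resolve_in_channels` are illustrative helpers, not terratorch API.

```python
from enum import Enum
from typing import List, Optional, Sequence, Union


class Band(Enum):
    """Hypothetical stand-in for terratorch.datasets.HLSBands (a few members only)."""
    BLUE = "BLUE"
    GREEN = "GREEN"
    RED = "RED"
    NIR_NARROW = "NIR_NARROW"


def normalize_bands(bands: Sequence[Union[Band, str, int]]) -> List[Union[Band, str, int]]:
    """Convert known band names to Band members; keep anything else unchanged."""
    result: List[Union[Band, str, int]] = []
    for band in bands:
        if isinstance(band, str) and band in Band.__members__:
            result.append(Band[band])
        else:
            result.append(band)
    return result


def resolve_in_channels(bands: Sequence[object], in_channels: Optional[int] = None) -> int:
    """Assumed fallback: one input channel per band when in_channels is None."""
    return len(bands) if in_channels is None else in_channels


bands = normalize_bands(["RED", "GREEN", "BLUE"])
print(bands)                                      # [<Band.RED: 'RED'>, <Band.GREEN: 'GREEN'>, <Band.BLUE: 'BLUE'>]
print(resolve_in_channels(bands))                 # 3
print(resolve_in_channels(bands, in_channels=6))  # 6
```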
Returns:
Type | Description |
---|---|
Model | nn.Module: full model with encoder, decoder and head. |
terratorch.models.satmae_model_factory.SatMAEModelFactory
Bases: ModelFactory
build_model(task, backbone, decoder, in_channels, bands, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs)
Model factory for SatMAE models.
Further arguments for the backbone, decoder or head can be passed as keyword arguments prefixed with `backbone_`, `decoder_` and `head_`, respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If a string, it should be parsable by the specified factory. | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from the class with the same name exposed in `decoder.__init__.py`. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. | required |
bands | list[HLSBands] | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take the encoder output as input. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size differs from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
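The `prepare_features_for_image_model` default can be read as: call the hook if one is given, otherwise pass the encoder features through unchanged. In the sketch below, features are plain strings standing in for feature maps, and `apply_feature_hook` and `keep_last_two` are hypothetical. A real hook would operate on tensors (for example reshaping transformer tokens into image-like maps, which is an assumption about typical use).

```python
from typing import Callable, Optional

FeatureHook = Callable[[list], list]


def apply_feature_hook(features: list, prepare: Optional[FeatureHook] = None) -> list:
    """Run the optional hook on encoder features; None behaves like the identity."""
    if prepare is None:
        return features
    return prepare(features)


def keep_last_two(features: list) -> list:
    """Hypothetical hook that keeps only the two deepest feature maps."""
    return features[-2:]


encoder_features = ["stage1", "stage2", "stage3", "stage4"]
print(apply_feature_hook(encoder_features))                 # unchanged
print(apply_feature_hook(encoder_features, keep_last_two))  # ['stage3', 'stage4']
```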
Raises:
Type | Description |
---|---|
NotImplementedError | No description provided. |
DecoderNotFoundException | No description provided. |
Returns:
Type | Description |
---|---|
Model | nn.Module (no description provided). |
terratorch.models.smp_model_factory.SMPModelFactory
Bases: ModelFactory
build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)
Factory class for creating SMP (Segmentation Models PyTorch) based models with optional customization.
This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.
Attributes:
Name | Type | Description |
---|---|---|
task | str | Specifies the task for which the model is being built. Currently, only "segmentation" is supported. |
backbone | str | Specifies the backbone model to be used. |
decoder | str | Specifies the decoder to be used for constructing the segmentation model. |
bands | list[HLSBands \| int] | A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands. |
in_channels | int | Specifies the number of input channels. Defaults to None. |
num_classes | int | The number of output classes for the model. |
pretrained | bool \| Path | Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True. |
num_frames | int | Specifies the number of timesteps the model should handle. Useful for temporal models. |
regression_relu | bool | Whether to apply a ReLU activation to the output in regression tasks. |
`**kwargs` | Any | Additional arguments to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately. |
Raises:
Type | Description |
---|---|
ValueError | If the specified decoder is not supported by SMP. |
Exception | If the specified task is not "segmentation". |
Returns:
Type | Description |
---|---|
Model | nn.Module: a model instance wrapped in SMPModelWrapper, configured according to the specified parameters and task. |
terratorch.models.encoder_decoder_factory.EncoderDecoderFactory
Bases: ModelFactory
build_model(task, backbone, decoder, backbone_kwargs=None, decoder_kwargs=None, head_kwargs=None, num_classes=None, necks=None, aux_decoders=None, rescale=True, peft_config=None, **kwargs)
Generic model factory that combines an encoder and decoder, together with a head, for a specific task.
Further arguments for the backbone, decoder or head can be passed as keyword arguments prefixed with `backbone_`, `decoder_` and `head_`, respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation", "regression" and "classification". | required |
backbone | (str, Module) | Backbone to be used. If a string, it is looked up in the supported registries (internal terratorch registry, timm, ...). If a torch nn.Module, it is used directly. The backbone should have and | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it is looked up in the supported registries (internal terratorch registry, smp, ...). If an nn.Module, we expect it to expose a property | required |
backbone_kwargs | dict \| None | Arguments to be passed to instantiate the backbone. | None |
decoder_kwargs | dict \| None | Arguments to be passed to instantiate the decoder. | None |
head_kwargs | dict \| None | Arguments to be passed to the head network. | None |
num_classes | int | Number of classes. None for regression tasks. | None |
necks | list[dict] | nn.Modules to be called in succession on encoder features before passing them to the decoder. They should be registered in the NECKS_REGISTRY registry. Each entry is expected to have a key "name" and subsequent keys for arguments, if any. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take the encoder output as input. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size differs from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
peft_config | dict | Configuration options for using PEFT. The dictionary should have the following keys: | None |
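The string-or-module behavior described for `backbone` and `decoder` can be sketched as a lookup over an ordered list of registries. `build_component`, the registry contents, and the first-match-wins precedence are all assumptions for illustration; the document does not state which registry takes priority or how lookup failures are reported.

```python
from typing import Any, Callable, Mapping, Sequence, Tuple

Registry = Mapping[str, Callable[..., Any]]


def build_component(spec: Any, registries: Sequence[Tuple[str, Registry]], **kwargs: Any) -> Any:
    """Use non-string specs directly; otherwise build from the first registry that has the name."""
    if not isinstance(spec, str):
        return spec  # e.g. an already constructed module is used as-is
    for _source, registry in registries:
        if spec in registry:
            return registry[spec](**kwargs)
    searched = ", ".join(source for source, _ in registries)
    raise KeyError(f"{spec!r} not found in registries: {searched}")


# Hypothetical registries standing in for "internal terratorch registry, timm, ...".
internal = {"tiny_encoder": lambda depth=2: {"source": "internal", "depth": depth}}
external = {
    "tiny_encoder": lambda depth=2: {"source": "external", "depth": depth},
    "wide_encoder": lambda: {"source": "external"},
}
registries = [("internal", internal), ("external", external)]

print(build_component("tiny_encoder", registries, depth=4))  # {'source': 'internal', 'depth': 4}
print(build_component("wide_encoder", registries))           # {'source': 'external'}
prebuilt = object()
print(build_component(prebuilt, registries) is prebuilt)     # True
```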
Returns:
Type | Description |
---|---|
Model | nn.Module: full model with encoder, decoder and head. |
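The `necks` format (a list of dicts, each with a `"name"` key plus arguments, applied in succession) can be sketched with a plain dictionary as the registry. `NECK_REGISTRY`, the `select` and `scale` necks, and `build_necks` are hypothetical stand-ins, not the contents of terratorch's NECKS_REGISTRY, and features are plain numbers rather than tensors.

```python
from typing import Any, Callable, Dict, List, Optional

Stage = Callable[[List[Any]], List[Any]]

# Hypothetical neck registry: each name maps to a factory that returns a stage
# transforming the list of encoder features.
NECK_REGISTRY: Dict[str, Callable[..., Stage]] = {
    "select": lambda indices: (lambda feats: [feats[i] for i in indices]),
    "scale": lambda factor=1.0: (lambda feats: [f * factor for f in feats]),
}


def build_necks(specs: Optional[List[Dict[str, Any]]]) -> Stage:
    """Chain the necks described by specs; no specs means the identity."""
    if not specs:
        return lambda feats: feats
    stages: List[Stage] = []
    for spec in specs:
        args = dict(spec)        # copy so the caller's config is not mutated
        name = args.pop("name")  # "name" selects the registry entry
        stages.append(NECK_REGISTRY[name](**args))  # remaining keys are arguments

    def run(feats: List[Any]) -> List[Any]:
        for stage in stages:     # applied in the order given
            feats = stage(feats)
        return feats

    return run


config = [{"name": "select", "indices": [1, 3]}, {"name": "scale", "factor": 2.0}]
necks = build_necks(config)
print(necks([1.0, 2.0, 3.0, 4.0]))  # [4.0, 8.0]
```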