Models
To interface with terratorch tasks correctly, models must conform to the Model ABC and have a forward method that returns a ModelOutput:
terratorch.models.model.Model
Bases: ABC, Module
Source code in terratorch/models/model.py
terratorch.models.model.ModelOutput
dataclass
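The contract can be illustrated without terratorch installed. The sketch below defines local stand-ins for Model and ModelOutput. The field names output and auxiliary_heads and the abstract freeze_encoder/freeze_decoder methods are assumptions (the latter two mirror the SMPModelWrapper example further down this page), and the real Model also subclasses nn.Module. In real code, import both from terratorch.models.model instead.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-ins for terratorch.models.model.ModelOutput and Model.
# Field and method names are assumptions made for illustration.
@dataclass
class ModelOutput:
    output: Any  # main prediction, e.g. a tensor
    auxiliary_heads: dict = field(default_factory=dict)  # name -> auxiliary prediction


class Model(ABC):
    @abstractmethod
    def freeze_encoder(self) -> None: ...

    @abstractmethod
    def freeze_decoder(self) -> None: ...

    @abstractmethod
    def forward(self, *args, **kwargs) -> ModelOutput: ...


class DoublingModel(Model):
    """Toy model whose 'prediction' is each input value doubled."""

    def __init__(self) -> None:
        self.encoder_frozen = False
        self.decoder_frozen = False

    def freeze_encoder(self) -> None:
        self.encoder_frozen = True

    def freeze_decoder(self) -> None:
        self.decoder_frozen = True

    def forward(self, x: list) -> ModelOutput:
        # Wrap the raw prediction so callers can always read .output.
        return ModelOutput(output=[2 * v for v in x])
```

Because Model is abstract, a subclass that forgets one of the required methods cannot be instantiated, which surfaces interface mistakes early.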
Model Factories
In order to be used by tasks, models must have a Model Factory which builds them.
Factories must conform to the ModelFactory ABC:
terratorch.models.model.ModelFactory
You most likely do not need to implement your own model factory, unless you are wrapping another library which generates full models.
For most cases, the encoder decoder factory can be used to combine a backbone with a decoder.
New backbones and decoders must be registered before they can be used with the encoder decoder factory.
To add a new model factory, register it in the MODEL_FACTORY_REGISTRY.
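Registration is essentially a name-to-constructor lookup. The following is a minimal sketch of that idea, not terratorch's registry implementation; the Registry class, MY_FACTORY_REGISTRY and EchoFactory are hypothetical names. It keys entries by the decorated object's __name__, in line with how registered backbones are named after their class or function.

```python
from typing import Callable


class Registry:
    """Minimal name -> constructor registry (illustrative, not terratorch's)."""

    def __init__(self) -> None:
        self._builders: dict = {}

    def register(self, constructor: Callable) -> Callable:
        # Use the class or function name as the lookup key.
        name = constructor.__name__
        if name in self._builders:
            raise KeyError(f"{name} is already registered")
        self._builders[name] = constructor
        # Return the object unchanged so this works as a decorator.
        return constructor

    def build(self, name: str, *args, **kwargs):
        if name not in self._builders:
            raise KeyError(f"Unknown entry {name!r}; available: {sorted(self._builders)}")
        return self._builders[name](*args, **kwargs)

    def __contains__(self, name: str) -> bool:
        return name in self._builders


MY_FACTORY_REGISTRY = Registry()  # hypothetical registry instance


@MY_FACTORY_REGISTRY.register
class EchoFactory:
    """Toy factory that just reports what it was asked to build."""

    def build_model(self, task: str, **kwargs) -> dict:
        return {"task": task, **kwargs}
```

Tasks can then look a factory up by name, e.g. MY_FACTORY_REGISTRY.build("EchoFactory").build_model("segmentation", num_classes=2), without importing the factory class directly.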
Adding a new model
To add a new backbone, create a class and decorate it (or a constructor function that instantiates it) with @TERRATORCH_BACKBONE_REGISTRY.register.
The model will be registered under the same name as the class or function. To create many model variants from the same class, the recommended approach is to decorate a constructor function for each variant, giving each a fully descriptive name.
```python
from torch import nn

from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY


# Make sure this module is in the import path for terratorch.
@TERRATORCH_BACKBONE_REGISTRY.register
class BasicBackbone(nn.Module):
    def __init__(self, out_channels=64):
        super().__init__()
        self.flatten = nn.Flatten()
        # The flattened input must have 224 * 224 = 50176 features,
        # e.g. a single-channel 224x224 image.
        self.layer = nn.Linear(224 * 224, out_channels)
        # Channel count of each output feature, as a list.
        self.out_channels = [out_channels]

    def forward(self, x):
        return self.layer(self.flatten(x))


# You can build directly with the TERRATORCH_BACKBONE_REGISTRY,
# but typically this will be accessed from the BACKBONE_REGISTRY.
>>> BACKBONE_REGISTRY.build("BasicBackbone", out_channels=64)
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=64, bias=True)
)


# A named variant, registered through a constructor function.
@TERRATORCH_BACKBONE_REGISTRY.register
def basic_backbone_128():
    return BasicBackbone(out_channels=128)


>>> BACKBONE_REGISTRY.build("basic_backbone_128")
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=128, bias=True)
)
```
Adding a new decoder can be done in the same way with the TERRATORCH_DECODER_REGISTRY.
Info
All decoders will be passed the channel_list as the first argument for initialization.
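As a sketch of that calling convention, assume the channel list is taken from the backbone's out_channels attribute (as set in BasicBackbone above). The hypothetical PickFeatureDecoder receives channel_list positionally, followed by its own keyword arguments, and forwards one feature by index, much like IdentityDecoder. The build_decoder helper is also hypothetical, and plain Python values stand in for tensors.

```python
class PickFeatureDecoder:
    """Hypothetical decoder: forwards one feature map from the encoder output list."""

    def __init__(self, channel_list: list, out_index: int = -1) -> None:
        # channel_list holds the channel count of each encoder output.
        self.out_index = out_index
        # Downstream heads need to know how many channels this decoder emits.
        self.out_channels = channel_list[out_index]

    def forward(self, features: list):
        return features[self.out_index]


def build_decoder(decoder_cls, backbone_out_channels, **decoder_kwargs):
    # The channel list is always passed first; everything else as keywords.
    return decoder_cls(list(backbone_out_channels), **decoder_kwargs)
```

Putting channel_list first lets a factory construct any registered decoder without knowing its other parameters.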
To load weights from your own path with the PrithviModelFactory, you can use timm's pretrained_cfg_overlay.
For example, to load from a local path, pass the parameter backbone_pretrained_cfg_overlay = {"file": "<local_path>"} to the model factory.
Besides file, you can also pass url, hf_hub_id, amongst others. Check timm's documentation for full details.
terratorch.models.backbones.select_patch_embed_weights
select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | dict | State dict. | required |
model | Module | Model to load the weights onto. | required |
pretrained_bands | list[HLSBands \| int] | List of bands the model was pretrained on, in the correct order. | required |
model_bands | list[HLSBands \| int] | List of bands the model is going to be fine-tuned on, in the correct order. | required |
Returns:
Name | Type | Description |
---|---|---|
dict | dict | New state dict. |
Source code in terratorch/models/backbones/select_patch_embed_weights.py
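The band-matching rule can be sketched on a bare patch-embedding kernel of shape (embed_dim, in_channels, patch, patch) with NumPy. This illustrates the technique and is not terratorch's implementation: the function name select_band_weights is hypothetical, band names are plain strings standing in for HLSBands, and "randomly initialize" is modeled by keeping the new model's own (randomly initialized) kernel slice for unseen bands.

```python
import numpy as np


def select_band_weights(
    pretrained_weight: np.ndarray,
    model_weight: np.ndarray,
    pretrained_bands: list,
    model_bands: list,
) -> np.ndarray:
    """Return a kernel shaped like model_weight whose input channels follow model_bands.

    Both kernels have shape (embed_dim, in_channels, patch, patch).
    """
    if pretrained_weight.shape[1] != len(pretrained_bands):
        raise ValueError("pretrained_weight channels do not match pretrained_bands")
    if model_weight.shape[1] != len(model_bands):
        raise ValueError("model_weight channels do not match model_bands")
    if (pretrained_weight.shape[0] != model_weight.shape[0]
            or pretrained_weight.shape[2:] != model_weight.shape[2:]):
        raise ValueError("embedding dimension or patch size differ")

    # Start from the model's own (randomly initialized) weights.
    new_weight = model_weight.copy()
    for i, band in enumerate(model_bands):
        if band in pretrained_bands:
            # Band seen during pretraining: copy its learned kernel slice.
            new_weight[:, i] = pretrained_weight[:, pretrained_bands.index(band)]
        # Otherwise the random init is kept. Pretrained bands absent from
        # model_bands are never copied, i.e. they are dropped.
    return new_weight
```

Matching by band identity rather than by position means reordering bands, or mixing in new ones such as NIR, still reuses every learned slice that applies.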
Decoders
terratorch.models.decoders.fcn_decoder
FCNDecoder
Bases: Module
Fully Convolutional Decoder
Source code in terratorch/models/decoders/fcn_decoder.py
__init__(embed_dim, channels=256, num_convs=4, in_index=-1)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim | _type_ | Input embedding dimension. | required |
channels | int | Number of channels for each conv. Defaults to 256. | 256 |
num_convs | int | Number of convs. Defaults to 4. | 4 |
in_index | int | Index of the input list to take. Defaults to -1. | -1 |
Source code in terratorch/models/decoders/fcn_decoder.py
terratorch.models.decoders.identity_decoder
Pass the features straight through
IdentityDecoder
Bases: Module
Identity decoder. Useful to pass the features straight to the head.
Source code in terratorch/models/decoders/identity_decoder.py
__init__(embed_dim, out_index=-1)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim | int | Input embedding dimension. | required |
out_index | int | Index of the input list to take. Defaults to -1. | -1 |
Source code in terratorch/models/decoders/identity_decoder.py
terratorch.models.decoders.upernet_decoder
PPM
Bases: ModuleList
Pooling Pyramid Module used in PSPNet.
Source code in terratorch/models/decoders/upernet_decoder.py
__init__(pool_scales, in_channels, channels, align_corners)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pool_scales | tuple[int] | Pooling scales used in Pooling Pyramid Module. | required |
in_channels | int | Input channels. | required |
channels | int | Channels after modules, before conv_seg. | required |
align_corners | bool | align_corners argument of F.interpolate. | required |
Source code in terratorch/models/decoders/upernet_decoder.py
forward(x)
Forward function.
Source code in terratorch/models/decoders/upernet_decoder.py
UperNetDecoder
Bases: Module
UperNetDecoder. Adapted from MMSegmentation.
Source code in terratorch/models/decoders/upernet_decoder.py
__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim | list[int] | Input embedding dimension for each input. | required |
pool_scales | tuple[int] | Pooling scales used in Pooling Pyramid Module applied on the last feature. Defaults to (1, 2, 3, 6). | (1, 2, 3, 6) |
channels | int | Channels used in the decoder. Defaults to 256. | 256 |
align_corners | bool | Whether to align corners in rescaling. Defaults to True. | True |
scale_modules | bool | Whether to apply scale modules to the inputs. Needed for plain ViT. Defaults to False. | False |
Source code in terratorch/models/decoders/upernet_decoder.py
forward(inputs)
Forward function for feature maps before classifying each pixel.
Args: inputs (list[Tensor]): List of multi-level image features.
Returns:
Name | Type | Description |
---|---|---|
feats | Tensor | A tensor of shape (batch_size, self.channels, H, W), which is the feature map for the last layer of the decoder head. |
Source code in terratorch/models/decoders/upernet_decoder.py
psp_forward(inputs)
Forward function of PSP module.
Source code in terratorch/models/decoders/upernet_decoder.py
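To illustrate what the PPM and psp_forward compute, here is a simplified NumPy sketch for a single (C, H, W) feature map: each pool scale s average-pools the map to s x s using PyTorch-style adaptive bins, the result is upsampled back to H x W, and all branches are concatenated with the input along the channel axis. Unlike the real module, it omits the convolutions (per-branch and bottleneck) and uses nearest-neighbor instead of bilinear upsampling; the function names are hypothetical.

```python
import math

import numpy as np


def adaptive_avg_pool2d(x: np.ndarray, size: int) -> np.ndarray:
    """(C, H, W) -> (C, size, size), with bins [floor(i*H/s), ceil((i+1)*H/s))."""
    c, h, w = x.shape
    out = np.empty((c, size, size), dtype=float)
    for i in range(size):
        h0, h1 = (i * h) // size, math.ceil((i + 1) * h / size)
        for j in range(size):
            w0, w1 = (j * w) // size, math.ceil((j + 1) * w / size)
            out[:, i, j] = x[:, h0:h1, w0:w1].mean(axis=(1, 2))
    return out


def upsample_nearest(x: np.ndarray, h: int, w: int) -> np.ndarray:
    """Nearest-neighbor resize of a (C, h_in, w_in) array to (C, h, w)."""
    rows = (np.arange(h) * x.shape[1]) // h
    cols = (np.arange(w) * x.shape[2]) // w
    return x[:, rows][:, :, cols]


def pyramid_pool(x: np.ndarray, pool_scales=(1, 2, 3, 6)) -> np.ndarray:
    """Concatenate the input with one pooled-and-upsampled branch per scale."""
    _, h, w = x.shape
    branches = [x] + [upsample_nearest(adaptive_avg_pool2d(x, s), h, w) for s in pool_scales]
    # Output has C * (1 + len(pool_scales)) channels at the original resolution.
    return np.concatenate(branches, axis=0)
```

The scale-1 branch carries global context (the per-channel mean), while larger scales keep progressively more spatial detail.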
Heads
terratorch.models.heads.regression_head
RegressionHead
Bases: Module
Regression head
Source code in terratorch/models/heads/regression_head.py
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_channels | int | Number of input channels. | required |
final_act | Module \| None | Final activation to be applied. Defaults to None. | None |
learned_upscale_layers | int | Number of PixelShuffle layers to create. Each upscales 2x. Defaults to 0. | 0 |
channel_list | list[int] \| None | List with the number of channels for each Conv layer to be created. Defaults to None. | None |
batch_norm | bool | Whether to apply batch norm. Defaults to True. | True |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
Source code in terratorch/models/heads/regression_head.py
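Each learned upscale layer relies on a pixel shuffle, so n layers multiply the spatial size by 2**n. The NumPy function below reproduces the rearrangement that torch.nn.PixelShuffle performs on an (N, C*r*r, H, W) array, yielding (N, C, H*r, W*r). It is an illustration only; any convolutions RegressionHead places around the shuffle are not shown.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Move channel blocks of size r*r into r x r spatial neighborhoods."""
    n, c, h, w = x.shape
    if c % (r * r):
        raise ValueError("channel count must be divisible by r**2")
    out_c = c // (r * r)
    # Split channels into (out_c, r, r): channel out_c*r*r + i*r + j -> (out_c, i, j).
    x = x.reshape(n, out_c, r, r, h, w)
    # Interleave: (n, out_c, h, i, w, j), so row = h*r + i and col = w*r + j.
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, out_c, h * r, w * r)
```

For example, two stacked 2x shuffles turn a (1, 16, 2, 2) array into (1, 1, 8, 8): spatial resolution is gained by spending channels.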
terratorch.models.heads.segmentation_head
SegmentationHead
Bases: Module
Segmentation head
Source code in terratorch/models/heads/segmentation_head.py
__init__(in_channels, num_classes, channel_list=None, dropout=0)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_channels | int | Number of input channels. | required |
num_classes | int | Number of output classes. | required |
channel_list | list[int] \| None | List with the number of channels for each Conv layer to be created. Defaults to None. | None |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
Source code in terratorch/models/heads/segmentation_head.py
terratorch.models.heads.classification_head
ClassificationHead
Bases: Module
Classification head
Source code in terratorch/models/heads/classification_head.py
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input dimensionality. | required |
num_classes | int | Number of output classes. | required |
dim_list | list[int] \| None | List with the number of dimensions for each Linear layer to be created. Defaults to None. | None |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
linear_after_pool | bool | Apply pooling first, then apply the linear layer. Defaults to False. | False |
Source code in terratorch/models/heads/classification_head.py
Auxiliary Heads
terratorch.models.model.AuxiliaryHead
dataclass
Class containing all information to create auxiliary heads.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the head. Should match the name given to the auxiliary loss. | required |
decoder | str | Name of the decoder class to be used. | required |
decoder_args | dict \| None | Parameters to be passed to the decoder constructor. Parameters for the decoder should be prefixed with | required |
Source code in terratorch/models/model.py
Model Output
terratorch.models.model.ModelOutput
dataclass
Model Factory
terratorch.models.PrithviModelFactory
Bases: ModelFactory
Source code in terratorch/models/prithvi_model_factory.py
build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)
Model factory for prithvi models.
Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.
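The prefix convention amounts to routing each keyword argument by its leading segment. Below is a minimal sketch with a hypothetical helper name; it is not the factory's actual code, and collecting unprefixed keys as leftovers is an assumption.

```python
def split_prefixed_kwargs(kwargs: dict, prefixes=("backbone_", "decoder_", "head_")):
    """Group keyword arguments by prefix, stripping the prefix from each key."""
    groups = {prefix.rstrip("_"): {} for prefix in prefixes}
    rest = {}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:  # no prefix matched
            rest[key] = value
    return groups, rest
```

Under this scheme, backbone_pretrained_cfg_overlay={"file": "<local_path>"} would reach the backbone as pretrained_cfg_overlay, and decoder_channels=128 would reach the decoder as channels.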
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If a string, it should be parsable by the specified factory. Defaults to "prithvi_100". | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. Defaults to 3. | None |
bands | list[HLSBands] | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE]. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take their input from the encoder. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
Returns:
Type | Description |
---|---|
Model | nn.Module: Full model with encoder, decoder and head. |
Source code in terratorch/models/prithvi_model_factory.py
terratorch.models.SMPModelFactory
Bases: ModelFactory
Source code in terratorch/models/smp_model_factory.py
build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)
Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.
This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.
Attributes:
Name | Type | Description |
---|---|---|
task | str | Specifies the task for which the model is being built. Supported tasks are "segmentation". |
backbone | str | Specifies the backbone model to be used. |
decoder | str | Specifies the decoder to be used for constructing the segmentation model. |
bands | list[HLSBands \| int] | A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands. |
in_channels | int | Specifies the number of input channels. Defaults to None. |
num_classes | int | The number of output classes for the model. |
pretrained | bool \| Path | Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True. |
num_frames | int | Specifies the number of timesteps the model should handle. Useful for temporal models. |
regression_relu | bool | Whether to apply ReLU activation in the case of regression tasks. |
**kwargs | bool | Additional arguments that might be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately. |
Raises:
Type | Description |
---|---|
ValueError | If the specified decoder is not supported by SMP. |
Exception | If the specified task is not "segmentation". |
Returns:
Type | Description |
---|---|
Model | nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified parameters and tasks. |
Source code in terratorch/models/smp_model_factory.py
Adding new model types
Adding new model types is as simple as creating a new factory that produces models. See, for instance, the example below of a potential SMPModelFactory:
```python
import segmentation_models_pytorch as smp
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory


@register_factory
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        **kwargs,
    ) -> Model:
        # Fixed example architecture; a real factory would honor task, backbone and decoder.
        model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=in_channels, classes=1)
        return SMPModelWrapper(model)


# The wrapper is a Model, not a factory, so it is not registered.
class SMPModelWrapper(Model, nn.Module):
    def __init__(self, smp_model) -> None:
        super().__init__()
        self.smp_model = smp_model

    def forward(self, *args, **kwargs):
        # With classes=1 the output is (B, 1, H, W); drop the channel dimension.
        return ModelOutput(self.smp_model(*args, **kwargs).squeeze(1))

    def freeze_encoder(self):
        pass

    def freeze_decoder(self):
        pass
```
Custom modules with CLI
Custom modules must be in the import path in order to be registered in the appropriate registries.
To do this without modifying the code when using the CLI, you may place your modules under a custom_modules directory, which must be in the directory from which you execute terratorch.