Models
To interface with terratorch tasks correctly, models must conform to the Model ABC:
terratorch.models.model.Model
Bases: ABC, Module
Source code in terratorch/models/model.py
and have a forward method which returns a ModelOutput:
terratorch.models.model.ModelOutput
dataclass
Model Factories
In order to be used by tasks, models must have a Model Factory which builds them.
Factories must conform to the ModelFactory ABC:
terratorch.models.model.ModelFactory
You most likely do not need to implement your own model factory, unless you are wrapping another library which generates full models.
In most cases, the encoder-decoder factory can be used to combine a backbone with a decoder.
To make new backbones or decoders available to the encoder-decoder factory, they must be registered.
To add a new model factory, register it in the MODEL_FACTORY_REGISTRY.
Adding a new model
To add a new backbone, create a class and decorate it (or a constructor function that instantiates it) with @TERRATORCH_BACKBONE_REGISTRY.register.
The model will be registered under the same name as the class or function. To create several variants of the same class, the recommended approach is to decorate one constructor function per variant, each with a fully descriptive name.
from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, BACKBONE_REGISTRY
from torch import nn

# make sure this is in the import path for terratorch
@TERRATORCH_BACKBONE_REGISTRY.register
class BasicBackbone(nn.Module):
    def __init__(self, out_channels=64):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layer = nn.Linear(224*224, out_channels)
        self.out_channels = [out_channels]

    def forward(self, x):
        return self.layer(self.flatten(x))

# you can build directly with the TERRATORCH_BACKBONE_REGISTRY
# but typically this will be accessed from the BACKBONE_REGISTRY
>>> BACKBONE_REGISTRY.build("BasicBackbone", out_channels=64)
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=64, bias=True)
)

@TERRATORCH_BACKBONE_REGISTRY.register
def basic_backbone_128():
    return BasicBackbone(out_channels=128)

>>> BACKBONE_REGISTRY.build("basic_backbone_128")
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=128, bias=True)
)
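The name-based registration used above can be illustrated with a toy registry. This is a hypothetical, self-contained sketch: `ToyRegistry`, `REGISTRY`, `Backbone` and `backbone_128` are illustrative names, not terratorch's implementation. It only mirrors the behaviour described here. The decorator stores a class or constructor function under its `__name__`, and `build` calls the stored entry with the given arguments.

```python
from typing import Any, Callable


class ToyRegistry:
    """Hypothetical stand-in for a terratorch registry (illustration only)."""

    def __init__(self) -> None:
        self._entries: dict[str, Callable[..., Any]] = {}

    def register(self, obj: Callable[..., Any]) -> Callable[..., Any]:
        # Store the class or function under its own name and return it unchanged,
        # so the decorated object can still be used directly.
        if obj.__name__ in self._entries:
            raise KeyError(f"{obj.__name__} is already registered")
        self._entries[obj.__name__] = obj
        return obj

    def __contains__(self, name: str) -> bool:
        return name in self._entries

    def build(self, name: str, *args: Any, **kwargs: Any) -> Any:
        # Look the entry up by name and call it with the given arguments.
        return self._entries[name](*args, **kwargs)


REGISTRY = ToyRegistry()


@REGISTRY.register
class Backbone:
    def __init__(self, out_channels: int = 64) -> None:
        self.out_channels = [out_channels]


@REGISTRY.register
def backbone_128() -> Backbone:
    # A variant is simply a constructor function with a descriptive name.
    return Backbone(out_channels=128)
```

With this sketch, `REGISTRY.build("backbone_128").out_channels` is `[128]`, and `REGISTRY.build("Backbone", out_channels=32)` forwards the keyword argument to the class.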
Adding a new decoder works the same way, using the TERRATORCH_DECODER_REGISTRY.
Info
All decoders will be passed the channel_list as the first argument for initialization.
To load the weights from your own path with the PrithviModelFactory, you can use timm's pretrained_cfg_overlay.
For example, to load from a local path, pass the parameter backbone_pretrained_cfg_overlay = {"file": "<local_path>"} to the model factory.
Besides file, you can also pass url, hf_hub_id, amongst others. Check timm's documentation for full details.
terratorch.models.backbones.select_patch_embed_weights
select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
Filter the patch embedding weights according to the bands being used. If a band exists in pretrained_bands but not in model_bands, drop it. If a band exists in model_bands but not in pretrained_bands, randomly initialize those weights.
Source code in terratorch/models/backbones/select_patch_embed_weights.py
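The band-matching rule above can be sketched for a single convolutional patch-embedding weight. This is a hypothetical illustration, not the body of select_patch_embed_weights. It assumes the weight has shape `(embed_dim, in_channels, *kernel)` and that input channel `i` corresponds to `pretrained_bands[i]`. The helper name `select_band_channels`, the band names, and the `0.02` scale of the random initialization are all illustrative choices.

```python
import torch


def select_band_channels(
    pretrained_weight: torch.Tensor,
    pretrained_bands: list[str],
    model_bands: list[str],
) -> torch.Tensor:
    """Hypothetical sketch: remap patch-embedding input channels between band lists."""
    out_shape = (pretrained_weight.shape[0], len(model_bands), *pretrained_weight.shape[2:])
    # Start from small random values; bands missing from the checkpoint keep them.
    new_weight = torch.randn(out_shape, dtype=pretrained_weight.dtype) * 0.02
    for new_idx, band in enumerate(model_bands):
        if band in pretrained_bands:
            # Copy the pretrained channel for bands present in both lists.
            old_idx = pretrained_bands.index(band)
            new_weight[:, new_idx] = pretrained_weight[:, old_idx]
    # Pretrained bands absent from model_bands are never copied, i.e. dropped.
    return new_weight
```

For example, remapping a checkpoint trained on `["BLUE", "GREEN", "RED"]` to a model using `["RED", "NIR", "BLUE"]` reuses the RED and BLUE channels, drops GREEN, and leaves NIR randomly initialized.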
Decoders
terratorch.models.decoders.fcn_decoder
FCNDecoder
Bases: Module
Fully Convolutional Decoder
Source code in terratorch/models/decoders/fcn_decoder.py
__init__(embed_dim, channels=256, num_convs=4, in_index=-1)
Constructor
Source code in terratorch/models/decoders/fcn_decoder.py
terratorch.models.decoders.identity_decoder
Pass the features straight through
IdentityDecoder
Bases: Module
Identity decoder. Useful to pass the features straight to the head.
Source code in terratorch/models/decoders/identity_decoder.py
__init__(embed_dim, out_index=-1)
Constructor
Source code in terratorch/models/decoders/identity_decoder.py
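An identity decoder only selects one of the backbone's feature maps. The following is a hypothetical sketch, not terratorch's code. It assumes that `embed_dim` is the backbone's channel list, which the decoder receives as its first argument, and that `forward` receives a list of feature maps. `IdentityDecoderSketch` is an illustrative name.

```python
import torch
from torch import nn


class IdentityDecoderSketch(nn.Module):
    """Hypothetical sketch: pick one feature map and pass it through unchanged."""

    def __init__(self, embed_dim: list[int], out_index: int = -1) -> None:
        super().__init__()
        self.out_index = out_index
        # The decoder exposes the channel count of the selected feature map.
        self.out_channels = embed_dim[out_index]

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        return features[self.out_index]
```

With the default `out_index=-1`, the last (typically deepest) feature map is forwarded to the head.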
terratorch.models.decoders.upernet_decoder
PPM
Bases: ModuleList
Pooling Pyramid Module used in PSPNet.
Source code in terratorch/models/decoders/upernet_decoder.py
__init__(pool_scales, in_channels, channels, align_corners)
Constructor
Source code in terratorch/models/decoders/upernet_decoder.py
forward(x)
Forward function.
Source code in terratorch/models/decoders/upernet_decoder.py
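A pooling pyramid module runs one branch per pool scale. Each branch pools the input to a `scale x scale` grid, projects it with a 1x1 convolution, and upsamples it back to the input resolution. The following is a minimal sketch of that PSPNet-style technique under the constructor arguments listed above. It is not terratorch's implementation, and `PPMSketch` is an illustrative name. The branch layout (1x1 conv without bias, then BatchNorm and ReLU) is an assumption.

```python
import torch
import torch.nn.functional as F
from torch import nn


class PPMSketch(nn.ModuleList):
    """Hypothetical sketch of a PSPNet-style pooling pyramid module."""

    def __init__(self, pool_scales, in_channels, channels, align_corners=False):
        super().__init__()
        self.align_corners = align_corners
        for scale in pool_scales:
            # Each branch pools to a scale x scale grid, then projects channels.
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    nn.Conv2d(in_channels, channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(channels),
                    nn.ReLU(inplace=True),
                )
            )

    def forward(self, x):
        outs = []
        for branch in self:
            pooled = branch(x)
            # Upsample each pooled branch back to the input's spatial size.
            outs.append(
                F.interpolate(pooled, size=x.shape[2:], mode="bilinear", align_corners=self.align_corners)
            )
        return outs
```

In MMSegmentation's UPerNet head, from which the UperNetDecoder below is adapted, these branch outputs are concatenated with the input along the channel dimension before a bottleneck convolution. Note that BatchNorm in training mode needs more than one value per channel, so the `scale=1` branch fails on a batch of size 1 unless the module is in eval mode.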
UperNetDecoder
Bases: Module
UperNetDecoder. Adapted from MMSegmentation.
Source code in terratorch/models/decoders/upernet_decoder.py
__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False)
Constructor
Source code in terratorch/models/decoders/upernet_decoder.py
forward(inputs)
Forward function for feature maps before classifying each pixel.
Args: inputs (list[Tensor]): List of multi-level image features.
Source code in terratorch/models/decoders/upernet_decoder.py
psp_forward(inputs)
Forward function of PSP module.
Source code in terratorch/models/decoders/upernet_decoder.py
Heads
terratorch.models.heads.regression_head
RegressionHead
Bases: Module
Regression head
Source code in terratorch/models/heads/regression_head.py
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)
Constructor
Source code in terratorch/models/heads/regression_head.py
terratorch.models.heads.segmentation_head
SegmentationHead
Bases: Module
Segmentation head
Source code in terratorch/models/heads/segmentation_head.py
__init__(in_channels, num_classes, channel_list=None, dropout=0)
Constructor
Source code in terratorch/models/heads/segmentation_head.py
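The following is a hypothetical sketch of a segmentation head with the constructor arguments listed above, for illustration only. `SegmentationHeadSketch` is an illustrative name. It assumes that each entry of `channel_list` adds a hidden 3x3 conv + BatchNorm + ReLU layer that preserves spatial size, followed by optional dropout and a 1x1 convolution producing per-pixel class logits. terratorch's actual layer choices may differ.

```python
import torch
from torch import nn


class SegmentationHeadSketch(nn.Module):
    """Hypothetical sketch: optional hidden conv layers, dropout, 1x1 conv to logits."""

    def __init__(self, in_channels, num_classes, channel_list=None, dropout=0.0):
        super().__init__()
        layers = []
        current = in_channels
        for ch in channel_list or []:
            # Assumed hidden layer: 3x3 conv (padding keeps H, W) + BatchNorm + ReLU.
            layers += [nn.Conv2d(current, ch, kernel_size=3, padding=1), nn.BatchNorm2d(ch), nn.ReLU()]
            current = ch
        self.pre_head = nn.Sequential(*layers)  # an empty Sequential returns its input
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.classifier = nn.Conv2d(current, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.dropout(self.pre_head(x)))
```

The output has shape `(batch, num_classes, H, W)` at the decoder's resolution.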
terratorch.models.heads.classification_head
ClassificationHead
Bases: Module
Classification head
Source code in terratorch/models/heads/classification_head.py
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)
Constructor
Source code in terratorch/models/heads/classification_head.py
Auxiliary Heads
terratorch.models.model.AuxiliaryHead
dataclass
Class containing all information to create auxiliary heads.
Source code in terratorch/models/model.py
Model Output
terratorch.models.model.ModelOutput
dataclass
Model Factory
terratorch.models.PrithviModelFactory
Bases: ModelFactory
Source code in terratorch/models/prithvi_model_factory.py
build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)
Model factory for Prithvi models.
Further arguments to be passed to the backbone, decoder or head should be prefixed with backbone_, decoder_ and head_, respectively.
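The prefix convention can be illustrated with a small helper that groups keyword arguments by prefix. This is a hypothetical sketch: `split_prefixed_kwargs` is an illustrative name, not a terratorch function. It assumes the prefix is stripped before the argument reaches its component, so `decoder_channels=128` would arrive at the decoder as `channels=128`.

```python
def split_prefixed_kwargs(kwargs, prefixes=("backbone_", "decoder_", "head_")):
    """Hypothetical helper: group keyword arguments by prefix and strip the prefix."""
    groups = {prefix.rstrip("_"): {} for prefix in prefixes}
    rest = {}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:
            # Arguments without a known prefix stay with the factory itself.
            rest[key] = value
    return groups, rest


groups, rest = split_prefixed_kwargs(
    {
        "backbone_pretrained_cfg_overlay": {"file": "<local_path>"},
        "decoder_channels": 128,
        "head_dropout": 0.1,
        "other_option": True,  # hypothetical unprefixed argument
    }
)
```

Here `groups["backbone"]` becomes `{"pretrained_cfg_overlay": {"file": "<local_path>"}}`, and the unprefixed `other_option` ends up in `rest`.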
Source code in terratorch/models/prithvi_model_factory.py
terratorch.models.SMPModelFactory
Bases: ModelFactory
Source code in terratorch/models/smp_model_factory.py
build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)
Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.
This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.
Source code in terratorch/models/smp_model_factory.py
Adding new model types
Adding new model types is as simple as creating a new factory that produces models. See, for instance, the example below of a possible SMPModelFactory:
import segmentation_models_pytorch as smp
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory


@register_factory
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        **kwargs,
    ) -> Model:
        # This minimal example ignores backbone and decoder and always builds a U-Net
        model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=in_channels, classes=1)
        return SMPModelWrapper(model)


# The wrapper is a model, not a factory, so it is not registered with @register_factory
class SMPModelWrapper(Model, nn.Module):
    def __init__(self, smp_model) -> None:
        super().__init__()
        self.smp_model = smp_model

    def forward(self, *args, **kwargs):
        # Drop the singleton class dimension: (B, 1, H, W) -> (B, H, W)
        return ModelOutput(self.smp_model(*args, **kwargs).squeeze(1))

    def freeze_encoder(self):
        pass

    def freeze_decoder(self):
        pass
Custom modules with CLI
Custom modules must be in the import path in order to be registered in the appropriate registries.
To do this without modifying the code when using the CLI, you may place your modules under a custom_modules directory. The custom_modules directory must be located in the directory from which you run terratorch.