Models
Prithvi backbones
We provide access to the Prithvi backbones through integration with timm.
By passing features_only=True, you get a model that outputs the features produced at each layer of the model.
Passing features_only=False gives you access to the full original model.
import timm
import terratorch # even though we don't use the import directly, we need it so that the models are available in the timm registry
# find available prithvi models by name
print(timm.list_models("prithvi*"))
# and those with pretrained weights
print(timm.list_pretrained("prithvi*"))
# instantiate your desired model with features_only=True to obtain a backbone
model = timm.create_model(
    "prithvi_vit_100", num_frames=1, pretrained=True, features_only=True
)
# instantiate your model with weights of your own
model = timm.create_model(
    "prithvi_vit_100",
    num_frames=1,
    pretrained=True,
    pretrained_cfg_overlay={"file": "<path to weights>"},
    features_only=True,
)
# Rest of your PyTorch / PyTorch Lightning code
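timm.list_models and timm.list_pretrained filter the registered model names with Unix shell-style wildcards. That is why "prithvi*" matches every name starting with "prithvi". To our understanding, timm does this with Python's fnmatch module. The sketch below reproduces the matching on a hypothetical list of names. The names you actually see depend on your installed terratorch and timm versions.

```python
import fnmatch

# Hypothetical registry contents; real names come from timm.list_models()
registered = ["prithvi_vit_100", "prithvi_vit_300", "resnet50", "vit_base_patch16_224"]

# "prithvi*" matches any name that starts with "prithvi"
prithvi_models = fnmatch.filter(registered, "prithvi*")
print(prithvi_models)  # ['prithvi_vit_100', 'prithvi_vit_300']
```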
We also provide a model factory that can build a task-specific model for a downstream task based on a Prithvi backbone.
When you pass it the list of bands you are using, the factory automatically drops the pretrained weights of unused bands. It also randomly initializes weights for new bands the backbone was not pretrained on.
Info
To load weights from your own path with the PrithviModelFactory, you can use timm's pretrained_cfg_overlay.
For example, to load from a local path, pass the parameter backbone_pretrained_cfg_overlay={"file": "<local_path>"} to the model factory.
Besides file, you can also pass url and hf_hub_id, among others. See timm's documentation for full details.
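Conceptually, the overlay replaces individual fields of the model's default pretrained config and leaves the rest untouched. The sketch below mimics this with a hypothetical PretrainedCfg dataclass and dataclasses.replace. timm's real config class has many more fields. When several weight sources are set, timm decides which one to load from; this sketch does not model that.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class PretrainedCfg:
    # Hypothetical stand-in for timm's pretrained config; fields are illustrative
    url: Optional[str] = None
    hf_hub_id: Optional[str] = None
    file: Optional[str] = None


default_cfg = PretrainedCfg(hf_hub_id="some-org/some-model")  # hypothetical hub id

# The overlay only names the fields to change
overlay = {"file": "/path/to/my_weights.pt"}
resolved = replace(default_cfg, **overlay)

print(resolved.file)       # /path/to/my_weights.pt
print(resolved.hf_hub_id)  # some-org/some-model
```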
terratorch.models.backbones.prithvi_select_patch_embed_weights
prithvi_select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | dict | State dict to filter. | required |
model | Module | Model to load the weights onto. | required |
pretrained_bands | list[HLSBands] | List of bands the model was pretrained on, in the correct order. | required |
model_bands | list[HLSBands] | List of bands the model is going to be fine-tuned on, in the correct order. | required |
Returns:
Name | Type | Description |
---|---|---|
dict | dict | New state dict. |
Source code in terratorch/models/backbones/prithvi_select_patch_embed_weights.py
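The patch embedding weight stores one slice per input band along its input-channel dimension (dim 1 of a PyTorch convolution weight). Selecting bands therefore means picking, reordering or newly creating those slices. The pure-Python sketch below mirrors the selection rule on plain lists. select_band_weights, the string band names and the Gaussian initializer are illustrative assumptions, not the library's code.

```python
import random
from typing import Callable, Sequence


def select_band_weights(
    pretrained_slices: Sequence[list[float]],
    pretrained_bands: Sequence[str],
    model_bands: Sequence[str],
    init_fn: Callable[[int], list[float]],
) -> list[list[float]]:
    """Return one weight slice per model band, in model_bands order.

    Bands in both lists reuse the pretrained slice; bands only in
    model_bands get a fresh slice from init_fn. Pretrained bands that
    are not in model_bands are dropped by never being selected.
    """
    index = {band: i for i, band in enumerate(pretrained_bands)}
    slice_len = len(pretrained_slices[0])
    selected = []
    for band in model_bands:
        if band in index:
            selected.append(list(pretrained_slices[index[band]]))
        else:
            selected.append(init_fn(slice_len))
    return selected


pretrained_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
# One tiny 2-value "slice" per band so the result is easy to inspect
pretrained_slices = [[float(i), float(i)] for i in range(len(pretrained_bands))]
model_bands = ["RED", "GREEN", "BLUE", "NEW_BAND"]  # NEW_BAND is hypothetical

rng = random.Random(0)
weights = select_band_weights(
    pretrained_slices,
    pretrained_bands,
    model_bands,
    init_fn=lambda n: [rng.gauss(0.0, 0.02) for _ in range(n)],
)
print(weights[:3])  # [[2.0, 2.0], [1.0, 1.0], [0.0, 0.0]]
```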
Decoders
terratorch.models.decoders.fcn_decoder
FCNDecoder
Bases: Module
Fully Convolutional Decoder
Source code in terratorch/models/decoders/fcn_decoder.py
__init__(embed_dim, channels=256, num_convs=4, in_index=-1)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim |
_type_
|
Input embedding dimension |
required |
channels |
int
|
Number of channels for each conv. Defaults to 256. |
256
|
num_convs |
int
|
Number of convs. Defaults to 4. |
4
|
in_index |
int
|
Index of the input list to take. Defaults to -1. |
-1
|
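in_index picks one entry from the list of feature maps the backbone returns; with features_only=True, that list holds one entry per layer. It follows Python's list indexing, so the default -1 makes the decoder consume the last layer's output. The strings below are placeholders for real feature tensors, and pick_input is a hypothetical helper.

```python
# Stand-ins for the per-layer feature maps a 12-layer backbone might return
features = [f"layer_{i}_features" for i in range(12)]


def pick_input(features: list[str], in_index: int = -1) -> str:
    # Same rule as Python list indexing: -1 is the last entry, 0 the first
    return features[in_index]


print(pick_input(features))               # layer_11_features
print(pick_input(features, in_index=5))   # layer_5_features
print(pick_input(features, in_index=-4))  # layer_8_features
```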
terratorch.models.decoders.identity_decoder
Pass the features straight through
IdentityDecoder
Bases: Module
Identity decoder. Useful to pass the features straight to the head.
Source code in terratorch/models/decoders/identity_decoder.py
__init__(embed_dim, out_index=-1)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim |
int
|
Input embedding dimension |
required |
out_index |
int
|
Index of the input list to take.. Defaults to -1. |
-1
|
terratorch.models.decoders.upernet_decoder
PPM
Bases: ModuleList
Pooling Pyramid Module used in PSPNet.
Source code in terratorch/models/decoders/upernet_decoder.py
__init__(pool_scales, in_channels, channels, align_corners)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pool_scales | tuple[int] | Pooling scales used in the Pooling Pyramid Module. | required |
in_channels | int | Input channels. | required |
channels | int | Channels after modules, before conv_seg. | required |
align_corners | bool | align_corners argument of F.interpolate. | required |
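PPM pools the same feature map to s x s bins for every s in pool_scales. In the MMSegmentation design this module is adapted from, each pooled map is then projected, upsampled back to the input size, and concatenated with the input in psp_forward. The sketch below shows only the pooling step on a plain 2D grid. It uses the bin boundaries of PyTorch's adaptive average pooling, where bin i spans floor(i*H/s) to ceil((i+1)*H/s). Projection and upsampling are omitted, and adaptive_avg_pool2d here is a hypothetical pure-Python helper.

```python
import math


def adaptive_avg_pool2d(grid: list[list[float]], out_size: int) -> list[list[float]]:
    """Average-pool an H x W grid into out_size x out_size bins."""
    h, w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_size):
        r0, r1 = (i * h) // out_size, math.ceil((i + 1) * h / out_size)
        row = []
        for j in range(out_size):
            c0, c1 = (j * w) // out_size, math.ceil((j + 1) * w / out_size)
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


# A 6 x 6 feature map whose value at (r, c) is r * 6 + c
grid = [[float(r * 6 + c) for c in range(6)] for r in range(6)]

for scale in (1, 2, 3, 6):  # the default UperNetDecoder pool_scales
    pooled = adaptive_avg_pool2d(grid, scale)
    print(f"scale {scale}: {len(pooled)}x{len(pooled[0])} bins, top-left mean {pooled[0][0]}")
```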
forward(x)
Forward function.
UperNetDecoder
Bases: Module
UperNetDecoder. Adapted from MMSegmentation.
__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embed_dim | list[int] | Input embedding dimension for each input. | required |
pool_scales | tuple[int] | Pooling scales used in the Pooling Pyramid Module applied on the last feature. Defaults to (1, 2, 3, 6). | (1, 2, 3, 6) |
channels | int | Channels used in the decoder. Defaults to 256. | 256 |
align_corners | bool | Whether to align corners in rescaling. Defaults to True. | True |
scale_modules | bool | Whether to apply scale modules to the inputs. Needed for plain ViT. Defaults to False. | False |
forward(inputs)
Forward function for the feature maps, before each pixel is classified.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | list[Tensor] | List of multi-level image features. | required |
Returns:
Name | Type | Description |
---|---|---|
feats | Tensor | A tensor of shape (batch_size, self.channels, H, W), which is the feature map of the last layer of the decoder head. |
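In MMSegmentation's UperNet head, which this decoder is adapted from, the lateral features are merged top-down. Each coarser level is resized to the next finer level and added to it, so the finest level ends up carrying information from all coarser ones. The sketch shows only that merge order. The real decoder works on 2D tensors, applies convolutions, and uses bilinear resizing (align_corners controls that). This sketch swaps in nearest-neighbour resizing on 1D lists, and its helper names are hypothetical.

```python
def upsample_nearest(values: list[float], size: int) -> list[float]:
    # Nearest-neighbour resize of a 1D list to `size` entries
    return [values[i * len(values) // size] for i in range(size)]


def top_down_merge(laterals: list[list[float]]) -> list[list[float]]:
    """Add each coarser level into the next finer one, coarsest first.

    laterals[0] is the finest level, laterals[-1] the coarsest.
    """
    merged = [list(level) for level in laterals]
    for i in range(len(merged) - 1, 0, -1):
        up = upsample_nearest(merged[i], len(merged[i - 1]))
        merged[i - 1] = [a + b for a, b in zip(merged[i - 1], up)]
    return merged


# Three 1D "feature maps" at resolutions 8, 4 and 2 (finest first)
laterals = [[0.0] * 8, [1.0] * 4, [10.0, 20.0]]
merged = top_down_merge(laterals)
print(merged[0])  # [11.0, 11.0, 11.0, 11.0, 21.0, 21.0, 21.0, 21.0]
```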
psp_forward(inputs)
Forward function of PSP module.
Heads
terratorch.models.heads.regression_head
RegressionHead
Bases: Module
Regression head
Source code in terratorch/models/heads/regression_head.py
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_channels | int | Number of input channels. | required |
final_act | Module \| None | Final activation to be applied. Defaults to None. | None |
learned_upscale_layers | int | Number of PixelShuffle layers to create. Each upscales 2x. Defaults to 0. | 0 |
channel_list | list[int] \| None | List with the number of channels for each Conv layer to be created. Defaults to None. | None |
batch_norm | bool | Whether to apply batch norm. Defaults to True. | True |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
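PixelShuffle trades channels for resolution: with factor r, a (C*r*r, H, W) map becomes (C, H*r, W*r). Each learned upscale layer uses r = 2, so n layers grow the spatial size by 2**n. The pure-Python sketch below reproduces PixelShuffle's index mapping, output[c][h*r+i][w*r+j] = input[c*r*r + i*r + j][h][w], for a single sample. The convolutions that produce the extra channels in the head are omitted, and pixel_shuffle here is an illustrative helper rather than torch.nn.PixelShuffle.

```python
def pixel_shuffle(x: list[list[list[float]]], r: int) -> list[list[list[float]]]:
    """Rearrange a (C*r*r, H, W) nested list into (C, H*r, W*r)."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    if c_in % (r * r):
        raise ValueError("channel count must be divisible by r*r")
    c_out = c_in // (r * r)
    out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):
            for j in range(r):
                src = x[c * r * r + i * r + j]
                for hh in range(h):
                    for ww in range(w):
                        out[c][hh * r + i][ww * r + j] = src[hh][ww]
    return out


# Four 1x1 channels become one 2x2 channel
x = [[[1.0]], [[2.0]], [[3.0]], [[4.0]]]
print(pixel_shuffle(x, 2))  # [[[1.0, 2.0], [3.0, 4.0]]]
```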
terratorch.models.heads.segmentation_head
SegmentationHead
Bases: Module
Segmentation head
Source code in terratorch/models/heads/segmentation_head.py
__init__(in_channels, num_classes, channel_list=None, dropout=0)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_channels | int | Number of input channels. | required |
num_classes | int | Number of output classes. | required |
channel_list | list[int] \| None | List with the number of channels for each Conv layer to be created. Defaults to None. | None |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
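Here is one hedged reading of channel_list. It assumes the head stacks one conv layer per entry and finishes with a layer that outputs num_classes channels. The hypothetical helper below only lists the resulting (in, out) channel pairs; dropout, kernel sizes and any upsampling are left out. Check segmentation_head.py for the exact layers.

```python
from typing import Optional


def head_channel_plan(
    in_channels: int, num_classes: int, channel_list: Optional[list[int]] = None
) -> list[tuple[int, int]]:
    # Assumption: one conv per channel_list entry, then a final conv to num_classes
    widths = [in_channels, *(channel_list or []), num_classes]
    return list(zip(widths[:-1], widths[1:]))


print(head_channel_plan(256, 5))             # [(256, 5)]
print(head_channel_plan(256, 5, [128, 64]))  # [(256, 128), (128, 64), (64, 5)]
```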
terratorch.models.heads.classification_head
ClassificationHead
Bases: Module
Classification head
Source code in terratorch/models/heads/classification_head.py
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)
Constructor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input dimensionality. | required |
num_classes | int | Number of output classes. | required |
dim_list | list[int] \| None | List with the number of dimensions for each Linear layer to be created. Defaults to None. | None |
dropout | float | Dropout value to apply. Defaults to 0. | 0 |
linear_after_pool | bool | Apply pooling first, then apply the linear layer. Defaults to False. | False |
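linear_after_pool decides whether features are pooled before or after the linear layers. The pure-Python check below assumes mean pooling and uses hypothetical weights. For a single linear layer, both orders give the same logits, because averaging commutes with an affine map. Once a nonlinear hidden layer (here ReLU) is involved, the orders can disagree. Pooling first is also cheaper, since the layers run once rather than once per feature vector.

```python
def linear(x: list[float], w: list[list[float]], b: list[float]) -> list[float]:
    # y = W x + b for a single feature vector
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def relu(x: list[float]) -> list[float]:
    return [max(0.0, v) for v in x]


def mean_pool(features: list[list[float]]) -> list[float]:
    # Average over the feature vectors (e.g. tokens or pixels)
    return [sum(col) / len(features) for col in zip(*features)]


features = [[1.0, -2.0], [3.0, 0.0]]          # two 2-d feature vectors
w, b = [[1.0, 1.0], [2.0, -1.0]], [0.5, 0.0]  # hypothetical weights

# Single linear layer: both orders agree
pool_then_linear = linear(mean_pool(features), w, b)
linear_then_pool = mean_pool([linear(f, w, b) for f in features])

# Linear + ReLU: the orders can disagree
pool_then_mlp = relu(linear(mean_pool(features), w, b))
mlp_then_pool = mean_pool([relu(linear(f, w, b)) for f in features])

print(pool_then_linear, linear_then_pool)  # [1.5, 5.0] [1.5, 5.0]
print(pool_then_mlp, mlp_then_pool)        # [1.5, 5.0] [1.75, 5.0]
```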
Auxiliary Heads
terratorch.models.model.AuxiliaryHead
dataclass
Class containing all information to create auxiliary heads.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the head. Should match the name given to the auxiliary loss. | required |
decoder | str | Name of the decoder class to be used. | required |
decoder_args | dict \| None | Parameters to be passed to the decoder constructor. Parameters for the decoder should be prefixed with | required |
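Each auxiliary head's name must match the name of its auxiliary loss, so a mismatch is easy to catch before training. The sketch below uses a hypothetical AuxHeadSpec that mirrors the documented fields. It also assumes auxiliary losses are configured as a dict mapping head name to loss weight. Both are illustrative assumptions, not terratorch's API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AuxHeadSpec:
    # Hypothetical mirror of AuxiliaryHead's documented fields
    name: str
    decoder: str
    decoder_args: Optional[dict]


def check_aux_names(heads: list[AuxHeadSpec], aux_loss: dict[str, float]) -> None:
    # Every auxiliary loss must refer to a head with the same name
    head_names = {h.name for h in heads}
    missing = set(aux_loss) - head_names
    if missing:
        raise ValueError(f"no auxiliary head named {sorted(missing)}")


heads = [AuxHeadSpec(name="fcn_aux", decoder="FCNDecoder", decoder_args=None)]
check_aux_names(heads, {"fcn_aux": 0.4})  # passes silently
```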
Source code in terratorch/models/model.py
Model Output
terratorch.models.model.ModelOutput
dataclass
Model Factory
terratorch.models.PrithviModelFactory
Bases: ModelFactory
Source code in terratorch/models/prithvi_model_factory.py
build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)
Model factory for Prithvi models.
Further arguments to be passed to the backbone, decoder or head should be prefixed with backbone_, decoder_ and head_, respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If a string, it should be parsable by the specified factory. | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from the class with the same name exposed in terratorch/models/decoders/__init__.py. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. Defaults to None. | None |
bands | list[HLSBands] | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders also take their input from the encoder. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size differs from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True. | True |
Returns:
Type | Description |
---|---|
Model | nn.Module: Full model with encoder, decoder and head. |
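The prefix convention routes extra keyword arguments to the part of the model they configure. The sketch below assumes the factory strips the prefix before passing the remaining name on. split_kwargs_by_prefix is a hypothetical helper, not terratorch's own code. It collects keys without a known prefix under "other"; the real factory may handle such keys differently.

```python
def split_kwargs_by_prefix(
    kwargs: dict, prefixes: tuple[str, ...] = ("backbone_", "decoder_", "head_")
) -> dict[str, dict]:
    """Group keyword arguments by prefix and strip the prefix from each key."""
    groups: dict[str, dict] = {p.rstrip("_"): {} for p in prefixes}
    groups["other"] = {}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:
            groups["other"][key] = value
    return groups


routed = split_kwargs_by_prefix({
    "backbone_pretrained_cfg_overlay": {"file": "<local_path>"},
    "decoder_channels": 128,
    "head_dropout": 0.1,
})
print(routed["backbone"])  # {'pretrained_cfg_overlay': {'file': '<local_path>'}}
```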
terratorch.models.SMPModelFactory
Bases: ModelFactory
Source code in terratorch/models/smp_model_factory.py
build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)
Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.
This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.
Attributes:
Name | Type | Description |
---|---|---|
task | str | Specifies the task for which the model is being built. Supported tasks are "segmentation". |
backbone | str | Specifies the backbone model to be used. |
decoder | str | Specifies the decoder to be used for constructing the segmentation model. |
bands | list[HLSBands \| int] | A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands. |
in_channels | int | Specifies the number of input channels. Defaults to None. |
num_classes | int | The number of output classes for the model. |
pretrained | bool \| Path | Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True. |
num_frames | int | Specifies the number of timesteps the model should handle. Useful for temporal models. |
regression_relu | bool | Whether to apply ReLU activation in the case of regression tasks. |
**kwargs | bool | Additional arguments that might be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately. |
Raises:
Type | Description |
---|---|
ValueError | If the specified decoder is not supported by SMP. |
Exception | If the specified task is not "segmentation". |
Returns:
Type | Description |
---|---|
Model | nn.Module: A model instance wrapped in SMPModelWrapper, configured according to the specified parameters and tasks. |
Adding new model types
Adding new model types is as simple as creating a new factory that produces models. See, for instance, the example below of a potential SMPModelFactory:
import segmentation_models_pytorch as smp
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory


@register_factory
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        **kwargs,
    ) -> Model:
        # U-Net with a ResNet-34 encoder and a single output channel
        model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=in_channels, classes=1)
        return SMPModelWrapper(model)


# The wrapper is a model, not a factory, so it is not registered
class SMPModelWrapper(Model, nn.Module):
    def __init__(self, smp_model) -> None:
        super().__init__()
        self.smp_model = smp_model

    def forward(self, *args, **kwargs):
        # Drop the singleton channel dimension: (B, 1, H, W) -> (B, H, W)
        return ModelOutput(self.smp_model(*args, **kwargs).squeeze(1))

    def freeze_encoder(self):
        pass

    def freeze_decoder(self):
        pass
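register_factory adds a factory class to a registry so the factory can later be selected by name, for example from a config file. The sketch below shows the general decorator-registry pattern. FACTORY_REGISTRY, register, ToyFactory and build are hypothetical names. terratorch's actual registry may store and look up factories differently.

```python
from typing import TypeVar

T = TypeVar("T", bound=type)

FACTORY_REGISTRY: dict[str, type] = {}


def register(cls: T) -> T:
    # Store the class under its own name and return it unchanged,
    # so decorating does not alter the class itself
    if cls.__name__ in FACTORY_REGISTRY:
        raise KeyError(f"{cls.__name__} is already registered")
    FACTORY_REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyFactory:
    def build_model(self, task: str, **kwargs) -> str:
        return f"model for {task}"


def build(factory_name: str, **kwargs):
    # Look up the factory by name, instantiate it, and delegate to it
    factory = FACTORY_REGISTRY[factory_name]()
    return factory.build_model(**kwargs)


print(build("ToyFactory", task="segmentation"))  # model for segmentation
```

Only factories go into the registry. Model classes such as a wrapper are returned by a factory and never looked up by name.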