Model Factories
Model factories build the model that is fine-tuned by TerraTorch. Specifically, a backbone is used as an encoder and combined with a task-specific decoder and head. Necks are used to reshape the encoder output so that it is compatible with the decoder input.
Tip
The EncoderDecoderFactory is the default factory for segmentation, pixel-wise regression, and classification tasks. Other commonly used factories are the ObjectDetectionModelFactory for object detection tasks and, in some cases, the FullModelFactory if a model is registered in the FULL_MODEL_REGISTRY and can be applied directly to a specific task.
terratorch.models.encoder_decoder_factory.EncoderDecoderFactory
Bases: ModelFactory
Source code in terratorch/models/encoder_decoder_factory.py
build_model(task, backbone, decoder, backbone_kwargs=None, decoder_kwargs=None, head_kwargs=None, num_classes=None, necks=None, aux_decoders=None, rescale=True, peft_config=None, **kwargs)
Generic model factory that combines an encoder and decoder, together with a head, for a specific task.
Further arguments can be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.
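The prefix convention can be illustrated with a small, library-independent sketch. The helper split_prefixed_kwargs and the argument names (pretrained, channels, dropout) are hypothetical and only show how prefixed keyword arguments might be grouped; this is not TerraTorch's own implementation.

```python
def split_prefixed_kwargs(kwargs, prefixes=("backbone_", "decoder_", "head_")):
    """Group keyword arguments by prefix and strip the prefix from each key."""
    groups = {prefix: {} for prefix in prefixes}
    remaining = {}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix][key[len(prefix):]] = value
                break
        else:
            # No known prefix matched, so the argument is kept untouched.
            remaining[key] = value
    return groups, remaining


# Illustrative argument names only; real names depend on the chosen components.
groups, remaining = split_prefixed_kwargs(
    {"backbone_pretrained": True, "decoder_channels": 256, "head_dropout": 0.1, "other": 1}
)
print(groups["backbone_"])  # {'pretrained': True}
print(groups["decoder_"])   # {'channels': 256}
print(remaining)            # {'other': 1}
```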
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation", "regression" and "classification". | required |
backbone | (str, Module) | Backbone to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it directly. The backbone should have and | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, will look for such decoders in the different registries supported (internal terratorch registry, smp, ...). If an nn.Module, we expect it to expose a property | required |
backbone_kwargs | dict, optional | Arguments to be passed to instantiate the backbone. | None |
decoder_kwargs | dict, optional | Arguments to be passed to instantiate the decoder. | None |
head_kwargs | dict, optional | Arguments to be passed to the head network. | None |
num_classes | int | Number of classes. None for regression tasks. | None |
necks | list[dict] | nn.Modules to be called in succession on encoder features before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry. Expects each one to have a key "name" and subsequent keys for arguments, if any. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True. | True |
peft_config | dict | Configuration options for using PEFT. The dictionary should have the following keys: | None |
Returns:
Type | Description |
---|---|
Model | nn.Module: Full model with encoder, decoder and head. |
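The necks argument is a list of dictionaries, each with a "name" key and optional argument keys. The sketch below mimics that shape with plain Python callables. EXAMPLE_NECKS, make_select and make_scale are hypothetical stand-ins for NECKS_REGISTRY and its entries, not actual TerraTorch necks.

```python
def make_select(indices):
    """Hypothetical neck: keep only the features at the given positions."""
    return lambda features: [features[i] for i in indices]


def make_scale(factor=1.0):
    """Hypothetical neck: multiply every feature by a constant."""
    return lambda features: [f * factor for f in features]


# Stand-in for a registry that maps neck names to constructors.
EXAMPLE_NECKS = {"Select": make_select, "Scale": make_scale}


def build_necks(necks):
    """Turn a list of {"name": ..., **args} dicts into one callable."""
    if not necks:
        # None (or an empty list) behaves as the identity function.
        return lambda features: features
    modules = []
    for spec in necks:
        args = dict(spec)        # copy so the caller's dict is not mutated
        name = args.pop("name")  # every entry must carry a "name" key
        modules.append(EXAMPLE_NECKS[name](**args))

    def apply(features):
        # Necks are called in succession, each on the previous output.
        for module in modules:
            features = module(features)
        return features

    return apply


necks = [{"name": "Select", "indices": [1, 3]}, {"name": "Scale", "factor": 2.0}]
result = build_necks(necks)([1.0, 2.0, 3.0, 4.0])
identity_result = build_necks(None)([1.0, 2.0])
print(result)           # [4.0, 8.0]
print(identity_result)  # [1.0, 2.0]
```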
terratorch.models.object_detection_model_factory.ObjectDetectionModelFactory
Bases: ModelFactory
Source code in terratorch/models/object_detection_model_factory.py
build_model(task, backbone, framework, num_classes=None, necks=None, **kwargs)
Generic model factory that combines an encoder and necks with one of the detection models in torchvision.models.detection, referred to here as the framework.
Further arguments for the backbone and the framework should be prefixed with backbone_ and framework_ respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "object_detection". | required |
backbone | (str, Module) | Backbone to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it directly. The backbone should have and | required |
framework | str | Object detection framework to be used: one of "faster-rcnn", "fcos" or "retinanet" for object detection, or "mask-rcnn" for instance segmentation. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
necks | list[dict] | nn.Modules to be called in succession on encoder features before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry. Expects each one to have a key "name" and subsequent keys for arguments, if any. Defaults to None, which applies the identity function. | None |
Returns:
Type | Description |
---|---|
Model | nn.Module: Full torchvision detection model. |
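The framework names listed above can be checked before a model is built. The following is a minimal sketch. FRAMEWORK_TASKS and check_framework are hypothetical names introduced for illustration, and the factory's own validation may differ.

```python
# Framework names as listed in the parameter table, with the kind of task each serves.
FRAMEWORK_TASKS = {
    "faster-rcnn": "object detection",
    "fcos": "object detection",
    "retinanet": "object detection",
    "mask-rcnn": "instance segmentation",
}


def check_framework(framework):
    """Return the task kind for a framework name, or raise ValueError if unknown."""
    try:
        return FRAMEWORK_TASKS[framework]
    except KeyError:
        supported = ", ".join(sorted(FRAMEWORK_TASKS))
        raise ValueError(
            f"Unknown framework {framework!r}; expected one of: {supported}"
        ) from None


print(check_framework("retinanet"))  # object detection
print(check_framework("mask-rcnn"))  # instance segmentation
```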
terratorch.models.full_model_factory.FullModelFactory
Bases: ModelFactory
Source code in terratorch/models/full_model_factory.py
build_model(model, rescale=True, padding='reflect', peft_config=None, **kwargs)
Generic model factory that wraps any model.
All kwargs are passed to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
model | (str, Module) | Model to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, ...). If a torch nn.Module, will use it directly. | required |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression, reconstruction). Defaults to True. | True |
padding | str | Padding method used if images are not divisible by the patch size. Defaults to "reflect". | 'reflect' |
peft_config | dict | Configuration options for using PEFT. The dictionary should have the following keys: "method": which PEFT method to use. Should be one implemented in PEFT, a list is available here; "replace_qkv": string containing a substring of the name of the submodules to replace with QKVSep. This should be used when the qkv matrices are merged together in a single linear layer and the PEFT method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in Q and V matrices). For example, if using Prithvi this should be "qkv"; "peft_config_kwargs": dictionary containing keyword arguments which will be passed to PeftConfig. | None |
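To make the padding parameter concrete, the sketch below computes how many pixels are missing for a side length to become divisible by the patch size, and shows what "reflect" padding does to a single row. Both helpers are hypothetical, and padding only at the end of each dimension is an assumption; the library may distribute the padding differently.

```python
def padding_needed(size, patch_size):
    """Number of extra pixels required so that size is divisible by patch_size."""
    return (patch_size - size % patch_size) % patch_size


def reflect_pad_end(values, pad):
    """Append pad mirrored values to the end, without repeating the edge value."""
    if pad >= len(values):
        raise ValueError("reflect padding must be smaller than the input length")
    return list(values) + [values[-1 - i] for i in range(1, pad + 1)]


height, width, patch = 30, 32, 16
pad_h = padding_needed(height, patch)   # 30 -> 32 needs 2 extra rows
pad_w = padding_needed(width, patch)    # 32 is already divisible by 16
row = reflect_pad_end([1, 2, 3, 4], 2)  # mirrors around the last element

print(pad_h, pad_w)  # 2 0
print(row)           # [1, 2, 3, 4, 3, 2]
```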
Returns:
Type | Description |
---|---|
Module | nn.Module: Full model. |
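A peft_config dictionary with the three keys described above might look like the sketch below. The values are assumptions: "LORA" is assumed to be spelled as in PEFT's list of methods, r and lora_alpha are LoRA options passed through as config keyword arguments, and the target_modules entries are placeholder names that depend on the backbone. The unknown_keys helper is hypothetical and only checks for misspelled keys.

```python
# Keys documented for peft_config in the parameter table.
KNOWN_KEYS = {"method", "replace_qkv", "peft_config_kwargs"}

peft_config = {
    "method": "LORA",      # assumed identifier for LoRA
    "replace_qkv": "qkv",  # substring of the merged qkv layer name (e.g. for Prithvi)
    "peft_config_kwargs": {
        "r": 8,
        "lora_alpha": 16,
        # Placeholder module names; real names depend on the backbone.
        "target_modules": ["q_placeholder", "v_placeholder"],
    },
}


def unknown_keys(config):
    """Return the keys of config that are not documented for peft_config."""
    return sorted(set(config) - KNOWN_KEYS)


print(unknown_keys(peft_config))                       # []
print(unknown_keys({"method": "LORA", "rank": 4}))     # ['rank']
```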
terratorch.models.smp_model_factory.SMPModelFactory
Bases: ModelFactory
Source code in terratorch/models/smp_model_factory.py
build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)
Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.
This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.
Attributes:
Name | Type | Description |
---|---|---|
task | str | Specifies the task for which the model is being built. Supported tasks are "segmentation". |
backbone | str | Specifies the backbone model to be used. |
decoder | str | Specifies the decoder to be used for constructing the segmentation model. |
bands | list[HLSBands \| int] | A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands. |
in_channels | int | Specifies the number of input channels. Defaults to None. |
num_classes | int | The number of output classes for the model. |
pretrained | bool \| Path | Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True. |
num_frames | int | Specifies the number of timesteps the model should handle. Useful for temporal models. |
regression_relu | bool | Whether to apply ReLU activation in the case of regression tasks. |
**kwargs | bool | Additional arguments that might be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately. |
Raises:
Type | Description |
---|---|
ValueError | If the specified decoder is not supported by SMP. |
Exception | If the specified task is not "segmentation". |
Returns:
Type | Description |
---|---|
Model | nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified parameters and tasks. |
terratorch.models.timm_model_factory.TimmModelFactory
Bases: ModelFactory
Source code in terratorch/models/timm_model_factory.py
build_model(task, backbone, in_channels, num_classes, pretrained=True, **kwargs)
Build a classifier from timm
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Must be "classification". | required |
backbone | str | Name of the backbone in timm. | required |
in_channels | int | Number of input channels. | required |
num_classes | int | Number of classes. | required |
Returns:
Name | Type | Description |
---|---|---|
Model | Model | Timm model wrapped in TimmModelWrapper. |
terratorch.models.generic_model_factory.GenericModelFactory
Bases: ModelFactory
Source code in terratorch/models/generic_model_factory.py
build_model(backbone=None, in_channels=6, pretrained=True, **kwargs)
Factory to create models from any custom module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | str | The name for the model class. | required |
in_channels | int | Number of input channels. | 6 |
pretrained | str \| bool | Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True. | True |
Returns:
Name | Type | Description |
---|---|---|
Model | Model | A wrapped generic model. |
terratorch.models.clay_model_factory.ClayModelFactory
Bases: ModelFactory
Source code in terratorch/models/clay_model_factory.py
build_model(task, backbone, decoder, in_channels, bands=[], num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs)
Model factory for Clay models.
Further arguments can be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100". | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. Defaults to 3. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True. | True |
Raises:
Type | Description |
---|---|
NotImplementedError | description |
DecoderNotFoundException | description |
Returns:
Type | Description |
---|---|
Model | nn.Module: description |
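The rescale option refers to bilinear interpolation of the model output to the size of the ground truth. The pure-Python sketch below shows the idea on a nested list, assuming the half-pixel-center convention (the one PyTorch calls align_corners=False). The function bilinear_resize is hypothetical; the library presumably relies on a tensor implementation instead.

```python
import math


def _source_position(dst, in_size, out_size):
    """Map an output index to two neighbouring input indices and a blend weight."""
    pos = (dst + 0.5) * in_size / out_size - 0.5
    pos = min(max(pos, 0.0), in_size - 1.0)  # clamp to the valid input range
    low = int(math.floor(pos))
    high = min(low + 1, in_size - 1)
    return low, high, pos - low


def bilinear_resize(image, out_h, out_w):
    """Resize a 2D list of floats to (out_h, out_w) with bilinear interpolation."""
    in_h, in_w = len(image), len(image[0])
    output = []
    for y in range(out_h):
        y0, y1, wy = _source_position(y, in_h, out_h)
        row = []
        for x in range(out_w):
            x0, x1, wx = _source_position(x, in_w, out_w)
            # Blend horizontally on both rows, then blend the two rows vertically.
            top = image[y0][x0] * (1 - wx) + image[y0][x1] * wx
            bottom = image[y1][x0] * (1 - wx) + image[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        output.append(row)
    return output


prediction = [[0.0, 1.0], [0.0, 1.0]]          # a 2x2 model output
resized = bilinear_resize(prediction, 2, 4)    # rescaled to a 2x4 target size
print(resized[0])  # [0.0, 0.25, 0.75, 1.0]
```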
terratorch.models.generic_unet_model_factory.GenericUnetModelFactory
Bases: ModelFactory
Source code in terratorch/models/generic_unet_model_factory.py
build_model(task='segmentation', backbone=None, decoder=None, dilations=(1, 6, 12, 18), in_channels=6, pretrained=True, regression_relu=False, **kwargs)
Factory to create model based on mmseg.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Must be "segmentation". | 'segmentation' |
model | str | Decoder architecture. Currently only supports "unet". | required |
in_channels | int | Number of input channels. | 6 |
pretrained | str \| bool | Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True. | True |
Returns:
Name | Type | Description |
---|---|---|
Model | Model | UNet model. |
terratorch.models.satmae_model_factory.SatMAEModelFactory
Bases: ModelFactory
Source code in terratorch/models/satmae_model_factory.py
build_model(task, backbone, decoder, in_channels, bands, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs)
Model factory for SatMAE models.
Further arguments can be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
backbone | (str, Module) | Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100". | required |
decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property | required |
in_channels | int | Number of input channels. Defaults to 3. | required |
bands | list[HLSBands] | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE]. | required |
num_classes | int | Number of classes. None for regression tasks. | None |
pretrained | Union[bool, Path] | Whether to load pretrained weights for the backbone, if available. Defaults to True. | True |
num_frames | int | Number of timesteps for the model to handle. Defaults to 1. | 1 |
prepare_features_for_image_model | Callable \| None | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | None |
aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | None |
rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True. | True |
Raises:
Type | Description |
---|---|
NotImplementedError | description |
DecoderNotFoundException | description |
Returns:
Type | Description |
---|---|
Model | nn.Module: description |
Base factory: