
EncoderDecoderFactory

Check the Glossary for more information about the terms used in this page.

The EncoderDecoderFactory is the main class used to instantiate and compose models for general tasks.

This factory leverages the BACKBONE_REGISTRY, DECODER_REGISTRY and NECK_REGISTRY to compose models formed as encoder + decoder, with some optional glue in between provided by the necks. As most current models work this way, this is a particularly important factory, allowing for great flexibility in combining encoders and decoders from different sources.

The factory allows arguments to be passed to the encoder, decoder and head. Arguments with the prefix backbone_ are routed to the backbone constructor, and arguments prefixed with decoder_ and head_ are routed to the decoder and head in the same way, with the prefix stripped. These arguments are accepted dynamically and are not checked by the factory. Any argument that is left unused raises a ValueError.
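The prefix routing can be pictured with the stand-alone sketch below. It is not terratorch's implementation: the helper route_kwargs is hypothetical and only mirrors the behaviour described above (strip the prefix, hand the rest to the matching component, reject leftovers).

```python
from typing import Any

# Hypothetical sketch of prefix-based routing; not terratorch's actual code.
PREFIXES = ("backbone_", "decoder_", "head_")


def route_kwargs(**kwargs: Any) -> dict[str, dict[str, Any]]:
    """Split prefixed keyword arguments into one dict per component."""
    routed: dict[str, dict[str, Any]] = {p.rstrip("_"): {} for p in PREFIXES}
    unused: list[str] = []
    for key, value in kwargs.items():
        for prefix in PREFIXES:
            if key.startswith(prefix):
                # Only the leading prefix is stripped, so "decoder_decoder_channels"
                # reaches the decoder as "decoder_channels".
                routed[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:
            unused.append(key)
    if unused:
        # Mirrors the documented behaviour: leftover arguments are an error.
        raise ValueError(f"Unused arguments: {unused}")
    return routed


routed = route_kwargs(backbone_pretrained=True, decoder_decoder_channels=[256, 128, 64])
print(routed)
# {'backbone': {'pretrained': True}, 'decoder': {'decoder_channels': [256, 128, 64]}, 'head': {}}
```

Because only the first matching prefix is removed, a component argument whose own name starts with the component name (such as decoder_channels) needs the prefix written twice.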

Both encoder and decoder may be passed as strings, in which case they are looked up in the respective registry, or as nn.Modules, in which case they are used as is. In the second case, the factory assumes in good faith that the encoder or decoder passed in conforms to the expected contract.
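A minimal sketch of that lookup rule follows, assuming a registry can be modelled as a plain dictionary from names to constructors (terratorch's real registries are richer objects with several sources). resolve_component, ToyDecoder and TOY_DECODERS are hypothetical names used only for illustration.

```python
from typing import Any, Callable


class ToyDecoder:
    """Hypothetical stand-in for a registered decoder."""

    def __init__(self, embed_dim: list[int], channels: int = 64) -> None:
        # First constructor argument: channel dimensions of the incoming features.
        self.embed_dim = embed_dim
        self.out_channels = channels


# Hypothetical registry: name -> constructor.
TOY_DECODERS: dict[str, Callable[..., Any]] = {"ToyDecoder": ToyDecoder}


def resolve_component(component: Any, registry: dict[str, Callable[..., Any]], *args: Any, **kwargs: Any) -> Any:
    """Build strings from the registry; pass ready-made objects through untouched."""
    if isinstance(component, str):
        if component not in registry:
            raise KeyError(f"{component!r} not found in registry")
        return registry[component](*args, **kwargs)
    # Already an instance: used as is, trusting that it honours the contract.
    return component


built = resolve_component("ToyDecoder", TOY_DECODERS, [1024, 1024], channels=128)
prebuilt = ToyDecoder([256])
print(built.out_channels)  # 128
```

Passing prebuilt instead of a string returns that very object, which is why the factory cannot check it and has to trust the contract.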

Not all decoders will readily accept the raw output of the given encoder. This is where necks come in. Necks are a sequence of operations which are applied to the output of the encoder before it is passed to the decoder. They must be instances of Neck, which is a subclass of nn.Module, meaning they can even define new trainable parameters.
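The data flow described above (encoder, then each neck in order, then decoder) can be summarised in plain Python. The toy callables below operate on lists of numbers instead of tensors and are purely illustrative; in terratorch each stage would be an nn.Module, and select_last_two is a hypothetical neck only in the spirit of SelectIndices.

```python
from functools import reduce
from typing import Callable, Sequence

Features = list[float]  # stand-in for list[Tensor]


def run_pipeline(
    x: float,
    encoder: Callable[[float], Features],
    necks: Sequence[Callable[[Features], Features]],
    decoder: Callable[[Features], float],
) -> float:
    features = encoder(x)
    # Apply the necks in succession; an empty sequence behaves like the identity.
    features = reduce(lambda feats, neck: neck(feats), necks, features)
    return decoder(features)


def toy_encoder(x: float) -> Features:
    # Pretend each "stage" produces a different feature.
    return [x, 2 * x, 3 * x, 4 * x]


def select_last_two(feats: Features) -> Features:
    # Hypothetical neck that keeps only the last two features.
    return feats[-2:]


result = run_pipeline(1.0, toy_encoder, [select_last_two], decoder=sum)
print(result)  # 7.0
```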

The EncoderDecoderFactory returns a PixelWiseModel or a ScalarOutputModel depending on the task.

terratorch.models.encoder_decoder_factory.EncoderDecoderFactory

Bases: ModelFactory

build_model(task, backbone, decoder, backbone_kwargs=None, decoder_kwargs=None, head_kwargs=None, num_classes=None, necks=None, aux_decoders=None, rescale=True, peft_config=None, **kwargs)

Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

Further keyword arguments (**kwargs) are passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

  • task (str, required): Task to be performed. Currently supports "segmentation", "regression" and "classification".
  • backbone (str | nn.Module, required): Backbone to be used. If a string, the factory looks for the model in the supported registries (internal terratorch registry, timm, ...). If a torch nn.Module, it is used directly. The backbone should have an out_channels attribute, and its forward should return a list[Tensor].
  • decoder (str | nn.Module, required): Decoder to be used. If a string, the factory looks for the decoder in the supported registries (internal terratorch registry, smp, ...). If an nn.Module, it is expected to expose a decoder.out_channels property. For pixel wise tasks, the decoder output is followed by a Conv2d as the final convolution.
  • backbone_kwargs (dict, optional, default: None): Arguments to be passed to instantiate the backbone.
  • decoder_kwargs (dict, optional, default: None): Arguments to be passed to instantiate the decoder.
  • head_kwargs (dict, optional, default: None): Arguments to be passed to the head network.
  • num_classes (int, default: None): Number of classes. None for regression tasks.
  • necks (list[dict], default: None): Necks (nn.Modules) to be called in succession on the encoder features before they are passed to the decoder. They should be registered in the NECK_REGISTRY. Each dict is expected to have a "name" key, plus one key per constructor argument, if any. None applies the identity function.
  • aux_decoders (list[AuxiliaryHead] | None, default: None): List of AuxiliaryHead decoders to be added to the model. These decoders also take the encoder output as input.
  • rescale (bool, default: True): Whether to apply bilinear interpolation to rescale the model output when its size differs from that of the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression).
  • peft_config (dict, default: None): Configuration options for using PEFT. The dictionary should have the following keys:
      ◦ "method": Which PEFT method to use. It should be one of the methods implemented in PEFT (see the PEFT documentation for the available list).
      ◦ "replace_qkv": String containing a substring of the names of the submodules to replace with QKVSep. Use this when the q, k and v matrices are merged into a single linear layer and the PEFT method should be applied separately to the query, key and value matrices (e.g. if LoRA is only desired in the Q and V matrices). For example, when using Prithvi this should be "qkv".
      ◦ "peft_config_kwargs": Dictionary of keyword arguments passed to PeftConfig.
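For illustration, a peft_config that applies LoRA only to the query and value projections of a Prithvi backbone might look like the dictionary below. The three top-level keys follow the description above; the values inside peft_config_kwargs (r, lora_alpha, target_modules) are assumptions modelled on PEFT's LoraConfig, and the submodule names "qkv.q_linear" and "qkv.v_linear" are hypothetical placeholders that must be checked against the actual module names after the QKVSep replacement.

```python
# Sketch of a peft_config dictionary; all values are illustrative assumptions.
peft_config = {
    "method": "LORA",       # assumed to name a method implemented in PEFT
    "replace_qkv": "qkv",   # split merged qkv layers so Q and V can be targeted separately
    "peft_config_kwargs": {  # forwarded to the PEFT config class
        "r": 16,
        "lora_alpha": 16,
        "target_modules": ["qkv.q_linear", "qkv.v_linear"],  # hypothetical names
    },
}

print(sorted(peft_config))  # ['method', 'peft_config_kwargs', 'replace_qkv']
```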

Returns:

  • Model: Full model (nn.Module) with encoder, decoder and head.

Encoders

To be a valid encoder, an object must be an nn.Module with an out_channels attribute: a list of the channel dimensions of the features it returns. The forward method of any encoder should return a list of torch.Tensor, one per entry in out_channels.

import numpy as np
import torch
from terratorch.registry import BACKBONE_REGISTRY

# Build the pretrained Prithvi EO v2 (300M) encoder from the registry
backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_300", pretrained=True)

# Random input: batch of 1, 6 bands, 224 x 224 pixels
input_image = torch.tensor(np.random.rand(1, 6, 224, 224).astype("float32"))
output = backbone(input_image)

# The encoder returns a list of feature tensors
print([item.shape for item in output])
# -> a list of 24 entries, each torch.Size([1, 197, 1024])
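The shape [1, 197, 1024] can be checked by hand. Assuming Prithvi EO v2 300M is a ViT-Large-style model with 16x16 patches, a 1024-dimensional embedding and a prepended class token (assumptions based on its standard configuration), a 224x224 image yields 14 x 14 = 196 patch tokens plus 1 class token, and one feature tensor per transformer block (24 here). The helper below is a hypothetical sketch of that arithmetic for a single temporal frame.

```python
def vit_token_count(height: int, width: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT produces for one image (single frame assumed)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    num_patches = (height // patch_size) * (width // patch_size)
    # Add one extra token when a class token is prepended.
    return num_patches + int(cls_token)


print(vit_token_count(224, 224, patch_size=16))  # 197
```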

Necks

Necks are the connectors between encoder and decoder. They can perform operations such as selecting elements from the output of the encoder (SelectIndices) or reshaping the outputs of ViTs so they are compatible with CNNs (ReshapeTokensToImage), amongst others. Necks are nn.Modules with an additional method, process_channel_list, which informs the EncoderDecoderFactory how the neck will alter the channel list provided by encoder.out_channels. A more detailed description of necks is available in the necks documentation.
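To make the role of process_channel_list concrete, here is a pure-Python sketch of how a factory could propagate encoder.out_channels through a neck chain given as list[dict], the format of the necks parameter. The classes are hypothetical stand-ins named after SelectIndices and ReshapeTokensToImage; they only model the channel bookkeeping, not the tensor operations, and the real necks' constructor arguments may differ.

```python
from typing import Any


class ToySelectIndices:
    """Hypothetical: keeps only the features at the given indices."""

    def __init__(self, channel_list: list[int], indices: list[int]) -> None:
        self.indices = indices

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [channel_list[i] for i in self.indices]


class ToyReshapeTokensToImage:
    """Hypothetical: turns token sequences into 2D maps; channels are unchanged."""

    def __init__(self, channel_list: list[int]) -> None:
        pass

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return list(channel_list)


TOY_NECKS = {"SelectIndices": ToySelectIndices, "ReshapeTokensToImage": ToyReshapeTokensToImage}


def build_necks(channel_list: list[int], neck_configs: list[dict[str, Any]]) -> tuple[list[Any], list[int]]:
    """Instantiate each neck in order and track the channel list the decoder will see."""
    necks = []
    for cfg in neck_configs:
        # "name" selects the class; every other key is a constructor argument.
        kwargs = {k: v for k, v in cfg.items() if k != "name"}
        neck = TOY_NECKS[cfg["name"]](channel_list, **kwargs)
        channel_list = neck.process_channel_list(channel_list)
        necks.append(neck)
    return necks, channel_list


encoder_out_channels = [1024] * 24  # e.g. one entry per ViT block in the example above
necks, decoder_in_channels = build_necks(
    encoder_out_channels,
    [{"name": "SelectIndices", "indices": [5, 11, 17, 23]}, {"name": "ReshapeTokensToImage"}],
)
print(decoder_in_channels)  # [1024, 1024, 1024, 1024]
```

The final channel list is what would be handed to the decoder constructor as its first argument.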

Decoders

To be a valid decoder, an object must be an nn.Module with an out_channels attribute, an int giving the channel dimension of its output. The first argument to its constructor is the list of channel dimensions it should expect as input, and its forward method should accept a list of embeddings. The list of built-in decoders is available in the related documentation.

Heads

Most decoders require a final head to be added for a specific task (e.g. semantic segmentation vs pixel wise regression). Registries that produce decoders which don't require a head must expose the attribute includes_head=True so that no head is added. Decoders passed as nn.Modules which do not require a head must expose the same attribute themselves. More about heads can be found in their documentation.
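The head decision can be sketched as a single attribute lookup. This is an assumption about how such a check might look, not terratorch's code; needs_head and the toy classes are hypothetical. Following the text, the registry is consulted for decoders built from a string and the decoder itself for decoders passed as modules.

```python
from typing import Optional


class PlainDecoder:
    """Hypothetical decoder that still needs a task head."""
    out_channels = 256


class DecoderWithHead:
    """Hypothetical decoder (passed as a module) that already ends in a head."""
    out_channels = 2
    includes_head = True


class RegistryWithHeads:
    """Hypothetical registry whose decoders all include their own head."""
    includes_head = True


def needs_head(decoder: object, registry: Optional[object] = None) -> bool:
    # Decoders built from a registry: ask the registry.
    # Decoders passed in directly: ask the decoder itself.
    owner = registry if registry is not None else decoder
    # Objects that do not declare the attribute are assumed to need a head.
    return not getattr(owner, "includes_head", False)


print(needs_head(PlainDecoder()), needs_head(DecoderWithHead()))  # True False
print(needs_head(PlainDecoder(), registry=RegistryWithHeads()))  # False
```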

Decoder compatibilities

Not all encoders and decoders are compatible. Below we include some caveats. Some decoders expect pyramidal outputs, but some encoders do not produce such outputs (e.g. vanilla ViT models). In this case, the InterpolateToPyramidal, MaxpoolToPyramidal and LearnedInterpolateToPyramidal necks may be particularly useful.

SMP decoders

Not all decoders are guaranteed to work with all encoders without additional necks. Please check the smp documentation to understand the embedding spatial dimensions expected by each decoder.

In particular, smp seems to assume the first feature in the passed feature list has the same spatial resolution as the input, which may not always be true, and may break some decoders.

In addition, some decoders expect the final 2 features to have the same spatial resolution. Adding the AddBottleneckLayer neck makes the encoder output compatible with these decoders.

Some smp decoders require additional parameters, such as decoder_channels. These must be passed through the factory. In the case of decoder_channels, it would be passed as decoder_decoder_channels (the first decoder_ routes the parameter to the decoder, where it is passed as decoder_channels).

MMSegmentation decoders

MMSegmentation decoders are available through the DECODER_REGISTRY.

Warning

MMSegmentation currently requires mmcv==2.1.0. Pre-built wheels for this only exist for torch==2.1.0. In order to use mmseg without building from source, you must downgrade your torch to this version. Install mmseg with:

pip install -U openmim
mim install mmengine
mim install mmcv==2.1.0
pip install regex ftfy mmsegmentation

We provide access to mmseg decoders as an external source of decoders, but are not directly responsible for the maintenance of that library.

Some mmseg decoders require the parameter in_index, which performs the same function as the SelectIndices neck. For pixel wise regression, mmseg decoders should take num_classes=1.


Last update: March 24, 2025