EncoderDecoderFactory

The Factory

A special factory provided by terratorch is the EncoderDecoderFactory.

This factory leverages the BACKBONE_REGISTRY, DECODER_REGISTRY and NECK_REGISTRY to compose models formed as encoder + decoder, with some optional glue in between provided by the necks. As most current models work this way, this is a particularly important factory, allowing for great flexibility in combining encoders and decoders from different sources.

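The factory allows arguments to be passed to the encoder, decoder and head. Arguments with the prefix backbone_ are routed to the backbone constructor, and arguments prefixed with decoder_ and head_ are routed to the decoder and head in the same way. The factory does not check these arguments against the target constructors; it only forwards them. However, any remaining argument that matches none of these prefixes is unused and raises a ValueError.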
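A minimal plain-Python sketch of this routing follows. In the factory the work is done by extract_prefix_keys and _check_all_args_used, as the source below shows. split_prefixed_kwargs and route_kwargs are hypothetical stand-ins written for illustration, not terratorch code. The sketch assumes the prefix is stripped before forwarding, so backbone_pretrained reaches the backbone as pretrained. The argument names are only examples.

```python
from typing import Any


def split_prefixed_kwargs(kwargs: dict, prefix: str) -> tuple:
    """Hypothetical helper: split kwargs into (prefix-stripped matches, remainder)."""
    matched: dict = {}
    remaining: dict = {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            matched[key[len(prefix):]] = value  # drop the prefix before forwarding
        else:
            remaining[key] = value
    return matched, remaining


def route_kwargs(**kwargs: Any) -> dict:
    """Route backbone_/decoder_/head_ arguments and reject anything left over."""
    routed = {}
    for target in ("backbone", "decoder", "head"):
        routed[target], kwargs = split_prefixed_kwargs(kwargs, f"{target}_")
    if kwargs:
        # mirrors the documented behaviour: unused arguments raise a ValueError
        raise ValueError(f"arguments {kwargs} were passed but not used")
    return routed


print(route_kwargs(backbone_pretrained=True, decoder_channels=256, head_dropout=0.1))
# {'backbone': {'pretrained': True}, 'decoder': {'channels': 256}, 'head': {'dropout': 0.1}}
```

A misspelled prefix such as backbon_pretrained falls through every split and triggers the ValueError instead of being silently ignored.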

Both encoder and decoder may be passed as strings, in which case they will be looked up in the respective registry, or as nn.Modules, in which case they will be used as is. In the second case, the factory assumes in good faith that the encoder or decoder which is passed conforms to the expected contract.
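The dispatch between the two cases can be sketched as below. DECODERS, register, DummyDecoder and resolve_decoder are hypothetical stand-ins for a terratorch registry and its lookup. They are not its actual API, and plain classes replace nn.Modules so that the sketch stays minimal.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for a registry such as DECODER_REGISTRY: name -> constructor
DECODERS: Dict[str, Callable[..., Any]] = {}


def register(cls):
    """Register a class under its own name (hypothetical helper)."""
    DECODERS[cls.__name__] = cls
    return cls


@register
class DummyDecoder:
    def __init__(self, channel_list: list, channels: int = 64):
        self.channel_list = channel_list
        self.out_channels = channels


def resolve_decoder(decoder: Any, channel_list: list, **kwargs: Any) -> Any:
    if isinstance(decoder, str):
        # a string is looked up in the registry and built with the routed decoder_ kwargs
        return DECODERS[decoder](channel_list, **kwargs)
    # anything else is trusted to already satisfy the decoder contract and is used as is
    return decoder


built = resolve_decoder("DummyDecoder", [64, 128], channels=32)
print(type(built).__name__, built.out_channels)  # DummyDecoder 32
```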

Not all decoders will readily accept the raw output of the given encoder. This is where necks come in. Necks are a sequence of operations which are applied to the output of the encoder before it is passed to the decoder. They must be instances of Neck, which is a subclass of nn.Module, meaning they can even define new trainable parameters.

The EncoderDecoderFactory returns a PixelWiseModel or a ScalarOutputModel depending on the task.

terratorch.models.encoder_decoder_factory.EncoderDecoderFactory

Bases: ModelFactory

Source code in terratorch/models/encoder_decoder_factory.py
@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        num_classes: int | None = None,
        necks: list[dict] | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        **kwargs,
    ) -> Model:
        """Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

        Further arguments can be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str | nn.Module): Backbone to be used. If a string, will look for such models in the different
                registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it
                directly. The backbone should have an `out_channels` attribute and its `forward` should return a
                list[Tensor].
            decoder (str | nn.Module): Decoder to be used. If a string, will look for such decoders in the different
                registries supported (internal terratorch registry, smp, ...).
                If an nn.Module, we expect it to expose an `out_channels` attribute.
                For pixel-wise tasks, a final Conv2d is appended after the decoder.
            num_classes (int | None, optional): Number of classes. None for regression tasks.
            necks (list[dict] | None, optional): Necks to be called in succession on encoder features
                before passing them to the decoder. Each should be registered in the NECKS_REGISTRY registry.
                Each dict must have a key "name" and subsequent keys for arguments, if any.
                Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression). Defaults to True.


        Returns:
            nn.Module: Full model with encoder, decoder and head.
        """
        task = task.lower()
        if task not in SUPPORTED_TASKS:
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)

        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
        backbone = _get_backbone(backbone, **backbone_kwargs)

        try:
            out_channels = backbone.out_channels
        except AttributeError as e:
            msg = "backbone must have out_channels attribute"
            raise AttributeError(msg) from e

        if necks is None:
            necks = []
        neck_list, channel_list = build_neck_list(necks, out_channels)

        # Some decoders already include a head: for these, num_classes is passed to the decoder.
        # Others don't include a head: for those, num_classes is not passed.
        decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")
        head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_")

        decoder, head_kwargs, decoder_includes_head = _get_decoder_and_head_kwargs(
            decoder, channel_list, decoder_kwargs, head_kwargs, num_classes=num_classes
        )

        if aux_decoders is None:
            _check_all_args_used(kwargs)
            return _build_appropriate_model(
                task,
                backbone,
                decoder,
                head_kwargs,
                necks=neck_list,
                decoder_includes_head=decoder_includes_head,
                rescale=rescale,
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_kwargs, args = extract_prefix_keys(args, "decoder_")
            aux_head_kwargs, args = extract_prefix_keys(args, "head_")
            aux_decoder_instance, aux_head_kwargs, aux_decoder_includes_head = _get_decoder_and_head_kwargs(
                aux_decoder.decoder, channel_list, aux_decoder_kwargs, aux_head_kwargs, num_classes=num_classes
            )
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )
            _check_all_args_used(args)

        _check_all_args_used(kwargs)

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            necks=neck_list,
            decoder_includes_head=decoder_includes_head,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )

build_model(task, backbone, decoder, num_classes=None, necks=None, aux_decoders=None, rescale=True, **kwargs)

Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

Further arguments can be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

task (str, required): Task to be performed. Currently supports "segmentation" and "regression".

backbone (str | Module, required): Backbone to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it directly. The backbone should have an out_channels attribute and its forward should return a list[Tensor].

decoder (str | Module, required): Decoder to be used. If a string, will look for such decoders in the different registries supported (internal terratorch registry, smp, ...). If an nn.Module, it is expected to expose an out_channels attribute. For pixel-wise tasks, a final Conv2d is appended after the decoder.

num_classes (int | None, default None): Number of classes. None for regression tasks.

necks (list[dict] | None, default None): Necks to be called in succession on encoder features before passing them to the decoder. Each should be registered in the NECKS_REGISTRY registry. Each dict must have a key "name" and subsequent keys for arguments, if any. None applies the identity function.

aux_decoders (list[AuxiliaryHead] | None, default None): List of AuxiliaryHead decoders to be added to the model. These decoders also receive the encoder features.

rescale (bool, default True): Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression).

Returns:

Model: Full model (an nn.Module) with encoder, decoder and head.


terratorch.models.pixel_wise_model.PixelWiseModel

terratorch.models.scalar_output_model.ScalarOutputModel

Encoders

To be a valid encoder, an object must be an nn.Module with an additional attribute out_channels, which is a list of the channel dimensions of the features it returns.

Its forward method should return a list of torch.Tensor.
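A minimal module satisfying this contract might look like the sketch below. TinyConvEncoder and its layer sizes are hypothetical. The only parts the factory relies on are the out_channels list and the list returned by forward, one tensor per entry in out_channels.

```python
import torch
from torch import nn


class TinyConvEncoder(nn.Module):
    """Hypothetical encoder: two strided conv stages, one returned feature map per stage."""

    def __init__(self, in_chans: int = 3):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Conv2d(in_chans, 32, 3, stride=2, padding=1), nn.ReLU())
        self.stage2 = nn.Sequential(nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU())
        # contract: one channel count per returned feature, in the same order
        self.out_channels = [32, 64]

    def forward(self, x: torch.Tensor) -> list:
        f1 = self.stage1(x)
        f2 = self.stage2(f1)
        return [f1, f2]  # contract: forward returns a list of tensors


feats = TinyConvEncoder()(torch.randn(2, 3, 64, 64))
print([tuple(f.shape) for f in feats])  # [(2, 32, 32, 32), (2, 64, 16, 16)]
```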

Necks

Necks are the glue between encoder and decoder. They can perform operations such as selecting elements from the output of the encoder (SelectIndices), reshaping the outputs of ViTs so they are compatible with CNNs (ReshapeTokensToImage), amongst others.

Necks are nn.Modules, with an additional method process_channel_list which informs the EncoderDecoderFactory about how it will alter the channel list provided by encoder.out_channels.
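The bookkeeping this enables can be sketched in plain Python. Each neck's process_channel_list is applied in order, and the final list is what the decoder constructor receives. The three functions copy the process_channel_list bodies of SelectIndices, of the necks that inherit the base implementation (such as ReshapeTokensToImage), and of LearnedInterpolateToPyramidal from the listings in this section. The chain corresponds roughly to a necks argument of [{"name": "SelectIndices", "indices": [2, 5, 8, 11]}, {"name": "ReshapeTokensToImage"}, {"name": "LearnedInterpolateToPyramidal"}]. propagate is a hypothetical helper, not the factory's build_neck_list, and the 12-block, 768-wide ViT is only an example.

```python
def select_indices(channel_list, indices):
    # SelectIndices.process_channel_list
    return [channel_list[i] for i in indices]


def inherit_default(channel_list):
    # Neck.process_channel_list, reused by ReshapeTokensToImage, PermuteDims, ...
    return channel_list


def learned_interpolate_to_pyramidal(channel_list):
    # LearnedInterpolateToPyramidal.process_channel_list
    return [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]


def propagate(channel_list, steps):
    """Hypothetical helper: thread the channel list through each neck step in order."""
    for step in steps:
        channel_list = step(channel_list)
    return channel_list


vit_out_channels = [768] * 12  # example: 12 transformer blocks with embed_dim 768
decoder_in = propagate(
    vit_out_channels,
    [lambda c: select_indices(c, [2, 5, 8, 11]), inherit_default, learned_interpolate_to_pyramidal],
)
print(decoder_in)  # [192, 384, 768, 768]
```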

terratorch.models.necks.Neck

Bases: ABC, Module

Base class for Neck

A neck must implement self.process_channel_list, which returns the new channel list.

Source code in terratorch/models/necks.py
class Neck(ABC, nn.Module):
    """Base class for Neck

    A neck must implement `self.process_channel_list`, which returns the new channel list.
    """

    def __init__(self, channel_list: list[int]) -> None:
        super().__init__()
        self.channel_list = channel_list

    @abstractmethod
    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return channel_list

    @abstractmethod
    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]: ...
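A custom neck only needs to implement the two abstract methods. The sketch below defines a local copy of the Neck base class shown above so that it runs without terratorch; in practice you would subclass terratorch.models.necks.Neck. ProjectChannels is a hypothetical neck with trainable parameters that changes the channel list, which is the case process_channel_list exists for. To refer to such a neck by name in necks, it would presumably also need to be registered, as the built-in necks are with @TERRATORCH_NECK_REGISTRY.register.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class Neck(ABC, nn.Module):
    """Local copy of the Neck base class shown above, so this sketch runs without terratorch."""

    def __init__(self, channel_list: list) -> None:
        super().__init__()
        self.channel_list = channel_list

    @abstractmethod
    def process_channel_list(self, channel_list: list) -> list:
        return channel_list

    @abstractmethod
    def forward(self, features: list) -> list: ...


class ProjectChannels(Neck):
    """Hypothetical neck: a trainable 1x1 conv maps every embedding to out_dim channels."""

    def __init__(self, channel_list: list, out_dim: int = 256):
        super().__init__(channel_list)
        self.out_dim = out_dim
        self.proj = nn.ModuleList([nn.Conv2d(c, out_dim, kernel_size=1) for c in channel_list])

    def forward(self, features: list) -> list:
        return [proj(feat) for proj, feat in zip(self.proj, features)]

    def process_channel_list(self, channel_list: list) -> list:
        # report the new widths so the decoder is built for out_dim-channel inputs
        return [self.out_dim] * len(channel_list)


neck = ProjectChannels([32, 64], out_dim=16)
outs = neck([torch.randn(1, 32, 8, 8), torch.randn(1, 64, 4, 4)])
print(neck.process_channel_list([32, 64]), [tuple(o.shape) for o in outs])
# [16, 16] [(1, 16, 8, 8), (1, 16, 4, 4)]
```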

terratorch.models.necks.SelectIndices

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class SelectIndices(Neck):
    def __init__(self, channel_list: list[int], indices: list[int]):
        """Select indices from the embedding list

        Args:
            indices (list[int]): list of indices to select.
        """
        super().__init__(channel_list)
        self.indices = indices

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [features[i] for i in self.indices]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        channel_list = [channel_list[i] for i in self.indices]
        return channel_list
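Because forward and process_channel_list use the same list-indexing expression, the selected features and their channel counts stay aligned. Ordinary Python indexing also applies, so negative indices such as -1 work. The four-stage widths below are hypothetical.

```python
import torch

# Hypothetical encoder output: four stages with different widths
channel_list = [8, 16, 32, 64]
features = [torch.zeros(1, c, 4, 4) for c in channel_list]

indices = [1, -1]  # second stage and last stage; negative indices are plain list indexing

selected = [features[i] for i in indices]               # same expression as SelectIndices.forward
selected_channels = [channel_list[i] for i in indices]  # same expression as process_channel_list
print(selected_channels)  # [16, 64]
```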

__init__(channel_list, indices)

Select indices from the embedding list

Parameters:

indices (list[int], required): List of indices to select.

terratorch.models.necks.PermuteDims

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class PermuteDims(Neck):
    def __init__(self, channel_list: list[int], new_order: list[int]):
        """Permute dimensions of each element in the embedding list

        Args:
            new_order (list[int]): list of indices to be passed to tensor.permute()
        """
        super().__init__(channel_list)
        self.new_order = new_order

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [feat.permute(*self.new_order).contiguous() for feat in features]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
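A typical use is turning channels-last features (B, H, W, C) into the channels-first layout (B, C, H, W) that Conv2d-based decoders expect, using new_order = [0, 3, 1, 2]. The channels-last shapes below are hypothetical. Because process_channel_list returns the list unchanged, the encoder's out_channels should already hold the C values.

```python
import torch

new_order = [0, 3, 1, 2]  # (B, H, W, C) -> (B, C, H, W)
# Hypothetical channels-last encoder outputs
features = [torch.randn(2, 16, 16, 96), torch.randn(2, 8, 8, 192)]

# same expression as PermuteDims.forward; .contiguous() materialises the new memory layout
permuted = [feat.permute(*new_order).contiguous() for feat in features]
print([tuple(f.shape) for f in permuted])  # [(2, 96, 16, 16), (2, 192, 8, 8)]
```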

__init__(channel_list, new_order)

Permute dimensions of each element in the embedding list

Parameters:

new_order (list[int], required): List of indices to be passed to tensor.permute().

terratorch.models.necks.InterpolateToPyramidal

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class InterpolateToPyramidal(Neck):
    def __init__(self, channel_list: list[int], scale_factor: int = 2, mode: str = "nearest"):
        """Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]

        Useful to make non-pyramidal backbones compatible with hierarchical ones.

        Args:
            scale_factor (int): Amount to scale embeddings by each layer. Defaults to 2.
            mode (str): Interpolation mode to be passed to torch.nn.functional.interpolate. Defaults to 'nearest'.
        """
        super().__init__(channel_list)
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features), 0, -1))
        for x, exponent in zip(features, scale_exponents, strict=True):
            out.append(F.interpolate(x, scale_factor=self.scale_factor**exponent, mode=self.mode))

        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
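The forward pass uses exponents len(features), ..., 1, so the first embedding is upsampled the most. Even the last embedding is upsampled, by scale_factor**1. The worked example below replays the same loop on four hypothetical embeddings on a 14x14 grid, which is what a 224x224 input with 16x16 patches would give. Only 8 channels are used to keep it light.

```python
import torch
import torch.nn.functional as F

scale_factor = 2
# Hypothetical non-pyramidal outputs: four ViT blocks, all on a 14x14 token grid
features = [torch.randn(1, 8, 14, 14) for _ in range(4)]

scale_exponents = list(range(len(features), 0, -1))  # [4, 3, 2, 1]
out = [
    F.interpolate(x, scale_factor=scale_factor**exponent, mode="nearest")
    for x, exponent in zip(features, scale_exponents)
]
print(scale_exponents, [tuple(o.shape[-2:]) for o in out])
# [4, 3, 2, 1] [(224, 224), (112, 112), (56, 56), (28, 28)]
```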

__init__(channel_list, scale_factor=2, mode='nearest')

Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i].

Useful to make non-pyramidal backbones compatible with hierarchical ones.

Parameters:

scale_factor (int, default 2): Amount to scale embeddings by at each layer.

mode (str, default 'nearest'): Interpolation mode to be passed to torch.nn.functional.interpolate.


terratorch.models.necks.MaxpoolToPyramidal

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class MaxpoolToPyramidal(Neck):
    def __init__(self, channel_list: list[int], kernel_size: int = 2):
        """Spatially downsample embeddings so that embedding[i - 1] is scale_factor times smaller than embedding[i]

        Useful to make non-pyramidal backbones compatible with hierarachical ones
        Args:
            kernel_size (int). Base kernel size to use for maxpool. Defaults to 2.
        """
        super().__init__(channel_list)
        self.kernel_size = kernel_size

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features)))
        for x, exponent in zip(features, scale_exponents, strict=True):
            if exponent == 0:
                out.append(x.clone())
            else:
                out.append(F.max_pool2d(x, kernel_size=self.kernel_size**exponent))

        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
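The sketch below replays the forward loop as a standalone function to show the resulting sizes. Because F.max_pool2d's stride defaults to kernel_size, level i shrinks by kernel_size**i. Sizes that are not divisible are floored, so a 14x14 grid ends at 1x1. maxpool_to_pyramid is a hypothetical name for this re-statement, and the input shapes are illustrative.

```python
import torch
import torch.nn.functional as F

kernel_size = 2


def maxpool_to_pyramid(features):
    """Same loop as MaxpoolToPyramidal.forward, written with enumerate."""
    out = []
    for exponent, x in enumerate(features):  # exponents 0, 1, 2, ...
        if exponent == 0:
            out.append(x.clone())  # the first embedding keeps its resolution
        else:
            # stride defaults to kernel_size, so each level shrinks by kernel_size**exponent
            out.append(F.max_pool2d(x, kernel_size=kernel_size**exponent))
    return out


sizes_64 = [tuple(o.shape[-2:]) for o in maxpool_to_pyramid([torch.randn(1, 8, 64, 64) for _ in range(4)])]
sizes_14 = [tuple(o.shape[-2:]) for o in maxpool_to_pyramid([torch.randn(1, 8, 14, 14) for _ in range(4)])]
print(sizes_64)  # [(64, 64), (32, 32), (16, 16), (8, 8)]
print(sizes_14)  # [(14, 14), (7, 7), (3, 3), (1, 1)]
```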

__init__(channel_list, kernel_size=2)

Spatially downsample embeddings so that embedding[i] is kernel_size times smaller than embedding[i - 1].

Useful to make non-pyramidal backbones compatible with hierarchical ones.

Parameters:

kernel_size (int, default 2): Base kernel size to use for maxpool.


terratorch.models.necks.ReshapeTokensToImage

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class ReshapeTokensToImage(Neck):
    def __init__(self, channel_list: list[int], remove_cls_token=True, effective_time_dim: int = 1):  # noqa: FBT002
        """Reshape output of transformer encoder so it can be passed to a conv net.

        Args:
            remove_cls_token (bool, optional): Whether to remove the cls token from the first position.
                Defaults to True.
            effective_time_dim (int, optional): The effective temporal dimension the transformer processes.
                For a ViT, this will be given by `num_frames // tubelet_size`. This is used to determine
                the temporal dimension of the embedding, which is concatenated with the embedding dimension.
                For example:
                - A model which processes 1 frame with a tubelet size of 1 has an effective_time_dim of 1.
                    The embedding produced by this model has embedding size embed_dim * 1.
                - A model which processes 3 frames with a tubelet size of 1 has an effective_time_dim of 3.
                    The embedding produced by this model has embedding size embed_dim * 3.
                - A model which processes 12 frames with a tubelet size of 4 has an effective_time_dim of 3.
                    The embedding produced by this model has an embedding size embed_dim * 3.
                Defaults to 1.
        """
        super().__init__(channel_list)
        self.remove_cls_token = remove_cls_token
        self.effective_time_dim = effective_time_dim

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        for x in features:
            if self.remove_cls_token:
                x_no_token = x[:, 1:, :]
            else:
                x_no_token = x
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // self.effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                batch=x_no_token.shape[0],
                t=self.effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
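The einops pattern "batch (t h w) e -> batch (t e) h w" can be restated with plain torch reshape and permute calls. This makes the channel layout explicit: the embedding of time step ti occupies channels ti*e to (ti+1)*e - 1. tokens_to_image is a hypothetical re-statement written for illustration. Like the neck, it assumes a square token grid per time step and, when remove_cls_token is True, a cls token in position 0.

```python
import math

import torch


def tokens_to_image(x: torch.Tensor, remove_cls_token: bool = True, effective_time_dim: int = 1) -> torch.Tensor:
    """Plain-torch version of rearrange(x, "batch (t h w) e -> batch (t e) h w")."""
    if remove_cls_token:
        x = x[:, 1:, :]  # drop the cls token in position 0
    batch, num_tokens, embed_dim = x.shape
    t = effective_time_dim
    h = w = math.isqrt(num_tokens // t)  # square token grid per time step
    x = x.reshape(batch, t, h, w, embed_dim)      # split (t h w)
    x = x.permute(0, 1, 4, 2, 3)                  # batch t e h w
    return x.reshape(batch, t * embed_dim, h, w)  # merge (t e)


tokens = torch.randn(2, 1 + 3 * 4 * 4, 8)  # batch 2, cls + 3 time steps of a 4x4 grid, embed_dim 8
print(tuple(tokens_to_image(tokens, effective_time_dim=3).shape))  # (2, 24, 4, 4)
```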

__init__(channel_list, remove_cls_token=True, effective_time_dim=1)

Reshape output of transformer encoder so it can be passed to a conv net.

Parameters:

remove_cls_token (bool, default True): Whether to remove the cls token from the first position.

effective_time_dim (int, default 1): The effective temporal dimension the transformer processes. For a ViT, this will be given by num_frames // tubelet_size. This is used to determine the temporal dimension of the embedding, which is concatenated with the embedding dimension. For example:
- A model which processes 1 frame with a tubelet size of 1 has an effective_time_dim of 1. The embedding produced by this model has embedding size embed_dim * 1.
- A model which processes 3 frames with a tubelet size of 1 has an effective_time_dim of 3. The embedding produced by this model has embedding size embed_dim * 3.
- A model which processes 12 frames with a tubelet size of 4 has an effective_time_dim of 3. The embedding produced by this model has embedding size embed_dim * 3.

terratorch.models.necks.AddBottleneckLayer

Bases: Neck

Add a layer that halves the channel dimension of the final embedding and appends the result to the embedding list.

Useful for compatibility with some smp decoders.

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class AddBottleneckLayer(Neck):
    """Add a layer that reduces the channel dimension of the final embedding by half, and concatenates it

    Useful for compatibility with some smp decoders.
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        self.bottleneck = nn.Conv2d(channel_list[-1], channel_list[-1]//2, kernel_size=1)

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        new_embedding = self.bottleneck(features[-1])
        # return a new list instead of appending in place, so the caller's list is left unchanged
        return [*features, new_embedding]

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [*channel_list, channel_list[-1] // 2]
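The effect on features and channels can be checked with the same 1x1 convolution the neck builds. The extra entry has the same spatial size as the last feature, because a 1x1 convolution with stride 1 keeps height and width. The encoder widths and sizes below are hypothetical. The new list is built without modifying features.

```python
import torch
from torch import nn

channel_list = [64, 128, 256, 512]  # hypothetical pyramidal encoder widths
features = [torch.randn(1, c, 64 // 2**k, 64 // 2**k) for k, c in enumerate(channel_list)]

# the same 1x1 convolution AddBottleneckLayer builds
bottleneck = nn.Conv2d(channel_list[-1], channel_list[-1] // 2, kernel_size=1)

new_features = [*features, bottleneck(features[-1])]  # one extra entry, same spatial size as the last one
new_channels = [*channel_list, channel_list[-1] // 2]  # what process_channel_list reports
print(new_channels, tuple(new_features[-1].shape))  # [64, 128, 256, 512, 256] (1, 256, 8, 8)
```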

terratorch.models.necks.LearnedInterpolateToPyramidal

Bases: Neck

Use learned convolutions to transform the output of a non-pyramidal encoder into pyramidal ones

Always requires exactly 4 embeddings

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class LearnedInterpolateToPyramidal(Neck):
    """Use learned convolutions to transform the output of a non-pyramidal encoder into pyramidal ones

    Always requires exactly 4 embeddings
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        if len(channel_list) != 4:
            msg = "This class can only handle exactly 4 input embeddings"
            raise Exception(msg)
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(channel_list[0], channel_list[0] // 2, 2, 2),
            nn.BatchNorm2d(channel_list[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(channel_list[0] // 2, channel_list[0] // 4, 2, 2),
        )
        self.fpn2 = nn.Sequential(nn.ConvTranspose2d(channel_list[1], channel_list[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embedding_dim = [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(features[0]))
        scaled_inputs.append(self.fpn2(features[1]))
        scaled_inputs.append(self.fpn3(features[2]))
        scaled_inputs.append(self.fpn4(features[3]))
        return scaled_inputs

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]
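The spatial effect of fpn1 to fpn4 is shown below. With four embeddings on a 14x14 grid (for example from a 224x224 input), the outputs have strides 4, 8, 16 and 32 relative to the input. The sketch uses the same layer types and hyper-parameters but leaves out BatchNorm2d and GELU, which do not change shapes. For brevity it feeds the same tensor to every branch, whereas the neck gives each branch a different embedding. The width c = 32 is a small hypothetical value.

```python
import torch
from torch import nn

c = 32  # small hypothetical width; a ViT-B would use 768
x = torch.randn(1, c, 14, 14)  # one ViT feature map on a 14x14 grid (1/16 of a 224x224 input)

# Same layer types and hyper-parameters as fpn1..fpn4; BatchNorm2d and GELU are left out
# because they do not change shapes.
fpn1 = nn.Sequential(nn.ConvTranspose2d(c, c // 2, 2, 2), nn.ConvTranspose2d(c // 2, c // 4, 2, 2))
fpn2 = nn.ConvTranspose2d(c, c // 2, 2, 2)
fpn3 = nn.Identity()
fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

shapes = [tuple(layer(x).shape[1:]) for layer in (fpn1, fpn2, fpn3, fpn4)]
print(shapes)  # [(8, 56, 56), (16, 28, 28), (32, 14, 14), (32, 7, 7)]
```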

Decoders

To be a valid decoder, an object must be an nn.Module with an additional attribute out_channels which is an int with the channel dimension of the output.

The first argument to its constructor will be a list of channel dimensions it should expect as input.

Its forward method should accept a list of embeddings.
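A decoder that meets this contract is sketched below. Its constructor takes the channel list produced by the encoder and necks, forward takes the list of embeddings, and out_channels is a single int. TinyFuseDecoder and its fusion scheme (1x1 projection, bilinear upsampling to the first embedding, sum) are hypothetical and are not one of the registered decoders.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyFuseDecoder(nn.Module):
    """Hypothetical decoder: project each embedding to `channels`, upsample to the first one and sum."""

    def __init__(self, embed_dims: list, channels: int = 64):
        super().__init__()
        self.proj = nn.ModuleList([nn.Conv2d(d, channels, kernel_size=1) for d in embed_dims])
        self.out_channels = channels  # contract: a single int

    def forward(self, embeddings: list) -> torch.Tensor:
        size = embeddings[0].shape[-2:]  # assumes the first embedding has the largest spatial size
        fused = 0
        for proj, emb in zip(self.proj, embeddings):
            fused = fused + F.interpolate(proj(emb), size=size, mode="bilinear", align_corners=False)
        return fused


decoder = TinyFuseDecoder([32, 64], channels=16)
out = decoder([torch.randn(2, 32, 32, 32), torch.randn(2, 64, 16, 16)])
print(tuple(out.shape))  # (2, 16, 32, 32)
```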

Heads

Most decoders require a final head to be added for a specific task (e.g. semantic segmentation vs pixel wise regression).

Registries whose decoders don't require a head must expose the attribute includes_head=True so that no head is added. Decoders passed as nn.Modules that don't require a head must expose the same attribute themselves.
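A decoder that brings its own head can declare this with a class attribute, as sketched below. TinyDecoderWithHead is hypothetical. The sketch assumes num_classes is passed to such decoders as a keyword argument, as the comments in build_model suggest, and that the attribute is read from the module instance.

```python
import torch
from torch import nn


class TinyDecoderWithHead(nn.Module):
    """Hypothetical decoder that produces class logits itself."""

    includes_head = True  # signals that no extra head should be added

    def __init__(self, embed_dims: list, num_classes: int):
        super().__init__()
        self.classifier = nn.Conv2d(embed_dims[-1], num_classes, kernel_size=1)
        self.out_channels = num_classes

    def forward(self, embeddings: list) -> torch.Tensor:
        # per-pixel logits from the last embedding only, for brevity
        return self.classifier(embeddings[-1])


decoder = TinyDecoderWithHead([32, 64], num_classes=5)
logits = decoder([torch.randn(1, 32, 16, 16), torch.randn(1, 64, 8, 8)])
print(getattr(decoder, "includes_head", False), tuple(logits.shape))  # True (1, 5, 8, 8)
```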

terratorch.models.heads.classification_head.ClassificationHead

Bases: Module

Classification head

Source code in terratorch/models/heads/classification_head.py
class ClassificationHead(nn.Module):
    """Classification head"""

    # how to allow cls token?
    def __init__(
        self,
        in_dim: int,
        num_classes: int,
        dim_list: list[int] | None = None,
        dropout: float = 0,
        linear_after_pool: bool = False,
    ) -> None:
        """Constructor

        Args:
            in_dim (int): Input dimensionality
            num_classes (int): Number of output classes
            dim_list (list[int] | None, optional):  List with number of dimensions for each Linear
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
            linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False
        """
        super().__init__()
        self.num_classes = num_classes
        self.linear_after_pool = linear_after_pool
        if dim_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_dim, out_dim):
                return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

            dim_list = [in_dim, *dim_list]
            pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
            in_dim = dim_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Linear(in_features=in_dim, out_features=num_classes),
        )

    def forward(self, x: Tensor):
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        if self.linear_after_pool:
            x = x.mean(axis=1)
            out = self.head(x)
        else:
            x = self.head(x)
            out = x.mean(axis=1)
        return out
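The forward pass above flattens a (B, C, H, W) map into H*W tokens of size C and averages over them either before or after self.head. The sketch below reproduces that reshape/permute line with plain PyTorch, using a single nn.Linear as a stand-in for self.head (the case dim_list=None, dropout=0), so it runs without terratorch. In that case the head is affine and averaging commutes with an affine map, so both linear_after_pool settings give the same logits up to floating-point error; with dim_list set, the ReLU blocks make the order matter.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Decoder output: batch of 2, 8 channels (in_dim), 4x4 spatial grid.
x = torch.randn(2, 8, 4, 4)

# Same reshape/permute as ClassificationHead.forward: (B, C, H, W) -> (B, H*W, C).
tokens = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

# Stand-in for self.head when dim_list=None and dropout=0.
head = nn.Linear(in_features=8, out_features=3)

pool_then_linear = head(tokens.mean(dim=1))  # linear_after_pool=True
linear_then_pool = head(tokens).mean(dim=1)  # linear_after_pool=False

print(tokens.shape)            # torch.Size([2, 16, 8])
print(pool_then_linear.shape)  # torch.Size([2, 3])
print(torch.allclose(pool_then_linear, linear_then_pool, atol=1e-6))  # True
```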

__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)

Constructor

Parameters:

in_dim (int, required): Input dimensionality.

num_classes (int, required): Number of output classes.

dim_list (list[int] | None, default None): Output dimensions of the hidden Linear + ReLU layers placed before the final Linear layer. If None, no hidden layers are created.

dropout (float, default 0): Dropout probability applied before the final Linear layer. 0 disables dropout.

linear_after_pool (bool, default False): If True, pool over spatial positions first and then apply the linear layers; otherwise apply them per position and pool afterwards.

terratorch.models.heads.regression_head.RegressionHead

Bases: Module

Regression head

Source code in terratorch/models/heads/regression_head.py
class RegressionHead(nn.Module):
    """Regression head"""

    def __init__(
        self,
        in_channels: int,
        final_act: nn.Module | str | None = None,
        learned_upscale_layers: int = 0,
        channel_list: list[int] | None = None,
        batch_norm: bool = True,
        dropout: float = 0,
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            final_act (nn.Module | None, optional): Final activation to be applied. Defaults to None.
            learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
                Defaults to 0.
            channel_list (list[int] | None, optional): List with number of channels for each Conv
                layer to be created. Defaults to None.
            batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
            dropout (float, optional): Dropout value to apply. Defaults to 0.

        """
        super().__init__()
        self.learned_upscale_layers = learned_upscale_layers
        self.final_act = final_act if final_act else nn.Identity()
        if isinstance(final_act, str):
            module_name, class_name = final_act.rsplit(".", 1)
            target_class = getattr(importlib.import_module(module_name), class_name)
            self.final_act = target_class()
        pre_layers = []
        if learned_upscale_layers != 0:
            learned_upscale = nn.Sequential(
                *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
            )
            pre_layers.append(learned_upscale)

        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
            pre_layers.append(pre_head)
        dropout = nn.Dropout2d(dropout)
        final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
        self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])

    def forward(self, x):
        output = self.head(x)
        return self.final_act(output)
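RegressionHead accepts final_act either as a module or as a dotted import path string. The helper below, resolve_final_act, is hypothetical: it copies that branch of RegressionHead.__init__ into a standalone function so it can be run without terratorch. Note also that the final 1x1 convolution always has out_channels=1, so the head predicts a single value per pixel.

```python
from __future__ import annotations

import importlib

from torch import nn


def resolve_final_act(final_act: nn.Module | str | None = None) -> nn.Module:
    """Hypothetical helper mirroring how RegressionHead builds self.final_act."""
    if isinstance(final_act, str):
        # "torch.nn.Softplus" -> module "torch.nn", class "Softplus";
        # the class is instantiated without arguments.
        module_name, class_name = final_act.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), class_name)()
    # A module is used as is; None falls back to an identity.
    return final_act if final_act else nn.Identity()


print(type(resolve_final_act("torch.nn.Softplus")).__name__)  # Softplus
print(type(resolve_final_act(None)).__name__)  # Identity
```

Because the class is called with no arguments, activations that need constructor parameters must be passed as an already-built nn.Module instead of a string.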

__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)

Constructor

Parameters:

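in_channels (int, required): Number of input channels.

final_act (nn.Module | str | None, default None): Final activation to apply. A string is treated as a dotted import path (e.g. "torch.nn.ReLU") and the class is instantiated without arguments. None means no activation.

learned_upscale_layers (int, default 0): Number of PixelShuffle upscaling layers to create; each upscales 2x.

channel_list (list[int] | None, default None): Number of output channels for each 3x3 Conv + BatchNorm + ReLU block placed before the final 1x1 Conv. If None, no blocks are created.

batch_norm (bool, default True): Whether to apply batch norm. Note that the source shown above does not read this flag: the blocks built from channel_list always include BatchNorm2d.

dropout (float, default 0): Dropout2d probability applied before the final 1x1 Conv.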

terratorch.models.heads.segmentation_head.SegmentationHead

Bases: Module

Segmentation head

Source code in terratorch/models/heads/segmentation_head.py
class SegmentationHead(nn.Module):
    """Segmentation head"""

    def __init__(
        self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            num_classes (int): Number of output classes
            channel_list (list[int] | None, optional):  List with number of channels for each Conv
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
        """
        super().__init__()
        self.num_classes = num_classes
        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=num_classes,
                kernel_size=1,
            ),
        )

    def forward(self, x):
        return self.head(x)
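To make the shapes concrete, the sketch below rebuilds the same layer stack as SegmentationHead (dropout omitted) with plain PyTorch. build_segmentation_head is a hypothetical helper, not terratorch API. The 3x3 convolutions use padding=1, so they preserve H and W, and the final 1x1 convolution produces one logit channel per class.

```python
from __future__ import annotations

import torch
from torch import nn


def build_segmentation_head(in_channels: int, num_classes: int, channel_list: list[int] | None = None) -> nn.Sequential:
    """Hypothetical re-creation of SegmentationHead's layers (dropout omitted)."""
    layers: list[nn.Module] = []
    dims = [in_channels, *(channel_list or [])]
    for c_in, c_out in zip(dims[:-1], dims[1:]):
        # 3x3 conv with padding=1 keeps H and W unchanged.
        layers += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
    # Final 1x1 conv: one output channel per class.
    layers.append(nn.Conv2d(dims[-1], num_classes, kernel_size=1))
    return nn.Sequential(*layers)


head = build_segmentation_head(in_channels=256, num_classes=4, channel_list=[128, 64])
logits = head(torch.randn(2, 256, 32, 32))
print(logits.shape)  # torch.Size([2, 4, 32, 32])
```

The head never changes spatial resolution, so the decoder output must already match the desired mask resolution.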

__init__(in_channels, num_classes, channel_list=None, dropout=0)

Constructor

Parameters:

in_channels (int, required): Number of input channels.

num_classes (int, required): Number of output classes.

channel_list (list[int] | None, default None): List with number of channels for each Conv layer to be created.

dropout (float, default 0): Dropout value to apply.

Decoder compatibilities

Not all encoders and decoders are compatible. Below we include some caveats.

Some decoders expect pyramidal outputs, but some encoders do not produce such outputs (e.g. vanilla ViT models). In this case, the InterpolateToPyramidal, MaxpoolToPyramidal and LearnedInterpolateToPyramidal necks may be particularly useful.
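A vanilla ViT returns every selected layer on the same patch grid, whereas pyramidal decoders expect feature maps at decreasing resolutions. The sketch below shows the idea behind an interpolation-based neck; it is not terratorch's InterpolateToPyramidal implementation, and the scale factors (4, 2, 1, 0.5) are an assumption chosen to mimic a typical CNN pyramid.

```python
import torch
import torch.nn.functional as F


def interpolate_to_pyramid(features: list[torch.Tensor], scales=(4.0, 2.0, 1.0, 0.5)) -> list[torch.Tensor]:
    """Hypothetical neck: resize same-resolution feature maps into a multi-scale pyramid."""
    if len(features) != len(scales):
        raise ValueError("expected one feature map per pyramid level")
    return [
        F.interpolate(f, scale_factor=s, mode="bilinear", align_corners=False)
        for f, s in zip(features, scales)
    ]


# e.g. 4 ViT layers on a 14x14 patch grid (224x224 input, patch size 16), embed dim 768
vit_feats = [torch.randn(1, 768, 14, 14) for _ in range(4)]
pyramid = interpolate_to_pyramid(vit_feats)
print([tuple(p.shape[-2:]) for p in pyramid])  # [(56, 56), (28, 28), (14, 14), (7, 7)]
```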

SMP decoders

Not all decoders are guaranteed to work with all encoders without additional necks. Please check the smp documentation to understand the embedding spatial dimensions expected by each decoder.

In particular, smp seems to assume the first feature in the passed feature list has the same spatial resolution as the input, which may not always be true, and may break some decoders.

In addition, for some decoders, the final 2 features have the same spatial resolution. Adding the AddBottleneckLayer neck will make this compatible.

Some smp decoders require additional parameters, such as decoder_channels. These must be passed through the factory with the decoder_ prefix: decoder_channels, for instance, is passed as decoder_decoder_channels (the first decoder_ routes the argument to the decoder, which receives it as decoder_channels).
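The prefix rule can be illustrated with a small routing function. route_prefixed_kwargs is a hypothetical sketch, not the factory's actual code; it assumes that only the leading prefix is stripped, which is what makes decoder_decoder_channels arrive as decoder_channels. The argument names in the usage line are illustrative.

```python
def route_prefixed_kwargs(kwargs: dict, prefixes=("backbone_", "decoder_", "head_")):
    """Hypothetical sketch: group kwargs by prefix, stripping only the leading prefix."""
    routed = {p: {} for p in prefixes}
    rest = {}
    for key, value in kwargs.items():
        prefix = next((p for p in prefixes if key.startswith(p)), None)
        if prefix is None:
            rest[key] = value  # not routed to any component
        else:
            # "decoder_decoder_channels" -> "decoder_channels"
            routed[prefix][key[len(prefix):]] = value
    return routed, rest


routed, rest = route_prefixed_kwargs(
    {"decoder_decoder_channels": [256, 128, 64], "backbone_pretrained": True, "num_classes": 2}
)
print(routed["decoder_"])  # {'decoder_channels': [256, 128, 64]}
```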

MMSegmentation decoders

MMSegmentation decoders are available through the DECODER_REGISTRY.

Warning

MMSegmentation currently requires mmcv==2.1.0, and pre-built wheels for it only exist for torch==2.1.0. To use mmseg without building mmcv from source, you must downgrade torch to that version. Install mmseg with:

pip install -U openmim
mim install mmengine
mim install mmcv==2.1.0
pip install regex ftfy mmsegmentation

We provide access to mmseg decoders as an external source of decoders, but are not directly responsible for the maintenance of that library.

Some mmseg decoders require the parameter in_index, which performs the same function as the SelectIndices neck. For pixel-wise regression, mmseg decoders should take num_classes=1.
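What in_index and the SelectIndices neck do amounts to picking a subset of the encoder's feature list. select_indices below is a hypothetical stand-in for that behavior, assuming features arrive as a Python list of tensors (one per encoder stage or block); the exact semantics of each implementation should be checked in its own source.

```python
import torch


def select_indices(features: list[torch.Tensor], indices: list[int]) -> list[torch.Tensor]:
    """Hypothetical stand-in for index selection: keep only the chosen feature maps, in order."""
    return [features[i] for i in indices]


# e.g. 12 transformer blocks; keep 4 evenly spaced ones for a 4-input decoder.
# Each map is filled with its block index so the selection is easy to inspect.
blocks = [torch.full((1, 8, 2, 2), float(i)) for i in range(12)]
kept = select_indices(blocks, [2, 5, 8, 11])
print([int(f[0, 0, 0, 0]) for f in kept])  # [2, 5, 8, 11]
```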