Models

To interface with terratorch tasks correctly, models must conform to the Model ABC:

terratorch.models.model.Model

Bases: ABC, Module

Source code in terratorch/models/model.py
class Model(ABC, nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @abstractmethod
    def freeze_encoder(self):
        pass

    @abstractmethod
    def freeze_decoder(self):
        pass

    @abstractmethod
    def forward(self, *args, **kwargs) -> ModelOutput:
        pass

and have a forward method which returns a ModelOutput:

terratorch.models.model.ModelOutput dataclass

Source code in terratorch/models/model.py
@dataclass
class ModelOutput:
    output: Tensor
    auxiliary_heads: dict[str, Tensor] = None
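
For illustration, a minimal model conforming to both requirements might look like the following sketch (the NoOpModel name is hypothetical):

from torch import Tensor, nn

from terratorch.models.model import Model, ModelOutput

class NoOpModel(Model):
    def __init__(self):
        super().__init__()
        self.layer = nn.Identity()

    def freeze_encoder(self):
        # nothing to freeze in this toy example
        pass

    def freeze_decoder(self):
        pass

    def forward(self, x: Tensor) -> ModelOutput:
        # tasks expect predictions wrapped in a ModelOutput
        return ModelOutput(output=self.layer(x))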

Model Factories

In order to be used by tasks, models must have a Model Factory which builds them.

Factories must conform to the ModelFactory ABC:

terratorch.models.model.ModelFactory

Bases: Protocol

Source code in terratorch/models/model.py
class ModelFactory(typing.Protocol):
    def build_model(self, *args, **kwargs) -> Model: ...

You most likely do not need to implement your own model factory, unless you are wrapping another library which generates full models.

For most cases, the encoder decoder factory can be used to combine a backbone with a decoder.

To add new backbones or decoders to be used with the encoder decoder factory, they should be registered.

To add a new model factory, it should be registered in the MODEL_FACTORY_REGISTRY.
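
As a sketch, a custom factory wrapping another library could be registered like this (assuming MODEL_FACTORY_REGISTRY is importable from terratorch.registry; the factory name and its internals are hypothetical):

from terratorch.models.model import Model, ModelFactory
from terratorch.registry import MODEL_FACTORY_REGISTRY

@MODEL_FACTORY_REGISTRY.register
class WrappedLibraryModelFactory(ModelFactory):
    def build_model(self, task: str, **kwargs) -> Model:
        # call into the wrapped library here and return a full Model
        ...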

Adding a new model

To add a new backbone, simply create a class and annotate it (or a constructor function that instantiates it) with @TERRATORCH_BACKBONE_REGISTRY.register.

The model will be registered under the same name as the class or function. To create many model variants from the same class, the recommended approach is to annotate a constructor function for each variant with a fully descriptive name.

from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, BACKBONE_REGISTRY

from torch import nn

# make sure this is in the import path for terratorch
@TERRATORCH_BACKBONE_REGISTRY.register
class BasicBackbone(nn.Module):
    def __init__(self, out_channels=64):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layer = nn.Linear(224*224, out_channels)
        self.out_channels = [out_channels]

    def forward(self, x):
        return self.layer(self.flatten(x))

# you can build directly with the TERRATORCH_BACKBONE_REGISTRY
# but typically this will be accessed from the BACKBONE_REGISTRY
>>> BACKBONE_REGISTRY.build("BasicBackbone", out_channels=64)
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=64, bias=True)
)

@TERRATORCH_BACKBONE_REGISTRY.register
def basic_backbone_128():
    return BasicBackbone(out_channels=128)

>>> BACKBONE_REGISTRY.build("basic_backbone_128")
BasicBackbone(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=50176, out_features=128, bias=True)
)

Adding a new decoder can be done in the same way with the TERRATORCH_DECODER_REGISTRY.

Info

All decoders will be passed the channel_list as the first argument for initialization.
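
For example, a minimal decoder honouring this convention could be registered as follows (a sketch; it assumes TERRATORCH_DECODER_REGISTRY is exposed from terratorch.registry like the backbone registry, and the decoder name is hypothetical):

from torch import nn

from terratorch.registry import TERRATORCH_DECODER_REGISTRY

@TERRATORCH_DECODER_REGISTRY.register
class BasicDecoder(nn.Module):
    def __init__(self, channel_list: list[int], channels: int = 64):
        super().__init__()
        # channel_list is the backbone's list of output channels,
        # always passed as the first argument
        self.conv = nn.Conv2d(channel_list[-1], channels, kernel_size=1)
        self.out_channels = channels

    def forward(self, x):
        # take the last feature map, as FCNDecoder does with in_index=-1
        return self.conv(x[-1])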

To load weights from your own path with the PrithviModelFactory, you can make use of timm's pretrained_cfg_overlay. E.g. to pass a local path, you can pass the parameter backbone_pretrained_cfg_overlay = {"file": "<local_path>"} to the model factory.

Besides file, you can also pass url, hf_hub_id, amongst others. Check timm's documentation for full details.
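
As a sketch, such a call might look like the following (the backbone name follows the docstring default and the band list is a placeholder; <local_path> should point to your weights file):

from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_100",  # treat the exact backbone name as an assumption
    decoder="FCNDecoder",
    bands=[HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE],
    num_classes=2,
    backbone_pretrained_cfg_overlay={"file": "<local_path>"},
)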

terratorch.models.backbones.select_patch_embed_weights

select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)

Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.

Parameters:
  • state_dict (dict) –

    State Dict

  • model (Module) –

    Model to load the weights onto.

  • pretrained_bands (list[HLSBands | int]) –

    List of bands the model was pretrained on, in the correct order.

  • model_bands (list[HLSBands | int]) –

    List of bands the model is going to be finetuned on, in the correct order.

Returns:
  • dict( dict ) –

    New state dict

Source code in terratorch/models/backbones/select_patch_embed_weights.py
def select_patch_embed_weights(
    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int]
) -> dict:
    """Filter out the patch embedding weights according to the bands being used.
    If a band exists in the pretrained_bands, but not in model_bands, drop it.
    If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.


    Args:
        state_dict (dict): State Dict
        model (nn.Module): Model to load the weights onto.
        pretrained_bands (list[HLSBands | int]): List of bands the model was pretrained on, in the correct order.
        model_bands (list[HLSBands | int]): List of bands the model is going to be finetuned on, in the correct order.

    Returns:
        dict: New state dict
    """
    _possible_keys_for_proj_weight = {
        "encoder.patch_embed.proj.weight",
        "patch_embed.proj.weight",
        "module.patch_embed.proj.weight",
        "patch_embed.projection.weight",
        "module.patch_embed.projection.weight",
    }
    patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight
    if len(patch_embed_proj_weight_key) == 0:
        msg = "Could not find key for patch embed weight"
        raise Exception(msg)
    if len(patch_embed_proj_weight_key) > 1:
        msg = "Too many matches for key for patch embed weight"
        raise Exception(msg)

    # extract the single element from the set
    (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
    patch_embed_weight = state_dict[patch_embed_proj_weight_key]

    temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone()

    # only do this if the patch size and tubelet size match. If not, start with random weights
    if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
        torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
        for index, band in enumerate(model_bands):
            if band in pretrained_bands:
                logging.info(f"Loaded weights for {band} in position {index} of patch embed")
                temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]
    else:
        warnings.warn(
            f"Incompatible shapes between patch embedding of model {temp_weight.shape} and\
            of checkpoint {patch_embed_weight.shape}",
            category=UserWarning,
            stacklevel=1,
        )

    state_dict[patch_embed_proj_weight_key] = temp_weight
    return state_dict
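
A hedged usage sketch: remapping the patch embedding of a checkpoint onto a model finetuned on a different band subset (the checkpoint path and the band lists are placeholders):

import torch

checkpoint_state_dict = torch.load("checkpoint.pt", map_location="cpu")
new_state_dict = select_patch_embed_weights(
    checkpoint_state_dict,
    model,  # an nn.Module whose state dict contains a patch embed projection key
    pretrained_bands=pretrained_bands,  # bands the checkpoint was trained on, in order
    model_bands=model_bands,  # bands the model will be finetuned on, in order
)
model.load_state_dict(new_state_dict, strict=False)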

Decoders

terratorch.models.decoders.fcn_decoder

FCNDecoder

Bases: Module

Fully Convolutional Decoder

Source code in terratorch/models/decoders/fcn_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class FCNDecoder(nn.Module):
    """Fully Convolutional Decoder"""

    def __init__(self, embed_dim: list[int], channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension
            channels (int, optional): Number of channels for each conv. Defaults to 256.
            num_convs (int, optional): Number of convs. Defaults to 4.
            in_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        self.channels = channels
        self.num_convs = num_convs
        self.in_index = in_index
        self.embed_dim = embed_dim[in_index]
        if num_convs < 1:
            msg = "num_convs must be >= 1"
            raise Exception(msg)

        convs = []

        for i in range(num_convs):
            in_channels = self.embed_dim if i == 0 else self.channels
            convs.append(
                _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
            )

        self.convs = nn.Sequential(*convs)

    @property
    def out_channels(self):
        return self.channels

    def forward(self, x: list[Tensor]):
        x = x[self.in_index]
        decoded = self.convs(x)
        return decoded
__init__(embed_dim, channels=256, num_convs=4, in_index=-1)

Constructor

Parameters:
  • embed_dim (list[int]) –

    Input embedding dimension

  • channels (int, default: 256 ) –

    Number of channels for each conv. Defaults to 256.

  • num_convs (int, default: 4 ) –

    Number of convs. Defaults to 4.

  • in_index (int, default: -1 ) –

    Index of the input list to take. Defaults to -1.

Source code in terratorch/models/decoders/fcn_decoder.py
def __init__(self, embed_dim: list[int], channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension
        channels (int, optional): Number of channels for each conv. Defaults to 256.
        num_convs (int, optional): Number of convs. Defaults to 4.
        in_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    kernel_size = 2
    stride = 2
    dilation = 1
    padding = 0
    output_padding = 0
    self.channels = channels
    self.num_convs = num_convs
    self.in_index = in_index
    self.embed_dim = embed_dim[in_index]
    if num_convs < 1:
        msg = "num_convs must be >= 1"
        raise Exception(msg)

    convs = []

    for i in range(num_convs):
        in_channels = self.embed_dim if i == 0 else self.channels
        convs.append(
            _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
        )

    self.convs = nn.Sequential(*convs)
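
A usage sketch: since kernel_size and stride are fixed at 2, each conv block appears to upscale by 2x (shapes below assume _conv_upscale_block doubles the spatial size):

import torch

decoder = FCNDecoder(embed_dim=[768], channels=256, num_convs=4, in_index=-1)
features = [torch.randn(1, 768, 14, 14)]  # e.g. reshaped ViT patch features
out = decoder(features)
# 4 blocks of 2x upscaling: 14x14 -> 224x224, with decoder.out_channels == 256
print(out.shape)  # expected: torch.Size([1, 256, 224, 224])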

terratorch.models.decoders.identity_decoder

Pass the features straight through

IdentityDecoder

Bases: Module

Identity decoder. Useful to pass the feature straight to the head.

Source code in terratorch/models/decoders/identity_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class IdentityDecoder(nn.Module):
    """Identity decoder. Useful to pass the feature straight to the head."""

    def __init__(self, embed_dim: list[int], out_index=-1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension
            out_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.dim = out_index

    @property
    def out_channels(self):
        return self.embed_dim[self.dim]

    def forward(self, x: list[Tensor]):
        return x[self.dim]
__init__(embed_dim, out_index=-1)

Constructor

Parameters:
  • embed_dim (list[int]) –

    Input embedding dimension

  • out_index (int, default: -1 ) –

    Index of the input list to take. Defaults to -1.

Source code in terratorch/models/decoders/identity_decoder.py
def __init__(self, embed_dim: list[int], out_index=-1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension
        out_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.dim = out_index

terratorch.models.decoders.upernet_decoder

PPM

Bases: ModuleList

Pooling Pyramid Module used in PSPNet.

Source code in terratorch/models/decoders/upernet_decoder.py
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet."""

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        """Constructor

        Args:
            pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
                Module.
            in_channels (int): Input channels.
            channels (int): Channels after modules, before conv_seg.
            align_corners (bool): align_corners argument of F.interpolate.
        """
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels

        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(self.in_channels, self.channels, 1, inplace=True),
                )
            )

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = torch.nn.functional.interpolate(
                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
            )
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs
__init__(pool_scales, in_channels, channels, align_corners)

Constructor

Parameters:
  • pool_scales (tuple[int]) –

    Pooling scales used in Pooling Pyramid Module.

  • in_channels (int) –

    Input channels.

  • channels (int) –

    Channels after modules, before conv_seg.

  • align_corners (bool) –

    align_corners argument of F.interpolate.

Source code in terratorch/models/decoders/upernet_decoder.py
def __init__(self, pool_scales, in_channels, channels, align_corners):
    """Constructor

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.
    """
    super().__init__()
    self.pool_scales = pool_scales
    self.align_corners = align_corners
    self.in_channels = in_channels
    self.channels = channels

    for pool_scale in pool_scales:
        self.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_scale),
                ConvModule(self.in_channels, self.channels, 1, inplace=True),
            )
        )
forward(x)

Forward function.

Source code in terratorch/models/decoders/upernet_decoder.py
def forward(self, x):
    """Forward function."""
    ppm_outs = []
    for ppm in self:
        ppm_out = ppm(x)
        upsampled_ppm_out = torch.nn.functional.interpolate(
            ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
        )
        ppm_outs.append(upsampled_ppm_out)
    return ppm_outs

UperNetDecoder

Bases: Module

UperNetDecoder. Adapted from MMSegmentation.

Source code in terratorch/models/decoders/upernet_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class UperNetDecoder(nn.Module):
    """UperNetDecoder. Adapted from MMSegmentation."""

    def __init__(
        self,
        embed_dim: list[int],
        pool_scales: tuple[int] = (1, 2, 3, 6),
        channels: int = 256,
        align_corners: bool = True,  # noqa: FBT001, FBT002
        scale_modules: bool = False
    ):
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
                Module applied on the last feature. Default: (1, 2, 3, 6).
            channels (int, optional): Channels used in the decoder. Defaults to 256.
            align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
            scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
                Defaults to False.
        """
        super().__init__()
        self.scale_modules = scale_modules
        if scale_modules:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[0],
                                embed_dim[0] // 2, 2, 2),
                nn.BatchNorm2d(embed_dim[0] // 2),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim[0] // 2,
                                embed_dim[0] // 4, 2, 2))
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[1],
                                embed_dim[1] // 2, 2, 2))
            self.fpn3 = nn.Sequential(nn.Identity())
            self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
            self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
        else:
            self.embed_dim = embed_dim

        self.out_channels = channels
        self.channels = channels
        self.align_corners = align_corners
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.embed_dim[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = ConvModule(
            self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for embed_dim in self.embed_dim[:-1]:  # skip the top layer
            l_conv = ConvModule(
                embed_dim,
                self.channels,
                1,
                inplace=False,
            )
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                inplace=False,
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """

        if self.scale_modules:
            scaled_inputs = []
            scaled_inputs.append(self.fpn1(inputs[0]))
            scaled_inputs.append(self.fpn2(inputs[1]))
            scaled_inputs.append(self.fpn3(inputs[2]))
            scaled_inputs.append(self.fpn4(inputs[3]))
            inputs = scaled_inputs
        # build laterals
        laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = torch.nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats
__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False)

Constructor

Parameters:
  • embed_dim (list[int]) –

    Input embedding dimension for each input.

  • pool_scales (tuple[int], default: (1, 2, 3, 6) ) –

    Pooling scales used in Pooling Pyramid Module applied on the last feature. Default: (1, 2, 3, 6).

  • channels (int, default: 256 ) –

    Channels used in the decoder. Defaults to 256.

  • align_corners (bool, default: True ) –

    Whether to align corners in rescaling. Defaults to True.

  • scale_modules (bool, default: False ) –

    Whether to apply scale modules to the inputs. Needed for plain ViT. Defaults to False.

Source code in terratorch/models/decoders/upernet_decoder.py
def __init__(
    self,
    embed_dim: list[int],
    pool_scales: tuple[int] = (1, 2, 3, 6),
    channels: int = 256,
    align_corners: bool = True,  # noqa: FBT001, FBT002
    scale_modules: bool = False
):
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        channels (int, optional): Channels used in the decoder. Defaults to 256.
        align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
        scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
            Defaults to False.
    """
    super().__init__()
    self.scale_modules = scale_modules
    if scale_modules:
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[0],
                            embed_dim[0] // 2, 2, 2),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim[0] // 2,
                            embed_dim[0] // 4, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[1],
                            embed_dim[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
    else:
        self.embed_dim = embed_dim

    self.out_channels = channels
    self.channels = channels
    self.align_corners = align_corners
    # PSP Module
    self.psp_modules = PPM(
        pool_scales,
        self.embed_dim[-1],
        self.channels,
        align_corners=self.align_corners,
    )
    self.bottleneck = ConvModule(
        self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
    )
    # FPN Module
    self.lateral_convs = nn.ModuleList()
    self.fpn_convs = nn.ModuleList()
    for embed_dim in self.embed_dim[:-1]:  # skip the top layer
        l_conv = ConvModule(
            embed_dim,
            self.channels,
            1,
            inplace=False,
        )
        fpn_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            inplace=False,
        )
        self.lateral_convs.append(l_conv)
        self.fpn_convs.append(fpn_conv)

    self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)
forward(inputs)

Forward function for feature maps before classifying each pixel.

Parameters:
  • inputs (list[Tensor]) –

    List of multi-level img features.

Returns:
  • feats( Tensor ) –

    A tensor of shape (batch_size, self.channels, H, W) which is feature map for last layer of decoder head.

Source code in terratorch/models/decoders/upernet_decoder.py
def forward(self, inputs):
    """Forward function for feature maps before classifying each pixel with
    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        feats (Tensor): A tensor of shape (batch_size, self.channels,
            H, W) which is feature map for last layer of decoder head.
    """

    if self.scale_modules:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(inputs[0]))
        scaled_inputs.append(self.fpn2(inputs[1]))
        scaled_inputs.append(self.fpn3(inputs[2]))
        scaled_inputs.append(self.fpn4(inputs[3]))
        inputs = scaled_inputs
    # build laterals
    laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
    laterals.append(self.psp_forward(inputs))

    # build top-down path
    used_backbone_levels = len(laterals)
    for i in range(used_backbone_levels - 1, 0, -1):
        prev_shape = laterals[i - 1].shape[2:]
        laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
            laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
        )

    # build outputs
    fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
    # append psp feature
    fpn_outs.append(laterals[-1])

    for i in range(used_backbone_levels - 1, 0, -1):
        fpn_outs[i] = torch.nn.functional.interpolate(
            fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
        )
    fpn_outs = torch.cat(fpn_outs, dim=1)
    feats = self.fpn_bottleneck(fpn_outs)
    return feats
psp_forward(inputs)

Forward function of PSP module.

Source code in terratorch/models/decoders/upernet_decoder.py
def psp_forward(self, inputs):
    """Forward function of PSP module."""
    x = inputs[-1]
    psp_outs = [x]
    psp_outs.extend(self.psp_modules(x))
    psp_outs = torch.cat(psp_outs, dim=1)
    output = self.bottleneck(psp_outs)

    return output
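
A usage sketch with four equally sized feature maps, as produced by a plain ViT (hence scale_modules=True, which rescales the levels into a pyramid before the FPN):

import torch

decoder = UperNetDecoder(embed_dim=[768, 768, 768, 768], channels=256, scale_modules=True)
features = [torch.randn(1, 768, 14, 14) for _ in range(4)]
feats = decoder(features)
# output is at the resolution of the largest lateral (here 4x the input features)
print(feats.shape)  # expected: torch.Size([1, 256, 56, 56])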

Heads

terratorch.models.heads.regression_head

RegressionHead

Bases: Module

Regression head

Source code in terratorch/models/heads/regression_head.py
class RegressionHead(nn.Module):
    """Regression head"""

    def __init__(
        self,
        in_channels: int,
        final_act: nn.Module | str | None = None,
        learned_upscale_layers: int = 0,
        channel_list: list[int] | None = None,
        batch_norm: bool = True,
        dropout: float = 0,
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            final_act (nn.Module | None, optional): Final activation to be applied. Defaults to None.
            learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
                Defaults to 0.
            channel_list (list[int] | None, optional): List with number of channels for each Conv
                layer to be created. Defaults to None.
            batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
            dropout (float, optional): Dropout value to apply. Defaults to 0.

        """
        super().__init__()
        self.learned_upscale_layers = learned_upscale_layers
        self.final_act = final_act if final_act else nn.Identity()
        if isinstance(final_act, str):
            module_name, class_name = final_act.rsplit(".", 1)
            target_class = getattr(importlib.import_module(module_name), class_name)
            self.final_act = target_class()
        pre_layers = []
        if learned_upscale_layers != 0:
            learned_upscale = nn.Sequential(
                *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
            )
            pre_layers.append(learned_upscale)

        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
            pre_layers.append(pre_head)
        dropout = nn.Dropout2d(dropout)
        final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
        self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])

    def forward(self, x):
        output = self.head(x)
        return self.final_act(output)
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)

Constructor

Parameters:
  • in_channels (int) –

    Number of input channels

  • final_act (Module | None, default: None ) –

    Final activation to be applied. Defaults to None.

  • learned_upscale_layers (int, default: 0 ) –

    Number of Pixelshuffle layers to create. Each upscales 2x. Defaults to 0.

  • channel_list (list[int] | None, default: None ) –

    List with number of channels for each Conv layer to be created. Defaults to None.

  • batch_norm (bool, default: True ) –

    Whether to apply batch norm. Defaults to True.

  • dropout (float, default: 0 ) –

    Dropout value to apply. Defaults to 0.

Source code in terratorch/models/heads/regression_head.py
def __init__(
    self,
    in_channels: int,
    final_act: nn.Module | str | None = None,
    learned_upscale_layers: int = 0,
    channel_list: list[int] | None = None,
    batch_norm: bool = True,
    dropout: float = 0,
) -> None:
    """Constructor

    Args:
        in_channels (int): Number of input channels
        final_act (nn.Module | None, optional): Final activation to be applied. Defaults to None.
        learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
            Defaults to 0.
        channel_list (list[int] | None, optional): List with number of channels for each Conv
            layer to be created. Defaults to None.
        batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
        dropout (float, optional): Dropout value to apply. Defaults to 0.

    """
    super().__init__()
    self.learned_upscale_layers = learned_upscale_layers
    self.final_act = final_act if final_act else nn.Identity()
    if isinstance(final_act, str):
        module_name, class_name = final_act.rsplit(".", 1)
        target_class = getattr(importlib.import_module(module_name), class_name)
        self.final_act = target_class()
    pre_layers = []
    if learned_upscale_layers != 0:
        learned_upscale = nn.Sequential(
            *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
        )
        pre_layers.append(learned_upscale)

    if channel_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        channel_list = [in_channels, *channel_list]
        pre_head = nn.Sequential(
            *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
        )
        in_channels = channel_list[-1]
        pre_layers.append(pre_head)
    dropout = nn.Dropout2d(dropout)
    final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
    self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])
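
A usage sketch: one pixel-shuffle upscale followed by a small conv stack and the final single-channel projection (shapes assume PixelShuffleUpscale preserves the channel count and doubles the spatial size, per the docstring):

import torch

head = RegressionHead(in_channels=256, learned_upscale_layers=1, channel_list=[128, 64], dropout=0.1)
x = torch.randn(1, 256, 56, 56)
out = head(x)
print(out.shape)  # expected: torch.Size([1, 1, 112, 112])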

terratorch.models.heads.segmentation_head

SegmentationHead

Bases: Module

Segmentation head

Source code in terratorch/models/heads/segmentation_head.py
class SegmentationHead(nn.Module):
    """Segmentation head"""

    def __init__(
        self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            num_classes (int): Number of output classes
            channel_list (list[int] | None, optional):  List with number of channels for each Conv
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
        """
        super().__init__()
        self.num_classes = num_classes
        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=num_classes,
                kernel_size=1,
            ),
        )

    def forward(self, x):
        return self.head(x)
__init__(in_channels, num_classes, channel_list=None, dropout=0)

Constructor

Parameters:
  • in_channels (int) –

    Number of input channels

  • num_classes (int) –

    Number of output classes

  • channel_list (list[int] | None, default: None ) –

    List with number of channels for each Conv layer to be created. Defaults to None.

  • dropout (float, default: 0 ) –

    Dropout value to apply. Defaults to 0.

Source code in terratorch/models/heads/segmentation_head.py
def __init__(
    self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
) -> None:
    """Constructor

    Args:
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        channel_list (list[int] | None, optional):  List with number of channels for each Conv
            layer to be created. Defaults to None.
        dropout (float, optional): Dropout value to apply. Defaults to 0.
    """
    super().__init__()
    self.num_classes = num_classes
    if channel_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
            )

        channel_list = [in_channels, *channel_list]
        pre_head = nn.Sequential(
            *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
        )
        in_channels = channel_list[-1]
    dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
    self.head = nn.Sequential(
        pre_head,
        dropout,
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_classes,
            kernel_size=1,
        ),
    )
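
A usage sketch mapping 256 decoder channels to 10 classes through one intermediate conv block:

import torch

head = SegmentationHead(in_channels=256, num_classes=10, channel_list=[128], dropout=0.1)
logits = head(torch.randn(1, 256, 224, 224))
print(logits.shape)  # torch.Size([1, 10, 224, 224])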

terratorch.models.heads.classification_head

ClassificationHead

Bases: Module

Classification head

Source code in terratorch/models/heads/classification_head.py
class ClassificationHead(nn.Module):
    """Classification head"""

    # how to allow cls token?
    def __init__(
        self,
        in_dim: int,
        num_classes: int,
        dim_list: list[int] | None = None,
        dropout: float = 0,
        linear_after_pool: bool = False,
    ) -> None:
        """Constructor

        Args:
            in_dim (int): Input dimensionality
            num_classes (int): Number of output classes
            dim_list (list[int] | None, optional):  List with number of dimensions for each Linear
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
            linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False.
        """
        super().__init__()
        self.num_classes = num_classes
        self.linear_after_pool = linear_after_pool
        if dim_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_dim, out_dim):
                return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

            dim_list = [in_dim, *dim_list]
            pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
            in_dim = dim_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Linear(in_features=in_dim, out_features=num_classes),
        )

    def forward(self, x: Tensor):
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        if self.linear_after_pool:
            x = x.mean(axis=1)
            out = self.head(x)
        else:
            x = self.head(x)
            out = x.mean(axis=1)
        return out
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)

Constructor

Parameters:
  • in_dim (int) –

    Input dimensionality

  • num_classes (int) –

    Number of output classes

  • dim_list (list[int] | None, default: None ) –

    List with number of dimensions for each Linear layer to be created. Defaults to None.

  • dropout (float, default: 0 ) –

    Dropout value to apply. Defaults to 0.

  • linear_after_pool (bool, default: False ) –

    Apply pooling first, then apply the linear layer. Defaults to False.

Source code in terratorch/models/heads/classification_head.py
def __init__(
    self,
    in_dim: int,
    num_classes: int,
    dim_list: list[int] | None = None,
    dropout: float = 0,
    linear_after_pool: bool = False,
) -> None:
    """Constructor

    Args:
        in_dim (int): Input dimensionality
        num_classes (int): Number of output classes
        dim_list (list[int] | None, optional):  List with number of dimensions for each Linear
            layer to be created. Defaults to None.
        dropout (float, optional): Dropout value to apply. Defaults to 0.
        linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False.
    """
    super().__init__()
    self.num_classes = num_classes
    self.linear_after_pool = linear_after_pool
    if dim_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_dim, out_dim):
            return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

        dim_list = [in_dim, *dim_list]
        pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
        in_dim = dim_list[-1]
    dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
    self.head = nn.Sequential(
        pre_head,
        dropout,
        nn.Linear(in_features=in_dim, out_features=num_classes),
    )
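
A usage sketch over ViT-style features; the forward flattens the spatial dimensions into tokens and mean-pools, either before or after the linear layers depending on linear_after_pool:

import torch

head = ClassificationHead(in_dim=768, num_classes=10, dim_list=[256], linear_after_pool=True)
x = torch.randn(1, 768, 14, 14)  # (batch, embedding dim, spatial dims)
out = head(x)
print(out.shape)  # torch.Size([1, 10])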

Auxiliary Heads

terratorch.models.model.AuxiliaryHead dataclass

Class containing all information to create auxiliary heads.

Parameters:
  • name (str) –

    Name of the head. Should match the name given to the auxiliary loss.

  • decoder (str) –

    Name of the decoder class to be used.

  • decoder_args (dict | None) –

    Parameters to be passed to the decoder constructor. Parameters for the decoder should be prefixed with decoder_. Parameters for the head should be prefixed with head_.

Source code in terratorch/models/model.py
@dataclass
class AuxiliaryHead:
    """Class containing all information to create auxiliary heads.

    Args:
        name (str): Name of the head. Should match the name given to the auxiliary loss.
        decoder (str): Name of the decoder class to be used.
        decoder_args (dict | None): parameters to be passed to the decoder constructor.
            Parameters for the decoder should be prefixed with `decoder_`.
            Parameters for the head should be prefixed with `head_`.
    """

    name: str
    decoder: str
    decoder_args: dict | None
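
For example, an auxiliary FCN head could be declared like this (the head name and decoder arguments are illustrative; note the decoder_ prefix on decoder parameters):

from terratorch.models.model import AuxiliaryHead

aux_head = AuxiliaryHead(
    name="aux_head",  # must match the name given to the auxiliary loss
    decoder="FCNDecoder",
    decoder_args={"decoder_channels": 128, "decoder_num_convs": 2},
)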

Model Output

terratorch.models.model.ModelOutput dataclass

Source code in terratorch/models/model.py
@dataclass
class ModelOutput:
    output: Tensor
    auxiliary_heads: dict[str, Tensor] = None

Model Factory

terratorch.models.PrithviModelFactory

Bases: ModelFactory

Source code in terratorch/models/prithvi_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class PrithviModelFactory(ModelFactory):
    def __init__(self) -> None:
        self._factory: EncoderDecoderFactory = EncoderDecoderFactory()
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        bands: list[HLSBands | int],
        in_channels: int
        | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
        num_classes: int | None = None,
        pretrained: bool = True,  # noqa: FBT001, FBT002
        num_frames: int = 1,
        prepare_features_for_image_model: Callable | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        **kwargs,
    ) -> Model:
        """Model factory for prithvi models.

        Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str, nn.Module): Backbone to be used. If string, should be able to be parsed
                by the specified factory. Defaults to "prithvi_100".
            decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                    If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                    If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                    Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
            in_channels (int, optional): Number of input channels. Defaults to 3.
            bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on.
                    Should be a list of terratorch.datasets.HLSBands.
                    Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].
            num_classes (int, optional): Number of classes. None for regression tasks.
            pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
                Defaults to True.
            num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
            prepare_features_for_image_model (Callable | None): Function to be called on encoder features
                before passing them to the decoder. Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression). Defaults to True.


        Returns:
            nn.Module: Full model with encoder, decoder and head.
        """
        warnings.warn("PrithviModelFactory is deprecated. Please switch to EncoderDecoderFactory.", stacklevel=1)
        if in_channels is None:
            in_channels = len(bands)
        # TODO: support auxiliary heads
        kwargs["backbone_bands"] = bands
        kwargs["backbone_in_chans"] = in_channels
        kwargs["backbone_pretrained"] = pretrained
        kwargs["backbone_num_frames"] = num_frames
        if prepare_features_for_image_model:
            msg = (
                "This functionality is no longer supported. Please migrate to EncoderDecoderFactory\
                         and use necks."
            )
            raise RuntimeError(msg)


        if not isinstance(backbone, nn.Module):
            if not backbone.startswith("prithvi_"):
                msg = "This class only handles models for `prithvi` encoders"
                raise NotImplementedError(msg)

        return self._factory.build_model(task,
                                         backbone,
                                         decoder,
                                         num_classes=num_classes,
                                         necks=None,
                                         aux_decoders=aux_decoders,
                                         rescale=rescale,
                                         **kwargs)

build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)

Model factory for prithvi models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:
  • task (str) –

    Task to be performed. Currently supports "segmentation" and "regression".

  • backbone (str | nn.Module) –

    Backbone to be used. If string, should be able to be parsed by the specified factory. Defaults to "prithvi_100".

  • decoder (Union[str, Module]) –

    Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, we expect it to expose a property decoder.out_channels. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".

  • in_channels (int, default: None ) –

    Number of input channels. Defaults to 3.

  • bands (list[HLSBands]) –

    Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].

  • num_classes (int, default: None ) –

    Number of classes. None for regression tasks.

  • pretrained (Union[bool, Path], default: True ) –

    Whether to load pretrained weights for the backbone, if available. Defaults to True.

  • num_frames (int, default: 1 ) –

    Number of timesteps for the model to handle. Defaults to 1.

  • prepare_features_for_image_model (Callable | None, default: None ) –

    Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function.

  • aux_decoders (list[AuxiliaryHead] | None, default: None ) –

    List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well.

  • rescale (bool, default: True ) –

    Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

Returns:
  • Model

    nn.Module: Full model with encoder, decoder and head.

Source code in terratorch/models/prithvi_model_factory.py
def build_model(
    self,
    task: str,
    backbone: str | nn.Module,
    decoder: str | nn.Module,
    bands: list[HLSBands | int],
    in_channels: int
    | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
    num_classes: int | None = None,
    pretrained: bool = True,  # noqa: FBT001, FBT002
    num_frames: int = 1,
    prepare_features_for_image_model: Callable | None = None,
    aux_decoders: list[AuxiliaryHead] | None = None,
    rescale: bool = True,  # noqa: FBT002, FBT001
    **kwargs,
) -> Model:
    """Model factory for prithvi models.

    Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
    `backbone_`, `decoder_` and `head_` respectively.

    Args:
        task (str): Task to be performed. Currently supports "segmentation" and "regression".
        backbone (str, nn.Module): Backbone to be used. If string, should be able to be parsed
            by the specified factory. Defaults to "prithvi_100".
        decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
        in_channels (int, optional): Number of input channels. Defaults to 3.
        bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on.
                Should be a list of terratorch.datasets.HLSBands.
                Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].
        num_classes (int, optional): Number of classes. None for regression tasks.
        pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
            Defaults to True.
        num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
        prepare_features_for_image_model (Callable | None): Function to be called on encoder features
            before passing them to the decoder. Defaults to None, which applies the identity function.
        aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
            These decoders take the input from the encoder as well.
        rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
            is different from the ground truth. Only applicable to pixel wise models
            (e.g. segmentation, pixel wise regression). Defaults to True.


    Returns:
        nn.Module: Full model with encoder, decoder and head.
    """
    warnings.warn("PrithviModelFactory is deprecated. Please switch to EncoderDecoderFactory.", stacklevel=1)
    if in_channels is None:
        in_channels = len(bands)
    # TODO: support auxiliary heads
    kwargs["backbone_bands"] = bands
    kwargs["backbone_in_chans"] = in_channels
    kwargs["backbone_pretrained"] = pretrained
    kwargs["backbone_num_frames"] = num_frames
    if prepare_features_for_image_model:
        msg = (
            "This functionality is no longer supported. Please migrate to EncoderDecoderFactory\
                     and use necks."
        )
        raise RuntimeError(msg)


    if not isinstance(backbone, nn.Module):
        if not backbone.startswith("prithvi_"):
            msg = "This class only handles models for `prithvi` encoders"
            raise NotImplementedError(msg)

    return self._factory.build_model(task,
                                     backbone,
                                     decoder,
                                     num_classes=num_classes,
                                     necks=None,
                                     aux_decoders=aux_decoders,
                                     rescale=rescale,
                                     **kwargs)
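
A usage sketch (remember the factory is deprecated in favour of the EncoderDecoderFactory; the backbone name follows the docstring default and should be treated as an assumption):

from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_100",
    decoder="FCNDecoder",
    bands=[HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE],
    num_classes=2,
    pretrained=False,
    decoder_channels=128,  # the decoder_ prefix routes this to the decoder
)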

terratorch.models.SMPModelFactory

Bases: ModelFactory

Source code in terratorch/models/smp_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str,
        model: str,
        bands: list[HLSBands | int],
        in_channels: int | None = None,
        num_classes: int = 1,
        pretrained: str | bool | None = True,  # noqa: FBT002
        prepare_features_for_image_model: Callable | None = None,
        regression_relu: bool = False,  # noqa: FBT001, FBT002
        **kwargs,
    ) -> Model:
        """
        Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

        This factory handles the instantiation of segmentation and regression models using specified
        encoders and decoders from the SMP library, along with custom modifications and extensions such
        as auxiliary decoders or modified encoders.

        Attributes:
            task (str): Specifies the task for which the model is being built. Supported tasks are
                        "segmentation".
            backbone (str): Specifies the backbone model to be used.
            decoder (str): Specifies the decoder to be used for constructing the
                        segmentation model.
            bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                        will operate on. These are expected to be from terratorch.datasets.HLSBands.
            in_channels (int, optional): Specifies the number of input channels. Defaults to None.
            num_classes (int, optional): The number of output classes for the model.
            pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the
                        backbone. Can also specify a path to weights. Defaults to True.
            num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                        for temporal models.
            regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
            **kwargs: Additional arguments that might be passed to further customize the backbone, decoder,
                        or any auxiliary heads. These should be prefixed appropriately

        Raises:
            ValueError: If the specified decoder is not supported by SMP.
            Exception: If the specified task is not "segmentation"

        Returns:
            nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                    parameters and tasks.
        """
        if task != "segmentation":
            msg = f"SMP models can only perform segmentation, but got task {task}"
            raise Exception(msg)

        bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
        if in_channels is None:
            in_channels = len(bands)

        # Gets decoder module.
        model_module = getattr(smp, model, None)
        if model_module is None:
            msg = f"Decoder {model} is not supported in SMP."
            raise ValueError(msg)

        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
        smp_kwargs, kwargs = extract_prefix_keys(backbone_kwargs, "smp_")  # Smp model params should be prefixed smp_
        aux_params, kwargs = extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
        aux_params = None if aux_params == {} else aux_params

        if isinstance(pretrained, bool):
            if pretrained:
                pretrained = "imagenet"
            else:
                pretrained = None

        # If encoder not currently supported by SMP (custom encoder).
        if backbone not in smp_encoders:
            # These params must be included in the config file with appropriate prefix.
            required_params = {
                "encoder_depth": smp_kwargs,
                "out_channels": backbone_kwargs,
                "output_stride": backbone_kwargs,
            }

            for param, config_dict in required_params.items():
                if param not in config_dict:
                    msg = f"Config must include the '{param}' parameter"
                    raise ValueError(msg)

            # Using new encoder.
            backbone_class = make_smp_encoder(backbone)
            backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
            # Registering custom encoder into SMP.
            register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

            model_args = {
                "encoder_name": "SMPEncoderWrapperWithPFFIM",
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }
        # Using SMP encoder.
        else:
            model_args = {
                "encoder_name": backbone,
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }

        model = model_module(**model_args, aux_params=aux_params)

        return SMPModelWrapper(
            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )

build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)

Factory class for creating SMP (Segmentation Models PyTorch) based models with optional customization.

This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.

Attributes:
  • task (str) –

    Specifies the task for which the model is being built. Currently, only "segmentation" is supported.

  • backbone (str) –

    Specifies the backbone model to be used.

  • model (str) –

    Specifies the decoder architecture used to construct the segmentation model; must name a model class exposed by smp (e.g. "Unet").

  • bands (list[HLSBands | int]) –

    A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands.

  • in_channels (int) –

    Specifies the number of input channels. Defaults to None.

  • num_classes (int) –

    The number of output classes for the model.

  • pretrained (str | bool) –

    Indicates whether to load pretrained weights for the backbone. A string selects a specific SMP weight set (True maps to "imagenet"). Defaults to True.

  • num_frames (int) –

    Specifies the number of timesteps the model should handle. Useful for temporal models.

  • regression_relu (bool) –

    Whether to apply ReLU activation in the case of regression tasks.

  • **kwargs –

    Additional arguments used to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately.

Raises:
  • ValueError

    If the specified decoder is not supported by SMP.

  • Exception

    If the specified task is not "segmentation".

Returns:
  • Model

    nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified parameters and task.

Source code in terratorch/models/smp_model_factory.py
def build_model(
    self,
    task: str,
    backbone: str,
    model: str,
    bands: list[HLSBands | int],
    in_channels: int | None = None,
    num_classes: int = 1,
    pretrained: str | bool | None = True,  # noqa: FBT002
    prepare_features_for_image_model: Callable | None = None,
    regression_relu: bool = False,  # noqa: FBT001, FBT002
    **kwargs,
) -> Model:
    """
    Factory class for creating SMP (Segmentation Models PyTorch) based models with optional customization.

    This factory handles the instantiation of segmentation and regression models using specified
    encoders and decoders from the SMP library, along with custom modifications and extensions such
    as auxiliary decoders or modified encoders.

    Attributes:
        task (str): Specifies the task for which the model is being built. Currently, only
                    "segmentation" is supported.
        backbone (str): Specifies the backbone model to be used.
        model (str): Specifies the decoder architecture used to construct the segmentation
                    model; must name a model class exposed by smp (e.g. "Unet").
        bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                    will operate on. These are expected to be from terratorch.datasets.HLSBands.
        in_channels (int, optional): Specifies the number of input channels. Defaults to None.
        num_classes (int, optional): The number of output classes for the model.
        pretrained (str | bool, optional): Indicates whether to load pretrained weights for the
                    backbone. A string selects a specific SMP weight set (True maps to "imagenet").
                    Defaults to True.
        num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                    for temporal models.
        regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
        **kwargs: Additional arguments used to further customize the backbone, decoder,
                    or any auxiliary heads. These should be prefixed appropriately.

    Raises:
        ValueError: If the specified decoder is not supported by SMP.
        Exception: If the specified task is not "segmentation".

    Returns:
        nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                parameters and task.
    """
    if task != "segmentation":
        msg = f"SMP models can only perform segmentation, but got task {task}"
        raise Exception(msg)

    bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
    if in_channels is None:
        in_channels = len(bands)

    # Gets decoder module.
    model_module = getattr(smp, model, None)
    if model_module is None:
        msg = f"Decoder {model} is not supported in SMP."
        raise ValueError(msg)

    backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")  # encoder params are prefixed "backbone_"
    smp_kwargs, kwargs = extract_prefix_keys(backbone_kwargs, "smp_")  # SMP model params are prefixed "backbone_smp_"
    aux_params, kwargs = extract_prefix_keys(backbone_kwargs, "aux_")  # auxiliary head params are prefixed "backbone_aux_"
    aux_params = None if aux_params == {} else aux_params

    if isinstance(pretrained, bool):
        if pretrained:
            pretrained = "imagenet"
        else:
            pretrained = None

    # If encoder not currently supported by SMP (custom encoder).
    if backbone not in smp_encoders:
        # These params must be included in the config file with appropriate prefix.
        required_params = {
            "encoder_depth": smp_kwargs,
            "out_channels": backbone_kwargs,
            "output_stride": backbone_kwargs,
        }

        for param, config_dict in required_params.items():
            if param not in config_dict:
                msg = f"Config must include the '{param}' parameter"
                raise ValueError(msg)

        # Using new encoder.
        backbone_class = make_smp_encoder(backbone)
        backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
        # Registering custom encoder into SMP.
        register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

        model_args = {
            "encoder_name": "SMPEncoderWrapperWithPFFIM",
            "encoder_weights": pretrained,
            "in_channels": in_channels,
            "classes": num_classes,
            **smp_kwargs,
        }
    # Using SMP encoder.
    else:
        model_args = {
            "encoder_name": backbone,
            "encoder_weights": pretrained,
            "in_channels": in_channels,
            "classes": num_classes,
            **smp_kwargs,
        }

    model = model_module(**model_args, aux_params=aux_params)

    return SMPModelWrapper(
        model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
    )
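
As an illustration of the prefix convention handled by extract_prefix_keys above, the following sketch (not taken from the terratorch docs; the parameter values are arbitrary) builds a Unet over the SMP resnet34 encoder. Because smp_ and aux_ keys are extracted from the backbone_ group, their full user-facing prefixes are backbone_smp_ and backbone_aux_.

from terratorch.datasets import HLSBands
from terratorch.models.smp_model_factory import SMPModelFactory

factory = SMPModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="resnet34",  # an encoder already bundled with SMP
    model="Unet",         # any decoder (model) class exposed by smp
    bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED],
    num_classes=2,
    pretrained=False,                  # skip downloading weights
    backbone_smp_encoder_depth=5,      # reaches smp.Unet as encoder_depth=5
)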

Adding new model types

Adding new model types is as simple as creating a new factory that produces models. See, for instance, the example below of a potential SMPModelFactory.

import segmentation_models_pytorch as smp
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput
from terratorch.registry import MODEL_FACTORY_REGISTRY

@MODEL_FACTORY_REGISTRY.register
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        **kwargs,
    ) -> Model:
        # For simplicity, this toy factory ignores `backbone` and `decoder`
        # and always builds a resnet34 Unet.
        model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=in_channels, classes=1)
        return SMPModelWrapper(model)

# The wrapper is a Model, not a factory, so it is not registered
# in the MODEL_FACTORY_REGISTRY. Model already subclasses nn.Module.
class SMPModelWrapper(Model):
    def __init__(self, smp_model) -> None:
        super().__init__()
        self.smp_model = smp_model

    def forward(self, *args, **kwargs):
        return ModelOutput(self.smp_model(*args, **kwargs).squeeze(1))

    def freeze_encoder(self):
        pass

    def freeze_decoder(self):
        pass
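
Once registered, a factory can be retrieved by name and used directly. Below is a minimal usage sketch, assuming the module defining the example factory above has been imported (so its registration has run) and that MODEL_FACTORY_REGISTRY follows the same build interface as the backbone registry:

import torch
from terratorch.registry import MODEL_FACTORY_REGISTRY

factory = MODEL_FACTORY_REGISTRY.build("SMPModelFactory")
wrapped = factory.build_model(
    task="segmentation", backbone="resnet34", decoder="Unet", in_channels=6
)
out = wrapped(torch.randn(2, 6, 224, 224))  # a ModelOutput, per the Model ABC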

Custom modules with CLI

Custom modules must be in the import path so that they are imported and thus registered in the appropriate registries.

To do this without modifying code when using the CLI, you may place your modules under a custom_modules directory, which must live in the directory from which you execute terratorch, as sketched below.
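
As a sketch, a working directory might look like the following (config.yaml and my_backbones.py are hypothetical names; check the terratorch docs for the exact discovery rules):

my_project/                # the directory from which you run `terratorch`
├── config.yaml
└── custom_modules/
    ├── __init__.py        # import submodules here so registrations run, e.g. `from . import my_backbones`
    └── my_backbones.py    # defines classes annotated with the registry decorators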