Models

Prithvi backbones

We provide access to the Prithvi backbones through integration with timm.

By passing features_only=True, you get a model that outputs the features produced at each layer.

Passing features_only=False gives you the full original model.

Instantiating a prithvi backbone from timm
import timm
import terratorch # even though we don't use the import directly, we need it so that the models are available in the timm registry

# find available prithvi models by name
print(timm.list_models("prithvi*"))
# and those with pretrained weights
print(timm.list_pretrained("prithvi*"))

# instantiate your desired model with features_only=True to obtain a backbone
model = timm.create_model(
    "prithvi_vit_100", num_frames=1, pretrained=True, features_only=True
)

# instantiate your model with weights of your own
model = timm.create_model(
    "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": "<path to weights>"}, features_only=True
)
# Rest of your PyTorch / PyTorchLightning code
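
If you need the full original model rather than a feature extractor, set features_only=False (a minimal sketch, reusing the model name from above):

# instantiate the full original model instead of a feature extractor
full_model = timm.create_model(
    "prithvi_vit_100", num_frames=1, pretrained=True, features_only=False
)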

We also provide a model factory that builds a task-specific model for a downstream task on top of a Prithvi backbone.

By passing the list of bands you will use to the constructor, the factory automatically filters out the patch embedding weights of unused bands and randomly initializes weights for bands the model was not pretrained on.

Info

To load weights from a path of your own with the PrithviModelFactory, you can make use of timm's pretrained_cfg_overlay. E.g. to pass a local path, pass the parameter backbone_pretrained_cfg_overlay = {"file": "<local_path>"} to the model factory.

Besides file, you can also pass url, hf_hub_id, amongst others. Check timm's documentation for full details.
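
Putting the factory and the weight overlay together, a minimal sketch; the task, band list and weights path are illustrative placeholders:

from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_vit_100",
    decoder="FCNDecoder",
    bands=[HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE],  # unused pretrained bands are filtered out
    num_classes=2,
    pretrained=True,
    # load the backbone weights from a local file via timm's pretrained_cfg_overlay
    backbone_pretrained_cfg_overlay={"file": "<local_path>"},
)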

terratorch.models.backbones.prithvi_select_patch_embed_weights

prithvi_select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)

Filter out the patch embedding weights according to the bands being used. If a band exists in the pretrained_bands, but not in model_bands, drop it. If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.

Parameters:

- state_dict (dict): State dict. Required.
- model (Module): Model to load the weights onto. Required.
- pretrained_bands (list[HLSBands]): List of bands the model was pretrained on, in the correct order. Required.
- model_bands (list[HLSBands]): List of bands the model is going to be finetuned on, in the correct order. Required.

Returns:

- dict: New state dict.
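
A minimal sketch of applying this function when remapping a checkpoint to a different band set; the checkpoint path and both band lists are placeholders, and the file is assumed to store a plain state dict:

import timm
import torch
import terratorch  # registers the prithvi models in the timm registry
from terratorch.datasets import HLSBands
from terratorch.models.backbones.prithvi_select_patch_embed_weights import (
    prithvi_select_patch_embed_weights,
)

pretrained_bands = [HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED]  # bands of the checkpoint
model_bands = [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE]       # bands for finetuning

# model whose patch embedding is sized for model_bands
model = timm.create_model(
    "prithvi_vit_100", pretrained=False, num_frames=1, in_chans=len(model_bands), bands=model_bands
)

state_dict = torch.load("<path to weights>", map_location="cpu")
# keep the pretrained filters for bands present in both lists;
# bands new to the model stay randomly (Xavier) initialized
state_dict = prithvi_select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
model.load_state_dict(state_dict, strict=False)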

Source code in terratorch/models/backbones/prithvi_select_patch_embed_weights.py
def prithvi_select_patch_embed_weights(
    state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int]
) -> dict:
    """Filter out the patch embedding weights according to the bands being used.
    If a band exists in the pretrained_bands, but not in model_bands, drop it.
    If a band exists in model_bands, but not pretrained_bands, randomly initialize those weights.


    Args:
        state_dict (dict): State Dict
        model (nn.Module): Model to load the weights onto.
        pretrained_bands (list[HLSBands]): List of bands the model was pretrained on, in the correct order.
        model_bands (list[HLSBands]): List of bands the model is going to be finetuned on, in the correct order

    Returns:
        dict: New state dict
    """
    _possible_keys_for_proj_weight = {
        "patch_embed.proj.weight",
        "module.patch_embed.proj.weight",
        "patch_embed.projection.weight",
        "module.patch_embed.projection.weight",
    }
    patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight
    if len(patch_embed_proj_weight_key) == 0:
        msg = "Could not find key for patch embed weight"
        raise Exception(msg)
    if len(patch_embed_proj_weight_key) > 1:
        msg = "Too many matches for key for patch embed weight"
        raise Exception(msg)

    # extract the single element from the set
    (patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
    patch_embed_weight = state_dict[patch_embed_proj_weight_key]

    temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone()
    torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
    for index, band in enumerate(model_bands):
        if band in pretrained_bands:
            temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]

    state_dict[patch_embed_proj_weight_key] = temp_weight
    return state_dict

Decoders

terratorch.models.decoders.fcn_decoder

FCNDecoder

Bases: Module

Fully Convolutional Decoder

Source code in terratorch/models/decoders/fcn_decoder.py
class FCNDecoder(nn.Module):
    """Fully Convolutional Decoder"""

    def __init__(self, embed_dim: list[int], channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            channels (int, optional): Number of channels for each conv. Defaults to 256.
            num_convs (int, optional): Number of convs. Defaults to 4.
            in_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        self.channels = channels
        self.num_convs = num_convs
        self.in_index = in_index
        self.embed_dim = embed_dim[in_index]
        if num_convs < 1:
            msg = "num_convs must be >= 1"
            raise Exception(msg)

        convs = []

        for i in range(num_convs):
            in_channels = self.embed_dim if i == 0 else self.channels
            convs.append(
                _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
            )

        self.convs = nn.Sequential(*convs)

    @property
    def output_embed_dim(self):
        return self.channels

    def forward(self, x: list[Tensor]):
        x = x[self.in_index]
        decoded = self.convs(x)
        return decoded
__init__(embed_dim, channels=256, num_convs=4, in_index=-1)

Constructor

Parameters:

- embed_dim (list[int]): Input embedding dimension for each input. Required.
- channels (int, optional): Number of channels for each conv. Defaults to 256.
- num_convs (int, optional): Number of convs. Defaults to 4.
- in_index (int, optional): Index of the input list to take. Defaults to -1.
Source code in terratorch/models/decoders/fcn_decoder.py
def __init__(self, embed_dim: list[int], channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        channels (int, optional): Number of channels for each conv. Defaults to 256.
        num_convs (int, optional): Number of convs. Defaults to 4.
        in_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    kernel_size = 2
    stride = 2
    dilation = 1
    padding = 0
    output_padding = 0
    self.channels = channels
    self.num_convs = num_convs
    self.in_index = in_index
    self.embed_dim = embed_dim[in_index]
    if num_convs < 1:
        msg = "num_convs must be >= 1"
        raise Exception(msg)

    convs = []

    for i in range(num_convs):
        in_channels = self.embed_dim if i == 0 else self.channels
        convs.append(
            _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
        )

    self.convs = nn.Sequential(*convs)
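
A minimal usage sketch; the four identical embedding dimensions stand in for whatever backbone.feature_info.channels() would return and are assumed values:

import torch
from terratorch.models.decoders.fcn_decoder import FCNDecoder

features = [torch.randn(2, 768, 14, 14) for _ in range(4)]  # assumed ViT feature maps
decoder = FCNDecoder(embed_dim=[768, 768, 768, 768], channels=256, num_convs=4, in_index=-1)
out = decoder(features)  # takes features[-1]; each conv block upscales 2x, so 4 blocks give 16x
print(out.shape)  # torch.Size([2, 256, 224, 224])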

terratorch.models.decoders.identity_decoder

Pass the features straight through

IdentityDecoder

Bases: Module

Identity decoder. Useful to pass the feature straight to the head.

Source code in terratorch/models/decoders/identity_decoder.py
class IdentityDecoder(nn.Module):
    """Identity decoder. Useful to pass the feature straight to the head."""

    def __init__(self, embed_dim: list[int], out_index=-1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            out_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.dim = out_index

    @property
    def output_embed_dim(self):
        return self.embed_dim[self.dim]

    def forward(self, x: list[Tensor]):
        return x[self.dim]
__init__(embed_dim, out_index=-1)

Constructor

Parameters:

- embed_dim (list[int]): Input embedding dimension for each input. Required.
- out_index (int, optional): Index of the input list to take. Defaults to -1.
Source code in terratorch/models/decoders/identity_decoder.py
def __init__(self, embed_dim: list[int], out_index=-1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        out_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.dim = out_index
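
A minimal usage sketch with assumed feature shapes:

import torch
from terratorch.models.decoders.identity_decoder import IdentityDecoder

features = [torch.randn(2, 768, 14, 14) for _ in range(4)]  # assumed backbone outputs
decoder = IdentityDecoder(embed_dim=[768, 768, 768, 768], out_index=-1)
out = decoder(features)  # simply returns features[-1]
print(decoder.output_embed_dim)  # 768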

terratorch.models.decoders.upernet_decoder

PPM

Bases: ModuleList

Pooling Pyramid Module used in PSPNet.

Source code in terratorch/models/decoders/upernet_decoder.py
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet."""

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        """Constructor

        Args:
            pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
                Module.
            in_channels (int): Input channels.
            channels (int): Channels after modules, before conv_seg.
            align_corners (bool): align_corners argument of F.interpolate.
        """
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels

        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(self.in_channels, self.channels, 1, inplace=True),
                )
            )

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = torch.nn.functional.interpolate(
                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
            )
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs
__init__(pool_scales, in_channels, channels, align_corners)

Constructor

Parameters:

- pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module. Required.
- in_channels (int): Input channels. Required.
- channels (int): Channels after modules, before conv_seg. Required.
- align_corners (bool): align_corners argument of F.interpolate. Required.
Source code in terratorch/models/decoders/upernet_decoder.py
def __init__(self, pool_scales, in_channels, channels, align_corners):
    """Constructor

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.
    """
    super().__init__()
    self.pool_scales = pool_scales
    self.align_corners = align_corners
    self.in_channels = in_channels
    self.channels = channels

    for pool_scale in pool_scales:
        self.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_scale),
                ConvModule(self.in_channels, self.channels, 1, inplace=True),
            )
        )
forward(x)

Forward function.

Source code in terratorch/models/decoders/upernet_decoder.py
def forward(self, x):
    """Forward function."""
    ppm_outs = []
    for ppm in self:
        ppm_out = ppm(x)
        upsampled_ppm_out = torch.nn.functional.interpolate(
            ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
        )
        ppm_outs.append(upsampled_ppm_out)
    return ppm_outs
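
A small standalone sketch of PPM with assumed channel sizes:

import torch
from terratorch.models.decoders.upernet_decoder import PPM

ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=768, channels=256, align_corners=True)
x = torch.randn(2, 768, 14, 14)
outs = ppm(x)  # one output per pool scale, each upsampled back to the input resolution
print([tuple(o.shape) for o in outs])  # four entries of (2, 256, 14, 14)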

UperNetDecoder

Bases: Module

UperNetDecoder. Adapted from MMSegmentation.

Source code in terratorch/models/decoders/upernet_decoder.py
class UperNetDecoder(nn.Module):
    """UperNetDecoder. Adapted from MMSegmentation."""

    def __init__(
        self,
        embed_dim: list[int],
        pool_scales: tuple[int] = (1, 2, 3, 6),
        channels: int = 256,
        align_corners: bool = True,  # noqa: FBT001, FBT002
        scale_modules: bool = False
    ):
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
                Module applied on the last feature. Default: (1, 2, 3, 6).
            channels (int, optional): Channels used in the decoder. Defaults to 256.
            align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
            scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
                Defaults to False.
        """
        super().__init__()
        self.scale_modules = scale_modules
        if scale_modules:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[0],
                                embed_dim[0] // 2, 2, 2),
                nn.BatchNorm2d(embed_dim[0] // 2),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim[0] // 2,
                                embed_dim[0] // 4, 2, 2))
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[1],
                                embed_dim[1] // 2, 2, 2))
            self.fpn3 = nn.Sequential(nn.Identity())
            self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
            self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
        else:
            self.embed_dim = embed_dim

        self.output_embed_dim = channels
        self.channels = channels
        self.align_corners = align_corners
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.embed_dim[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = ConvModule(
            self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for embed_dim in self.embed_dim[:-1]:  # skip the top layer
            l_conv = ConvModule(
                embed_dim,
                self.channels,
                1,
                inplace=False,
            )
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                inplace=False,
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """

        if self.scale_modules:
            scaled_inputs = []
            scaled_inputs.append(self.fpn1(inputs[0]))
            scaled_inputs.append(self.fpn2(inputs[1]))
            scaled_inputs.append(self.fpn3(inputs[2]))
            scaled_inputs.append(self.fpn4(inputs[3]))
            inputs = scaled_inputs
        # build laterals
        laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = torch.nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats
__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False)

Constructor

Parameters:

- embed_dim (list[int]): Input embedding dimension for each input. Required.
- pool_scales (tuple[int], optional): Pooling scales used in the Pooling Pyramid Module applied on the last feature. Defaults to (1, 2, 3, 6).
- channels (int, optional): Channels used in the decoder. Defaults to 256.
- align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
- scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT. Defaults to False.
Source code in terratorch/models/decoders/upernet_decoder.py
def __init__(
    self,
    embed_dim: list[int],
    pool_scales: tuple[int] = (1, 2, 3, 6),
    channels: int = 256,
    align_corners: bool = True,  # noqa: FBT001, FBT002
    scale_modules: bool = False
):
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        channels (int, optional): Channels used in the decoder. Defaults to 256.
        align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
        scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
            Defaults to False.
    """
    super().__init__()
    self.scale_modules = scale_modules
    if scale_modules:
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[0],
                            embed_dim[0] // 2, 2, 2),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim[0] // 2,
                            embed_dim[0] // 4, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[1],
                            embed_dim[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
    else:
        self.embed_dim = embed_dim

    self.output_embed_dim = channels
    self.channels = channels
    self.align_corners = align_corners
    # PSP Module
    self.psp_modules = PPM(
        pool_scales,
        self.embed_dim[-1],
        self.channels,
        align_corners=self.align_corners,
    )
    self.bottleneck = ConvModule(
        self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
    )
    # FPN Module
    self.lateral_convs = nn.ModuleList()
    self.fpn_convs = nn.ModuleList()
    for embed_dim in self.embed_dim[:-1]:  # skip the top layer
        l_conv = ConvModule(
            embed_dim,
            self.channels,
            1,
            inplace=False,
        )
        fpn_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            inplace=False,
        )
        self.lateral_convs.append(l_conv)
        self.fpn_convs.append(fpn_conv)

    self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)
forward(inputs)

Forward function for feature maps before classifying each pixel.

Parameters:

- inputs (list[Tensor]): List of multi-level img features. Required.

Returns:

- feats (Tensor): A tensor of shape (batch_size, self.channels, H, W), the feature map of the last decoder layer.

Source code in terratorch/models/decoders/upernet_decoder.py
def forward(self, inputs):
    """Forward function for feature maps before classifying each pixel with
    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        feats (Tensor): A tensor of shape (batch_size, self.channels,
            H, W) which is feature map for last layer of decoder head.
    """

    if self.scale_modules:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(inputs[0]))
        scaled_inputs.append(self.fpn2(inputs[1]))
        scaled_inputs.append(self.fpn3(inputs[2]))
        scaled_inputs.append(self.fpn4(inputs[3]))
        inputs = scaled_inputs
    # build laterals
    laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
    laterals.append(self.psp_forward(inputs))

    # build top-down path
    used_backbone_levels = len(laterals)
    for i in range(used_backbone_levels - 1, 0, -1):
        prev_shape = laterals[i - 1].shape[2:]
        laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
            laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
        )

    # build outputs
    fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
    # append psp feature
    fpn_outs.append(laterals[-1])

    for i in range(used_backbone_levels - 1, 0, -1):
        fpn_outs[i] = torch.nn.functional.interpolate(
            fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
        )
    fpn_outs = torch.cat(fpn_outs, dim=1)
    feats = self.fpn_bottleneck(fpn_outs)
    return feats
psp_forward(inputs)

Forward function of PSP module.

Source code in terratorch/models/decoders/upernet_decoder.py
def psp_forward(self, inputs):
    """Forward function of PSP module."""
    x = inputs[-1]
    psp_outs = [x]
    psp_outs.extend(self.psp_modules(x))
    psp_outs = torch.cat(psp_outs, dim=1)
    output = self.bottleneck(psp_outs)

    return output
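
A minimal usage sketch; the feature shapes are assumptions for a plain ViT backbone, which is also why scale_modules=True:

import torch
from terratorch.models.decoders.upernet_decoder import UperNetDecoder

features = [torch.randn(2, 768, 14, 14) for _ in range(4)]  # assumed plain ViT features
decoder = UperNetDecoder(embed_dim=[768, 768, 768, 768], channels=256, scale_modules=True)
feats = decoder(features)
# fpn1 upscales the first level 4x and all levels are fused at that resolution
print(feats.shape)  # torch.Size([2, 256, 56, 56])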

Heads

terratorch.models.heads.regression_head

RegressionHead

Bases: Module

Regression head

Source code in terratorch/models/heads/regression_head.py
class RegressionHead(nn.Module):
    """Regression head"""

    def __init__(
        self,
        in_channels: int,
        final_act: nn.Module | str | None = None,
        learned_upscale_layers: int = 0,
        channel_list: list[int] | None = None,
        batch_norm: bool = True,
        dropout: float = 0,
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            final_act (nn.Module | str | None, optional): Final activation to be applied. A string is resolved by import path and instantiated. Defaults to None.
            learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
                Defaults to 0.
            channel_list (list[int] | None, optional): List with number of channels for each Conv
                layer to be created. Defaults to None.
            batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
            dropout (float, optional): Dropout value to apply. Defaults to 0.

        """
        super().__init__()
        self.learned_upscale_layers = learned_upscale_layers
        self.final_act = final_act if final_act else nn.Identity()
        if isinstance(final_act, str):
            module_name, class_name = final_act.rsplit(".", 1)
            target_class = getattr(importlib.import_module(module_name), class_name)
            self.final_act = target_class()
        pre_layers = []
        if learned_upscale_layers != 0:
            learned_upscale = nn.Sequential(
                *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
            )
            pre_layers.append(learned_upscale)

        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
            pre_layers.append(pre_head)
        dropout = nn.Dropout2d(dropout)
        final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
        self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])

    def forward(self, x):
        output = self.head(x)
        return self.final_act(output)
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)

Constructor

Parameters:

- in_channels (int): Number of input channels. Required.
- final_act (Module | str | None, optional): Final activation to be applied. A string is resolved by import path and instantiated. Defaults to None.
- learned_upscale_layers (int, optional): Number of PixelShuffle layers to create. Each upscales 2x. Defaults to 0.
- channel_list (list[int] | None, optional): List with the number of channels for each Conv layer to be created. Defaults to None.
- batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
- dropout (float, optional): Dropout value to apply. Defaults to 0.
Source code in terratorch/models/heads/regression_head.py
def __init__(
    self,
    in_channels: int,
    final_act: nn.Module | str | None = None,
    learned_upscale_layers: int = 0,
    channel_list: list[int] | None = None,
    batch_norm: bool = True,
    dropout: float = 0,
) -> None:
    """Constructor

    Args:
        in_channels (int): Number of input channels
        final_act (nn.Module | str | None, optional): Final activation to be applied. A string is resolved by import path and instantiated. Defaults to None.
        learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
            Defaults to 0.
        channel_list (list[int] | None, optional): List with number of channels for each Conv
            layer to be created. Defaults to None.
        batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
        dropout (float, optional): Dropout value to apply. Defaults to 0.

    """
    super().__init__()
    self.learned_upscale_layers = learned_upscale_layers
    self.final_act = final_act if final_act else nn.Identity()
    if isinstance(final_act, str):
        module_name, class_name = final_act.rsplit(".", 1)
        target_class = getattr(importlib.import_module(module_name), class_name)
        self.final_act = target_class()
    pre_layers = []
    if learned_upscale_layers != 0:
        learned_upscale = nn.Sequential(
            *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
        )
        pre_layers.append(learned_upscale)

    if channel_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        channel_list = [in_channels, *channel_list]
        pre_head = nn.Sequential(
            *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
        )
        in_channels = channel_list[-1]
        pre_layers.append(pre_head)
    dropout = nn.Dropout2d(dropout)
    final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
    self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])
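
A minimal usage sketch with an assumed decoder output shape:

import torch
from terratorch.models.heads.regression_head import RegressionHead

head = RegressionHead(in_channels=256, channel_list=[128], dropout=0.1)
x = torch.randn(2, 256, 56, 56)  # assumed decoder output
out = head(x)  # one regression value per pixel
print(out.shape)  # torch.Size([2, 1, 56, 56])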

terratorch.models.heads.segmentation_head

SegmentationHead

Bases: Module

Segmentation head

Source code in terratorch/models/heads/segmentation_head.py
class SegmentationHead(nn.Module):
    """Segmentation head"""

    def __init__(
        self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
    ) -> None:
        """Constructor

        Args:
            in_channels (int): Number of input channels
            num_classes (int): Number of output classes
            channel_list (list[int] | None, optional):  List with number of channels for each Conv
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
        """
        super().__init__()
        self.num_classes = num_classes
        if channel_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
                )

            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=num_classes,
                kernel_size=1,
            ),
        )

    def forward(self, x):
        return self.head(x)
__init__(in_channels, num_classes, channel_list=None, dropout=0)

Constructor

Parameters:

- in_channels (int): Number of input channels. Required.
- num_classes (int): Number of output classes. Required.
- channel_list (list[int] | None, optional): List with the number of channels for each Conv layer to be created. Defaults to None.
- dropout (float, optional): Dropout value to apply. Defaults to 0.
Source code in terratorch/models/heads/segmentation_head.py
def __init__(
    self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
) -> None:
    """Constructor

    Args:
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        channel_list (list[int] | None, optional):  List with number of channels for each Conv
            layer to be created. Defaults to None.
        dropout (float, optional): Dropout value to apply. Defaults to 0.
    """
    super().__init__()
    self.num_classes = num_classes
    if channel_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
            )

        channel_list = [in_channels, *channel_list]
        pre_head = nn.Sequential(
            *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
        )
        in_channels = channel_list[-1]
    dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
    self.head = nn.Sequential(
        pre_head,
        dropout,
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_classes,
            kernel_size=1,
        ),
    )
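
A minimal usage sketch with an assumed decoder output shape:

import torch
from terratorch.models.heads.segmentation_head import SegmentationHead

head = SegmentationHead(in_channels=256, num_classes=4, channel_list=[128], dropout=0.1)
x = torch.randn(2, 256, 56, 56)  # assumed decoder output
logits = head(x)
print(logits.shape)  # torch.Size([2, 4, 56, 56])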

terratorch.models.heads.classification_head

ClassificationHead

Bases: Module

Classification head

Source code in terratorch/models/heads/classification_head.py
class ClassificationHead(nn.Module):
    """Classification head"""

    # how to allow cls token?
    def __init__(
        self,
        in_dim: int,
        num_classes: int,
        dim_list: list[int] | None = None,
        dropout: float = 0,
        linear_after_pool: bool = False,
    ) -> None:
        """Constructor

        Args:
            in_dim (int): Input dimensionality
            num_classes (int): Number of output classes
            dim_list (list[int] | None, optional):  List with number of dimensions for each Linear
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
            linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False
        """
        super().__init__()
        self.num_classes = num_classes
        self.linear_after_pool = linear_after_pool
        if dim_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_dim, out_dim):
                return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

            dim_list = [in_dim, *dim_list]
            pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
            in_dim = dim_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Linear(in_features=in_dim, out_features=num_classes),
        )

    def forward(self, x: Tensor):
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

        if self.linear_after_pool:
            x = x.mean(axis=1)
            out = self.head(x)
        else:
            x = self.head(x)
            out = x.mean(axis=1)
        return out
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)

Constructor

Parameters:

- in_dim (int): Input dimensionality. Required.
- num_classes (int): Number of output classes. Required.
- dim_list (list[int] | None, optional): List with the number of dimensions for each Linear layer to be created. Defaults to None.
- dropout (float, optional): Dropout value to apply. Defaults to 0.
- linear_after_pool (bool, optional): Apply pooling first, then the linear layer. Defaults to False.
Source code in terratorch/models/heads/classification_head.py
def __init__(
    self,
    in_dim: int,
    num_classes: int,
    dim_list: list[int] | None = None,
    dropout: float = 0,
    linear_after_pool: bool = False,
) -> None:
    """Constructor

    Args:
        in_dim (int): Input dimensionality
        num_classes (int): Number of output classes
        dim_list (list[int] | None, optional):  List with number of dimensions for each Linear
            layer to be created. Defaults to None.
        dropout (float, optional): Dropout value to apply. Defaults to 0.
        linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False
    """
    super().__init__()
    self.num_classes = num_classes
    self.linear_after_pool = linear_after_pool
    if dim_list is None:
        pre_head = nn.Identity()
    else:

        def block(in_dim, out_dim):
            return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

        dim_list = [in_dim, *dim_list]
        pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
        in_dim = dim_list[-1]
    dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
    self.head = nn.Sequential(
        pre_head,
        dropout,
        nn.Linear(in_features=in_dim, out_features=num_classes),
    )
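
A minimal usage sketch with an assumed feature map shape:

import torch
from terratorch.models.heads.classification_head import ClassificationHead

head = ClassificationHead(in_dim=768, num_classes=10, linear_after_pool=True)
x = torch.randn(2, 768, 14, 14)  # assumed backbone feature map
logits = head(x)  # tokens are mean-pooled before the linear layer
print(logits.shape)  # torch.Size([2, 10])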

Auxiliary Heads

terratorch.models.model.AuxiliaryHead dataclass

Class containing all information to create auxiliary heads.

Parameters:

- name (str): Name of the head. Should match the name given to the auxiliary loss. Required.
- decoder (str): Name of the decoder class to be used. Required.
- decoder_args (dict | None): Parameters to be passed to the decoder constructor. Parameters for the decoder should be prefixed with decoder_; parameters for the head should be prefixed with head_. Required.
Source code in terratorch/models/model.py
@dataclass
class AuxiliaryHead:
    """Class containing all information to create auxiliary heads.

    Args:
        name (str): Name of the head. Should match the name given to the auxiliary loss.
        decoder (str): Name of the decoder class to be used.
        decoder_args (dict | None): parameters to be passed to the decoder constructor.
            Parameters for the decoder should be prefixed with `decoder_`.
            Parameters for the head should be prefixed with `head_`.
    """

    name: str
    decoder: str
    decoder_args: dict | None
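
A minimal construction sketch; the head name and the decoder arguments are illustrative:

from terratorch.models.model import AuxiliaryHead

aux_head = AuxiliaryHead(
    name="aux_head",  # must match the name given to the auxiliary loss
    decoder="FCNDecoder",
    decoder_args={
        "decoder_channels": 128,  # decoder_ prefix: passed to the decoder constructor
        "head_dropout": 0.1,      # head_ prefix: passed to the head
    },
)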

Model Output

terratorch.models.model.ModelOutput dataclass

Source code in terratorch/models/model.py
@dataclass
class ModelOutput:
    output: Tensor
    auxiliary_heads: dict[str, Tensor] | None = None

Model Factory

terratorch.models.PrithviModelFactory

Bases: ModelFactory

Source code in terratorch/models/prithvi_model_factory.py
@register_factory
class PrithviModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        bands: list[HLSBands | int],
        in_channels: int
        | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
        num_classes: int | None = None,
        pretrained: bool = True,  # noqa: FBT001, FBT002
        num_frames: int = 1,
        prepare_features_for_image_model: Callable | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        **kwargs,
    ) -> Model:
        """Model factory for prithvi models.

        Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str, nn.Module): Backbone to be used. If string, should be able to be parsed
                by the specified factory.
            decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                    If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                    If an nn.Module, we expect it to expose a property `decoder.output_embed_dim`.
                    Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
            in_channels (int, optional): Number of input channels. Defaults to None, in which case it is derived from bands.
            bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on.
                    Should be a list of terratorch.datasets.HLSBands.
                    Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].
            num_classes (int, optional): Number of classes. None for regression tasks.
            pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
                Defaults to True.
            num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
            prepare_features_for_image_model (Callable | None): Function to be called on encoder features
                before passing them to the decoder. Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression). Defaults to True.


        Returns:
            nn.Module: Full model with encoder, decoder and head.
        """
        bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
        if in_channels is None:
            in_channels = len(bands)
        # TODO: support auxiliary heads
        if not isinstance(backbone, nn.Module):
            if not backbone.startswith("prithvi_"):
                msg = "This class only handles models for `prithvi` encoders"
                raise NotImplementedError(msg)

            task = task.lower()
            if task not in SUPPORTED_TASKS:
                msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
                raise NotImplementedError(msg)

            backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
            # These params are used in case we need a SMP decoder
            # but should not be used for timm encoder
            output_stride = backbone_kwargs.pop("output_stride", None)
            out_channels = backbone_kwargs.pop("out_channels", None)

            backbone: nn.Module = timm.create_model(
                backbone,
                pretrained=pretrained,
                in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
                num_frames=num_frames,
                bands=bands,
                features_only=True,
                **backbone_kwargs,
            )

        decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_")
        # TODO: remove this
        if decoder.startswith("smp_"):
            decoder: nn.Module = _get_smp_decoder(
                decoder,
                backbone_kwargs,
                decoder_kwargs,
                out_channels,
                in_channels,
                num_classes,
                output_stride,
            )
        else:
            # allow decoder to be a module passed directly
            decoder_cls = _get_decoder(decoder)
            decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
            # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

        head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_")
        if num_classes:
            head_kwargs["num_classes"] = num_classes
        if aux_decoders is None:
            return _build_appropriate_model(
                task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
            aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
            aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)

            aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
            if num_classes:
                aux_head_kwargs["num_classes"] = num_classes
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            prepare_features_for_image_model,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )

build_model(task, backbone, decoder, bands, in_channels=None, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, **kwargs)

Model factory for prithvi models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

- task (str): Task to be performed. Currently supports "segmentation" and "regression". Required.
- backbone (str | nn.Module): Backbone to be used. If a string, it should be parsable by the specified factory. Required.
- decoder (str | nn.Module): Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in decoder.__init__.py with the same name. If an nn.Module, it is expected to expose a property decoder.output_embed_dim, and will be concatenated with a Conv2d for the final convolution. Required.
- bands (list[HLSBands | int]): Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Required.
- in_channels (int | None, optional): Number of input channels. Defaults to None, in which case it is derived from bands.
- num_classes (int | None, optional): Number of classes. None for regression tasks. Defaults to None.
- pretrained (bool, optional): Whether to load pretrained weights for the backbone, if available. Defaults to True.
- num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
- prepare_features_for_image_model (Callable | None, optional): Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function.
- aux_decoders (list[AuxiliaryHead] | None, optional): List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. Defaults to None.
- rescale (bool, optional): Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). Defaults to True.

Returns:

- Model: Full model with encoder, decoder and head.
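
A minimal end-to-end sketch; the band list, class count and decoder arguments are illustrative:

from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory
from terratorch.models.model import AuxiliaryHead

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_vit_100",
    decoder="UperNetDecoder",
    decoder_scale_modules=True,  # decoder_ prefix: forwarded to UperNetDecoder
    bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW],
    num_classes=2,
    head_dropout=0.1,  # head_ prefix: forwarded to the head
    aux_decoders=[AuxiliaryHead("aux_head", "FCNDecoder", None)],
)

The returned model produces a ModelOutput, with the main prediction under output and any auxiliary predictions under auxiliary_heads.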

Source code in terratorch/models/prithvi_model_factory.py
def build_model(
    self,
    task: str,
    backbone: str | nn.Module,
    decoder: str | nn.Module,
    bands: list[HLSBands | int],
    in_channels: int
    | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
    num_classes: int | None = None,
    pretrained: bool = True,  # noqa: FBT001, FBT002
    num_frames: int = 1,
    prepare_features_for_image_model: Callable | None = None,
    aux_decoders: list[AuxiliaryHead] | None = None,
    rescale: bool = True,  # noqa: FBT002, FBT001
    **kwargs,
) -> Model:
    """Model factory for prithvi models.

    Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
    `backbone_`, `decoder_` and `head_` respectively.

    Args:
        task (str): Task to be performed. Currently supports "segmentation" and "regression".
        backbone (str, nn.Module): Backbone to be used. If string, should be able to be parsed
            by the specified factory.
        decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                If an nn.Module, we expect it to expose a property `decoder.output_embed_dim`.
                Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
        in_channels (int, optional): Number of input channels. Defaults to None, in which case it is derived from bands.
        bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on.
                Should be a list of terratorch.datasets.HLSBands.
                Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].
        num_classes (int, optional): Number of classes. None for regression tasks.
        pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
            Defaults to True.
        num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
        prepare_features_for_image_model (Callable | None): Function to be called on encoder features
            before passing them to the decoder. Defaults to None, which applies the identity function.
        aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
            These decoders take the input from the encoder as well.
        rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
            is different from the ground truth. Only applicable to pixel wise models
            (e.g. segmentation, pixel wise regression). Defaults to True.


    Returns:
        nn.Module: Full model with encoder, decoder and head.
    """
    bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
    if in_channels is None:
        in_channels = len(bands)
    # TODO: support auxiliary heads
    if not isinstance(backbone, nn.Module):
        if not backbone.startswith("prithvi_"):
            msg = "This class only handles models for `prithvi` encoders"
            raise NotImplementedError(msg)

        task = task.lower()
        if task not in SUPPORTED_TASKS:
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)

        backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
        # These params are used in case we need a SMP decoder
        # but should not be used for timm encoder
        output_stride = backbone_kwargs.pop("output_stride", None)
        out_channels = backbone_kwargs.pop("out_channels", None)

        backbone: nn.Module = timm.create_model(
            backbone,
            pretrained=pretrained,
            in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
            num_frames=num_frames,
            bands=bands,
            features_only=True,
            **backbone_kwargs,
        )

    decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_")
    # TODO: remove this
    if decoder.startswith("smp_"):
        decoder: nn.Module = _get_smp_decoder(
            decoder,
            backbone_kwargs,
            decoder_kwargs,
            out_channels,
            in_channels,
            num_classes,
            output_stride,
        )
    else:
        # allow decoder to be a module passed directly
        decoder_cls = _get_decoder(decoder)
        decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
        # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

    head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_")
    if num_classes:
        head_kwargs["num_classes"] = num_classes
    if aux_decoders is None:
        return _build_appropriate_model(
            task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
        )

    to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
    for aux_decoder in aux_decoders:
        args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
        aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
        aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
        aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)

        aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
        if num_classes:
            aux_head_kwargs["num_classes"] = num_classes
        to_be_aux_decoders.append(
            AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
        )

    return _build_appropriate_model(
        task,
        backbone,
        decoder,
        head_kwargs,
        prepare_features_for_image_model,
        rescale=rescale,
        auxiliary_heads=to_be_aux_decoders,
    )

terratorch.models.SMPModelFactory

Bases: ModelFactory

Source code in terratorch/models/smp_model_factory.py
@register_factory
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str,
        model: str,
        bands: list[HLSBands | int],
        in_channels: int | None = None,
        num_classes: int = 1,
        pretrained: str | bool | None = True,  # noqa: FBT002
        prepare_features_for_image_model: Callable | None = None,
        regression_relu: bool = False,  # noqa: FBT001, FBT002
        **kwargs,
    ) -> Model:
        """
        Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

        This factory handles the instantiation of segmentation and regression models using specified
        encoders and decoders from the SMP library, along with custom modifications and extensions such
        as auxiliary decoders or modified encoders.

        Attributes:
            task (str): Specifies the task for which the model is being built. Supported tasks are
                        "segmentation".
            backbone (str): Specifies the backbone model to be used.
            decoder (str): Specifies the decoder to be used for constructing the
                        segmentation model.
            bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                        will operate on. These are expected to be from terratorch.datasets.HLSBands.
            in_channels (int, optional): Specifies the number of input channels. Defaults to None.
            num_classes (int, optional): The number of output classes for the model.
            pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the
                        backbone. Can also specify a path to weights. Defaults to True.
            num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                        for temporal models.
            regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
            **kwargs: Additional arguments that might be passed to further customize the backbone, decoder,
                        or any auxiliary heads. These should be prefixed appropriately

        Raises:
            ValueError: If the specified decoder is not supported by SMP.
            Exception: If the specified task is not "segmentation"

        Returns:
            nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                    parameters and tasks.
        """
        if task != "segmentation":
            msg = f"SMP models can only perform segmentation, but got task {task}"
            raise Exception(msg)

        bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
        if in_channels is None:
            in_channels = len(bands)

        # Gets decoder module.
        model_module = getattr(smp, model, None)
        if model_module is None:
            msg = f"Decoder {model} is not supported in SMP."
            raise ValueError(msg)

        backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
        smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_")  # Smp model params should be prefixed smp_
        aux_params = _extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
        aux_params = None if aux_params == {} else aux_params

        if isinstance(pretrained, bool):
            if pretrained:
                pretrained = "imagenet"
            else:
                pretrained = None

        # If encoder not currently supported by SMP (custom encoder).
        if backbone not in smp_encoders:
            # These params must be included in the config file with appropriate prefix.
            required_params = {
                "encoder_depth": smp_kwargs,
                "out_channels": backbone_kwargs,
                "output_stride": backbone_kwargs,
            }

            for param, config_dict in required_params.items():
                if param not in config_dict:
                    msg = f"Config must include the '{param}' parameter"
                    raise ValueError(msg)

            # Using new encoder.
            backbone_class = make_smp_encoder(backbone)
            backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
            # Registering custom encoder into SMP.
            register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

            model_args = {
                "encoder_name": "SMPEncoderWrapperWithPFFIM",
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }
        # Using SMP encoder.
        else:
            model_args = {
                "encoder_name": backbone,
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }

        model = model_module(**model_args, aux_params=aux_params)

        return SMPModelWrapper(
            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )

build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs)

Factory method for creating SMP (Segmentation Models PyTorch) based models with optional customization.

This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.

Parameters:

    task (str): Specifies the task for which the model is being built. Only "segmentation" is supported.
    backbone (str): Specifies the backbone (encoder) model to be used.
    model (str): Specifies the SMP architecture providing the decoder, e.g. "Unet", used to construct the segmentation model.
    bands (list[HLSBands | int]): A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands.
    in_channels (int, optional): Specifies the number of input channels. Defaults to None, in which case it is inferred from the number of bands.
    num_classes (int, optional): The number of output classes for the model. Defaults to 1.
    pretrained (bool | str | None, optional): Whether to load pretrained weights for the backbone. A named weight set, e.g. "imagenet", may also be given. Defaults to True.
    prepare_features_for_image_model (Callable, optional): Function applied to the features of a custom (non-SMP) encoder before they are passed to the decoder.
    regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
    **kwargs: Additional arguments that may be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately (backbone_, smp_, aux_).

Raises:

    ValueError: If the specified decoder is not supported by SMP.
    Exception: If the specified task is not "segmentation".

Returns:

    Model: A model instance wrapped in SMPModelWrapper, configured according to the specified parameters and task.
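
As a quick orientation before the source, a usage sketch follows. This is a hedged example rather than canonical usage: it assumes SMPModelFactory is importable from terratorch.models.smp_model_factory (the module the source below comes from) and that the HLSBands members used here exist; all values are illustrative.

Building a Unet with a standard SMP encoder

import torch
from terratorch.datasets import HLSBands
from terratorch.models.smp_model_factory import SMPModelFactory

factory = SMPModelFactory()
# "resnet34" is a standard SMP encoder; "Unet" is resolved via getattr(smp, "Unet").
model = factory.build_model(
    task="segmentation",
    backbone="resnet34",
    model="Unet",
    bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED],  # in_channels is inferred as 3
    num_classes=2,
    pretrained=False,  # mapped to encoder_weights=None, so no weights are downloaded
)
# The wrapper is expected to return a ModelOutput holding the (2, 2, 224, 224) logits.
out = model(torch.randn(2, 3, 224, 224))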

Source code in terratorch/models/smp_model_factory.py
def build_model(
    self,
    task: str,
    backbone: str,
    model: str,
    bands: list[HLSBands | int],
    in_channels: int | None = None,
    num_classes: int = 1,
    pretrained: str | bool | None = True,  # noqa: FBT002
    prepare_features_for_image_model: Callable | None = None,
    regression_relu: bool = False,  # noqa: FBT001, FBT002
    **kwargs,
) -> Model:
    """
    Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

    This factory handles the instantiation of segmentation and regression models using specified
    encoders and decoders from the SMP library, along with custom modifications and extensions such
    as auxiliary decoders or modified encoders.

    Attributes:
        task (str): Specifies the task for which the model is being built. Supported tasks are
                    "segmentation".
        backbone (str): Specifies the backbone model to be used.
        decoder (str): Specifies the decoder to be used for constructing the
                    segmentation model.
        bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                    will operate on. These are expected to be from terratorch.datasets.HLSBands.
        in_channels (int, optional): Specifies the number of input channels. Defaults to None.
        num_classes (int, optional): The number of output classes for the model.
        pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the
                    backbone. Can also specify a path to weights. Defaults to True.
        num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                    for temporal models.
        regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
        **kwargs: Additional arguments that might be passed to further customize the backbone, decoder,
                    or any auxiliary heads. These should be prefixed appropriately

    Raises:
        ValueError: If the specified decoder is not supported by SMP.
        Exception: If the specified task is not "segmentation"

    Returns:
        nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                parameters and tasks.
    """
    if task != "segmentation":
        msg = f"SMP models can only perform segmentation, but got task {task}"
        raise Exception(msg)

    bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
    if in_channels is None:
        in_channels = len(bands)

    # Gets decoder module.
    model_module = getattr(smp, model, None)
    if model_module is None:
        msg = f"Decoder {model} is not supported in SMP."
        raise ValueError(msg)

    backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
    smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_")  # SMP model params should be prefixed smp_
    aux_params = _extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
    aux_params = None if aux_params == {} else aux_params

    if isinstance(pretrained, bool):
        if pretrained:
            pretrained = "imagenet"
        else:
            pretrained = None

    # If encoder not currently supported by SMP (custom encoder).
    if backbone not in smp_encoders:
        # These params must be included in the config file with appropriate prefix.
        required_params = {
            "encoder_depth": smp_kwargs,
            "out_channels": backbone_kwargs,
            "output_stride": backbone_kwargs,
        }

        for param, config_dict in required_params.items():
            if param not in config_dict:
                msg = f"Config must include the '{param}' parameter"
                raise ValueError(msg)

        # Using new encoder.
        backbone_class = make_smp_encoder(backbone)
        backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
        # Registering custom encoder into SMP.
        register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

        model_args = {
            "encoder_name": "SMPEncoderWrapperWithPFFIM",
            "encoder_weights": pretrained,
            "in_channels": in_channels,
            "classes": num_classes,
            **smp_kwargs,
        }
    # Using SMP encoder.
    else:
        model_args = {
            "encoder_name": backbone,
            "encoder_weights": pretrained,
            "in_channels": in_channels,
            "classes": num_classes,
            **smp_kwargs,
        }

    model = model_module(**model_args, aux_params=aux_params)

    return SMPModelWrapper(
        model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
    )
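
The helper _extract_prefix_keys is used above but not shown on this page. Below is a minimal sketch consistent with its call sites, assuming it pops prefixed keys out of the dict and returns them with the prefix stripped; the real implementation in terratorch may differ.

A possible reading of _extract_prefix_keys

def _extract_prefix_keys(d: dict, prefix: str) -> dict:
    # Pop every key starting with `prefix` out of `d` (mutating it) and
    # return those entries with the prefix removed.
    extracted = {}
    for key in list(d):
        if key.startswith(prefix):
            extracted[key[len(prefix):]] = d.pop(key)
    return extracted

Under this reading, a kwarg such as backbone_out_channels lands in backbone_kwargs as out_channels, which is exactly what the custom-encoder branch above checks for.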

Adding new model types

Adding a new model type is as simple as creating a new factory that produces models, as in the example below for a minimal SMPModelFactory.

import segmentation_models_pytorch as smp
import torch.nn as nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory

@register_factory
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        **kwargs,
    ) -> Model:
        # For brevity this toy factory ignores `backbone` and `decoder`
        # and always builds a single-class resnet34 Unet.
        model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=in_channels, classes=1)
        return SMPModelWrapper(model)

@register_factory
class SMPModelWrapper(Model, nn.Module):
    def __init__(self, smp_model) -> None:
        super().__init__()
        self.smp_model = smp_model

    def forward(self, *args, **kwargs):
        # Drop the singleton class dimension and wrap the prediction in ModelOutput.
        return ModelOutput(self.smp_model(*args, **kwargs).squeeze(1))

    def freeze_encoder(self):
        pass

    def freeze_decoder(self):
        pass
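
The example factory can be exercised directly. Below is a hedged smoke test, assuming ModelOutput simply stores the prediction tensor passed to it; shapes are illustrative.

Exercising the example factory

import torch

factory = SMPModelFactory()
model = factory.build_model(task="segmentation", backbone="resnet34", decoder="Unet", in_channels=6)
# smp.Unet with classes=1 produces (2, 1, 224, 224) logits; squeeze(1) yields (2, 224, 224).
out = model(torch.randn(2, 6, 224, 224))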