Backbones

terratorch.models.backbones.swin_encoder_decoder #

Swin Transformer implementation: a mix of the MMSegmentation and timm implementations.

We use this implementation instead of the original or timm's because it offers a few advantages, namely the ability to handle dynamic input sizes through padding.

Please note the original timm implementation can still be used as a backbone via timm.create_model("swin_..."). You can see the available models with timm.list_models("swin*").
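For instance (a minimal sketch; the exact model names available depend on your timm version):

```python
import timm

# List the Swin variants registered in this timm installation.
print(timm.list_models("swin*")[:5])

# Instantiate one as a backbone; features_only=True is timm's standard way
# to expose intermediate feature maps rather than classification logits.
model = timm.create_model("swin_base_patch4_window7_224", features_only=True, pretrained=False)
```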

AdaptivePadding #

Bases: Module

Applies padding to the input (if needed) so that the input can be fully covered by the specified filter. It supports two modes, "same" and "corner". The "same" mode matches TensorFlow's "SAME" padding mode, padding zeros around the input. The "corner" mode pads zeros to the bottom right.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| kernel_size | int \| tuple | Size of the kernel. | 1 |
| stride | int \| tuple | Stride of the filter. | 1 |
| dilation | int \| tuple | Spacing between kernel elements. | 1 |
| padding | str | Supports "same" and "corner". The "corner" mode pads zeros to the bottom right; the "same" mode pads zeros around the input. | 'corner' |

Example:

    >>> import torch
    >>> kernel_size = 16
    >>> stride = 16
    >>> dilation = 1
    >>> input = torch.rand(1, 1, 15, 17)
    >>> adap_pad = AdaptivePadding(
    ...     kernel_size=kernel_size,
    ...     stride=stride,
    ...     dilation=dilation,
    ...     padding="corner")
    >>> out = adap_pad(input)
    >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    >>> input = torch.rand(1, 1, 16, 17)
    >>> out = adap_pad(input)
    >>> assert (out.shape[2], out.shape[3]) == (16, 32)

FFN #

Bases: Module

Implements a feed-forward network (FFN) with an identity connection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dims | int | The feature dimension. Same as MultiheadAttention. | 256 |
| feedforward_channels | int | The hidden dimension of the FFN. | 1024 |
| num_fcs | int | The number of fully connected layers in the FFN. | 2 |
| act_cfg | dict | The activation config for the FFN, e.g. dict(type='ReLU'). | required |
| ffn_drop | float | Probability of an element to be zeroed in the FFN. | 0.0 |
| add_identity | bool | Whether to add the identity connection. | True |
| dropout_layer | ConfigDict | The dropout layer used when adding the shortcut. | None |
| init_cfg | mmcv.ConfigDict | The config for initialization. | required |

forward(x, identity=None) #

Forward function for FFN.

The function adds x to the output tensor if identity is None.
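As a rough sketch, with the default num_fcs=2 this computes the following (plain PyTorch with illustrative names, not the class's actual attributes):

```python
import torch
import torch.nn as nn

class TinyFFN(nn.Module):
    """Illustrative two-layer FFN with an identity (residual) connection."""
    def __init__(self, embed_dims=256, feedforward_channels=1024, ffn_drop=0.0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embed_dims, feedforward_channels),
            nn.ReLU(),
            nn.Dropout(ffn_drop),
            nn.Linear(feedforward_channels, embed_dims),
            nn.Dropout(ffn_drop),
        )

    def forward(self, x, identity=None):
        out = self.layers(x)
        # When no explicit identity is passed, x itself is the shortcut.
        if identity is None:
            identity = x
        return identity + out

tokens = torch.rand(2, 196, 256)
assert TinyFFN()(tokens).shape == tokens.shape
```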

MMSegSwinTransformer #

Bases: Module

__init__(pretrain_img_size=224, in_chans=3, embed_dim=96, patch_size=4, window_size=7, mlp_ratio=4, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), strides=(4, 2, 2, 2), num_classes=1000, global_pool='avg', out_indices=(0, 1, 2, 3), qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, with_cp=False, frozen_stages=-1) #

MMSeg Swin Transformer backbone.

This backbone is the implementation of [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030), with inspiration from https://github.com/microsoft/Swin-Transformer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pretrain_img_size | int \| tuple[int] | The size of the input image during pretraining. | 224 |
| in_chans | int | The number of input channels. | 3 |
| embed_dim | int | The feature dimension. | 96 |
| patch_size | int \| tuple[int] | Patch size. | 4 |
| window_size | int | Window size. | 7 |
| mlp_ratio | int \| float | Ratio of the MLP hidden dim to the embedding dim. | 4 |
| depths | tuple[int] | Depths of each Swin Transformer stage. | (2, 2, 6, 2) |
| num_heads | tuple[int] | Parallel attention heads of each Swin Transformer stage. | (3, 6, 12, 24) |
| strides | tuple[int] | The patch merging or patch embedding stride of each Swin Transformer stage. (In Swin, the kernel size is set equal to the stride.) | (4, 2, 2, 2) |
| out_indices | tuple[int] | Stages to output features from. | (0, 1, 2, 3) |
| qkv_bias | bool | If True, add a learnable bias to query, key, value. | True |
| qk_scale | float \| None | Override the default qk scale of head_dim ** -0.5 if set. | None |
| patch_norm | bool | Whether to add a norm layer for patch embedding and patch merging. | required |
| drop_rate | float | Dropout rate. | 0.0 |
| attn_drop_rate | float | Attention dropout rate. | 0.0 |
| drop_path_rate | float | Stochastic depth rate. | 0.1 |
| act_layer | nn.Module | Activation layer. | nn.GELU |
| norm_layer | nn.Module | Normalization layer at the output of the backbone. | nn.LayerNorm |
| with_cp | bool | Whether to use checkpointing. Checkpointing saves some memory while slowing down the training speed. | False |
| frozen_stages | int | Stages to be frozen (stop grad and set eval mode). -1 means no parameters are frozen. | -1 |
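A minimal usage sketch with the defaults above (how the per-stage outputs are packaged may vary between versions; inspect the result):

```python
import torch

# Hypothetical instantiation using the documented defaults (Swin-T style config).
backbone = MMSegSwinTransformer(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
features = backbone(torch.rand(1, 3, 224, 224))
# With out_indices=(0, 1, 2, 3), one feature map per stage is returned,
# at strides 4, 8, 16 and 32 with widths 96, 192, 384 and 768.
for f in features:
    print(f.shape)
```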

train(mode=True) #

Converts the model to training mode while keeping the frozen stages frozen.

PatchEmbed #

Bases: Module

Image to Patch Embedding.

We use a conv layer to implement PatchEmbed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_chans | int | The number of input channels. | 3 |
| embed_dim | int | The dimension of the embedding. | 768 |
| kernel_size | int | The kernel size of the embedding conv. | 16 |
| stride | int | The stride of the embedding conv. If None, it is set to kernel_size. | None |
| padding | int \| tuple \| string | The padding length of the embedding conv. When it is a string, it selects the adaptive padding mode; "same" and "corner" are currently supported. | 'corner' |
| padding_mode | string | The padding mode to use. | 'constant' |
| dilation | int | The dilation rate of the embedding conv. | 1 |
| bias | bool | Bias of the embedding conv. | True |
| norm_cfg | dict | Config dict for the normalization layer. | required |
| input_size | int \| tuple \| None | The size of the input, used to calculate the output size. Only takes effect when dynamic_size is False. | None |

forward(x) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Has shape (B, C, H, W). In most cases, C is 3. | required |

Returns:

tuple: Contains the embedded results and their spatial shape.

  • x (Tensor): Has shape (B, out_h * out_w, embed_dim).
  • out_size (tuple[int]): Spatial shape of x, arranged as (out_h, out_w).
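Sketching the shape arithmetic with the defaults (kernel_size=16, stride=16, "corner" padding): the input is padded up to a multiple of 16, convolved, and flattened:

```python
import torch

# Assumed construction; kernel_size=16 and stride=16 are the documented defaults.
patch_embed = PatchEmbed(in_chans=3, embed_dim=768, kernel_size=16, stride=16)
x, out_size = patch_embed(torch.rand(2, 3, 224, 224))
# 224 / 16 = 14, so out_size == (14, 14) and x has shape (2, 14 * 14, 768) = (2, 196, 768).
print(x.shape, out_size)
```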

PatchMerging #

Bases: Module

Merge patch feature map.

This layer groups the feature map by kernel_size and applies norm and linear layers to the grouped feature map. Our implementation uses nn.Unfold to merge patches, which is about 25% faster than the original implementation; however, pretrained models need to be modified for compatibility.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_chans | int | The number of input channels. | required |
| out_channels | int | The number of output channels. | required |
| kernel_size | int \| tuple | The kernel size in the unfold layer. | 2 |
| stride | int \| tuple | The stride of the sliding blocks in the unfold layer. If None, it is set to kernel_size. | None |
| padding | int \| tuple \| string | The padding length of the unfold layer. When it is a string, it selects the adaptive padding mode; "same" and "corner" are currently supported. | 'corner' |
| dilation | int \| tuple | The dilation parameter in the unfold layer. | 1 |
| bias | bool | Whether to add bias to the linear layer. | False |
| norm_cfg | dict | Config dict for the normalization layer, e.g. dict(type='LN'). | required |

forward(x, input_size) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Has shape (B, H*W, C_in). | required |
| input_size | tuple[int] | The spatial shape of x, arranged as (H, W). | required |

Returns:

tuple: Contains the merged results and their spatial shape.

  • x (Tensor): Has shape (B, Merged_H * Merged_W, C_out).
  • out_size (tuple[int]): Spatial shape of x, arranged as (Merged_H, Merged_W).
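A standalone sketch of the unfold-based merge described above (illustrative; not this class's exact code):

```python
import torch
import torch.nn as nn

B, H, W, C_in, C_out = 2, 8, 8, 96, 192
x = torch.rand(B, H * W, C_in)

# nn.Unfold groups each non-overlapping 2x2 neighbourhood into one vector of
# length 4 * C_in, halving the spatial resolution.
unfold = nn.Unfold(kernel_size=2, stride=2)
merged = unfold(x.transpose(1, 2).reshape(B, C_in, H, W))  # (B, 4 * C_in, H/2 * W/2)
merged = merged.transpose(1, 2)                            # (B, H/2 * W/2, 4 * C_in)

# A norm and a linear projection then map 4 * C_in -> C_out.
out = nn.Linear(4 * C_in, C_out, bias=False)(nn.LayerNorm(4 * C_in)(merged))
print(out.shape)  # torch.Size([2, 16, 192])
```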

ShiftWindowMSA #

Bases: Module

Shifted Window Multihead Self-Attention Module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Number of input channels. | required |
| num_heads | int | Number of attention heads. | required |
| window_size | int | The height and width of the window. | required |
| shift_size | int | The shift step of each window towards the bottom right. If zero, acts as regular window MSA. | 0 |
| qkv_bias | bool | If True, add a learnable bias to query, key, value. | True |
| qk_scale | float \| None | Override the default qk scale of head_dim ** -0.5 if set. | None |
| attn_drop_rate | float | Dropout ratio of the attention weights. | 0 |
| proj_drop_rate | float | Dropout ratio of the output. | 0 |
| drop_path_rate | float | Dropout ratio of the layer used before the output. | 0 |

window_partition(x) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Has shape (B, H, W, C). | required |

Returns:

Tensor: Windows with shape (num_windows*B, window_size, window_size, C).

window_reverse(windows, H, W) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| windows | Tensor | Has shape (num_windows*B, window_size, window_size, C). | required |
| H | int | Height of the image. | required |
| W | int | Width of the image. | required |

Returns:

Tensor: Has shape (B, H, W, C).
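Both operations are pure reshapes. A self-contained sketch of the round trip, following the standard Swin bookkeeping (not necessarily this module's exact code):

```python
import torch

def partition(x, ws):
    # (B, H, W, C) -> (num_windows * B, ws, ws, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

def reverse(windows, ws, H, W):
    # (num_windows * B, ws, ws, C) -> (B, H, W, C)
    B = windows.shape[0] // (H // ws * (W // ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

x = torch.rand(2, 14, 14, 96)
assert torch.equal(reverse(partition(x, 7), 7, 14, 14), x)
```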

SwinBlock #

Bases: Module

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | The feature dimension. | required |
| num_heads | int | Parallel attention heads. | required |
| feedforward_channels | int | The hidden dimension for MLPs. | required |
| window_size | int | The local window scale. | 7 |
| shift | bool | Whether to shift the window. | False |
| qkv_bias | bool | Enable bias for qkv if True. | True |
| qk_scale | float \| None | Override the default qk scale of head_dim ** -0.5 if set. | None |
| drop_rate | float | Dropout rate. | 0.0 |
| attn_drop_rate | float | Attention dropout rate. | 0.0 |
| drop_path_rate | float | Stochastic depth rate. | 0.0 |
| act_cfg | dict | The config dict of the activation function, e.g. dict(type='GELU'). | required |
| norm_cfg | dict | The config dict of normalization, e.g. dict(type='LN'). | required |
| with_cp | bool | Whether to use checkpointing. Checkpointing saves some memory while slowing down the training speed. | False |

SwinBlockSequence #

Bases: Module

Implements one stage in Swin Transformer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | The feature dimension. | required |
| num_heads | int | Parallel attention heads. | required |
| feedforward_channels | int | The hidden dimension for MLPs. | required |
| depth | int | The number of blocks in this stage. | required |
| window_size | int | The local window scale. | 7 |
| qkv_bias | bool | Enable bias for qkv if True. | True |
| qk_scale | float \| None | Override the default qk scale of head_dim ** -0.5 if set. | None |
| drop_rate | float | Dropout rate. | 0.0 |
| attn_drop_rate | float | Attention dropout rate. | 0.0 |
| drop_path_rate | float \| list[float] | Stochastic depth rate. | 0.0 |
| downsample | BaseModule \| None | The downsample operation module. | None |
| act_cfg | dict | The config dict of the activation function, e.g. dict(type='GELU'). | required |
| norm_cfg | dict | The config dict of normalization, e.g. dict(type='LN'). | required |
| with_cp | bool | Whether to use checkpointing. Checkpointing saves some memory while slowing down the training speed. | False |

WindowMSA #

Bases: Module

Window based multi-head self-attention (W-MSA) module with relative position bias.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Number of input channels. | required |
| num_heads | int | Number of attention heads. | required |
| window_size | tuple[int] | The height and width of the window. | required |
| qkv_bias | bool | If True, add a learnable bias to query, key, value. | True |
| qk_scale | float \| None | Override the default qk scale of head_dim ** -0.5 if set. | None |
| attn_drop_rate | float | Dropout ratio of the attention weights. | 0.0 |
| proj_drop_rate | float | Dropout ratio of the output. | 0.0 |

forward(x, mask=None) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input features with shape (num_windows*B, N, C). | required |
| mask | Tensor \| None | Mask with shape (num_windows, Wh*Ww, Wh*Ww); values should be in (-inf, 0]. | None |
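The mask is additive: a value of -inf suppresses attention between positions in different shifted sub-windows, while 0 leaves the logit unchanged. Schematically, with the shapes from the table (names illustrative):

```python
import torch

num_windows, B, num_heads, N = 4, 2, 3, 49   # N = Wh * Ww
attn = torch.rand(num_windows * B, num_heads, N, N)
mask = torch.zeros(num_windows, N, N)        # 0 keeps; -inf would block

# Broadcast the per-window mask over batch and heads, then add to the logits.
attn = attn.view(B, num_windows, num_heads, N, N) + mask[None, :, None]
attn = attn.view(-1, num_heads, N, N).softmax(dim=-1)
```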

terratorch.models.backbones.prithvi_mae #

LocationEncoder #

Bases: Module

forward(location_coords) #

location_coords: lat and lon info with shape (B, 2).

MAEDecoder #

Bases: Module

Transformer Decoder used in the Prithvi MAE

PatchEmbed #

Bases: Module

3D version of timm.models.vision_transformer.PatchEmbed

PrithviMAE #

Bases: Module

Prithvi Masked Autoencoder

forward_loss(pixel_values, pred, mask) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pixel_values | torch.FloatTensor of shape (batch_size, num_channels, time, height, width) | Pixel values. | required |
| pred | torch.FloatTensor | Predicted pixel values. | required |
| mask | torch.FloatTensor of shape (batch_size, sequence_length) | Tensor indicating which patches are masked (1) and which are not (0). | required |

Returns:

torch.FloatTensor: Pixel reconstruction loss.
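This likely follows the standard MAE recipe, where the per-patch mean squared error is averaged over masked patches only; a sketch under that assumption:

```python
import torch

def mae_loss(target, pred, mask):
    # target, pred: (batch_size, num_patches, patch_dim); mask: (batch_size, num_patches),
    # with 1 marking masked patches and 0 visible ones.
    loss = ((pred - target) ** 2).mean(dim=-1)   # per-patch MSE
    return (loss * mask).sum() / mask.sum()      # mean over masked patches only
```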

patchify(pixel_values) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pixel_values | torch.FloatTensor of shape (batch_size, num_channels, time, height, width) | Pixel values. | required |

Returns:

torch.FloatTensor of shape (batch_size, num_patches, patch_size[0] * patch_size[1] * patch_size[2] * num_channels): Patchified pixel values.
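The shape bookkeeping, sketched in plain PyTorch (the patch size, input shape, and the channel ordering inside each patch are illustrative assumptions):

```python
import torch

B, C, T, H, W = 2, 6, 1, 224, 224
t, p = 1, 16                                   # illustrative patch_size = (t, p, p)
pixel_values = torch.rand(B, C, T, H, W)

# Split each axis into (number of patches, patch extent), then flatten each patch.
x = pixel_values.reshape(B, C, T // t, t, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)          # (B, nT, nH, nW, t, p, p, C)
patches = x.reshape(B, (T // t) * (H // p) * (W // p), t * p * p * C)
print(patches.shape)                           # torch.Size([2, 196, 1536])
```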

unpatchify(patchified_pixel_values, image_size=None) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| patchified_pixel_values | torch.FloatTensor | Patchified pixel values, as produced by patchify. | required |
| image_size | tuple[int, int], optional | Original image size. | None |

Returns:

torch.FloatTensor of shape (batch_size, num_channels, height, width): Pixel values.

PrithviViT #

Bases: Module

Prithvi ViT Encoder

random_masking(sequence, mask_ratio, noise=None) #

Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsorting random noise.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mask_ratio | float | Mask ratio to use. | required |
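The argsort trick works as follows: random noise is sorted per sample, the lowest-noise tokens are kept, and a binary mask is recovered in the original token order. A sketch mirroring the reference MAE implementation (the class's own version may differ in details):

```python
import torch

def random_masking(x, mask_ratio):
    B, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(B, L)                        # per-sample random noise
    ids_shuffle = torch.argsort(noise, dim=1)       # ascending: low noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, L)                         # 1 = masked, 0 = kept
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)       # back to the original order
    return x_masked, mask, ids_restore
```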

TemporalEncoder #

Bases: Module

forward(temporal_coords, tokens_per_frame=None) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| temporal_coords | Tensor | Year and day-of-year info with shape (B, T, 2). | required |
| tokens_per_frame | int \| None | Number of tokens for each frame in the sample. If provided, embeddings are repeated over the T dimension, giving a final shape of (B, T * tokens_per_frame, embed_dim). | None |
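The tokens_per_frame repetition is pure shape bookkeeping; illustratively:

```python
import torch

B, T, embed_dim, tokens_per_frame = 2, 3, 768, 196
frame_embeddings = torch.rand(B, T, embed_dim)

# Repeating each frame's embedding once per token in that frame yields
# (B, T * tokens_per_frame, embed_dim), aligned with the token sequence.
token_embeddings = frame_embeddings.repeat_interleave(tokens_per_frame, dim=1)
print(token_embeddings.shape)  # torch.Size([2, 588, 768])
```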

get_1d_sincos_pos_embed_from_grid(embed_dim, pos) #

  • embed_dim: output dimension for each position.
  • pos: a list of positions to be encoded, of size (M,).
  • out: (M, D).
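A minimal sketch of the classic sin/cos construction this implements (details such as dtype handling may differ):

```python
import torch

def sincos_1d(embed_dim, pos):
    # Frequencies are geometrically spaced, as in the original Transformer;
    # sin fills the first half of the embedding, cos the second.
    omega = 1.0 / 10000 ** (torch.arange(embed_dim // 2) / (embed_dim / 2))
    out = pos.float()[:, None] * omega[None, :]       # (M, embed_dim / 2)
    return torch.cat([out.sin(), out.cos()], dim=1)   # (M, embed_dim)

print(sincos_1d(8, torch.arange(4)).shape)  # torch.Size([4, 8])
```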

get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False) #

Create 3D sin/cos positional embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embed_dim | int | Embedding dimension. | required |
| grid_size | tuple[int, int, int] \| list[int] | The grid depth, height and width. | required |
| add_cls_token | bool, optional | Whether or not to add a classification (CLS) token. | False |

Returns:

torch.FloatTensor: The position embeddings, of shape (grid_size[0] * grid_size[1] * grid_size[2], embed_dim) without the cls token, or (1 + grid_size[0] * grid_size[1] * grid_size[2], embed_dim) with it.
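For example, a 4-frame 14 × 14 token grid yields 4 * 14 * 14 = 784 positions (a hypothetical call, assuming the signature above):

```python
pos_embed = get_3d_sincos_pos_embed(768, (4, 14, 14), add_cls_token=True)
print(pos_embed.shape)  # (785, 768): 1 cls slot + 784 positions
```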

terratorch.models.backbones.unet #

UNet #

Bases: Module

UNet backbone.

This backbone is the implementation of [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_channels | int | Number of input image channels. | 3 |
| out_channels | int | Number of base channels in each stage; the output channels of the first stage. | 64 |
| num_stages | int | Number of stages in the encoder, normally 5. | 5 |
| strides | Sequence[int] (each 1 or 2) | Strides of each stage in the encoder; len(strides) equals num_stages. Normally the stride of the first stage is 1. If strides[i] = 2, stage i uses a strided convolution to downsample. | (1, 1, 1, 1, 1) |
| enc_num_convs | Sequence[int] | Number of convolutional layers in the convolution block of the corresponding encoder stage. | (2, 2, 2, 2, 2) |
| dec_num_convs | Sequence[int] | Number of convolutional layers in the convolution block of the corresponding decoder stage. | (2, 2, 2, 2) |
| downsamples | Sequence[int] | Whether to use MaxPool to downsample the feature map after the first stage of the encoder (stages in [1, num_stages)). If the corresponding encoder stage uses a strided convolution (strides[i] = 2), it never uses MaxPool, even if downsamples[i-1] is True. | (True, True, True, True) |
| enc_dilations | Sequence[int] | Dilation rate of each stage in the encoder. | (1, 1, 1, 1, 1) |
| dec_dilations | Sequence[int] | Dilation rate of each stage in the decoder. | (1, 1, 1, 1) |
| with_cp | bool | Whether to use checkpointing. Checkpointing saves some memory while slowing down the training speed. | False |
| conv_cfg | dict \| None | Config dict for the convolution layer. | None |
| norm_cfg | dict \| None | Config dict for the normalization layer. | dict(type='BN') |
| act_cfg | dict \| None | Config dict for the activation layer in ConvModule. | dict(type='ReLU') |
| upsample_cfg | dict | The upsample config of the upsample module in the decoder, e.g. dict(type='InterpConv'). | None |
| norm_eval | bool | Whether to set norm layers to eval mode, freezing running stats (mean and var). Affects BatchNorm and its variants only. | False |
| dcn | bool | Whether to use deformable convolution in convolutional layers. | None |
| plugins | dict | Plugins for convolutional layers. | None |
| pretrained | str | Path to a pretrained model. | None |
| init_cfg | dict \| list[dict] | Initialization config dict. | None |
Notice

The input image size should be divisible by the overall downsample rate of the encoder, as illustrated below. More detail on the overall downsample rate can be found in UNet._check_input_divisible.
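For example, with the default strides (1, 1, 1, 1, 1) and downsamples (True, True, True, True), the encoder downsamples four times, so H and W must be divisible by 2**4 = 16. A sketch of the intent of that check (illustrative, not the method's actual code):

```python
def check_input_divisible(h, w, strides=(1, 1, 1, 1, 1), downsamples=(True, True, True, True)):
    # Each stage after the first halves the feature map if it strides or max-pools.
    rate = 1
    for stride, downsample in zip(strides[1:], downsamples):
        if stride == 2 or downsample:
            rate *= 2
    assert h % rate == 0 and w % rate == 0, f"input {h}x{w} not divisible by {rate}"

check_input_divisible(256, 256)   # passes: 256 is divisible by 16
```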

train(mode=True) #

Converts the model to training mode while keeping the normalization layers frozen.
