Backbones#

Built-in Backbones#

terratorch.models.backbones.terramind.model.terramind_vit.TerraMindViT #

Bases: Module

Modified TerraMind model, adapted to behave as a raw data-only ViT.

Parameters:

Name Type Description Default
img_size int

Input image size.

224
modalities (list, dict)

List of modality keys and dicts, or dict with modality keys and values being ints (num_channels of modality) or nn.Module (patch embedding layer).

None
merge_method str

Specify how the output is merged for further processing. One of 'mean', 'max', 'concat', 'dict', or None. 'mean', 'max', and 'concat' drop all sequence modality tokens, split the remaining image modality tokens by modality, and reduce them with the respective operation. 'dict' splits all tokens into a dictionary {'modality': torch.Tensor}. Defaults to 'mean'.

'mean'
patch_size int

Patch size.

16
in_chans int

Number of input image channels.

3
dim int

Patch embedding dimension.

768
encoder_depth int

Depth of ViT / number of encoder blocks.

12
num_heads int

Number of attention heads in each ViT block.

12
mlp_ratio float

Ratio of mlp hidden dim to embedding dim.

4.0
qkv_bias bool

If True, add a learnable bias to query, key, value.

True
proj_bias bool

If True, adds a bias to the attention out proj layer.

True
mlp_bias bool

If True, adds a learnable bias for the feedforward.

True
drop_path_rate float

Stochastic depth rate.

0.0
drop_rate float

Dropout rate.

0.0
attn_drop_rate float

Attention dropout rate.

0.0
modality_drop_rate float

Probability of dropping a modality input during training.

0.0
act_layer Module

Activation layer.

GELU
norm_layer Module

Normalization layer.

partial(LayerNorm, eps=1e-06)
gated_mlp bool

If True, makes the feedforward gated (e.g., for SwiGLU)

False
qk_norm bool

If True, normalizes the query and keys (as in ViT-22B)

False
encoder_norm bool

If True, adds a norm layer after the last encoder block.

True
tokenizer_dict dict

Dictionary of tokenizers.

None
Source code in terratorch/models/backbones/terramind/model/terramind_vit.py
class TerraMindViT(nn.Module):
    """Modified TerraMind model, adapted to behave as a raw data-only ViT.

    Args:
        img_size (int): Input image size.
        modalities (list, dict, optional): List of modality keys and dicts, or dict with modality keys and values being
            ints (num_channels of modality) or nn.Module (patch embedding layer).
        merge_method (str, optional): Specify how the output is merged for further processing. One of 'mean', 'max',
            'concat', 'dict', or None. 'mean', 'max', and 'concat' drop all sequence modality tokens, split the
            remaining image modality tokens by modality, and reduce them with the respective operation. 'dict' splits
            all tokens into a dictionary {'modality': torch.Tensor}. Defaults to 'mean'.
        patch_size (int): Patch size.
        in_chans (int): Number of input image channels.
        dim (int): Patch embedding dimension.
        encoder_depth (int): Depth of ViT / number of encoder blocks.
        num_heads (int): Number of attention heads in each ViT block.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
        proj_bias (bool): If True, adds a bias to the attention out proj layer.
        mlp_bias (bool): If True, adds a learnable bias for the feedforward.
        drop_path_rate (float): Stochastic depth rate.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        modality_drop_rate (float): Probability of dropping a modality input during training.
        act_layer (nn.Module): Activation layer.
        norm_layer (nn.Module): Normalization layer.
        gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU)
        qk_norm (bool): If True, normalizes the query and keys (as in ViT-22B)
        encoder_norm (bool): If True, adds a norm layer after the last encoder block.
        tokenizer_dict (dict): Dictionary of tokenizers.
    """
    def __init__(
        self,
        img_size: int = 224,
        modalities: list | dict[str, int | nn.Module] | None = None,
        merge_method: str | None = 'mean',
        patch_size: int = 16,
        in_chans: int = 3,
        dim: int = 768,
        encoder_depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        mlp_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        modality_drop_rate: float = 0.0,
        act_layer: type[nn.Module] = nn.GELU,
        norm_layer: partial | nn.Module = partial(LayerNorm, eps=1e-6),
        gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
        qk_norm: bool = False,
        encoder_norm: bool = True,
        tokenizer_dict: dict | None = None,
    ):
        super().__init__()

        if modalities is None or len(modalities) == 0:
            # Init new image modality
            modalities = [{'image': in_chans}]
        elif isinstance(modalities, dict):
            modalities = [modalities]
        elif not isinstance(modalities, list):
            raise ValueError('Modalities must be None, a list of modality keys, or a dict with ints/embedding layers.')

        # Build embedding layers for all defined modalities
        mod_embeddings, mod_name_mapping = build_modality_embeddings(MODALITY_INFO, modalities, img_size=img_size,
                                                                     dim=dim, patch_size=patch_size)
        self.encoder_embeddings = nn.ModuleDict(mod_embeddings)
        self.mod_name_mapping = mod_name_mapping
        self.modalities = list(mod_name_mapping.keys())  # Further code expects list

        self.img_size = img_size
        self.merge_method = merge_method
        self.image_modalities = [key for key, value in self.encoder_embeddings.items()
             if isinstance(value, ImageEncoderEmbedding) or isinstance(value, ImageTokenEncoderEmbedding)]
        self.modality_drop_rate = modality_drop_rate
        assert 0 <= self.modality_drop_rate <= 1, "modality_drop_rate must be in [0, 1]"
        # New learned parameter for handling missing modalities
        if self.merge_method == 'concat':
            self.missing_mod_token = nn.Parameter(torch.Tensor(dim))

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias,
                  mlp_bias=mlp_bias, drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer,
                  norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])

        # Needed for terratorch decoders
        if merge_method == 'concat':
            self.out_channels = [dim * len(self.image_modalities) for i in range(encoder_depth)]
        else:
            self.out_channels = [dim for i in range(encoder_depth)]

        self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()

        if tokenizer_dict is not None:
            self.tokenizer = build_tokenizer(tokenizer_dict=tokenizer_dict,
                                             input_modalities=list(self.encoder_embeddings.keys()))
        else:
            # Ensure the attribute always exists; forward() checks membership in it.
            self.tokenizer = {}

        # Weight init
        self.init_weights()

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme"""

        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue
            # Linear
            elif isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm
            elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)

            # Embedding
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
            # Conv2d
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    @torch.jit.ignore
    def no_weight_decay(self):
        no_wd_set = set()

        for mod, emb_module in self.encoder_embeddings.items():
            if hasattr(emb_module, 'no_weight_decay'):
                to_skip = emb_module.no_weight_decay()
                to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip

        return no_wd_set

    def forward(self, d: dict[str, torch.Tensor] | torch.Tensor | None = None, **kwargs) -> list[torch.Tensor]:
        """
        Forward pass of the model.

        Args:
            d (dict, torch.Tensor): Dict of inputs or input tensor with shape (B, C, H, W)

            Alternatively, keyword arguments with modality=tensor.

        Returns:
            list[torch.Tensor]: List of transformer layer outputs. Shape (B, L, D).
        """
        # Handle missing or single-tensor input
        if d is None or (isinstance(d, dict) and len(d) == 0):
            d = {}
            if not len(kwargs):
                raise ValueError("No input provided.")
        elif not isinstance(d, dict):
            # Single tensor input; assume the first modality
            d = {self.modalities[0]: d}

        # Add additional keyword args to input dict
        for key, value in kwargs.items():
            d[key] = value

        # Check for unknown modalities in input
        for mod in list(d.keys()):
            if mod not in self.mod_name_mapping:
                warnings.warn(f"Unknown input modality: {mod}. Ignoring input.")
                del d[mod]
        if len(d) == 0:
            raise ValueError("No valid inputs provided.")

        if self.training and self.modality_drop_rate:
            # Drop random modalities during training
            for key in random.sample(list(d.keys()), k=len(d) - 1):
                if random.random() < self.modality_drop_rate:
                    _ = d.pop(key)

        x = []
        num_tokens = []
        image_mod = []
        for mod, tensor in d.items():
            if self.mod_name_mapping[mod] in self.tokenizer:
                # Tokenize input if required
                device = next(self.parameters()).device
                tensor = self.tokenizer[self.mod_name_mapping[mod]].encode(tensor, device)
                if self.mod_name_mapping[mod] in self.image_modalities:
                    tensor = tensor[-1]
                else:
                    tensor = tensor["tensor"]

            mod_dict = self.encoder_embeddings[self.mod_name_mapping[mod]](tensor)
            # Add embeddings to patchified data
            x.append(mod_dict['x'] + mod_dict['emb'])
            num_tokens.append(mod_dict['x'].shape[-2])
            image_mod.append(self.mod_name_mapping[mod] in self.image_modalities)

        # Concatenate along token dim
        x = torch.cat(x, dim=1)  # Shape: (B, N, D)

        # Forward encoder blocks
        out = []
        for block in self.encoder:
            x = block(x)
            out.append(x.clone())

        out[-1] = self.encoder_norm(x)  # Shape: (B, N, D)

        def _unstack_image_modalities(x):
            x = torch.split(x, num_tokens, dim=1)  # Split tokens by modality
            x = [m for m, keep in zip(x, image_mod) if keep]  # Drop sequence modalities
            x = torch.stack(x, dim=1)  # (B, M, N, D)
            return x

        # Merge tokens from different modalities
        if self.merge_method == 'mean':
            out = [_unstack_image_modalities(x) for x in out]
            out = [x.mean(dim=1) for x in out]

        elif self.merge_method == 'max':
            out = [_unstack_image_modalities(x) for x in out]
            out = [x.max(dim=1)[0] for x in out]

        elif self.merge_method == 'concat':
            out = [_unstack_image_modalities(x) for x in out]
            if len(d) < len(self.image_modalities):
                # Handle missing modalities with missing_mod_token
                num_missing = len(self.image_modalities) - len(d)
                missing_tokens = self.missing_mod_token.repeat(out[-1].shape[0], num_missing, out[-1].shape[2], 1)
                out = [torch.cat([x, missing_tokens], dim=1) for x in out]
            # Concat along embedding dim
            out = [torch.cat(x.unbind(dim=1), dim=-1) for x in out]

        elif self.merge_method == 'dict':
            out = [torch.split(x, num_tokens, dim=1) for x in out]
            out = [{mod: x[i] for i, mod in enumerate(d.keys())} for x in out]

        elif self.merge_method is None:
            pass  # Do nothing
        else:
            raise NotImplementedError(f'Merging method {self.merge_method} is not implemented. '
                                      f'Select one of mean, max, concat, dict, or None.')

        return out
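
Example usage (a minimal sketch; the modality keys 'S2L2A' and 'S1GRD' below are assumptions for illustration and must be resolvable by build_modality_embeddings via MODALITY_INFO):

import torch
from terratorch.models.backbones.terramind.model.terramind_vit import TerraMindViT

# Two image modalities with 12 and 2 input channels; 'mean' averages their
# tokens per encoder block.
model = TerraMindViT(
    img_size=224,
    modalities=[{"S2L2A": 12}, {"S1GRD": 2}],
    merge_method="mean",
)
model.eval()

inputs = {
    "S2L2A": torch.randn(1, 12, 224, 224),
    "S1GRD": torch.randn(1, 2, 224, 224),
}
with torch.no_grad():
    features = model(inputs)

# One output per encoder block; each is (B, 196, 768) for 224x224 inputs
# with patch size 16.
print(len(features), features[-1].shape)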

forward(d=None, **kwargs) #

Forward pass of the model.

Parameters:

Name Type Description Default
d (dict, Tensor)

Dict of inputs or input tensor with shape (B, C, H, W)

None

Returns:

Type Description
list[Tensor]

list[torch.Tensor]: List of transformer layer outputs. Shape (B, L, D).

Source code in terratorch/models/backbones/terramind/model/terramind_vit.py
def forward(self, d: dict[str, torch.Tensor] | torch.Tensor | None = None, **kwargs) -> list[torch.Tensor]:
    """
    Forward pass of the model.

    Args:
        d (dict, torch.Tensor): Dict of inputs or input tensor with shape (B, C, H, W)

        Alternatively, keyword arguments with modality=tensor.

    Returns:
        list[torch.Tensor]: List of transformer layer outputs. Shape (B, L, D).
    """
    # Handle missing or single-tensor input
    if d is None or (isinstance(d, dict) and len(d) == 0):
        d = {}
        if not len(kwargs):
            raise ValueError("No input provided.")
    elif not isinstance(d, dict):
        # Single tensor input; assume the first modality
        d = {self.modalities[0]: d}

    # Add additional keyword args to input dict
    for key, value in kwargs.items():
        d[key] = value

    # Check for unknown modalities in input
    for mod in list(d.keys()):
        if mod not in self.mod_name_mapping:
            warnings.warn(f"Unknown input modality: {mod}. Ignoring input.")
            del d[mod]
    if len(d) == 0:
        raise ValueError("No valid inputs provided.")

    if self.training and self.modality_drop_rate:
        # Drop random modalities during training
        for key in random.sample(list(d.keys()), k=len(d) - 1):
            if random.random() < self.modality_drop_rate:
                _ = d.pop(key)

    x = []
    num_tokens = []
    image_mod = []
    for mod, tensor in d.items():
        if self.mod_name_mapping[mod] in self.tokenizer:
            # Tokenize input if required
            device = next(self.parameters()).device
            tensor = self.tokenizer[self.mod_name_mapping[mod]].encode(tensor, device)
            if self.mod_name_mapping[mod] in self.image_modalities:
                tensor = tensor[-1]
            else:
                tensor = tensor["tensor"]

        mod_dict = self.encoder_embeddings[self.mod_name_mapping[mod]](tensor)
        # Add embeddings to patchified data
        x.append(mod_dict['x'] + mod_dict['emb'])
        num_tokens.append(mod_dict['x'].shape[-2])
        image_mod.append(self.mod_name_mapping[mod] in self.image_modalities)

    # Concatenate along token dim
    x = torch.cat(x, dim=1)  # Shape: (B, N, D)

    # Forward encoder blocks
    out = []
    for block in self.encoder:
        x = block(x)
        out.append(x.clone())

    out[-1] = self.encoder_norm(x)  # Shape: (B, N, D)

    def _unstack_image_modalities(x):
        x = torch.split(x, num_tokens, dim=1)  # Split tokens by modality
        x = [m for m, keep in zip(x, image_mod) if keep]  # Drop sequence modalities
        x = torch.stack(x, dim=1)  # (B, M, N, D)
        return x

    # Merge tokens from different modalities
    if self.merge_method == 'mean':
        out = [_unstack_image_modalities(x) for x in out]
        out = [x.mean(dim=1) for x in out]

    elif self.merge_method == 'max':
        out = [_unstack_image_modalities(x) for x in out]
        out = [x.max(dim=1)[0] for x in out]

    elif self.merge_method == 'concat':
        out = [_unstack_image_modalities(x) for x in out]
        if len(d) < len(self.image_modalities):
            # Handle missing modalities with missing_mod_token
            num_missing = len(self.image_modalities) - len(d)
            missing_tokens = self.missing_mod_token.repeat(out[-1].shape[0], num_missing, out[-1].shape[2], 1)
            out = [torch.cat([x, missing_tokens], dim=1) for x in out]
        # Concat along embedding dim
        out = [torch.cat(x.unbind(dim=1), dim=-1) for x in out]

    elif self.merge_method == 'dict':
        out = [torch.split(x, num_tokens, dim=1) for x in out]
        out = [{mod: x[i] for i, mod in enumerate(d.keys())} for x in out]

    elif self.merge_method is None:
        pass  # Do nothing
    else:
        raise NotImplementedError(f'Merging method {self.merge_method} is not implemented. '
                                  f'Select one of mean, max, concat, dict, or None.')

    return out
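
The 'dict' merge path described above keeps per-modality features separate instead of reducing them; a short sketch (same assumed modality keys as the earlier example):

import torch
from terratorch.models.backbones.terramind.model.terramind_vit import TerraMindViT

model = TerraMindViT(modalities=[{"S2L2A": 12}, {"S1GRD": 2}], merge_method="dict")
out = model({
    "S2L2A": torch.randn(1, 12, 224, 224),
    "S1GRD": torch.randn(1, 2, 224, 224),
})
# Each block output is split back into a {modality: tensor} dict.
print(out[-1]["S2L2A"].shape)  # torch.Size([1, 196, 768])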

init_weights() #

Weight initialization following MAE's initialization scheme

Source code in terratorch/models/backbones/terramind/model/terramind_vit.py
def init_weights(self):
    """Weight initialization following MAE's initialization scheme"""

    for name, m in self.named_modules():
        # Skipping tokenizers to avoid reinitializing them
        if "tokenizer" in name:
            continue
        # Linear
        elif isinstance(m, nn.Linear):
            if 'qkv' in name:
                # treat the weights of Q, K, V separately
                val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                nn.init.uniform_(m.weight, -val, val)
            elif 'kv' in name:
                # treat the weights of K, V separately
                val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                nn.init.uniform_(m.weight, -val, val)
            else:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        # LayerNorm
        elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

        # Embedding
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)
        # Conv2d
        elif isinstance(m, nn.Conv2d):
            if '.proj' in name:
                # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                w = m.weight.data
                nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
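
The qkv branch above sizes the Xavier bound as if Q, K, and V were separate matrices: for a fused (3*dim, dim) projection it uses fan_out // 3 instead of the full fused fan-out. A standalone check of the bound for dim=768:

import math

dim = 768
fused_out, fan_in = 3 * dim, dim  # fused qkv weight has shape (2304, 768)
val = math.sqrt(6.0 / (fused_out // 3 + fan_in))
print(round(val, 4))  # 0.0625, vs. ~0.0442 if the fused fan-out were used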

terratorch.models.backbones.prithvi_mae.PrithviViT #

Bases: Module

Prithvi ViT Encoder

Source code in terratorch/models/backbones/prithvi_mae.py
class PrithviViT(nn.Module):
    """Prithvi ViT Encoder"""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        vpt: bool = False,
        vpt_n_tokens: int | None = None,
        vpt_dropout: float = 0,
        **kwargs,
    ):
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = 'time' in coords_encoding
        self.location_encoding = 'location' in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                                     drop_path=drop_path,))
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.vpt = vpt
        self.vpt_n_tokens = vpt_n_tokens
        self.vpt_dropout = vpt_dropout
        if self.vpt:
            if self.vpt_n_tokens is None:
                msg = "vpt_n_tokens must be provided when using VPT"
                raise ValueError(msg)
            self.vpt_prompt_embeddings = nn.ParameterList(
                [nn.Parameter(torch.zeros(1, self.vpt_n_tokens, embed_dim)) for _ in range(depth)]
            )
            self.vpt_dropout_layers = nn.ModuleList([nn.Dropout(vpt_dropout) for _ in range(depth)])

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

        # initialize VPT prompt embeddings
        if self.vpt:
            # extracted from https://github.com/KMnP/vpt/blob/4410440ec1b489f24f66b9fad3d9b10ff3443567/src/models/vit_prompt/vit.py#L57
            val = np.sqrt(6.0 / float(3 * reduce(mul, self.patch_embed.patch_size[1:], 1) + self.embed_dim))
            for emb in self.vpt_prompt_embeddings:
                nn.init.uniform_(emb, -val, val)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): mainly used for
                testing purposes to control randomness and maintain reproducibility.
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.pos_embed,
            grid_size=self.patch_embed.grid_size,
            patch_size=self.patch_embed.patch_size,
            shape=sample_shape,
            embed_dim=self.embed_dim,
        )
        return pos_embed

    def forward(
        self, x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.75
    ):
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        bs = x.shape[0]
        for idx, block in enumerate(self.blocks):
            if self.vpt:
                x = torch.cat(
                    (
                        x[:, :1, :],
                        self.vpt_dropout_layers[idx](self.vpt_prompt_embeddings[idx].expand(bs, -1, -1)),
                        x[:, 1:, :],
                    ),
                    dim=1,
                )  # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
            x = block(x)
            if self.vpt:
                x = torch.cat(
                    (x[:, :1, :], x[:, (1 + self.vpt_n_tokens) :, :]),
                    dim=1,
                )
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        bs = x.shape[0]
        out = []
        for idx, block in enumerate(self.blocks):
            if self.vpt:
                x = torch.cat(
                    (
                        x[:, :1, :],
                        self.vpt_dropout_layers[idx](self.vpt_prompt_embeddings[idx].expand(bs, -1, -1)),
                        x[:, 1:, :],
                    ),
                    dim=1,
                )  # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
            x = block(x)
            if self.vpt:
                x = torch.cat(
                    (x[:, :1, :], x[:, (1 + self.vpt_n_tokens) :, :]),
                    dim=1,
                )
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out

    def freeze(self):
        for n, param in self.named_parameters():
            if "vpt_prompt_embeddings" not in n:
                param.requires_grad_(False)

random_masking(sequence, mask_ratio, noise=None) #

Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise.

Parameters:

Name Type Description Default
mask_ratio float

mask ratio to use.

required
Source code in terratorch/models/backbones/prithvi_mae.py
def random_masking(self, sequence, mask_ratio, noise=None):
    """
    Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
    noise.

    Args:
        sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
        mask_ratio (float): mask ratio to use.
        noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): mainly used for
            testing purposes to control randomness and maintain reproducibility.
    """
    batch_size, seq_length, dim = sequence.shape
    len_keep = int(seq_length * (1 - mask_ratio))

    if noise is None:
        noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([batch_size, seq_length], device=sequence.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return sequence_unmasked, mask, ids_restore
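
The masking semantics are easy to verify on toy data (a sketch; the small encoder below exists only to call the method):

import torch
from terratorch.models.backbones.prithvi_mae import PrithviViT

encoder = PrithviViT(embed_dim=768, depth=1, num_heads=12)
seq = torch.randn(2, 196, 768)
unmasked, mask, ids_restore = encoder.random_masking(seq, mask_ratio=0.75)
print(unmasked.shape)     # torch.Size([2, 49, 768]): 25% of tokens kept
print(mask.sum(dim=1))    # tensor([147., 147.]): 1 marks removed tokens
print(ids_restore.shape)  # torch.Size([2, 196])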

terratorch.models.backbones.swin_encoder_decoder.MMSegSwinTransformer #

Bases: Module

Source code in terratorch/models/backbones/swin_encoder_decoder.py
class MMSegSwinTransformer(nn.Module):

    def __init__(
        self,
        pretrain_img_size=224,
        in_chans=3,
        embed_dim=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        strides=(4, 2, 2, 2),
        num_classes: int = 1000,
        global_pool: str = "avg",
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,  # noqa: FBT002
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        with_cp=False,  # noqa: FBT002
        frozen_stages=-1,
    ):
        """MMSeg Swin Transformer backbone.

        This backbone is the implementation of `Swin Transformer:
        Hierarchical Vision Transformer using Shifted
        Windows <https://arxiv.org/abs/2103.14030>`_.
        Inspiration from https://github.com/microsoft/Swin-Transformer.

        Args:
            pretrain_img_size (int | tuple[int]): The size of the input image
                during pretraining. Default: 224.
            in_chans (int): Number of input channels. Default: 3.
            embed_dim (int): The feature dimension. Default: 96.
            patch_size (int | tuple[int]): Patch size. Default: 4.
            window_size (int): Window size. Default: 7.
            mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
                Default: 4.
            depths (tuple[int]): Depths of each Swin Transformer stage.
                Default: (2, 2, 6, 2).
            num_heads (tuple[int]): Parallel attention heads of each Swin
                Transformer stage. Default: (3, 6, 12, 24).
            strides (tuple[int]): The patch merging or patch embedding stride of
                each Swin Transformer stage. (In swin, we set kernel size equal to
                stride.) Default: (4, 2, 2, 2).
            out_indices (tuple[int]): Output from which stages.
                Default: (0, 1, 2, 3).
            qkv_bias (bool, optional): If True, add a learnable bias to query, key,
                value. Default: True
            qk_scale (float | None, optional): Override default qk scale of
                head_dim ** -0.5 if set. Default: None.
            drop_rate (float): Dropout rate. Defaults: 0.
            attn_drop_rate (float): Attention dropout rate. Default: 0.
            drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
            act_layer (nn.Module): Activation layer. Default: nn.GELU.
            norm_layer (nn.Module): Normalization layer at the output of the
                backbone. Default: nn.LayerNorm.
            with_cp (bool, optional): Use checkpoint or not. Using checkpoint
                will save some memory while slowing down the training speed.
                Default: False.
            frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
                -1 means not freezing any parameters.
        """

        self.frozen_stages = frozen_stages
        self.output_fmt = "NHWC"
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            if not len(pretrain_img_size) == 2:  # noqa: PLR2004
                msg = f"The size of image should have length 1 or 2, but got {len(pretrain_img_size)}"
                raise Exception(msg)

        super().__init__()

        self.num_layers = len(depths)
        self.out_indices = out_indices
        self.feature_info = []

        if strides[0] != patch_size:
            msg = "Non-overlapping patch embed is required: strides[0] must equal patch_size."
            raise Exception(msg)

        self.patch_embed = PatchEmbed(
            in_chans=in_chans,
            embed_dim=embed_dim,
            kernel_size=patch_size,
            stride=strides[0],
            padding="corner",
            norm_layer=norm_layer,
            padding_mode="replicate",
            drop_rate=drop_rate,
        )

        # self.drop_after_pos = nn.Dropout(p=drop_rate)

        # set stochastic depth decay rule
        total_depth = sum(depths)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

        stages = []
        in_chans = embed_dim
        scale = 1
        for i in range(self.num_layers):
            if i < self.num_layers - 1:
                downsample = PatchMerging(
                    in_chans=in_chans,
                    out_channels=2 * in_chans,
                    stride=strides[i + 1],
                    norm_layer=norm_layer,
                )
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dim=in_chans,
                num_heads=num_heads[i],
                feedforward_channels=int(mlp_ratio * in_chans),
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                downsample=downsample,
                act_layer=act_layer,
                norm_layer=norm_layer,
                with_cp=with_cp,
            )
            stages.append(stage)
            if i > 0:
                scale *= 2
            self.feature_info += [{"num_chs": in_chans, "reduction": 4 * scale, "module": f"stages.{i}"}]
            if downsample:
                in_chans = downsample.out_channels
        self.stages = nn.Sequential(*stages)
        self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        # Add a norm layer for each output

        self.head = ClassifierHead(
            self.num_features[-1],
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            input_fmt=self.output_fmt,
        )

    def train(self, mode=True):  # noqa: FBT002
        """Convert the model into training mode while keep layers freezed."""
        super().train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            if (i - 1) in self.out_indices:
                norm_layer = getattr(self, f"norm{i-1}")
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.stages[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    @torch.jit.ignore
    def init_weights(self, mode=""):
        modes = ("jax", "jax_nlhb", "moco", "")
        if mode not in modes:
            msg = f"mode must be one of {modes}"
            raise Exception(msg)
        head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
        named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        nwd = set()
        for n, _ in self.named_parameters():
            if "relative_position_bias_table" in n:
                nwd.add(n)
        return nwd

    @torch.jit.ignore
    def group_matcher(self, coarse=False):  # noqa: FBT002
        return {
            "stem": r"^patch_embed",  # stem and embed
            "blocks": r"^layers\.(\d+)"
            if coarse
            else [
                (r"^layers\.(\d+).downsample", (0,)),
                (r"^layers\.(\d+)\.\w+\.(\d+)", None),
                (r"^norm", (99999,)),
            ],
        }

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):  # noqa: FBT002, FBT001
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        features = self.forward_features(x)
        x = self.forward_head(features[0])
        return x
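
A minimal construction sketch (assuming the import path from the "Source code in" note above and the default Swin-T style configuration):

from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer

swin = MMSegSwinTransformer()  # embed_dim 96, depths (2, 2, 6, 2)
# Channels double per stage while the spatial reduction grows from 4 to 32.
print(swin.num_features)  # [96, 192, 384, 768]
print([f["reduction"] for f in swin.feature_info])  # [4, 8, 16, 32]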

__init__(pretrain_img_size=224, in_chans=3, embed_dim=96, patch_size=4, window_size=7, mlp_ratio=4, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), strides=(4, 2, 2, 2), num_classes=1000, global_pool='avg', out_indices=(0, 1, 2, 3), qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, with_cp=False, frozen_stages=-1) #

MMSeg Swin Transformer backbone.

This backbone implements Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (https://arxiv.org/abs/2103.14030), with inspiration from https://github.com/microsoft/Swin-Transformer.

Parameters:

Name Type Description Default
pretrain_img_size int | tuple[int]

The size of the input image during pretraining. Default: 224.

224
in_chans int

Number of input channels. Default: 3.

3
embed_dim int

The feature dimension. Default: 96.

96
patch_size int | tuple[int]

Patch size. Default: 4.

4
window_size int

Window size. Default: 7.

7
mlp_ratio int | float

Ratio of mlp hidden dim to embedding dim. Default: 4.

4
depths tuple[int]

Depths of each Swin Transformer stage. Default: (2, 2, 6, 2).

(2, 2, 6, 2)
num_heads tuple[int]

Parallel attention heads of each Swin Transformer stage. Default: (3, 6, 12, 24).

(3, 6, 12, 24)
strides tuple[int]

The patch merging or patch embedding stride of each Swin Transformer stage. (In swin, we set kernel size equal to stride.) Default: (4, 2, 2, 2).

(4, 2, 2, 2)
out_indices tuple[int]

Output from which stages. Default: (0, 1, 2, 3).

(0, 1, 2, 3)
qkv_bias bool

If True, add a learnable bias to query, key, value. Default: True

True
qk_scale float | None

Override default qk scale of head_dim ** -0.5 if set. Default: None.

None
drop_rate float

Dropout rate. Defaults: 0.

0.0
attn_drop_rate float

Attention dropout rate. Default: 0.

0.0
drop_path_rate float

Stochastic depth rate. Defaults: 0.1.

0.1
act_layer Module

Activation layer. Default: nn.GELU.

GELU
norm_layer Module

Normalization layer at the output of the backbone. Default: nn.LayerNorm.

LayerNorm
with_cp bool

Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False.

False
frozen_stages int

Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters.

-1
Source code in terratorch/models/backbones/swin_encoder_decoder.py
def __init__(
    self,
    pretrain_img_size=224,
    in_chans=3,
    embed_dim=96,
    patch_size=4,
    window_size=7,
    mlp_ratio=4,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    strides=(4, 2, 2, 2),
    num_classes: int = 1000,
    global_pool: str = "avg",
    out_indices=(0, 1, 2, 3),
    qkv_bias=True,  # noqa: FBT002
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.1,
    act_layer=nn.GELU,
    norm_layer=nn.LayerNorm,
    with_cp=False,  # noqa: FBT002
    frozen_stages=-1,
):
    """MMSeg Swin Transformer backbone.

    This backbone is the implementation of `Swin Transformer:
    Hierarchical Vision Transformer using Shifted
    Windows <https://arxiv.org/abs/2103.14030>`_.
    Inspiration from https://github.com/microsoft/Swin-Transformer.

    Args:
        pretrain_img_size (int | tuple[int]): The size of the input image
            during pretraining. Default: 224.
        in_chans (int): Number of input channels. Default: 3.
        embed_dim (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        act_layer (nn.Module): Activation layer. Default: nn.GELU.
        norm_layer (nn.Module): Normalization layer at the output of the
            backbone. Default: nn.LayerNorm.
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
    """

    self.frozen_stages = frozen_stages
    self.output_fmt = "NHWC"
    if isinstance(pretrain_img_size, int):
        pretrain_img_size = to_2tuple(pretrain_img_size)
    elif isinstance(pretrain_img_size, tuple):
        if len(pretrain_img_size) == 1:
            pretrain_img_size = to_2tuple(pretrain_img_size[0])
        if not len(pretrain_img_size) == 2:  # noqa: PLR2004
            msg = f"The size of image should have length 1 or 2, but got {len(pretrain_img_size)}"
            raise Exception(msg)

    super().__init__()

    self.num_layers = len(depths)
    self.out_indices = out_indices
    self.feature_info = []

    if strides[0] != patch_size:
        msg = "Non-overlapping patch embed is required: strides[0] must equal patch_size."
        raise Exception(msg)

    self.patch_embed = PatchEmbed(
        in_chans=in_chans,
        embed_dim=embed_dim,
        kernel_size=patch_size,
        stride=strides[0],
        padding="corner",
        norm_layer=norm_layer,
        padding_mode="replicate",
        drop_rate=drop_rate,
    )

    # self.drop_after_pos = nn.Dropout(p=drop_rate)

    # set stochastic depth decay rule
    total_depth = sum(depths)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

    stages = []
    in_chans = embed_dim
    scale = 1
    for i in range(self.num_layers):
        if i < self.num_layers - 1:
            downsample = PatchMerging(
                in_chans=in_chans,
                out_channels=2 * in_chans,
                stride=strides[i + 1],
                norm_layer=norm_layer,
            )
        else:
            downsample = None

        stage = SwinBlockSequence(
            embed_dim=in_chans,
            num_heads=num_heads[i],
            feedforward_channels=int(mlp_ratio * in_chans),
            depth=depths[i],
            window_size=window_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
            downsample=downsample,
            act_layer=act_layer,
            norm_layer=norm_layer,
            with_cp=with_cp,
        )
        stages.append(stage)
        if i > 0:
            scale *= 2
        self.feature_info += [{"num_chs": in_chans, "reduction": 4 * scale, "module": f"stages.{i}"}]
        if downsample:
            in_chans = downsample.out_channels
    self.stages = nn.Sequential(*stages)
    self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
    # Add a norm layer for each output

    self.head = ClassifierHead(
        self.num_features[-1],
        num_classes,
        pool_type=global_pool,
        drop_rate=drop_rate,
        input_fmt=self.output_fmt,
    )
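
The "set stochastic depth decay rule" step above spreads drop_path_rate linearly over all blocks and then slices the schedule per stage by cumulative depth; a standalone illustration:

import torch

depths, drop_path_rate = (2, 2, 6, 2), 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
print([len(s) for s in per_stage])  # [2, 2, 6, 2]
print(per_stage[-1])  # the deepest blocks get the largest drop rates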

train(mode=True) #

Switch the model into training mode while keeping frozen stages frozen.

Source code in terratorch/models/backbones/swin_encoder_decoder.py
def train(self, mode=True):  # noqa: FBT002
    """Convert the model into training mode while keep layers freezed."""
    super().train(mode)
    self._freeze_stages()

terratorch.models.backbones.unet.UNet #

Bases: Module

UNet backbone.

This backbone implements U-Net: Convolutional Networks for Biomedical Image Segmentation (https://arxiv.org/abs/1505.04597).

Parameters:

Name Type Description Default
in_channels int

Number of input image channels. Default: 3.

3
out_channels int

Number of base channels of each stage. The output channels of the first stage. Default: 64.

64
num_stages int

Number of stages in encoder, normally 5. Default: 5.

5
strides Sequence[int]

Strides (1 or 2) of each stage in encoder. len(strides) is equal to num_stages. Normally the stride of the first stage in encoder is 1. If strides[i]=2, a strided convolution is used to downsample in the corresponding encoder stage. Default: (1, 1, 1, 1, 1).

(1, 1, 1, 1, 1)
enc_num_convs Sequence[int]

Number of convolutional layers in the convolution block of the corresponding encoder stage. Default: (2, 2, 2, 2, 2).

(2, 2, 2, 2, 2)
dec_num_convs Sequence[int]

Number of convolutional layers in the convolution block of the corresponding decoder stage. Default: (2, 2, 2, 2).

(2, 2, 2, 2)
downsamples Sequence[int]

Whether to use MaxPool to downsample the feature map after the first stage of the encoder (stages: [1, num_stages)). If the corresponding encoder stage uses a strided convolution (strides[i]=2), MaxPool is never used for downsampling, even if downsamples[i-1]=True. Default: (True, True, True, True).

(True, True, True, True)
enc_dilations Sequence[int]

Dilation rate of each stage in encoder. Default: (1, 1, 1, 1, 1).

(1, 1, 1, 1, 1)
dec_dilations Sequence[int]

Dilation rate of each stage in decoder. Default: (1, 1, 1, 1).

(1, 1, 1, 1)
with_cp bool

Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False.

False
conv_cfg dict | None

Config dict for convolution layer. Default: None.

None
norm_cfg dict | None

Config dict for normalization layer. Default: dict(type='BN').

dict(type='BN')
act_cfg dict | None

Config dict for activation layer in ConvModule. Default: dict(type='ReLU').

dict(type='ReLU')
upsample_cfg dict

The upsample config of the upsample module in decoder, e.g. dict(type='InterpConv'). Default: None.

None
norm_eval bool

Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False.

False
dcn bool

Use deformable convolution in convolutional layer or not. Default: None.

None
plugins dict

plugins for convolutional layers. Default: None.

None
pretrained str

Path to a pretrained model. Default: None.

None
init_cfg dict or list[dict]

Initialization config dict. Default: None

None
Notice

The input image size should be divisible by the whole downsample rate of the encoder. More detail of the whole downsample rate can be found in UNet._check_input_divisible.
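
A minimal construction sketch (assuming the import path from the "Source code in" note below and that the defaults instantiate cleanly):

from terratorch.models.backbones.unet import UNet

# Defaults: 5 stages with 4 MaxPool downsamples, so the total downsample
# rate is 16 and input H and W must be divisible by 16 (see the Notice).
unet = UNet(in_channels=3, out_channels=64, num_stages=5)
print(unet.out_channels)  # [1024, 512, 256, 128, 64], deepest stage first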

Source code in terratorch/models/backbones/unet.py
@TERRATORCH_BACKBONE_REGISTRY.register
class UNet(nn.Module):
    """UNet backbone.

    This backbone is the implementation of `U-Net: Convolutional Networks
    for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        out_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int]): Strides (1 or 2) of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, a strided convolution
            is used to downsample in the corresponding encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
            feature map after the first stage of the encoder
            (stages: [1, num_stages)). If the corresponding encoder stage uses
            a strided convolution (strides[i]=2), MaxPool is never used for
            downsampling, even if downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder, e.g. dict(type='InterpConv'). Default: None.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
        pretrained (str, optional): Path to a pretrained model. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in `UNet._check_input_divisible`.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=None,
                 norm_eval=False,
                 dcn=None,
                 plugins=None,
                 pretrained=None,
                 init_cfg=None):
        super(UNet, self).__init__()

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.out_channels = [out_channels * 2**i for i in reversed(range(num_stages))]

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=out_channels * 2**i,
                        skip_channels=out_channels * 2**(i - 1),
                        out_channels=out_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=out_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            in_channels = out_channels * 2**i

    def forward(self, x):

        # We can check just the first image: the batch has already passed
        # the stackability test, so all images have the same dimensions.
        self._check_input_divisible(x[0])

        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)
        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _check_input_divisible(self, x):
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'
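The whole downsample rate follows directly from strides and downsamples: every stage after the first halves the resolution if it uses a stride-2 convolution or a MaxPool. A minimal standalone sketch of the rule _check_input_divisible enforces:

def whole_downsample_rate(strides, downsamples):
    # Mirrors UNet._check_input_divisible: one factor of 2 per downsampling stage.
    rate = 1
    for i in range(1, len(strides)):
        if strides[i] == 2 or downsamples[i - 1]:
            rate *= 2
    return rate

# Default config: five stages with four MaxPool downsamples -> rate 16,
# so input H and W must both be divisible by 16.
assert whole_downsample_rate((1, 1, 1, 1, 1), (True, True, True, True)) == 16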

train(mode=True) #

Convert the model into training mode while keeping the normalization layers frozen.

Source code in terratorch/models/backbones/unet.py
def train(self, mode=True):
    """Convert the model into training mode while keep normalization layer
    freezed."""
    super(UNet, self).train(mode)
    if mode and self.norm_eval:
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()
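A hedged usage sketch (module path per the source listing above; the upsample_cfg value is an assumption based on the docstring, since the signature defaults it to None): with the default five-stage configuration, forward returns one tensor per stage, bottleneck first and finest decoder output last.

import torch
from terratorch.models.backbones.unet import UNet

model = UNet(in_channels=3, out_channels=64, num_stages=5,
             upsample_cfg=dict(type='InterpConv'))  # assumed upsample config
x = torch.randn(1, 3, 64, 64)  # 64 is divisible by the downsample rate of 16
outs = model(x)
print([tuple(o.shape) for o in outs])  # bottleneck first, e.g. (1, 1024, 4, 4)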

terratorch.models.backbones.mmearth_convnextv2.ConvNeXtV2 #

Bases: Module

ConvNeXt V2

Parameters:

Name Type Description Default
in_chans int

Number of input image channels. Default: 3

3
num_classes int

Number of classes for classification head. Default: 1000

1000
depths list[int]

Number of blocks at each stage. Default: [3, 3, 9, 3]

None
dims list[int]

Feature dimension at each stage. Default: [96, 192, 384, 768]

None
drop_path_rate float

Stochastic depth rate. Default: 0.

0.0
head_init_scale float

Init scaling value for classifier weights and biases. Default: 1.

1.0
Source code in terratorch/models/backbones/mmearth_convnextv2.py
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (list[int]): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list[int]): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        patch_size: int = 32,
        img_size: int = 128,
        in_chans: int = 3,
        num_classes: int = 1000,
        depths: list[int] = None,
        dims: list[int] = None,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        use_orig_stem: bool = False,
        args: Namespace = None,
    ):
        super().__init__()
        self.depths = depths
        if self.depths is None:  # set default value
            self.depths = [3, 3, 9, 3]
        depths = self.depths  # keep the local in sync so later uses work when depths is None
        self.img_size = img_size
        self.use_orig_stem = use_orig_stem
        self.num_stage = len(self.depths)
        self.downsample_layers = (
            nn.ModuleList()
        )  # stem and 3 intermediate downsampling conv layer
        self.patch_size = patch_size
        if dims is None:
            dims = [96, 192, 384, 768]

        if self.use_orig_stem:
            self.stem_orig = nn.Sequential(
                nn.Conv2d(
                    in_chans,
                    dims[0],
                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
                    stride=patch_size // (2 ** (self.num_stage - 1)),
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        else:
            self.initial_conv = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=3, stride=1),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
                nn.GELU(),
            )
            # depthwise conv for stem
            self.stem = nn.Sequential(
                nn.Conv2d(
                    dims[0],
                    dims[0],
                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
                    stride=patch_size // (2 ** (self.num_stage - 1)),
                    padding=(patch_size // (2 ** (self.num_stage - 1))) // 2,
                    groups=dims[0],
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )

        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = (
            nn.ModuleList()
        )  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(self.num_stage):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        if self.use_orig_stem:
            x = self.stem_orig(x)
        else:
            x = self.initial_conv(x)
            x = self.stem(x)

        x = self.stages[0](x)
        for i in range(3):
            x = self.downsample_layers[i](x)
            x = self.stages[i + 1](x)

        return self.norm(
            x.mean([-2, -1])
        )  # global average pooling, (N, C, H, W) -> (N, C)

    def upsample_mask(self, mask, scale):
        assert len(mask.shape) == 2
        p = int(mask.shape[1] ** 0.5)
        return (
            mask.reshape(-1, p, p)
            .repeat_interleave(scale, axis=1)
            .repeat_interleave(scale, axis=2)
        )

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        if mask is not None:  # for the pretraining case
            num_patches = mask.shape[1]
            scale = int(self.img_size // (num_patches**0.5))
            mask = self.upsample_mask(mask, scale)

            mask = mask.unsqueeze(1).type_as(x)
            x *= 1.0 - mask
            if self.use_orig_stem:
                x = self.stem_orig(x)
            else:
                x = self.initial_conv(x)
                x = self.stem(x)

            x = self.stages[0](x)
            for i in range(3):
                x = self.downsample_layers[i](x)
                x = self.stages[i + 1](x)
            return x

        x = self.forward_features(x)
        x = self.head(x)
        return x
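A hedged usage sketch (module path per the listing above): the plain forward pools the final feature map and applies the classification head, while passing a mask takes the pretraining path and returns the unpooled feature map.

import torch
from terratorch.models.backbones.mmearth_convnextv2 import ConvNeXtV2

model = ConvNeXtV2(patch_size=32, img_size=128, in_chans=3, num_classes=10,
                   depths=[3, 3, 9, 3])
x = torch.randn(2, 3, 128, 128)

logits = model(x)             # (2, 10): pooled features through the head
mask = torch.zeros(2, 16)     # 16 patch tokens form a 4x4 grid; nothing masked
feats = model(x, mask=mask)   # (2, 768, 4, 4): feature map, head skipped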

terratorch.models.backbones.dofa_vit.DOFAEncoderWrapper #

Bases: Module

A wrapper for DOFA models from torchgeo that returns only the forward pass of the encoder.

Attributes:
    dofa_model (DOFA): The instantiated DOFA model.

Methods:
    forward(x: List[torch.Tensor], wavelengths: list[float]) -> torch.Tensor: Forward pass for embeddings with specified indices.

Source code in terratorch/models/backbones/dofa_vit.py
class DOFAEncoderWrapper(nn.Module):
    """
    A wrapper for DOFA models from torchgeo to return only the forward pass of the encoder
    Attributes:
        dofa_model (DOFA): The instantiated dofa model
    Methods:
        forward(x: List[torch.Tensor], wavelengths: list[float]) -> torch.Tensor:
            Forward pass for embeddings with specified indices.
    """

    def __init__(self, dofa_model, wavelengths, weights=None, out_indices=None) -> None:
        """
        Args:
            dofa_model (DOFA): The DOFA encoder module to be wrapped.
            wavelengths (list[float]): Central wavelengths of the input bands.
            weights (Weights, optional): Pretrained weights associated with the model.
            out_indices (list, optional): Indices of the transformer blocks whose
                outputs are returned; -1 selects the last block. Defaults to [-1].
        """
        super().__init__()
        self.dofa_model = dofa_model
        self.weights = weights
        self.wavelengths = wavelengths

        self.out_indices = out_indices if out_indices else [-1]
        self.out_channels = [self.dofa_model.patch_embed.embed_dim] * len(self.out_indices)

    def forward(self, x: list[torch.Tensor], **kwargs) -> torch.Tensor:
        wavelist = torch.tensor(self.wavelengths, device=x.device).float()

        x, _ = self.dofa_model.patch_embed(x, wavelist)
        x = x + self.dofa_model.pos_embed[:, 1:, :]
        # append cls token
        cls_token = self.dofa_model.cls_token + self.dofa_model.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        outs = []
        # apply Transformer blocks
        for i, block in enumerate(self.dofa_model.blocks):
            x = block(x)
            if i in self.out_indices:
                outs.append(x)
            elif (i == (len(self.dofa_model.blocks) - 1)) & (-1 in self.out_indices):
                outs.append(x)

        return tuple(outs)
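A hedged usage sketch, assuming torchgeo's dofa_base_patch16_224 constructor; wavelengths are the central wavelengths of the input bands in micrometers, one per channel:

import torch
from torchgeo.models import dofa_base_patch16_224
from terratorch.models.backbones.dofa_vit import DOFAEncoderWrapper

wavelengths = [0.665, 0.560, 0.490]  # red, green, blue
encoder = DOFAEncoderWrapper(dofa_base_patch16_224(), wavelengths)

x = torch.randn(2, 3, 224, 224)
(features,) = encoder(x)  # out_indices defaults to [-1]: last block only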

__init__(dofa_model, wavelengths, weights=None, out_indices=None) #

Parameters:

Name Type Description Default
dofa_model DOFA

The DOFA encoder module to be wrapped.

required
Source code in terratorch/models/backbones/dofa_vit.py
def __init__(self, dofa_model, wavelengths, weights=None, out_indices=None) -> None:
    """
    Args:
        dofa_model (DOFA): The DOFA encoder module to be wrapped.
        wavelengths (list[float]): Central wavelengths of the input bands.
        weights (Weights, optional): Pretrained weights associated with the model.
        out_indices (list, optional): Indices of the transformer blocks whose
            outputs are returned; -1 selects the last block. Defaults to [-1].
    """
    super().__init__()
    self.dofa_model = dofa_model
    self.weights = weights
    self.wavelengths = wavelengths

    self.out_indices = out_indices if out_indices else [-1]
    self.out_channels = [self.dofa_model.patch_embed.embed_dim] * len(self.out_indices)

terratorch.models.backbones.clay_v1.embedder #

Embedder #

Bases: Module

Source code in terratorch/models/backbones/clay_v1/embedder.py
class Embedder(nn.Module):
    default_out_indices = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)

    def __init__(
        self,
        img_size=256,
        num_frames=1,
        ckpt_path=None,
        bands=["blue", "green", "red", "nir", "swir16", "swir22"],
        out_indices: tuple[int] = default_out_indices,
        vpt: bool = False,
        vpt_n_tokens: int | None = None,
        vpt_dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.feature_info = []
        self.img_size = img_size
        self.num_frames = num_frames
        self.bands = bands
        self.out_indices = out_indices

        self.datacuber = Datacuber(bands=bands)

        # TODO: add support for various clay versions
        self.clay_encoder = (
            EmbeddingEncoder(  # Default parameters for the Clay base model
                img_size=img_size,
                patch_size=8,
                dim=768,
                depth=12,
                heads=12,
                dim_head=64,
                mlp_ratio=4.0,
                vpt=vpt,
                vpt_n_tokens=vpt_n_tokens,
                vpt_dropout=vpt_dropout,
            )
        )

        # for use in features list.
        for i in range(12):
            self.feature_info.append({"num_chs": 768, "reduction": 1, "module": f"blocks.{i}"})

        # assuming this is used to fine tune a network on top of the embeddings

        if ckpt_path:
            self.load_clay_weights(ckpt_path)

    def load_clay_weights(self, ckpt_path):
        "Load the weights from the Clay model encoder."
        ckpt = torch.load(ckpt_path, weights_only=True)
        state_dict = ckpt.get("state_dict")
        state_dict = {
            re.sub(r"^model\.encoder\.", "", name): param
            for name, param in state_dict.items()
            if name.startswith("model.encoder")
        }

        with torch.no_grad():
            for name, param in self.clay_encoder.named_parameters():
                if name in state_dict and param.size() == state_dict[name].size():
                    param.data.copy_(state_dict[name])  # Copy the weights
                else:
                    print(
                        f"No matching parameter for {name} with size {param.size()}")

        for param in self.clay_encoder.parameters():
            param.requires_grad = False

        self.clay_encoder.eval()

    @staticmethod
    def transform_state_dict(state_dict, model):
        state_dict = state_dict.get("state_dict")
        state_dict = {
            re.sub(r"^model\.encoder\.", "clay_encoder.", name): param
            for name, param in state_dict.items()
            if name.startswith("model.encoder")
        }
        for k, v in model.state_dict().items():
            if "vpt_prompt_embeddings" in k:
                state_dict[k] = v
        return state_dict

    def forward_features(
        self,
        x: torch.Tensor,
        time: torch.Tensor | None = None,
        latlon: torch.Tensor | None = None,
        waves: torch.Tensor | None = None,
        gsd: float | None = None,
    ):
        datacube = self.datacuber(x=x, time=time, latlon=latlon, waves=waves, gsd=gsd)
        embeddings = self.clay_encoder(datacube)

        return [embeddings[i] for i in self.out_indices]

    def fake_datacube(self):
        "Generate a fake datacube for model export."
        dummy_datacube = {
            "pixels": torch.randn(2, 3, self.img_size, self.img_size),
            "time": torch.randn(2, 4),
            "latlon": torch.randn(2, 4),
            "waves": torch.randn(3),
            "gsd": torch.randn(1),
        }
        dummy_datacube = {k: v
                          for k, v in dummy_datacube.items()}
        return dummy_datacube

    def prepare_features_for_image_model(self, features: list[Tensor]) -> list[Tensor]:
        x_no_token = features[-1][:, 1:, :]
        encoded = x_no_token.permute(0, 2, 1).reshape(
            x_no_token.shape[0],
            -1,
            int(np.sqrt(x_no_token.shape[1] // self.num_frames)),
            int(np.sqrt(x_no_token.shape[1] // self.num_frames)),
        )

        # return as list for features list compatibility
        return [encoded]

    def freeze(self):
        for n, param in self.named_parameters():
            if "vpt_prompt_embeddings" not in n:
                param.requires_grad_(False)
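A hedged sketch: fake_datacube() below documents the dictionary layout the Clay encoder consumes; forward_features assembles the same structure from pixels plus optional time, latlon, waves, and gsd metadata.

from terratorch.models.backbones.clay_v1.embedder import Embedder

emb = Embedder(img_size=256)  # random init; pass ckpt_path to load Clay weights
cube = emb.fake_datacube()
print({k: tuple(v.shape) for k, v in cube.items()})
# {'pixels': (2, 3, 256, 256), 'time': (2, 4), 'latlon': (2, 4),
#  'waves': (3,), 'gsd': (1,)}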
fake_datacube() #

Generate a fake datacube for model export.

Source code in terratorch/models/backbones/clay_v1/embedder.py
def fake_datacube(self):
    "Generate a fake datacube for model export."
    dummy_datacube = {
        "pixels": torch.randn(2, 3, self.img_size, self.img_size),
        "time": torch.randn(2, 4),
        "latlon": torch.randn(2, 4),
        "waves": torch.randn(3),
        "gsd": torch.randn(1),
    }
    dummy_datacube = {k: v
                      for k, v in dummy_datacube.items()}
    return dummy_datacube
load_clay_weights(ckpt_path) #

Load the weights from the Clay model encoder.

Source code in terratorch/models/backbones/clay_v1/embedder.py
def load_clay_weights(self, ckpt_path):
    "Load the weights from the Clay model encoder."
    ckpt = torch.load(ckpt_path, weights_only=True)
    state_dict = ckpt.get("state_dict")
    state_dict = {
        re.sub(r"^model\.encoder\.", "", name): param
        for name, param in state_dict.items()
        if name.startswith("model.encoder")
    }

    with torch.no_grad():
        for name, param in self.clay_encoder.named_parameters():
            if name in state_dict and param.size() == state_dict[name].size():
                param.data.copy_(state_dict[name])  # Copy the weights
            else:
                print(
                    f"No matching parameter for {name} with size {param.size()}")

    for param in self.clay_encoder.parameters():
        param.requires_grad = False

    self.clay_encoder.eval()

APIs for External Models#

Tip

You can find a detailed overview of all models in the TorchGeo documentation.

terratorch.models.backbones.torchgeo_vit #

ViTEncoderWrapper #

Bases: Module

A wrapper for ViT models from torchgeo that returns only the forward pass of the encoder.

Attributes:
    vit_model (VisionTransformer): The instantiated ViT model.
    weights: Pretrained weights associated with the model.

Methods:
    forward(x: List[torch.Tensor]) -> torch.Tensor: Forward pass returning the intermediate features of the encoder.

Source code in terratorch/models/backbones/torchgeo_vit.py
class ViTEncoderWrapper(nn.Module):

    """
    A wrapper for ViT models from torchgeo to return only the forward pass of the encoder 
    Attributes:
        satlas_model (VisionTransformer): The instantiated dofa model
        weights
    Methods:
        forward(x: List[torch.Tensor], wavelengths: list[float]) -> torch.Tensor:
            Forward pass for embeddings with specified indices.
    """

    def __init__(self, vit_model, vit_meta, weights=None, out_indices=None) -> None:
        """
        Args:
            vit_model (VisionTransformer): The ViT encoder module to be wrapped.
            vit_meta (dict): Metadata describing the ViT model.
            weights (Weights, optional): Pretrained weights associated with the model.
            out_indices (list, optional): Currently unused by this wrapper.
        """
        super().__init__()
        self.vit_model = vit_model
        self.weights = weights
        self.out_channels = [x['num_chs'] for x in self.vit_model.feature_info]
        self.vit_meta = vit_meta


    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        return self.vit_model.forward_intermediates(x, intermediates_only=True)
__init__(vit_model, vit_meta, weights=None, out_indices=None) #

Parameters:

Name Type Description Default
vit_model VisionTransformer

The ViT encoder module to be wrapped.

required
Source code in terratorch/models/backbones/torchgeo_vit.py
def __init__(self, vit_model, vit_meta, weights=None, out_indices=None) -> None:
    """
    Args:
        vit_model (VisionTransformer): The ViT encoder module to be wrapped.
        vit_meta (dict): Metadata describing the ViT model.
        weights (Weights, optional): Pretrained weights associated with the model.
        out_indices (list, optional): Currently unused by this wrapper.
    """
    super().__init__()
    self.vit_model = vit_model
    self.weights = weights
    self.out_channels = [x['num_chs'] for x in self.vit_model.feature_info]
    self.vit_meta = vit_meta
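A hedged usage sketch, assuming a timm/torchgeo version where ViT models expose forward_intermediates and feature_info (as the wrapper above requires): the wrapper returns one feature map per transformer block.

import torch
from torchgeo.models import vit_small_patch16_224
from terratorch.models.backbones.torchgeo_vit import ViTEncoderWrapper

vit = vit_small_patch16_224(in_chans=6)  # random init, six input bands
wrapper = ViTEncoderWrapper(vit, vit_meta={})
x = torch.randn(1, 6, 224, 224)
feats = wrapper(x)  # one (1, 384, 14, 14) map per block for ViT-S/16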

ssl4eol_vit_small_patch16_224_landsat_etm_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)
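Because the factory is registered, a hedged alternative is to build it by name through terratorch's backbone registry; the band names below are placeholders:

from terratorch.registry import BACKBONE_REGISTRY

backbone = BACKBONE_REGISTRY.build(
    "ssl4eol_vit_small_patch16_224_landsat_etm_sr_moco",
    model_bands=["B1", "B2", "B3", "B4", "B5", "B7"],  # assumed band names
    pretrained=False,
)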

ssl4eol_vit_small_patch16_224_landsat_etm_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ViTSmall16_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_etm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_ETM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_etm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ViTSmall16_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_oli_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_oli_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =   ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_oli_tirs_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ViTSmall16_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_tm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_TM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ViTSmall16_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eol_vit_small_patch16_224_landsat_tm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_vit_small_patch16_224_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ViTSmall16_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eos12_vit_small_patch16_224_sentinel2_all_dino(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.SENTINEL2_ALL_DINO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_vit_small_patch16_224_sentinel2_all_dino(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =   ViTSmall16_Weights.SENTINEL2_ALL_DINO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

ssl4eos12_vit_small_patch16_224_sentinel2_all_moco(model_bands, pretrained=False, ckpt_data=None, weights=ViTSmall16_Weights.SENTINEL2_ALL_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ViTEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_vit.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_vit_small_patch16_224_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =   ViTSmall16_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ViTEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = vit_small_patch16_224(**kwargs)
    if pretrained:
        model = load_vit_weights(model, model_bands, ckpt_data, weights)
    return ViTEncoderWrapper(model, vit_s_meta, weights, out_indices)

terratorch.models.backbones.torchgeo_resnet #

ResNetEncoderWrapper #

Bases: Module

A wrapper for ResNet models from torchgeo that returns only the forward pass of the encoder.

Attributes:
    resnet_model (ResNet): The instantiated ResNet model.
    weights: Pretrained weights associated with the model.

Methods:
    forward(x: List[torch.Tensor]) -> torch.Tensor: Forward pass returning the intermediate features at the indices specified by out_indices.

Source code in terratorch/models/backbones/torchgeo_resnet.py
class ResNetEncoderWrapper(nn.Module):

    """
    A wrapper for ViT models from torchgeo to return only the forward pass of the encoder 
    Attributes:
        satlas_model (VisionTransformer): The instantiated dofa model
        weights
    Methods:
        forward(x: List[torch.Tensor], wavelengths: list[float]) -> torch.Tensor:
            Forward pass for embeddings with specified indices.
    """

    def __init__(self, resnet_model, resnet_meta, weights=None, out_indices=None) -> None:
        """
        Args:
            resnet_model (ResNet): The ResNet encoder module to be wrapped.
            resnet_meta (dict): Metadata describing the ResNet model.
            weights (Weights, optional): Pretrained weights associated with the model.
            out_indices (list, optional): Indices of the stages whose outputs are
                returned; -1 selects the last stage. Defaults to [-1].
        """
        super().__init__()
        self.resnet_model = resnet_model
        self.resnet_meta = resnet_meta
        self.weights = weights
        self.out_indices = out_indices if out_indices else [-1]
        self.out_channels = [x['num_chs'] for x in self.resnet_model.feature_info]
        self.resnet_meta['original_out_channels'] = self.out_channels
        self.out_channels = [x for i, x in enumerate(self.out_channels) if (i in self.out_indices) | (i == (len(self.out_channels)-1)) & (-1 in self.out_indices)]


    def forward(self, x: List[torch.Tensor], **kwargs) -> torch.Tensor:

        features = self.resnet_model.forward_intermediates(x, intermediates_only=True)

        outs = []
        for i, feature in enumerate(features):
            if i in self.out_indices:
                outs.append(feature)
            elif (i == (len(self.resnet_meta["original_out_channels"])-1)) & (-1 in self.out_indices):
                outs.append(feature)

        return outs
__init__(resnet_model, resnet_meta, weights=None, out_indices=None) #

Parameters:

Name Type Description Default
resnet_model ResNet

The ResNet encoder module to be wrapped.

required
Source code in terratorch/models/backbones/torchgeo_resnet.py
def __init__(self, resnet_model, resnet_meta, weights=None, out_indices=None) -> None:
    """
    Args:
        resnet_model (ResNet): The ResNet encoder module to be wrapped.
        resnet_meta (dict): Metadata describing the ResNet model.
        weights (Weights, optional): Pretrained weights associated with the model.
        out_indices (list, optional): Indices of the stages whose outputs are
            returned; -1 selects the last stage. Defaults to [-1].
    """
    super().__init__()
    self.resnet_model = resnet_model
    self.resnet_meta = resnet_meta
    self.weights = weights
    self.out_indices = out_indices if out_indices else [-1]
    self.out_channels = [x['num_chs'] for x in self.resnet_model.feature_info]
    self.resnet_meta['original_out_channels'] = self.out_channels
    self.out_channels = [x for i, x in enumerate(self.out_channels) if (i in self.out_indices) | (i == (len(self.out_channels)-1)) & (-1 in self.out_indices)]
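The out_channels filter above is terse; a small standalone sketch of the same selection rule (& binds tighter than |, so a stage is kept if its index is requested explicitly, or if it is the last stage and -1 was requested):

channels = [64, 256, 512, 1024, 2048]  # e.g. timm resnet50 feature_info
out_indices = [1, -1]
kept = [c for i, c in enumerate(channels)
        if (i in out_indices) | (i == len(channels) - 1) & (-1 in out_indices)]
print(kept)  # [256, 2048]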

fmow_resnet50_fmow_rgb_gassl(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.FMOW_RGB_GASSL, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def fmow_resnet50_fmow_rgb_gassl(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.FMOW_RGB_GASSL, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
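A hedged usage sketch: with pretrained=False the backbone is randomly initialized; set pretrained=True (optionally with ckpt_data) to restore the FMoW GASSL weights.

import torch
from terratorch.models.backbones.torchgeo_resnet import fmow_resnet50_fmow_rgb_gassl

backbone = fmow_resnet50_fmow_rgb_gassl(model_bands=["red", "green", "blue"])
x = torch.randn(2, 3, 224, 224)
feats = backbone(x)  # out_indices defaults to [-1]: last stage feature map only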

satlas_resnet152_sentinel2_mi_ms(model_bands, pretrained=False, ckpt_data=None, weights=ResNet152_Weights.SENTINEL2_MI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet152_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ResNet152_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet152(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)

satlas_resnet152_sentinel2_mi_rgb(model_bands, pretrained=False, ckpt_data=None, weights=ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet152_sentinel2_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet152(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)

satlas_resnet152_sentinel2_si_ms_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet152_Weights.SENTINEL2_SI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet152_sentinel2_si_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet152_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet152(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)

satlas_resnet152_sentinel2_si_rgb_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet152_sentinel2_si_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet152(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet152_meta, weights, out_indices)

satlas_resnet50_sentinel2_mi_ms_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_MI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet50_sentinel2_mi_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

satlas_resnet50_sentinel2_mi_rgb_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet50_sentinel2_mi_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

satlas_resnet50_sentinel2_si_ms_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_SI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet50_sentinel2_si_ms_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

satlas_resnet50_sentinel2_si_rgb_satlas(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether the model is pretrained, i.e., weights are available and can be restored.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_resnet50_sentinel2_si_rgb_satlas(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether the model is pretrained, i.e., weights are available and can be restored.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

seco_resnet18_sentinel2_rgb_seco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.SENTINEL2_RGB_SECO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def seco_resnet18_sentinel2_rgb_seco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.SENTINEL2_RGB_SECO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

seco_resnet50_sentinel2_rgb_seco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_RGB_SECO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def seco_resnet50_sentinel2_rgb_seco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_RGB_SECO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
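
Because each of these factories is decorated with @TERRATORCH_BACKBONE_REGISTRY.register, the same models can also be built by name through TerraTorch's backbone registry. A sketch under the assumption that the standard BACKBONE_REGISTRY.build entry point is used, with illustrative Sentinel-2 RGB band names:

from terratorch.registry import BACKBONE_REGISTRY

backbone = BACKBONE_REGISTRY.build(
    "seco_resnet50_sentinel2_rgb_seco",  # registered factory name
    model_bands=["B04", "B03", "B02"],   # assumed RGB band names
    pretrained=True,
)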

ssl4eol_resnet18_landsat_etm_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_ETM_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)
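
All ssl4eol_* factories below follow this same pattern for the SSL4EO-L Landsat checkpoints; only the ResNet variant and the Weights enum change. Since the expected Landsat band names depend on the checkpoint, a safe sketch reads them from the weights metadata (torchgeo stores the training bands in weights.meta["bands"], as the Sentinel factories further down also rely on) rather than hard-coding them:

from torchgeo.models import ResNet18_Weights
from terratorch.models.backbones.torchgeo_resnet import ssl4eol_resnet18_landsat_etm_sr_moco

weights = ResNet18_Weights.LANDSAT_ETM_SR_MOCO
backbone = ssl4eol_resnet18_landsat_etm_sr_moco(
    model_bands=weights.meta["bands"],  # the bands the checkpoint was trained on
    pretrained=True,
)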

ssl4eol_resnet18_landsat_etm_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_etm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_ETM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_etm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_oli_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_OLI_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_oli_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_oli_tirs_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_oli_tirs_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_oli_tirs_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_tm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_TM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet18_landsat_tm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet18_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eol_resnet50_landsat_etm_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_ETM_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_etm_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_etm_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_etm_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_etm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_ETM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_etm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_etm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_etm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_ETM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_oli_sr_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_OLI_SR_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_oli_sr_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_SR_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_oli_sr_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_oli_sr_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_oli_tirs_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_oli_tirs_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_oli_tirs_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_oli_tirs_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_tm_toa_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_TM_TOA_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_tm_toa_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_TM_TOA_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eol_resnet50_landsat_tm_toa_simclr(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eol_resnet50_landsat_tm_toa_simclr(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.LANDSAT_TM_TOA_SIMCLR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eos12_resnet18_sentinel2_all_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet18_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ResNet18_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eos12_resnet18_sentinel2_rgb_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet18_Weights.SENTINEL2_RGB_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet18_sentinel2_rgb_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet18_Weights.SENTINEL2_RGB_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet18(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet18_meta, weights, out_indices)

ssl4eos12_resnet50_sentinel1_all_decur(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL1_ALL_DECUR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel1_all_decur(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL1_ALL_DECUR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        if weights is not None:
            weights.meta['bands'] = ['VV', 'VH']
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
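
Note that this factory pins weights.meta['bands'] to ['VV', 'VH'] before loading, so the natural input is a two-band SAR tensor. A short sketch (the input size is an illustrative assumption):

import torch
from terratorch.models.backbones.torchgeo_resnet import ssl4eos12_resnet50_sentinel1_all_decur

backbone = ssl4eos12_resnet50_sentinel1_all_decur(
    model_bands=["VV", "VH"],  # matches the metadata pinned by the factory
    pretrained=True,
)
features = backbone(torch.randn(1, 2, 224, 224))  # dual-pol Sentinel-1 patch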

ssl4eos12_resnet50_sentinel1_all_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel1_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None =  ResNet50_Weights.SENTINEL1_ALL_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        if weights is not None:
            weights.meta['bands'] = ['VV', 'VH']
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eos12_resnet50_sentinel2_all_decur(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_ALL_DECUR, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel2_all_decur(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_DECUR, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        if weights is not None:
            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eos12_resnet50_sentinel2_all_dino(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_ALL_DINO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel2_all_dino(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_DINO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        if weights is not None:
            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

ssl4eos12_resnet50_sentinel2_all_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_ALL_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel2_all_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_ALL_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        if weights is not None:
            weights.meta['bands'] = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)
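
The Sentinel-2 "all bands" variants pin the checkpoint metadata to the 13 L1C bands listed above, so passing that same list keeps model_bands, in_chans, and the pretrained weights consistent:

from terratorch.models.backbones.torchgeo_resnet import ssl4eos12_resnet50_sentinel2_all_moco

s2_bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07",
            "B08", "B8A", "B09", "B10", "B11", "B12"]
backbone = ssl4eos12_resnet50_sentinel2_all_moco(
    model_bands=s2_bands,  # in_chans defaults to len(s2_bands) == 13
    pretrained=True,
)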

ssl4eos12_resnet50_sentinel2_rgb_moco(model_bands, pretrained=False, ckpt_data=None, weights=ResNet50_Weights.SENTINEL2_RGB_MOCO, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: ResNetEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_resnet.py
@TERRATORCH_BACKBONE_REGISTRY.register
def ssl4eos12_resnet50_sentinel2_rgb_moco(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = ResNet50_Weights.SENTINEL2_RGB_MOCO, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        ResNetEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = resnet50(**kwargs)
    if pretrained:
        model = load_resnet_weights(model, model_bands, ckpt_data, weights)
    return ResNetEncoderWrapper(model, resnet50_meta, weights, out_indices)

terratorch.models.backbones.torchgeo_swin_satlas #

SwinEncoderWrapper #

Bases: Module

A wrapper for Satlas Swin models from torchgeo to return only the forward pass of the encoder. Attributes: swin_model (SwinTransformer): The instantiated Swin model. weights (Weights): The weights used to build the model. Methods: forward(x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, ...]: Forward pass returning the feature maps at the specified indices.

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
class SwinEncoderWrapper(nn.Module):

    """
    A wrapper for Satlas models from torchgeo to return only the forward pass of the encoder 
    Attributes:
        swin_model (SwinTransformer): The instantiated dofa model
        weights
    Methods:
        forward(x: List[torch.Tensor], wavelengths: list[float]) -> torch.Tensor:
            Forward pass for embeddings with specified indices.
    """

    def __init__(self, swin_model, swin_meta, weights=None, out_indices=None) -> None:
        """
        Args:
            swin_model (SwinTransformer): The backbone module to be wrapped.
            swin_meta (dict): dict containing the metadata for swin.
            weights (Weights): Weights class for the swin model to be wrapped.
            out_indices (list): List containing the feature indices to be returned.
        """
        super().__init__()
        self.swin_model = swin_model
        self.weights = weights
        self.out_indices = out_indices if out_indices else [-1]

        self.out_channels = []
        for i in range(len(swin_meta["depths"])):
            self.out_channels.append(swin_meta["embed_dim"] * 2**i)
        self.out_channels = [elem for elem in self.out_channels for _ in range(2)]
        self.out_channels = [
            x for i, x in enumerate(self.out_channels)
            if i in self.out_indices
            or (i == len(self.out_channels) - 1 and -1 in self.out_indices)
        ]

    def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, ...]:

        outs = []
        for i, layer in enumerate(self.swin_model.features):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
            elif (i == len(self.swin_model.features) - 1) and (-1 in self.out_indices):
                outs.append(x)

        return tuple(outs)
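
To make the index handling concrete: out_indices selects positions in swin_model.features, -1 always maps to the last position, and out_channels mirrors the selected positions. Each stage's width is listed twice because every resolution contributes two entries to the features sequence (a patch embedding/merging layer plus the stage's transformer blocks). A standalone sketch of that bookkeeping, with hypothetical Swin-B-like values for embed_dim and depths:

embed_dim, depths = 128, [2, 2, 18, 2]
channels = [embed_dim * 2**i for i in range(len(depths))]  # [128, 256, 512, 1024]
channels = [c for c in channels for _ in range(2)]         # each width listed twice
out_indices = [1, 3, 5, 7]                                 # the end of every stage
selected = [c for i, c in enumerate(channels)
            if i in out_indices
            or (i == len(channels) - 1 and -1 in out_indices)]
print(selected)  # [128, 256, 512, 1024]
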
__init__(swin_model, swin_meta, weights=None, out_indices=None) #

Parameters:

Name Type Description Default
swin_model SwinTransformer

The backbone module to be wrapped.

required
swin_meta dict

dict containing the metadata for swin.

required
weights Weights

Weights class for the swin model to be wrapped.

None
out_indices list

List containing the feature indices to be returned.

None
Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
def __init__(self, swin_model, swin_meta, weights=None, out_indices=None) -> None:
    """
    Args:
        swin_model (SwinTransformer): The backbone module to be wrapped.
        swin_meta (dict): dict containing the metadata for swin.
        weights (Weights): Weights class for the swin model to be wrapped.
        out_indices (list): List containing the feature indices to be returned.
    """
    super().__init__()
    self.swin_model = swin_model
    self.weights = weights
    self.out_indices = out_indices if out_indices else [-1]

    self.out_channels = []
    for i in range(len(swin_meta["depths"])):
        self.out_channels.append(swin_meta["embed_dim"] * 2**i)
    self.out_channels = [elem for elem in self.out_channels for _ in range(2)]
    self.out_channels = [
        x for i, x in enumerate(self.out_channels)
        if i in self.out_indices
        or (i == len(self.out_channels) - 1 and -1 in self.out_indices)
    ]

satlas_swin_b_landsat_mi_ms(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.LANDSAT_MI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_landsat_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.LANDSAT_MI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
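
A usage sketch for these Swin factories (band names and input size are illustrative assumptions; the wrapper returns a tuple with one feature map per selected index):

import torch
from terratorch.models.backbones.torchgeo_swin_satlas import satlas_swin_b_landsat_mi_ms

bands = ["B2", "B3", "B4", "B5", "B6", "B7"]  # hypothetical Landsat subset
backbone = satlas_swin_b_landsat_mi_ms(
    model_bands=bands,
    pretrained=False,
    out_indices=[1, 3, 5, 7],  # one feature map per Swin stage
)
feats = backbone(torch.randn(1, len(bands), 256, 256))  # tuple of four tensors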

satlas_swin_b_landsat_mi_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.LANDSAT_SI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_landsat_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.LANDSAT_SI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_naip_mi_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_naip_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_naip_si_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_naip_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_sentinel1_mi(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.SENTINEL1_MI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_sentinel1_mi(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.SENTINEL1_MI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_sentinel1_si(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.SENTINEL1_SI_SATLAS, out_indices=None, **kwargs) #

Parameters:

Name Type Description Default
model_bands list[str]

A list containing the names for the bands expected by the model.

required
pretrained bool

Whether to load the pretrained weights, if available.

False
ckpt_data str | None

Path for a checkpoint containing the model weights.

None

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_sentinel1_si(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.SENTINEL1_SI_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): Whether to load the pretrained weights, if available.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_sentinel2_mi_ms(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_b_sentinel2_si_ms(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_sentinel2_si_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)
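The out_indices argument is forwarded to SwinEncoderWrapper and selects which intermediate feature maps are returned, e.g. to feed a multi-scale decoder. A sketch; the index values, band names, and input size below are illustrative assumptions:

```python
# Sketch: request multi-scale features from the Sentinel-2 MS backbone.
import torch
from terratorch.registry import BACKBONE_REGISTRY

backbone = BACKBONE_REGISTRY.build(
    "satlas_swin_b_sentinel2_si_ms",
    model_bands=["B02", "B03", "B04", "B08"],  # -> in_chans = 4
    out_indices=[1, 3, 5, 7],  # illustrative stage indices
)
features = backbone(torch.randn(1, 4, 256, 256))
# `features` is expected to hold one feature map per selected index.
```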

satlas_swin_b_sentinel2_si_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_b_sentinel2_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_b, swin_v2_b_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_b_meta, weights, out_indices)

satlas_swin_t_sentinel2_mi_ms(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_t_sentinel2_mi_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)

satlas_swin_t_sentinel2_mi_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_t_sentinel2_mi_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)

satlas_swin_t_sentinel2_si_ms(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_t_sentinel2_si_ms(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)

satlas_swin_t_sentinel2_si_rgb(model_bands, pretrained=False, ckpt_data=None, weights=Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices=None, **kwargs) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_bands | list[str] | A list containing the names of the bands expected by the model. | required |
| pretrained | bool | Whether the model is pretrained, i.e. weights are available and can be restored. | False |
| ckpt_data | str \| None | Path for a checkpoint containing the model weights. | None |

Returns: SwinEncoderWrapper

Source code in terratorch/models/backbones/torchgeo_swin_satlas.py
@TERRATORCH_BACKBONE_REGISTRY.register
def satlas_swin_t_sentinel2_si_rgb(model_bands, pretrained = False, ckpt_data: str | None = None,  weights: Weights | None = Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS, out_indices: list | None = None, **kwargs):
    """
    Args:
        model_bands (list[str]): A list containing the names for the bands expected by the model.
        pretrained (bool): The model is already pretrained (weights are available and can be restored) or not.
        ckpt_data (str | None): Path for a checkpoint containing the model weights.
    Returns:
        SwinEncoderWrapper
    """

    if "in_chans" not in kwargs: kwargs["in_chans"] = len(model_bands)
    model = load_model(swin_v2_t, swin_v2_t_meta, **kwargs)
    if pretrained:
        model = load_swin_weights(model, model_bands, ckpt_data, weights)
    return SwinEncoderWrapper(model, swin_v2_t_meta, weights, out_indices)
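When a local checkpoint is available, ckpt_data can be passed instead of relying on the torchgeo weights enum; load_swin_weights receives both, and the checkpoint path is presumably preferred when set. A sketch with a placeholder path and illustrative band names:

```python
# Sketch: restore weights from a local checkpoint via ckpt_data.
from terratorch.models.backbones.torchgeo_swin_satlas import satlas_swin_t_sentinel2_si_rgb

backbone = satlas_swin_t_sentinel2_si_rgb(
    model_bands=["B04", "B03", "B02"],  # RGB composite, illustrative
    pretrained=True,
    ckpt_data="path/to/checkpoint.pth",  # placeholder path
)
```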