
Necks#

Necks reshape the output of an encoder into a format suitable for the decoder. By chaining different necks, you can pair any backbone with any decoder.
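
For example, the sketch below chains three of the necks documented on this page to turn the token embeddings of a plain ViT into a feature pyramid a convolutional decoder can consume. The batch size, layer indices and token counts are illustrative only; in a full model the same chain would normally be declared through the model factory configuration rather than called by hand.

```python
import torch
from terratorch.models.necks import SelectIndices, ReshapeTokensToImage, InterpolateToPyramidal

# Hypothetical ViT encoder output: 12 layers of (batch, 1 cls token + 14 * 14 patch tokens, embed_dim)
embeddings = [torch.randn(2, 197, 768) for _ in range(12)]
channel_list = [768] * 12

necks = [
    SelectIndices(channel_list, indices=[2, 5, 8, 11]),  # keep four intermediate layers
    ReshapeTokensToImage([768] * 4),                     # tokens -> (batch, 768, 14, 14) maps
    InterpolateToPyramidal([768] * 4),                   # upsample the maps into a pyramid
]

features = embeddings
for neck in necks:
    features = neck(features)

print([tuple(f.shape) for f in features])
# [(2, 768, 224, 224), (2, 768, 112, 112), (2, 768, 56, 56), (2, 768, 28, 28)]
```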

terratorch.models.necks.Neck #

Bases: ABC, Module

Base class for Neck

A neck must implement self.process_channel_list, which returns the new channel list.

Source code in terratorch/models/necks.py
class Neck(ABC, nn.Module):
    """Base class for Neck

    A neck must implement `self.process_channel_list` which returns the new channel list.
    """

    def __init__(self, channel_list: list[int]) -> None:
        super().__init__()
        self.channel_list = channel_list

    @abstractmethod
    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return channel_list

    @abstractmethod
    def forward(self, channel_list: list[torch.Tensor]) -> list[torch.Tensor]: ...
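
A custom neck therefore only needs to implement the two abstract methods above. The sketch below shows a hypothetical NormalizeFeatures neck that L2-normalizes each embedding along the channel axis; since this does not change the number of channels, process_channel_list returns the list unchanged. To make such a neck usable from the model factories, it would additionally be registered with the @TERRATORCH_NECK_REGISTRY.register decorator used by the built-in necks below.

```python
import torch
import torch.nn.functional as F
from terratorch.models.necks import Neck


class NormalizeFeatures(Neck):
    """Hypothetical neck that L2-normalizes each embedding along the channel axis."""

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        # Normalization does not change the number of channels
        return channel_list

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        # Unit-norm channel vectors at every spatial location
        return [F.normalize(feat, dim=1) for feat in features]


neck = NormalizeFeatures([64, 128])
out = neck([torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)])
print([feat.shape for feat in out])  # shapes unchanged
```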

terratorch.models.necks.SelectIndices #

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class SelectIndices(Neck):
    def __init__(self, channel_list: list[int], indices: list[int]):
        """Select indices from the embedding list

        Args:
            indices (list[int]): list of indices to select.
        """
        super().__init__(channel_list)
        self.indices = indices

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [features[i] for i in self.indices]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        channel_list = [channel_list[i] for i in self.indices]
        return channel_list

__init__(channel_list, indices) #

Select indices from the embedding list

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| indices | list[int] | List of indices to select. | required |
Source code in terratorch/models/necks.py
def __init__(self, channel_list: list[int], indices: list[int]):
    """Select indices from the embedding list

    Args:
        indices (list[int]): list of indices to select.
    """
    super().__init__(channel_list)
    self.indices = indices
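
A minimal usage sketch with a hypothetical four-stage channel list; SelectIndices keeps only the embeddings a decoder actually consumes and updates the channel list to match.

```python
import torch
from terratorch.models.necks import SelectIndices

# Hypothetical 4-stage encoder output
channel_list = [64, 128, 256, 512]
features = [torch.randn(2, c, 32, 32) for c in channel_list]

neck = SelectIndices(channel_list, indices=[1, 3])
print(neck.process_channel_list(channel_list))  # [128, 512]
print([f.shape[1] for f in neck(features)])     # [128, 512]
```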

terratorch.models.necks.ReshapeTokensToImage #

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class ReshapeTokensToImage(Neck):
    def __init__(
        self, channel_list: list[int], remove_cls_token=True, effective_time_dim: int = 1, h: int | None = None
    ):
        """Reshape output of transformer encoder so it can be passed to a conv net.

        Args:
            remove_cls_token (bool, optional): Whether to remove the cls token from the first position.
                Defaults to True.
            effective_time_dim (int, optional): The effective temporal dimension the transformer processes.
                For a ViT, this will be given by `num_frames // tubelet_size`. This is used to determine
                the temporal dimension of the embedding, which is concatenated with the embedding dimension.
                For example:
                - A model which processes 1 frame with a tubelet size of 1 has an effective_time_dim of 1.
                    The embedding produced by this model has embedding size embed_dim * 1.
                - A model which processes 3 frames with a tubelet size of 1 has an effective_time_dim of 3.
                    The embedding produced by this model has embedding size embed_dim * 3.
                - A model which processes 12 frames with a tubelet size of 4 has an effective_time_dim of 3.
                    The embedding produced by this model has an embedding size embed_dim * 3.
                Defaults to 1.
            h (int | None):
                Optional height for the reshaped image. If given, the width of
                the reshaped image is inferred from it.
        """
        super().__init__(channel_list)
        self.remove_cls_token = remove_cls_token
        self.effective_time_dim = effective_time_dim
        self.h = h

    def collapse_dims(self, x):
        """
        When the encoder output has more than 3 dimensions, it is necessary to
        reshape it.
        """
        shape = x.shape
        batch = x.shape[0]
        e = x.shape[-1]
        collapsed_dim = np.prod(x.shape[1:-1])

        return x.reshape(batch, collapsed_dim, e)

    @staticmethod
    def is_prime(n):
        if n <= 1:
            return False
        for i in range(2, int(math.sqrt(n)) + 1):
            if n % i == 0:
                return False
        return True

    def factorize_to_get_h(self, tokens_per_timestep):
        primes = [2, 3, 5, 7, 11]
        j = primes[0]
        i = 0
        value = tokens_per_timestep
        dividers = []
        status = 0

        while not status:
            if self.is_prime(value):
                status = 1
            else:
                if value % j == 0:
                    value //= j
                    dividers.append(j)
                else:
                    i += 1
                    j = primes[i]

                status = 0

        return int(np.prod(dividers) / 2)

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        for x in features:
            if self.remove_cls_token:
                x_no_token = x[:, 1:, :]
            else:
                x_no_token = x
            x_no_token = self.collapse_dims(x_no_token)
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // self.effective_time_dim

            # Adaptation to use non-square images
            h = self.h or math.sqrt(tokens_per_timestep)
            if h - int(h) == 0:
                h = int(h)
            else:
                h = self.factorize_to_get_h(tokens_per_timestep)

            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                batch=x_no_token.shape[0],
                t=self.effective_time_dim,
                h=h,
            )

            out.append(encoded)
        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)

__init__(channel_list, remove_cls_token=True, effective_time_dim=1, h=None) #

Reshape output of transformer encoder so it can be passed to a conv net.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| remove_cls_token | bool | Whether to remove the cls token from the first position. | True |
| effective_time_dim | int | The effective temporal dimension the transformer processes. For a ViT, this is given by num_frames // tubelet_size. It determines the temporal dimension of the embedding, which is concatenated with the embedding dimension: a model processing 1 frame with tubelet size 1 has effective_time_dim 1 and embedding size embed_dim * 1; 3 frames with tubelet size 1 give effective_time_dim 3 and embedding size embed_dim * 3; 12 frames with tubelet size 4 also give effective_time_dim 3 and embedding size embed_dim * 3. | 1 |
| h | int \| None | Optional height for the reshaped image. If given, the width of the reshaped image is inferred from it. | None |
Source code in terratorch/models/necks.py
def __init__(
    self, channel_list: list[int], remove_cls_token=True, effective_time_dim: int = 1, h: int | None = None
):
    """Reshape output of transformer encoder so it can be passed to a conv net.

    Args:
        remove_cls_token (bool, optional): Whether to remove the cls token from the first position.
            Defaults to True.
        effective_time_dim (int, optional): The effective temporal dimension the transformer processes.
            For a ViT, this will be given by `num_frames // tubelet_size`. This is used to determine
            the temporal dimension of the embedding, which is concatenated with the embedding dimension.
            For example:
            - A model which processes 1 frame with a tubelet size of 1 has an effective_time_dim of 1.
                The embedding produced by this model has embedding size embed_dim * 1.
            - A model which processes 3 frames with a tubelet size of 1 has an effective_time_dim of 3.
                The embedding produced by this model has embedding size embed_dim * 3.
            - A model which processes 12 frames with a tubelet size of 4 has an effective_time_dim of 3.
                The embedding produced by this model has an embedding size embed_dim * 3.
            Defaults to 1.
        h (int | None):
            Optional height for the reshaped image. If given, the width of
            the reshaped image is inferred from it.
    """
    super().__init__(channel_list)
    self.remove_cls_token = remove_cls_token
    self.effective_time_dim = effective_time_dim
    self.h = h

collapse_dims(x) #

When the encoder output has more than 3 dimensions, it is necessary to reshape it.

Source code in terratorch/models/necks.py
def collapse_dims(self, x):
    """
    When the encoder output has more than 3 dimensions, it is necessary to
    reshape it.
    """
    shape = x.shape
    batch = x.shape[0]
    e = x.shape[-1]
    collapsed_dim = np.prod(x.shape[1:-1])

    return x.reshape(batch, collapsed_dim, e)
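
A minimal usage sketch, assuming a single-timestep ViT output with a cls token and 14 x 14 patch tokens (shapes are illustrative):

```python
import torch
from terratorch.models.necks import ReshapeTokensToImage

embed_dim = 768
# Hypothetical ViT output: batch of 2, 1 cls token + 14 * 14 patch tokens
tokens = torch.randn(2, 1 + 14 * 14, embed_dim)

neck = ReshapeTokensToImage([embed_dim], remove_cls_token=True, effective_time_dim=1)
print(neck([tokens])[0].shape)  # torch.Size([2, 768, 14, 14])
```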

terratorch.models.necks.InterpolateToPyramidal #

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class InterpolateToPyramidal(Neck):
    def __init__(self, channel_list: list[int], scale_factor: int = 2, mode: str = "nearest"):
        """Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]

        Useful to make non-pyramidal backbones compatible with hierarchical ones
        Args:
            scale_factor (int): Amount to scale embeddings by each layer. Defaults to 2.
            mode (str): Interpolation mode to be passed to torch.nn.functional.interpolate. Defaults to 'nearest'.
        """
        super().__init__(channel_list)
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features), 0, -1))
        for x, exponent in zip(features, scale_exponents, strict=True):
            out.append(F.interpolate(x, scale_factor=self.scale_factor**exponent, mode=self.mode))

        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)

__init__(channel_list, scale_factor=2, mode='nearest') #

Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]

Useful to make non-pyramidal backbones compatible with hierarchical ones.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| scale_factor | int | Amount to scale embeddings by each layer. | 2 |
| mode | str | Interpolation mode to be passed to torch.nn.functional.interpolate. | 'nearest' |

Source code in terratorch/models/necks.py
def __init__(self, channel_list: list[int], scale_factor: int = 2, mode: str = "nearest"):
    """Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]

    Useful to make non-pyramidal backbones compatible with hierarchical ones
    Args:
        scale_factor (int): Amount to scale embeddings by each layer. Defaults to 2.
        mode (str): Interpolation mode to be passed to torch.nn.functional.interpolate. Defaults to 'nearest'.
    """
    super().__init__(channel_list)
    self.scale_factor = scale_factor
    self.mode = mode
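
A minimal usage sketch, assuming four same-resolution feature maps (for example the output of ReshapeTokensToImage applied to four ViT layers); each map is upsampled by a decreasing power of scale_factor:

```python
import torch
from terratorch.models.necks import InterpolateToPyramidal

# Hypothetical: four same-resolution feature maps from a non-pyramidal backbone
features = [torch.randn(2, 768, 14, 14) for _ in range(4)]

neck = InterpolateToPyramidal([768] * 4, scale_factor=2, mode="nearest")
print([f.shape[-1] for f in neck(features)])  # [224, 112, 56, 28]
```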

terratorch.models.necks.LearnedInterpolateToPyramidal #

Bases: Neck

Use learned convolutions to transform the output of a non-pyramidal encoder into pyramidal ones

Always requires exactly 4 embeddings

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class LearnedInterpolateToPyramidal(Neck):
    """Use learned convolutions to transform the output of a non-pyramidal encoder into pyramidal ones

    Always requires exactly 4 embeddings
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        if len(channel_list) != 4:
            msg = "This class can only handle exactly 4 input embeddings"
            raise Exception(msg)
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(channel_list[0], channel_list[0] // 2, 2, 2),
            nn.BatchNorm2d(channel_list[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(channel_list[0] // 2, channel_list[0] // 4, 2, 2),
        )
        self.fpn2 = nn.Sequential(nn.ConvTranspose2d(channel_list[1], channel_list[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embedding_dim = [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(features[0]))
        scaled_inputs.append(self.fpn2(features[1]))
        scaled_inputs.append(self.fpn3(features[2]))
        scaled_inputs.append(self.fpn4(features[3]))
        return scaled_inputs

    def process_channel_list(self, channel_list: list[int] = None) -> list[int]:
        return [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]
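
A minimal usage sketch with four hypothetical same-resolution feature maps; note how both the spatial sizes and the channel counts of the first two outputs change, as reflected by process_channel_list:

```python
import torch
from terratorch.models.necks import LearnedInterpolateToPyramidal

# Hypothetical: exactly four same-resolution feature maps, as required
features = [torch.randn(2, 768, 14, 14) for _ in range(4)]

neck = LearnedInterpolateToPyramidal([768] * 4)
print([tuple(f.shape[1:]) for f in neck(features)])
# [(192, 56, 56), (384, 28, 28), (768, 14, 14), (768, 7, 7)]
print(neck.process_channel_list([768] * 4))  # [192, 384, 768, 768]
```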

terratorch.models.necks.PermuteDims #

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class PermuteDims(Neck):
    def __init__(self, channel_list: list[int], new_order: list[int]):
        """Permute dimensions of each element in the embedding list

        Args:
            new_order (list[int]): list of indices to be passed to tensor.permute()
        """
        super().__init__(channel_list)
        self.new_order = new_order

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [feat.permute(*self.new_order).contiguous() for feat in features]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)

__init__(channel_list, new_order) #

Permute dimensions of each element in the embedding list

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| new_order | list[int] | List of indices to be passed to tensor.permute(). | required |
Source code in terratorch/models/necks.py
def __init__(self, channel_list: list[int], new_order: list[int]):
    """Permute dimensions of each element in the embedding list

    Args:
        new_order (list[int]): list of indices to be passed to tensor.permute()
    """
    super().__init__(channel_list)
    self.new_order = new_order
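
A minimal usage sketch, assuming a hypothetical channels-last embedding that a convolutional decoder expects in channels-first layout:

```python
import torch
from terratorch.models.necks import PermuteDims

# Hypothetical channels-last feature map: (batch, height, width, channels)
features = [torch.randn(2, 14, 14, 768)]

neck = PermuteDims([768], new_order=[0, 3, 1, 2])  # to (batch, channels, height, width)
print(neck(features)[0].shape)  # torch.Size([2, 768, 14, 14])
```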

terratorch.models.necks.MaxpoolToPyramidal #

Bases: Neck

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class MaxpoolToPyramidal(Neck):
    def __init__(self, channel_list: list[int], kernel_size: int = 2):
        """Spatially downsample embeddings so that embedding[i - 1] is scale_factor times smaller than embedding[i]

        Useful to make non-pyramidal backbones compatible with hierarachical ones
        Args:
            kernel_size (int). Base kernel size to use for maxpool. Defaults to 2.
        """
        super().__init__(channel_list)
        self.kernel_size = kernel_size

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features)))
        for x, exponent in zip(features, scale_exponents, strict=True):
            if exponent == 0:
                out.append(x.clone())
            else:
                out.append(F.max_pool2d(x, kernel_size=self.kernel_size**exponent))

        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)

__init__(channel_list, kernel_size=2) #

Spatially downsample embeddings so that embedding[i] is kernel_size times smaller than embedding[i - 1]

Useful to make non-pyramidal backbones compatible with hierarchical ones.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| kernel_size | int | Base kernel size to use for maxpool. | 2 |

Source code in terratorch/models/necks.py
def __init__(self, channel_list: list[int], kernel_size: int = 2):
    """Spatially downsample embeddings so that embedding[i - 1] is scale_factor times smaller than embedding[i]

    Useful to make non-pyramidal backbones compatible with hierarachical ones
    Args:
        kernel_size (int). Base kernel size to use for maxpool. Defaults to 2.
    """
    super().__init__(channel_list)
    self.kernel_size = kernel_size
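
A minimal usage sketch with four hypothetical same-resolution feature maps; each map after the first is max-pooled by an increasing power of kernel_size:

```python
import torch
from terratorch.models.necks import MaxpoolToPyramidal

# Hypothetical: four same-resolution feature maps from a non-pyramidal backbone
features = [torch.randn(2, 768, 56, 56) for _ in range(4)]

neck = MaxpoolToPyramidal([768] * 4, kernel_size=2)
print([f.shape[-1] for f in neck(features)])  # [56, 28, 14, 7]
```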

terratorch.models.necks.AddBottleneckLayer #

Bases: Neck

Add a layer that reduces the channel dimension of the final embedding by half and appends it to the embedding list

Useful for compatibility with some smp decoders.

Source code in terratorch/models/necks.py
@TERRATORCH_NECK_REGISTRY.register
class AddBottleneckLayer(Neck):
    """Add a layer that reduces the channel dimension of the final embedding by half, and concatenates it

    Useful for compatibility with some smp decoders.
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        self.bottleneck = nn.Conv2d(channel_list[-1], channel_list[-1] // 2, kernel_size=1)

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        new_embedding = self.bottleneck(features[-1])
        features.append(new_embedding)
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [*channel_list, channel_list[-1] // 2]
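
A minimal usage sketch with a hypothetical four-stage channel list; the neck appends one extra embedding with half the channels of the last one:

```python
import torch
from terratorch.models.necks import AddBottleneckLayer

# Hypothetical 4-stage channel list
channel_list = [64, 128, 256, 512]
features = [torch.randn(2, c, 32, 32) for c in channel_list]

neck = AddBottleneckLayer(channel_list)
out = neck(features)
print(len(out), out[-1].shape[1])               # 5 256
print(neck.process_channel_list(channel_list))  # [64, 128, 256, 512, 256]
```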