
Decoder#

Tip

The IdentityDecoder is typically used for classification tasks, while the UNetDecoder is suited for pixel-wise segmentation or regression tasks.

terratorch.models.decoders.identity_decoder.IdentityDecoder #

Bases: Module

Identity decoder. Useful to pass the feature straight to the head.

Source code in terratorch/models/decoders/identity_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class IdentityDecoder(nn.Module):
    """Identity decoder. Useful to pass the feature straight to the head."""

    def __init__(self, embed_dim: int, out_index=-1) -> None:
        """Constructor

        Args:
            embed_dim (int): Input embedding dimension
            out_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.dim = out_index

    @property
    def out_channels(self):
        return self.embed_dim[self.dim]

    def forward(self, x: list[Tensor]):
        return x[self.dim]

__init__(embed_dim, out_index=-1) #

Constructor

Parameters:

embed_dim (int): Input embedding dimension. Required.
out_index (int, optional): Index of the input list to take. Defaults to -1.
Source code in terratorch/models/decoders/identity_decoder.py
def __init__(self, embed_dim: int, out_index=-1) -> None:
    """Constructor

    Args:
        embed_dim (int): Input embedding dimension
        out_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.dim = out_index
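
A minimal usage sketch (feature shapes and embedding dimensions below are illustrative assumptions): the decoder simply returns the feature map at out_index from the list produced by the backbone. Note that out_channels indexes embed_dim, so a list of dimensions is expected.

import torch
from terratorch.models.decoders.identity_decoder import IdentityDecoder

# Two feature maps with assumed shapes; embed_dim lists their channel counts.
features = [torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16)]
decoder = IdentityDecoder(embed_dim=[256, 512], out_index=-1)
out = decoder(features)          # returns features[-1] unchanged
print(decoder.out_channels)      # 512
print(out.shape)                 # torch.Size([2, 512, 16, 16])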

terratorch.models.decoders.unet_decoder.UNetDecoder #

Bases: Module

UNetDecoder. Wrapper around UNetDecoder from segmentation_models_pytorch to avoid ignoring the first layer.

Source code in terratorch/models/decoders/unet_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class UNetDecoder(nn.Module):
    """UNetDecoder. Wrapper around UNetDecoder from segmentation_models_pytorch to avoid ignoring the first layer."""

    def __init__(
        self, embed_dim: list[int], channels: list[int], use_batchnorm: bool = True, attention_type: str | None = None
    ):
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            channels (list[int]): Channels used in the decoder.
            use_batchnorm (bool, optional): Whether to use batchnorm. Defaults to True.
            attention_type (str | None, optional): Attention type to use. Defaults to None.
        """
        if len(embed_dim) != len(channels):
            msg = "channels should have the same length as embed_dim"
            raise ValueError(msg)
        super().__init__()
        self.decoder = UnetDecoder(
            encoder_channels=[embed_dim[0], *embed_dim],
            decoder_channels=channels,
            n_blocks=len(channels),
            use_batchnorm=use_batchnorm,
            center=False,
            attention_type=attention_type,
        )
        initialize_decoder(self.decoder)
        self.out_channels = channels[-1]

    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
        # The first layer is ignored in the original UnetDecoder, so we need to duplicate the first layer
        x = [x[0].clone(), *x]
        return self.decoder(*x)

__init__(embed_dim, channels, use_batchnorm=True, attention_type=None) #

Constructor

Parameters:

embed_dim (list[int]): Input embedding dimension for each input. Required.
channels (list[int]): Channels used in the decoder. Required.
use_batchnorm (bool, optional): Whether to use batchnorm. Defaults to True.
attention_type (str | None, optional): Attention type to use. Defaults to None.
Source code in terratorch/models/decoders/unet_decoder.py
def __init__(
    self, embed_dim: list[int], channels: list[int], use_batchnorm: bool = True, attention_type: str | None = None
):
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        channels (list[int]): Channels used in the decoder.
        use_batchnorm (bool, optional): Whether to use batchnorm. Defaults to True.
        attention_type (str | None, optional): Attention type to use. Defaults to None.
    """
    if len(embed_dim) != len(channels):
        msg = "channels should have the same length as embed_dim"
        raise ValueError(msg)
    super().__init__()
    self.decoder = UnetDecoder(
        encoder_channels=[embed_dim[0], *embed_dim],
        decoder_channels=channels,
        n_blocks=len(channels),
        use_batchnorm=use_batchnorm,
        center=False,
        attention_type=attention_type,
    )
    initialize_decoder(self.decoder)
    self.out_channels = channels[-1]
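
A minimal usage sketch (the pyramidal feature shapes below are assumptions, with each level halving the spatial resolution): embed_dim lists the channel count of each input feature map, and the output has channels[-1] channels after one x2 upsampling per decoder block.

import torch
from terratorch.models.decoders.unet_decoder import UNetDecoder

# Four pyramid levels with assumed shapes (highest resolution first).
features = [
    torch.randn(2, 64, 128, 128),
    torch.randn(2, 128, 64, 64),
    torch.randn(2, 256, 32, 32),
    torch.randn(2, 512, 16, 16),
]
decoder = UNetDecoder(embed_dim=[64, 128, 256, 512], channels=[256, 128, 64, 32])
out = decoder(features)
print(decoder.out_channels)  # 32
print(out.shape)             # expected torch.Size([2, 32, 256, 256]): 4 decoder blocks, each upsampling by 2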

terratorch.models.decoders.upernet_decoder.UperNetDecoder #

Bases: Module

UperNetDecoder. Adapted from MMSegmentation.

Source code in terratorch/models/decoders/upernet_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class UperNetDecoder(nn.Module):
    """UperNetDecoder. Adapted from MMSegmentation."""

    def __init__(
        self,
        embed_dim: list[int],
        pool_scales: tuple[int] = (1, 2, 3, 6),
        channels: int = 256,
        align_corners: bool = True,  # noqa: FBT001, FBT002
        scale_modules: bool = False,
    ):
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
                Module applied on the last feature. Default: (1, 2, 3, 6).
            channels (int, optional): Channels used in the decoder. Defaults to 256.
            align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
            scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
                Defaults to False.
        """
        super().__init__()
        if scale_modules:
            # TODO: remove scale_modules before v1?
            warnings.warn(
                "DeprecationWarning: scale_modules is deprecated and will be removed in future versions. "
                "Use LearnedInterpolateToPyramidal neck instead.",
                stacklevel=1,
            )

        self.scale_modules = scale_modules
        if scale_modules:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[0],
                                embed_dim[0] // 2, 2, 2),
                nn.BatchNorm2d(embed_dim[0] // 2),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim[0] // 2,
                                embed_dim[0] // 4, 2, 2))
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim[1],
                                embed_dim[1] // 2, 2, 2))
            self.fpn3 = nn.Sequential(nn.Identity())
            self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
            self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
        else:
            self.embed_dim = embed_dim

        self.out_channels = channels
        self.channels = channels
        self.align_corners = align_corners
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.embed_dim[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = ConvModule(
            self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for embed_dim in self.embed_dim[:-1]:  # skip the top layer
            l_conv = ConvModule(
                embed_dim,
                self.channels,
                1,
                inplace=False,
            )
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                inplace=False,
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W), which is the feature map of the last layer of the decoder head.
        """

        if self.scale_modules:
            scaled_inputs = []
            scaled_inputs.append(self.fpn1(inputs[0]))
            scaled_inputs.append(self.fpn2(inputs[1]))
            scaled_inputs.append(self.fpn3(inputs[2]))
            scaled_inputs.append(self.fpn4(inputs[3]))
            inputs = scaled_inputs
        # build laterals
        laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = torch.nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

__init__(embed_dim, pool_scales=(1, 2, 3, 6), channels=256, align_corners=True, scale_modules=False) #

Constructor

Parameters:

embed_dim (list[int]): Input embedding dimension for each input. Required.
pool_scales (tuple[int], optional): Pooling scales used in the Pooling Pyramid Module applied on the last feature. Defaults to (1, 2, 3, 6).
channels (int, optional): Channels used in the decoder. Defaults to 256.
align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT. Defaults to False.
Source code in terratorch/models/decoders/upernet_decoder.py
def __init__(
    self,
    embed_dim: list[int],
    pool_scales: tuple[int] = (1, 2, 3, 6),
    channels: int = 256,
    align_corners: bool = True,  # noqa: FBT001, FBT002
    scale_modules: bool = False,
):
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        channels (int, optional): Channels used in the decoder. Defaults to 256.
        align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
        scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
            Defaults to False.
    """
    super().__init__()
    if scale_modules:
        # TODO: remove scale_modules before v1?
        warnings.warn(
            "DeprecationWarning: scale_modules is deprecated and will be removed in future versions. "
            "Use LearnedInterpolateToPyramidal neck instead.",
            stacklevel=1,
        )

    self.scale_modules = scale_modules
    if scale_modules:
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[0],
                            embed_dim[0] // 2, 2, 2),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim[0] // 2,
                            embed_dim[0] // 4, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim[1],
                            embed_dim[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
    else:
        self.embed_dim = embed_dim

    self.out_channels = channels
    self.channels = channels
    self.align_corners = align_corners
    # PSP Module
    self.psp_modules = PPM(
        pool_scales,
        self.embed_dim[-1],
        self.channels,
        align_corners=self.align_corners,
    )
    self.bottleneck = ConvModule(
        self.embed_dim[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1, inplace=True
    )
    # FPN Module
    self.lateral_convs = nn.ModuleList()
    self.fpn_convs = nn.ModuleList()
    for embed_dim in self.embed_dim[:-1]:  # skip the top layer
        l_conv = ConvModule(
            embed_dim,
            self.channels,
            1,
            inplace=False,
        )
        fpn_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            inplace=False,
        )
        self.lateral_convs.append(l_conv)
        self.fpn_convs.append(fpn_conv)

    self.fpn_bottleneck = ConvModule(len(self.embed_dim) * self.channels, self.channels, 3, padding=1, inplace=True)

forward(inputs) #

Forward function for feature maps before classifying each pixel.

Parameters:

inputs (list[Tensor]): List of multi-level image features. Required.

Returns:

feats (Tensor): A tensor of shape (batch_size, self.channels, H, W), the feature map from the last layer of the decoder head.

Source code in terratorch/models/decoders/upernet_decoder.py
def forward(self, inputs):
    """Forward function for feature maps before classifying each pixel with
    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        feats (Tensor): A tensor of shape (batch_size, self.channels,
            H, W), which is the feature map of the last layer of the decoder head.
    """

    if self.scale_modules:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(inputs[0]))
        scaled_inputs.append(self.fpn2(inputs[1]))
        scaled_inputs.append(self.fpn3(inputs[2]))
        scaled_inputs.append(self.fpn4(inputs[3]))
        inputs = scaled_inputs
    # build laterals
    laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
    laterals.append(self.psp_forward(inputs))

    # build top-down path
    used_backbone_levels = len(laterals)
    for i in range(used_backbone_levels - 1, 0, -1):
        prev_shape = laterals[i - 1].shape[2:]
        laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(
            laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
        )

    # build outputs
    fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
    # append psp feature
    fpn_outs.append(laterals[-1])

    for i in range(used_backbone_levels - 1, 0, -1):
        fpn_outs[i] = torch.nn.functional.interpolate(
            fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
        )
    fpn_outs = torch.cat(fpn_outs, dim=1)
    feats = self.fpn_bottleneck(fpn_outs)
    return feats

psp_forward(inputs) #

Forward function of PSP module.

Source code in terratorch/models/decoders/upernet_decoder.py
def psp_forward(self, inputs):
    """Forward function of PSP module."""
    x = inputs[-1]
    psp_outs = [x]
    psp_outs.extend(self.psp_modules(x))
    psp_outs = torch.cat(psp_outs, dim=1)
    output = self.bottleneck(psp_outs)

    return output
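
A minimal usage sketch (the pyramidal feature shapes below are assumptions): all levels are fused into a single feature map with channels channels at the spatial resolution of the highest-resolution input.

import torch
from terratorch.models.decoders.upernet_decoder import UperNetDecoder

# Four pyramid levels with assumed shapes (highest resolution first).
features = [
    torch.randn(2, 64, 64, 64),
    torch.randn(2, 128, 32, 32),
    torch.randn(2, 256, 16, 16),
    torch.randn(2, 512, 8, 8),
]
decoder = UperNetDecoder(embed_dim=[64, 128, 256, 512], channels=256)
out = decoder(features)
print(decoder.out_channels)  # 256
print(out.shape)             # torch.Size([2, 256, 64, 64]): resolution of the first input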

terratorch.models.decoders.fcn_decoder.FCNDecoder #

Bases: Module

Fully Convolutional Decoder

Source code in terratorch/models/decoders/fcn_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class FCNDecoder(nn.Module):
    """Fully Convolutional Decoder"""

    def __init__(self, embed_dim: int, channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): Input embedding dimension for each input.
            channels (int, optional): Number of channels for each conv. Defaults to 256.
            num_convs (int, optional): Number of convs. Defaults to 4.
            in_index (int, optional): Index of the input list to take. Defaults to -1.
        """
        super().__init__()
        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        self.channels = channels
        self.num_convs = num_convs
        self.in_index = in_index
        self.embed_dim = embed_dim[in_index]
        if num_convs < 1:
            msg = "num_convs must be >= 1"
            raise Exception(msg)

        convs = []

        for i in range(num_convs):
            in_channels = self.embed_dim if i == 0 else self.channels
            convs.append(
                _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
            )

        self.convs = nn.Sequential(*convs)

    @property
    def out_channels(self):
        return self.channels

    def forward(self, x: list[Tensor]):
        x = x[self.in_index]
        decoded = self.convs(x)
        return decoded

__init__(embed_dim, channels=256, num_convs=4, in_index=-1) #

Constructor

Parameters:

embed_dim (list[int]): Input embedding dimension for each input. Required.
channels (int, optional): Number of channels for each conv. Defaults to 256.
num_convs (int, optional): Number of convs. Defaults to 4.
in_index (int, optional): Index of the input list to take. Defaults to -1.
Source code in terratorch/models/decoders/fcn_decoder.py
def __init__(self, embed_dim: int, channels: int = 256, num_convs: int = 4, in_index: int = -1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): Input embedding dimension for each input.
        channels (int, optional): Number of channels for each conv. Defaults to 256.
        num_convs (int, optional): Number of convs. Defaults to 4.
        in_index (int, optional): Index of the input list to take. Defaults to -1.
    """
    super().__init__()
    kernel_size = 2
    stride = 2
    dilation = 1
    padding = 0
    output_padding = 0
    self.channels = channels
    self.num_convs = num_convs
    self.in_index = in_index
    self.embed_dim = embed_dim[in_index]
    if num_convs < 1:
        msg = "num_convs must be >= 1"
        raise Exception(msg)

    convs = []

    for i in range(num_convs):
        in_channels = self.embed_dim if i == 0 else self.channels
        convs.append(
            _conv_upscale_block(in_channels, self.channels, kernel_size, stride, dilation, padding, output_padding)
        )

    self.convs = nn.Sequential(*convs)
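
A minimal usage sketch (feature shapes are assumptions, and it is assumed that each internal conv block upscales by a factor of 2, following the kernel size 2 and stride 2 set in the constructor): the feature map selected by in_index is upscaled 2**num_convs times. Note that embed_dim is indexed, so a list of dimensions is expected.

import torch
from terratorch.models.decoders.fcn_decoder import FCNDecoder

# Two feature maps with assumed shapes; in_index=-1 selects the last one.
features = [torch.randn(2, 512, 14, 14), torch.randn(2, 768, 14, 14)]
decoder = FCNDecoder(embed_dim=[512, 768], channels=256, num_convs=4, in_index=-1)
out = decoder(features)
print(decoder.out_channels)  # 256
print(out.shape)             # expected torch.Size([2, 256, 224, 224]): 14 * 2**4, assuming x2 upscaling per block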

terratorch.models.decoders.linear_decoder.LinearDecoder #

Bases: Module

A linear decoder using a transposed convolution layer for upsampling.

Source code in terratorch/models/decoders/linear_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class LinearDecoder(nn.Module):
    """
    A linear decoder using a transposed convolution layer for upsampling.
    """
    includes_head: bool = True

    def __init__(self, embed_dim: list[int], num_classes: int, upsampling_size: int, in_index: int = -1) -> None:
        """Constructor

        Args:
            embed_dim (list[int]): A list of embedding dimensions for different feature maps.
            num_classes (int): Number of output classes.
            upsampling_size (int): Kernel and stride size for transposed convolution.
            in_index (int, optional): Index of the input feature map to use. Defaults to -1.
        """
        super().__init__()
        self.num_classes = num_classes
        self.in_index = in_index
        self.embed_dim = embed_dim[in_index]

        self.conv = nn.ConvTranspose2d(
            in_channels=self.embed_dim,
            out_channels=self.num_classes,
            kernel_size=upsampling_size,
            stride=upsampling_size,
            padding=0,
            output_padding=0,
        )

    @property
    def out_channels(self) -> int:
        return self.num_classes

    def forward(self, x: list[Tensor]) -> Tensor:
        return self.conv(x[self.in_index])

__init__(embed_dim, num_classes, upsampling_size, in_index=-1) #

Constructor

Parameters:

embed_dim (list[int]): A list of embedding dimensions for different feature maps. Required.
num_classes (int): Number of output classes. Required.
upsampling_size (int): Kernel and stride size for the transposed convolution. Required.
in_index (int, optional): Index of the input feature map to use. Defaults to -1.
Source code in terratorch/models/decoders/linear_decoder.py
def __init__(self, embed_dim: list[int], num_classes: int, upsampling_size: int, in_index: int = -1) -> None:
    """Constructor

    Args:
        embed_dim (list[int]): A list of embedding dimensions for different feature maps.
        num_classes (int): Number of output classes.
        upsampling_size (int): Kernel and stride size for transposed convolution.
        in_index (int, optional): Index of the input feature map to use. Defaults to -1.
    """
    super().__init__()
    self.num_classes = num_classes
    self.in_index = in_index
    self.embed_dim = embed_dim[in_index]

    self.conv = nn.ConvTranspose2d(
        in_channels=self.embed_dim,
        out_channels=self.num_classes,
        kernel_size=upsampling_size,
        stride=upsampling_size,
        padding=0,
        output_padding=0,
    )
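
A minimal usage sketch (the feature shape below is an assumption): because includes_head is True, the decoder directly produces per-class logits, upsampled by upsampling_size.

import torch
from terratorch.models.decoders.linear_decoder import LinearDecoder

# Single feature map with an assumed ViT-like shape.
features = [torch.randn(2, 768, 14, 14)]
decoder = LinearDecoder(embed_dim=[768], num_classes=10, upsampling_size=16)
logits = decoder(features)
print(decoder.out_channels)  # 10
print(logits.shape)          # torch.Size([2, 10, 224, 224]): upsampled by a factor of 16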

terratorch.models.decoders.mlp_decoder.MLPDecoder #

Bases: Module

MLP decoder. Concatenates the input feature maps along the channel dimension and projects them to the selected embedding dimension with a linear layer and an activation.

Source code in terratorch/models/decoders/mlp_decoder.py
@TERRATORCH_DECODER_REGISTRY.register
class MLPDecoder(nn.Module):
    """Identity decoder. Useful to pass the feature straight to the head."""

    def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activation: str = "ReLU", out_index=-1) -> None:
        """Constructor

        Args:
            embed_dim (int): Input embedding dimension
            out_index (int, optional): Index of the input list to take. Defaults to -1.
        """

        super().__init__()
        self.embed_dim = embed_dim
        self.channels = channels
        self.dim = out_index
        self.n_inputs = len(self.embed_dim)
        self.out_channels = self.embed_dim[self.dim]
        self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels)
        self.activation = getattr(nn, activation)()

    def forward(self, x: list[Tensor]):

        data_ = torch.cat(x, axis=1)
        data_ = data_.permute(0, 2, 3, 1)
        data_ = self.activation(self.hidden_layer(data_))
        data_ = data_.permute(0, 3, 1, 2)

        return data_ 

__init__(embed_dim, channels=100, out_dim=100, activation='ReLU', out_index=-1) #

Constructor

Parameters:

embed_dim (int): Input embedding dimension. Required.
out_index (int, optional): Index of the input list to take. Defaults to -1.
Source code in terratorch/models/decoders/mlp_decoder.py
def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activation: str = "ReLU", out_index=-1) -> None:
    """Constructor

    Args:
        embed_dim (int): Input embedding dimension
        out_index (int, optional): Index of the input list to take. Defaults to -1.
    """

    super().__init__()
    self.embed_dim = embed_dim
    self.channels = channels
    self.dim = out_index
    self.n_inputs = len(self.embed_dim)
    self.out_channels = self.embed_dim[self.dim]
    self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels)
    self.activation = getattr(nn, activation)()
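
A minimal usage sketch (feature shapes are assumptions): the features are concatenated along the channel dimension before the linear projection, so every input must have the same channel count (equal to embed_dim[out_index]) and the same spatial size. Note that the constructor calls len(embed_dim) and indexes it, so a list of dimensions is expected despite the int annotation.

import torch
from terratorch.models.decoders.mlp_decoder import MLPDecoder

# Four feature maps with identical assumed shapes (e.g. from a ViT backbone).
features = [torch.randn(2, 768, 14, 14) for _ in range(4)]
decoder = MLPDecoder(embed_dim=[768, 768, 768, 768], activation="ReLU")
out = decoder(features)
print(decoder.out_channels)  # 768
print(out.shape)             # torch.Size([2, 768, 14, 14])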

terratorch.models.decoders.aspp_head #

ASPPHead #

Bases: Module

Rethinking Atrous Convolution for Semantic Image Segmentation.

This head is the implementation of DeepLabV3 (https://arxiv.org/abs/1706.05587).

Parameters:

dilations (tuple[int]): Dilation rates for the ASPP module. Defaults to (1, 6, 12, 18).
Source code in terratorch/models/decoders/aspp_head.py
@TERRATORCH_DECODER_REGISTRY.register
class ASPPHead(nn.Module):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations:list | tuple =(1, 6, 12, 18), 
                 in_channels:int=None, 
                 channels:int=None,
                 out_dim:int=3,
                 align_corners=False,
                 head_dropout_ratio:float=0.3,
                 input_transform: str = None,
                 in_index: int = -1,
                 **kwargs):

        super(ASPPHead, self).__init__(**kwargs)

        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.out_dim = out_dim

        self.align_corners = align_corners
        self.input_transform = input_transform
        self.in_index = in_index 

        if 'conv_cfg' not in kwargs:
            self.conv_cfg = self._default_conv_cfg

        if 'norm_cfg' not in kwargs:
            self.norm_cfg = self._default_norm_cfg

        if 'act_cfg' not in kwargs:
            self.act_cfg = self._default_act_cfg

        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,

                # TODO Extend it to support more possible configurations
                # for convolution, normalization and activation.
                #self.conv_cfg,
                #norm_cfg=self.norm_cfg,
                #act_cfg=self.act_cfg))
                ))

        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            padding=1,)
            # TODO Extend it to support more possible configurations
            # for convolution, normalization and activation.
            #conv_cfg=self.conv_cfg,
            #norm_cfg=self.norm_cfg,
            #act_cfg=self.act_cfg)

        if head_dropout_ratio > 0:
            self.dropout = nn.Dropout2d(head_dropout_ratio)

    @property
    def _default_conv_cfg(self):
        return {"kernel_size": 3, "padding": 0, "bias": False}

    @property
    def _default_norm_cfg(self):
        return {}

    @property
    def _default_act_cfg(self):
        return {}

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _forward_feature(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W), which is the feature map of the last layer of the decoder head.
        """
        inputs = self._transform_inputs(inputs)

        aspp_outs = [
            resize(
                self.image_pool(inputs),
                size=inputs.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]

        aspp_outs.extend(self.aspp_modules(inputs))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        feats = self.bottleneck(aspp_outs)

        return feats

    def forward(self, inputs):

        output = self._forward_feature(inputs)

        return output
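
A minimal usage sketch (the feature shape below is an assumption; with the default input_transform of None, in_index selects a single feature map from the list): the head returns a fused ASPP feature map with channels channels, which the subclasses below turn into segmentation or regression outputs.

import torch
from terratorch.models.decoders.aspp_head import ASPPHead

# Single feature map with an assumed shape; in_index=-1 selects it from the list.
features = [torch.randn(2, 768, 14, 14)]
head = ASPPHead(dilations=(1, 6, 12, 18), in_channels=768, channels=256, in_index=-1)
feats = head(features)
print(feats.shape)  # e.g. torch.Size([2, 256, 14, 14]), assuming the ASPP branches preserve spatial size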

ASPPSegmentationHead #

Bases: ASPPHead

Rethinking Atrous Convolution for Semantic Image Segmentation.

This head is the implementation of DeepLabV3 (https://arxiv.org/abs/1706.05587).

Parameters:

dilations (tuple[int]): Dilation rates for the ASPP module. Defaults to (1, 6, 12, 18).
Source code in terratorch/models/decoders/aspp_head.py
@TERRATORCH_DECODER_REGISTRY.register
class ASPPSegmentationHead(ASPPHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, channel_list,
                 dilations:list | tuple =(1, 6, 12, 18), 
                 in_channels:int=None, 
                 channels:int=None,
                 num_classes:int=2,
                 align_corners=False,
                 head_dropout_ratio:float=0.3,
                 input_transform: str = None,
                 in_index: int = -1,
                 **kwargs):

        super(ASPPSegmentationHead, self).__init__(
                 dilations=dilations, 
                 in_channels=in_channels, 
                 channels=channels,
                 align_corners=align_corners,
                 head_dropout_ratio=head_dropout_ratio,
                 input_transform=input_transform,
                 in_index=in_index,
                **kwargs)

        self.num_classes = num_classes
        self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1)

        if head_dropout_ratio > 0:
            self.dropout = nn.Dropout2d(head_dropout_ratio)

    def segmentation_head(self, features):

        """PixelWise classification"""

        if self.dropout is not None:
            features = self.dropout(features)
        output = self.conv_seg(features)
        return output

    def forward(self, inputs):

        output = self._forward_feature(inputs)
        output = self.segmentation_head(output)

        return output

segmentation_head(features) #

PixelWise classification

Source code in terratorch/models/decoders/aspp_head.py
def segmentation_head(self, features):

    """PixelWise classification"""

    if self.dropout is not None:
        features = self.dropout(features)
    output = self.conv_seg(features)
    return output
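
A minimal usage sketch (feature shapes are assumptions; channel_list is simply forwarded as given in the constructor signature): the head produces per-pixel class logits with num_classes channels.

import torch
from terratorch.models.decoders.aspp_head import ASPPSegmentationHead

# Single feature map with an assumed shape.
features = [torch.randn(2, 768, 14, 14)]
head = ASPPSegmentationHead(
    channel_list=[768],
    in_channels=768,
    channels=256,
    num_classes=5,
    in_index=-1,
)
logits = head(features)
print(logits.shape)  # e.g. torch.Size([2, 5, 14, 14]), assuming the ASPP branches preserve spatial size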

ASPPRegressionHead #

Bases: ASPPHead

Rethinking Atrous Convolution for regression.

This head is the implementation of DeepLabV3 (https://arxiv.org/abs/1706.05587).

Parameters:

dilations (tuple[int]): Dilation rates for the ASPP module. Defaults to (1, 6, 12, 18).
Source code in terratorch/models/decoders/aspp_head.py
@TERRATORCH_DECODER_REGISTRY.register
class ASPPRegressionHead(ASPPHead):
    """Rethinking Atrous Convolution for regression.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, channel_list,
                 dilations:list | tuple =(1, 6, 12, 18), 
                 in_channels:int=None, 
                 channels:int=None,
                 out_channels:int=1,
                 align_corners=False,
                 head_dropout_ratio:float=0.3,
                 input_transform: str = None,
                 in_index: int = -1,
                 **kwargs):

        super(ASPPRegressionHead, self).__init__(
                 dilations=dilations, 
                 in_channels=in_channels, 
                 channels=channels,
                 out_dim=out_channels,
                 align_corners=align_corners,
                 head_dropout_ratio=head_dropout_ratio,
                 input_transform=input_transform,
                 in_index=in_index,
                **kwargs)

        self.out_channels = out_channels
        self.conv_reg = nn.Conv2d(self.channels, self.out_channels, kernel_size=1)

    def regression_head(self, features):

        """PixelWise regression"""
        if self.dropout is not None:
            features = self.dropout(features)
        output = self.conv_reg(features)
        return output


    def forward(self, inputs):

        output = self._forward_feature(inputs)
        output = self.regression_head(output)

        return output

regression_head(features) #

PixelWise regression

Source code in terratorch/models/decoders/aspp_head.py
def regression_head(self, features):

    """PixelWise regression"""
    if self.dropout is not None:
        features = self.dropout(features)
    output = self.conv_reg(features)
    return output
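
A minimal usage sketch (feature shapes are assumptions): the head produces a per-pixel regression output with out_channels channels.

import torch
from terratorch.models.decoders.aspp_head import ASPPRegressionHead

# Single feature map with an assumed shape.
features = [torch.randn(2, 768, 14, 14)]
head = ASPPRegressionHead(
    channel_list=[768],
    in_channels=768,
    channels=256,
    out_channels=1,
    in_index=-1,
)
pred = head(features)
print(pred.shape)  # e.g. torch.Size([2, 1, 14, 14]), assuming the ASPP branches preserve spatial size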