EncoderDecoderFactory
The Factory
A special factory provided by terratorch is the EncoderDecoderFactory.
This factory leverages the BACKBONE_REGISTRY, DECODER_REGISTRY and NECK_REGISTRY to compose models formed as encoder + decoder, with some optional glue in between provided by the necks.
As most current models work this way, this is a particularly important factory, allowing for great flexibility in combining encoders and decoders from different sources.
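The factory allows arguments to be passed to the encoder, decoder and head. Arguments with the prefix backbone_ are routed to the backbone constructor, and arguments prefixed with decoder_ and head_ are routed to the decoder and head in the same way. These arguments are accepted dynamically and are not checked against the constructors' signatures. Any argument left over after this routing (i.e. one without any of these prefixes) will raise a ValueError.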
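The routing rule can be pictured with a small, self-contained sketch. `split_prefixed_kwargs` and `route_kwargs` are hypothetical helpers written for illustration, not terratorch's `extract_prefix_keys` or `_check_all_args_used`; stripping the prefix before forwarding is an assumption of the sketch, and argument names such as `decoder_channels` are made up.

```python
def split_prefixed_kwargs(kwargs: dict, prefix: str) -> tuple[dict, dict]:
    """Hypothetical helper: split kwargs into (matched, remaining).

    Keys starting with `prefix` are returned with the prefix stripped (assumption);
    every other key is returned untouched in the second dict.
    """
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    remaining = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, remaining


def route_kwargs(kwargs: dict) -> dict[str, dict]:
    """Route backbone_/decoder_/head_ arguments and reject anything left over."""
    routed = {}
    for name in ("backbone", "decoder", "head"):
        routed[name], kwargs = split_prefixed_kwargs(kwargs, f"{name}_")
    if kwargs:
        # Mirrors the documented behaviour: leftover arguments are an error.
        msg = f"Unused arguments: {sorted(kwargs)}"
        raise ValueError(msg)
    return routed


routed = route_kwargs({"backbone_pretrained": True, "decoder_channels": 256, "head_dropout": 0.1})
print(routed)
# {'backbone': {'pretrained': True}, 'decoder': {'channels': 256}, 'head': {'dropout': 0.1}}
```

Passing an argument such as `lr=1e-3` with no recognised prefix would make `route_kwargs` raise a ValueError, matching the rule described above.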
Both encoder and decoder may be passed as strings, in which case they will be looked up in the respective registry, or as nn.Modules, in which case they will be used as is. In the second case, the factory assumes in good faith that the encoder or decoder passed in conforms to the expected contract.
Not all decoders will readily accept the raw output of the given encoder. This is where necks come in.
Necks are a sequence of operations which are applied to the output of the encoder before it is passed to the decoder.
They must be instances of Neck, which is a subclass of nn.Module, meaning they can even define new trainable parameters.
The EncoderDecoderFactory returns a PixelWiseModel or a ScalarOutputModel depending on the task.
terratorch.models.encoder_decoder_factory.EncoderDecoderFactory
Bases: ModelFactory
Source code in terratorch/models/encoder_decoder_factory.py
```python
@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        num_classes: int | None = None,
        necks: list[dict] | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        **kwargs,
    ) -> Model:
        """Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

        Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str, nn.Module): Backbone to be used. If a string, will look for such models in the different
                registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it
                directly. The backbone should have an `out_channels` attribute and its `forward` should return a list[Tensor].
            decoder (Union[str, nn.Module]): Decoder to be used for the segmentation model.
                If a string, will look for such decoders in the different
                registries supported (internal terratorch registry, smp, ...).
                If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                Pixel wise tasks will be concatenated with a Conv2d for the final convolution.
            num_classes (int, optional): Number of classes. None for regression tasks.
            necks (list[dict]): nn.Modules to be called in succession on encoder features
                before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry.
                Expects each one to have a key "name" and subsequent keys for arguments, if any.
                Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression). Defaults to True.

        Returns:
            nn.Module: Full model with encoder, decoder and head.
        """
        task = task.lower()
        if task not in SUPPORTED_TASKS:
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)

        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
        backbone = _get_backbone(backbone, **backbone_kwargs)

        try:
            out_channels = backbone.out_channels
        except AttributeError as e:
            msg = "backbone must have out_channels attribute"
            raise AttributeError(msg) from e

        if necks is None:
            necks = []
        neck_list, channel_list = build_neck_list(necks, out_channels)

        # some decoders already include a head
        # for these, we pass the num_classes to them
        # others don't include a head
        # for those, we don't pass num_classes
        decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")
        head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_")

        decoder, head_kwargs, decoder_includes_head = _get_decoder_and_head_kwargs(
            decoder, channel_list, decoder_kwargs, head_kwargs, num_classes=num_classes
        )

        if aux_decoders is None:
            _check_all_args_used(kwargs)
            return _build_appropriate_model(
                task,
                backbone,
                decoder,
                head_kwargs,
                necks=neck_list,
                decoder_includes_head=decoder_includes_head,
                rescale=rescale,
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_kwargs, args = extract_prefix_keys(args, "decoder_")
            aux_head_kwargs, args = extract_prefix_keys(args, "head_")
            aux_decoder_instance, aux_head_kwargs, aux_decoder_includes_head = _get_decoder_and_head_kwargs(
                aux_decoder.decoder, channel_list, aux_decoder_kwargs, aux_head_kwargs, num_classes=num_classes
            )
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )
            _check_all_args_used(args)

        _check_all_args_used(kwargs)

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            necks=neck_list,
            decoder_includes_head=decoder_includes_head,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )
```
build_model(task, backbone, decoder, num_classes=None, necks=None, aux_decoders=None, rescale=True, **kwargs)
Generic model factory that combines an encoder and decoder, together with a head, for a specific task. Further arguments to be passed to the backbone, decoder or head should be prefixed with backbone_, decoder_ and head_ respectively.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| task | str | Task to be performed. Currently supports "segmentation" and "regression". | required |
| backbone | (str, Module) | Backbone to be used. If a string, will look for such models in the different registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it directly. The backbone should have an out_channels attribute and its forward should return a list[Tensor]. | required |
| decoder | Union[str, Module] | Decoder to be used for the segmentation model. If a string, will look for such decoders in the different registries supported (internal terratorch registry, smp, ...). If an nn.Module, we expect it to expose a property decoder.out_channels. Pixel wise tasks will be concatenated with a Conv2d for the final convolution. | required |
| num_classes | int | Number of classes. None for regression tasks. | None |
| necks | list[dict] | nn.Modules to be called in succession on encoder features before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry. Expects each one to have a key "name" and subsequent keys for arguments, if any. None applies the identity function. | None |
| aux_decoders | list[AuxiliaryHead] \| None | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | None |
| rescale | bool | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). | True |

Returns:

| Type | Description |
|---|---|
| Model | nn.Module: Full model with encoder, decoder and head. |
terratorch.models.pixel_wise_model.PixelWiseModel
terratorch.models.scalar_output_model.ScalarOutputModel
Encoders
To be a valid encoder, an object must be an nn.Module with an additional attribute out_channels, which is a list of the channel dimensions of the features it returns. Its forward method should return a list of torch.Tensor.
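A quick way to catch contract violations early is to check the attribute before building a model. The function below is a hypothetical pre-flight check written for this page (the `build_model` code shown above only checks that `out_channels` exists); it covers the `out_channels` part of the contract, not the `nn.Module` base class or the output of `forward`. `FakeEncoder` is a stand-in object, not a real backbone.

```python
from typing import Any


def check_encoder_contract(encoder: Any) -> list[int]:
    """Hypothetical check: return encoder.out_channels if it is a list of ints,
    otherwise raise a descriptive error."""
    try:
        out_channels = encoder.out_channels
    except AttributeError as e:
        raise AttributeError("encoder must have an out_channels attribute") from e
    if not isinstance(out_channels, list) or not all(isinstance(c, int) for c in out_channels):
        raise TypeError(f"out_channels must be a list[int], got {out_channels!r}")
    return out_channels


class FakeEncoder:
    """Stand-in for an encoder that returns four 768-channel feature maps."""

    out_channels = [768, 768, 768, 768]


print(check_encoder_contract(FakeEncoder()))  # [768, 768, 768, 768]
```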
Necks
Necks are the glue between encoder and decoder. They can perform operations such as selecting elements from the output of the encoder (SelectIndices), reshaping the outputs of ViTs so they are compatible with CNNs (ReshapeTokensToImage), amongst others.
Necks are nn.Modules with an additional method, process_channel_list, which informs the EncoderDecoderFactory how they will alter the channel list provided by encoder.out_channels.
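Because each neck reports how it changes the channel list, the factory can work out the decoder's input channels before any data flows through the model. The sketch below assumes this is done by applying each neck's `process_channel_list` in sequence (the factory delegates this to `build_neck_list`, whose internals are not shown here). The stand-in classes are hypothetical and only reproduce the channel arithmetic of `SelectIndices` and `AddBottleneckLayer`, without any tensors.

```python
class SelectIndicesChannels:
    """Stand-in mirroring SelectIndices.process_channel_list."""

    def __init__(self, indices: list[int]):
        self.indices = indices

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [channel_list[i] for i in self.indices]


class AddBottleneckChannels:
    """Stand-in mirroring AddBottleneckLayer.process_channel_list."""

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [*channel_list, channel_list[-1] // 2]


def propagate_channels(encoder_channels: list[int], necks: list) -> list[int]:
    """Apply each neck's process_channel_list in order; the result is the
    channel list the decoder would be constructed with."""
    channels = list(encoder_channels)
    for neck in necks:
        channels = neck.process_channel_list(channels)
    return channels


# Illustrative encoder reporting 12 blocks of 768 channels
encoder_channels = [768] * 12
necks = [SelectIndicesChannels([2, 5, 8, 11]), AddBottleneckChannels()]
print(propagate_channels(encoder_channels, necks))  # [768, 768, 768, 768, 384]
```

With an empty neck list the channel list passes through unchanged, which matches the identity behaviour documented for `necks=None`.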
terratorch.models.necks.Neck
Bases: ABC, Module
Base class for Neck
A neck must implement self.process_channel_list, which returns the new channel list.
Source code in terratorch/models/necks.py
```python
class Neck(ABC, nn.Module):
    """Base class for Neck

    A neck must implement `self.process_channel_list` which returns the new channel list.
    """

    def __init__(self, channel_list: list[int]) -> None:
        super().__init__()
        self.channel_list = channel_list

    @abstractmethod
    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return channel_list

    @abstractmethod
    def forward(self, channel_list: list[torch.Tensor]) -> list[torch.Tensor]: ...
```
terratorch.models.necks.SelectIndices
Bases: Neck
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class SelectIndices(Neck):
    def __init__(self, channel_list: list[int], indices: list[int]):
        """Select indices from the embedding list

        Args:
            indices (list[int]): list of indices to select.
        """
        super().__init__(channel_list)
        self.indices = indices

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [features[i] for i in self.indices]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        channel_list = [channel_list[i] for i in self.indices]
        return channel_list
```
__init__(channel_list, indices)
Select indices from the embedding list
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| indices | list[int] | List of indices to select. | required |
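Since `forward` and `process_channel_list` apply the same list comprehension with the same `indices`, the selected feature maps and the reported channel counts stay aligned, and ordinary Python indexing means negative indices also work. In the sketch below, strings stand in for feature tensors and the channel values are illustrative.

```python
channel_list = [96, 192, 384, 768]  # illustrative hierarchical backbone
features = ["feat_stage0", "feat_stage1", "feat_stage2", "feat_stage3"]  # tensor placeholders
indices = [1, -1]

# The same comprehension SelectIndices uses in forward() and process_channel_list()
selected_features = [features[i] for i in indices]
selected_channels = [channel_list[i] for i in indices]

print(selected_features)  # ['feat_stage1', 'feat_stage3']
print(selected_channels)  # [192, 768]
```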
terratorch.models.necks.PermuteDims
Bases: Neck
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class PermuteDims(Neck):
    def __init__(self, channel_list: list[int], new_order: list[int]):
        """Permute dimensions of each element in the embedding list

        Args:
            new_order (list[int]): list of indices to be passed to tensor.permute()
        """
        super().__init__(channel_list)
        self.new_order = new_order

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        features = [feat.permute(*self.new_order).contiguous() for feat in features]
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
```
__init__(channel_list, new_order)
Permute dimensions of each element in the embedding list
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| new_order | list[int] | List of indices to be passed to tensor.permute(). | required |
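PermuteDims only reorders each tensor's axes, so its effect is easiest to see on shapes. The helper below is a hypothetical illustration that computes the shape `tensor.permute(*new_order)` would produce; the channels-last (B, H, W, C) input layout and the sizes are assumptions for the example. Because `process_channel_list` returns the channel list unchanged, the decoder is built with the channel counts the encoder reported.

```python
def permuted_shape(shape: tuple[int, ...], new_order: list[int]) -> tuple[int, ...]:
    """Shape that tensor.permute(*new_order) would give for a tensor of `shape`."""
    if sorted(new_order) != list(range(len(shape))):
        raise ValueError("new_order must be a permutation of the tensor's dimensions")
    return tuple(shape[i] for i in new_order)


# Assumed channels-last feature map (B, H, W, C) -> channels-first (B, C, H, W)
print(permuted_shape((2, 14, 14, 768), [0, 3, 1, 2]))  # (2, 768, 14, 14)
```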
terratorch.models.necks.InterpolateToPyramidal
Bases: Neck
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class InterpolateToPyramidal(Neck):
    def __init__(self, channel_list: list[int], scale_factor: int = 2, mode: str = "nearest"):
        """Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]

        Useful to make non-pyramidal backbones compatible with hierarchical ones

        Args:
            scale_factor (int): Amount to scale embeddings by each layer. Defaults to 2.
            mode (str): Interpolation mode to be passed to torch.nn.functional.interpolate. Defaults to 'nearest'.
        """
        super().__init__(channel_list)
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features), 0, -1))
        for x, exponent in zip(features, scale_exponents, strict=True):
            out.append(F.interpolate(x, scale_factor=self.scale_factor**exponent, mode=self.mode))
        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
```
__init__(channel_list, scale_factor=2, mode='nearest')
Spatially interpolate embeddings so that embedding[i - 1] is scale_factor times larger than embedding[i]. Useful to make non-pyramidal backbones compatible with hierarchical ones.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| scale_factor | int | Amount to scale embeddings by each layer. | 2 |
| mode | str | Interpolation mode to be passed to torch.nn.functional.interpolate. | 'nearest' |
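The exponents in `forward` run from `len(features)` down to 1, so every embedding is upscaled, the first one the most. The sketch below reproduces only that size arithmetic in plain Python (no tensors); the 14 x 14 grid, corresponding to a 224-pixel input with 16-pixel patches, is an illustrative assumption.

```python
def pyramidal_sizes(num_features: int, size: int, scale_factor: int = 2) -> list[int]:
    """Side length of each embedding after InterpolateToPyramidal.

    Embedding i is scaled by scale_factor ** (num_features - i): the first one
    is upscaled the most and the last one by scale_factor.
    """
    exponents = range(num_features, 0, -1)  # same exponents as forward()
    return [size * scale_factor**e for e in exponents]


# Four 14x14 ViT feature maps with the default scale_factor=2
print(pyramidal_sizes(4, 14))  # [224, 112, 56, 28]
```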
terratorch.models.necks.MaxpoolToPyramidal
Bases: Neck
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class MaxpoolToPyramidal(Neck):
    def __init__(self, channel_list: list[int], kernel_size: int = 2):
        """Spatially downsample embeddings so that embedding[i] is kernel_size times smaller than embedding[i - 1]

        Useful to make non-pyramidal backbones compatible with hierarchical ones

        Args:
            kernel_size (int): Base kernel size to use for maxpool. Defaults to 2.
        """
        super().__init__(channel_list)
        self.kernel_size = kernel_size

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        scale_exponents = list(range(len(features)))
        for x, exponent in zip(features, scale_exponents, strict=True):
            if exponent == 0:
                out.append(x.clone())
            else:
                out.append(F.max_pool2d(x, kernel_size=self.kernel_size**exponent))
        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
```
__init__(channel_list, kernel_size=2)
Spatially downsample embeddings so that embedding[i] is kernel_size times smaller than embedding[i - 1]. Useful to make non-pyramidal backbones compatible with hierarchical ones.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kernel_size | int | Base kernel size to use for maxpool. | 2 |
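`MaxpoolToPyramidal` keeps the first embedding as is and max-pools embedding i with a kernel of `kernel_size ** i`. Because `F.max_pool2d` defaults its stride to the kernel size and uses no padding, each output side length is `(size - k) // k + 1`. The sketch below only reproduces that arithmetic; the input sizes are illustrative.

```python
def maxpool_pyramid_sizes(num_features: int, size: int, kernel_size: int = 2) -> list[int]:
    """Side length of each embedding after MaxpoolToPyramidal.

    Embedding 0 is copied unchanged; embedding i > 0 is max-pooled with
    kernel kernel_size ** i (stride == kernel, no padding).
    """
    sizes = []
    for exponent in range(num_features):
        if exponent == 0:
            sizes.append(size)
        else:
            k = kernel_size**exponent
            sizes.append((size - k) // k + 1)
    return sizes


print(maxpool_pyramid_sizes(4, 64))  # [64, 32, 16, 8]
print(maxpool_pyramid_sizes(4, 14))  # [14, 7, 3, 1]
```

The second call shows that with a side length that is not divisible by the kernel, the remainder is dropped, so the factor between levels is only approximately `kernel_size`.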
terratorch.models.necks.ReshapeTokensToImage
Bases: Neck
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class ReshapeTokensToImage(Neck):
    def __init__(self, channel_list: list[int], remove_cls_token=True, effective_time_dim: int = 1):  # noqa: FBT002
        """Reshape output of transformer encoder so it can be passed to a conv net.

        Args:
            remove_cls_token (bool, optional): Whether to remove the cls token from the first position.
                Defaults to True.
            effective_time_dim (int, optional): The effective temporal dimension the transformer processes.
                For a ViT, this will be given by `num_frames // tubelet size`. This is used to determine
                the temporal dimension of the embedding, which is concatenated with the embedding dimension.
                For example:
                - A model which processes 1 frame with a tubelet size of 1 has an effective_time_dim of 1.
                    The embedding produced by this model has embedding size embed_dim * 1.
                - A model which processes 3 frames with a tubelet size of 1 has an effective_time_dim of 3.
                    The embedding produced by this model has embedding size embed_dim * 3.
                - A model which processes 12 frames with a tubelet size of 4 has an effective_time_dim of 3.
                    The embedding produced by this model has an embedding size embed_dim * 3.
                Defaults to 1.
        """
        super().__init__(channel_list)
        self.remove_cls_token = remove_cls_token
        self.effective_time_dim = effective_time_dim

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        for x in features:
            if self.remove_cls_token:
                x_no_token = x[:, 1:, :]
            else:
                x_no_token = x
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // self.effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                batch=x_no_token.shape[0],
                t=self.effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return super().process_channel_list(channel_list)
```
__init__(channel_list, remove_cls_token=True, effective_time_dim=1)
Reshape output of transformer encoder so it can be passed to a conv net.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| remove_cls_token | bool | Whether to remove the cls token from the first position. | True |
| effective_time_dim | int | The effective temporal dimension the transformer processes. For a ViT, this will be given by num_frames // tubelet size. This is used to determine the temporal dimension of the embedding, which is concatenated with the embedding dimension. For example, a model processing 1 frame with a tubelet size of 1 has an effective_time_dim of 1 (embedding size embed_dim * 1); 3 frames with a tubelet size of 1 give an effective_time_dim of 3 (embedding size embed_dim * 3); 12 frames with a tubelet size of 4 also give an effective_time_dim of 3 (embedding size embed_dim * 3). | 1 |
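The shape bookkeeping in `forward` can be reproduced without tensors. The sketch below is an illustration, not terratorch code: it mirrors the cls-token removal, the per-time-step token split and the `batch (t h w) e -> batch (t e) h w` rearrangement, using `math.isqrt` in place of `int(np.sqrt(...))`. The ViT-B/16-style numbers (196 patch tokens per frame plus one cls token, 768-dim embeddings) are assumed for the example.

```python
import math


def tokens_to_image_shape(
    batch: int,
    num_tokens: int,
    embed_dim: int,
    effective_time_dim: int = 1,
    remove_cls_token: bool = True,
) -> tuple[int, int, int, int]:
    """Output shape of ReshapeTokensToImage for a (batch, num_tokens, embed_dim) input."""
    if remove_cls_token:
        num_tokens -= 1  # drop the token at position 0
    tokens_per_timestep = num_tokens // effective_time_dim
    h = math.isqrt(tokens_per_timestep)  # grid height, as forward() computes it
    w = tokens_per_timestep // h  # width is inferred by rearrange
    if effective_time_dim * h * w != num_tokens:
        raise ValueError("tokens cannot be split into t grids of size h x w")
    # Time is folded into the channel axis: (t e)
    return batch, effective_time_dim * embed_dim, h, w


print(tokens_to_image_shape(2, 197, 768))  # (2, 768, 14, 14)
print(tokens_to_image_shape(2, 3 * 196 + 1, 768, effective_time_dim=3))  # (2, 2304, 14, 14)
```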
terratorch.models.necks.AddBottleneckLayer
Bases: Neck
Add a layer that reduces the channel dimension of the final embedding by half, and appends the result to the embedding list as a new final element.
Useful for compatibility with some smp decoders.
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class AddBottleneckLayer(Neck):
    """Add a layer that reduces the channel dimension of the final embedding by half, and appends it to the embedding list

    Useful for compatibility with some smp decoders.
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        self.bottleneck = nn.Conv2d(channel_list[-1], channel_list[-1] // 2, kernel_size=1)

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        new_embedding = self.bottleneck(features[-1])
        features.append(new_embedding)
        return features

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [*channel_list, channel_list[-1] // 2]
```
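At the shape level, the sketch below tracks `(channels, side length)` pairs: the 1 x 1 convolution halves the channels of the last embedding without changing its spatial size, and the result becomes an extra list entry. The Swin-like pyramid values are assumed for the example.

```python
def add_bottleneck_shapes(shapes: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """(channels, side) pairs after AddBottleneckLayer."""
    last_channels, last_side = shapes[-1]
    # kernel_size=1, stride 1, no padding: spatial size is unchanged
    return [*shapes, (last_channels // 2, last_side)]


swin_like = [(96, 56), (192, 28), (384, 14), (768, 7)]  # illustrative pyramid
print(add_bottleneck_shapes(swin_like))
# [(96, 56), (192, 28), (384, 14), (768, 7), (384, 7)]
```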
terratorch.models.necks.LearnedInterpolateToPyramidal
Bases: Neck
Use learned convolutions to transform the output of a non-pyramidal encoder into a pyramidal one.
Always requires exactly 4 embeddings.
Source code in terratorch/models/necks.py
```python
@TERRATORCH_NECK_REGISTRY.register
class LearnedInterpolateToPyramidal(Neck):
    """Use learned convolutions to transform the output of a non-pyramidal encoder into a pyramidal one

    Always requires exactly 4 embeddings
    """

    def __init__(self, channel_list: list[int]):
        super().__init__(channel_list)
        if len(channel_list) != 4:
            msg = "This class can only handle exactly 4 input embeddings"
            raise Exception(msg)
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(channel_list[0], channel_list[0] // 2, 2, 2),
            nn.BatchNorm2d(channel_list[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(channel_list[0] // 2, channel_list[0] // 4, 2, 2),
        )
        self.fpn2 = nn.Sequential(nn.ConvTranspose2d(channel_list[1], channel_list[1] // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
        self.embedding_dim = [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        scaled_inputs = []
        scaled_inputs.append(self.fpn1(features[0]))
        scaled_inputs.append(self.fpn2(features[1]))
        scaled_inputs.append(self.fpn3(features[2]))
        scaled_inputs.append(self.fpn4(features[3]))
        return scaled_inputs

    def process_channel_list(self, channel_list: list[int]) -> list[int]:
        return [channel_list[0] // 4, channel_list[1] // 2, channel_list[2], channel_list[3]]
```
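The four branches rescale their inputs by 4x, 2x, 1x and 0.5x respectively. The sketch below computes the resulting channel counts and side lengths in plain Python, using the standard output-size formulas for a kernel-2, stride-2 `ConvTranspose2d` and `MaxPool2d`; the four 768-channel, 14 x 14 inputs are an illustrative assumption.

```python
def learned_pyramid_shapes(channels: list[int], side: int) -> list[tuple[int, int]]:
    """(channels, side) of each output of LearnedInterpolateToPyramidal."""
    if len(channels) != 4:
        raise ValueError("exactly 4 input embeddings are required")

    def conv_t(s: int) -> int:
        # ConvTranspose2d, kernel 2, stride 2, no padding: (s - 1) * 2 + 2
        return (s - 1) * 2 + 2

    return [
        (channels[0] // 4, conv_t(conv_t(side))),  # fpn1: two transposed convs, 4x larger
        (channels[1] // 2, conv_t(side)),          # fpn2: one transposed conv, 2x larger
        (channels[2], side),                       # fpn3: identity
        (channels[3], (side - 2) // 2 + 1),        # fpn4: MaxPool2d(2, 2), 2x smaller
    ]


print(learned_pyramid_shapes([768, 768, 768, 768], 14))
# [(192, 56), (384, 28), (768, 14), (768, 7)]
```

The channel part of the result matches what `process_channel_list` reports, so a decoder built from that list receives consistent inputs.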
Decoders
To be a valid decoder, an object must be an nn.Module with an additional attribute out_channels, which is an int with the channel dimension of the output.
The first argument to its constructor will be a list of channel dimensions it should expect as input.
Its forward method should accept a list of embeddings.
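The constructor and `out_channels` parts of the decoder contract can be exercised without PyTorch. `build_and_check_decoder` and `ToyConcatDecoder` below are hypothetical names used only for illustration; the check does not cover `forward`.

```python
def build_and_check_decoder(decoder_cls: type, channel_list: list[int]):
    """Hypothetical check: construct the decoder with the channel list as its first
    argument, as the contract above requires, and confirm out_channels is an int."""
    decoder = decoder_cls(channel_list)
    out_channels = getattr(decoder, "out_channels", None)
    if not isinstance(out_channels, int):
        raise TypeError(f"decoder.out_channels must be an int, got {out_channels!r}")
    return decoder


class ToyConcatDecoder:
    """Stand-in decoder whose output would concatenate every input embedding."""

    def __init__(self, embed_dims: list[int]):
        self.embed_dims = embed_dims
        self.out_channels = sum(embed_dims)


decoder = build_and_check_decoder(ToyConcatDecoder, [192, 384, 768, 768])
print(decoder.out_channels)  # 2112
```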
Heads
Most decoders require a final head to be added for a specific task (e.g. semantic segmentation vs pixel-wise regression).
Registries producing decoders that don't require a head must expose the attribute includes_head=True so that a head is not added.
Decoders passed as nn.Modules which don't require a head must expose the same attribute themselves.
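One way to picture the `includes_head` convention is the hypothetical helper below: the flag is read from the registry when the decoder was given by name, and from the module itself otherwise. It is a sketch of the rule described above, not terratorch's `_get_decoder_and_head_kwargs`, and the stand-in classes are made up.

```python
def needs_head(decoder, registry=None) -> bool:
    """Return True when a task head should be appended.

    `registry` is the registry a string decoder was resolved from, or None
    when the decoder was passed directly as a module.
    """
    source = decoder if registry is None else registry
    return not getattr(source, "includes_head", False)


class RegistryWithHeads:
    includes_head = True  # decoders from this registry already end in a head


class PlainDecoder:
    out_channels = 256  # no includes_head attribute, so a head is added


class DecoderWithOwnHead:
    out_channels = 2
    includes_head = True


print(needs_head(PlainDecoder()))                                # True
print(needs_head(DecoderWithOwnHead()))                          # False
print(needs_head(PlainDecoder(), registry=RegistryWithHeads()))  # False
```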
terratorch.models.heads.classification_head.ClassificationHead
Bases: Module
Classification head
Source code in terratorch/models/heads/classification_head.py
```python
class ClassificationHead(nn.Module):
    """Classification head"""

    # how to allow cls token?
    def __init__(
        self,
        in_dim: int,
        num_classes: int,
        dim_list: list[int] | None = None,
        dropout: float = 0,
        linear_after_pool: bool = False,
    ) -> None:
        """Constructor

        Args:
            in_dim (int): Input dimensionality
            num_classes (int): Number of output classes
            dim_list (list[int] | None, optional): List with number of dimensions for each Linear
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
            linear_after_pool (bool, optional): Apply pooling first, then apply the linear layer. Defaults to False
        """
        super().__init__()
        self.num_classes = num_classes
        self.linear_after_pool = linear_after_pool
        if dim_list is None:
            pre_head = nn.Identity()
        else:

            def block(in_dim, out_dim):
                return nn.Sequential(nn.Linear(in_features=in_dim, out_features=out_dim), nn.ReLU())

            dim_list = [in_dim, *dim_list]
            pre_head = nn.Sequential(*[block(dim_list[i], dim_list[i + 1]) for i in range(len(dim_list) - 1)])
            in_dim = dim_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Linear(in_features=in_dim, out_features=num_classes),
        )

    def forward(self, x: Tensor):
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
        if self.linear_after_pool:
            x = x.mean(axis=1)
            out = self.head(x)
        else:
            x = self.head(x)
            out = x.mean(axis=1)
        return out
```
__init__(in_dim, num_classes, dim_list=None, dropout=0, linear_after_pool=False)
Constructor
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimensionality. | required |
| num_classes | int | Number of output classes. | required |
| dim_list | list[int] \| None | List with the number of dimensions for each Linear layer to be created. | None |
| dropout | float | Dropout value to apply. | 0 |
| linear_after_pool | bool | Apply pooling first, then apply the linear layer. | False |
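With the default `dim_list=None` and dropout inactive (dropout of 0, or evaluation mode), the head reduces to a single `nn.Linear`. Averaging commutes with an affine map, so both settings of `linear_after_pool` then give the same logits; the flag only changes results once `dim_list` adds the ReLU blocks, or while dropout is active during training. The pure-Python sketch below checks the commuting case on made-up numbers, with lists standing in for the `(H*W, C)` token view that `forward` builds.

```python
import math


def linear(x: list[float], weight: list[list[float]], bias: list[float]) -> list[float]:
    """y = W x + b for a single token."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def mean_tokens(tokens: list[list[float]]) -> list[float]:
    """Average over the token (spatial) axis."""
    return [sum(col) / len(tokens) for col in zip(*tokens)]


tokens = [[1.0, 2.0], [3.0, 0.0], [-1.0, 4.0]]  # 3 spatial tokens, 2 channels
weight = [[0.5, -1.0], [2.0, 1.0], [0.0, 3.0]]  # 3 classes x 2 channels
bias = [0.1, 0.0, -0.2]

pool_then_linear = linear(mean_tokens(tokens), weight, bias)  # linear_after_pool=True
linear_then_pool = mean_tokens([linear(t, weight, bias) for t in tokens])  # linear_after_pool=False

same = all(math.isclose(a, b) for a, b in zip(pool_then_linear, linear_then_pool))
print(same)  # True
```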
terratorch.models.heads.regression_head.RegressionHead
Bases: Module
Regression head
Source code in terratorch/models/heads/regression_head.py
class RegressionHead(nn.Module):
    """Regression head"""
    def __init__(
        self,
        in_channels: int,
        final_act: nn.Module | str | None = None,
        learned_upscale_layers: int = 0,
        channel_list: list[int] | None = None,
        batch_norm: bool = True,
        dropout: float = 0,
    ) -> None:
        """Constructor
        Args:
            in_channels (int): Number of input channels
            final_act (nn.Module | None, optional): Final activation to be applied. Defaults to None.
            learned_upscale_layers (int, optional): Number of Pixelshuffle layers to create. Each upscales 2x.
                Defaults to 0.
            channel_list (list[int] | None, optional): List with number of channels for each Conv
                layer to be created. Defaults to None.
            batch_norm (bool, optional): Whether to apply batch norm. Defaults to True.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
        """
        super().__init__()
        self.learned_upscale_layers = learned_upscale_layers
        self.final_act = final_act if final_act else nn.Identity()
        if isinstance(final_act, str):
            module_name, class_name = final_act.rsplit(".", 1)
            target_class = getattr(importlib.import_module(module_name), class_name)
            self.final_act = target_class()
        pre_layers = []
        if learned_upscale_layers != 0:
            learned_upscale = nn.Sequential(
                *[PixelShuffleUpscale(in_channels) for _ in range(self.learned_upscale_layers)]
            )
            pre_layers.append(learned_upscale)
        if channel_list is None:
            pre_head = nn.Identity()
        else:
            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
        pre_layers.append(pre_head)
        dropout = nn.Dropout2d(dropout)
        final_layer = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1)
        self.head = nn.Sequential(*[*pre_layers, dropout, final_layer])
    def forward(self, x):
        output = self.head(x)
        return self.final_act(output)
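To make the layer wiring of RegressionHead concrete, the sketch below traces how in_channels, channel_list and learned_upscale_layers determine the convolutions the head builds and the per-sample shape it returns. regression_head_plan is a hypothetical helper written for illustration, not part of terratorch. It assumes, as the source implies by chaining several PixelShuffleUpscale(in_channels) modules before the conv blocks, that each upscale layer keeps the channel count and doubles height and width.

```python
def regression_head_plan(in_channels, height, width, channel_list=None, learned_upscale_layers=0):
    """Describe the conv layers RegressionHead builds and its output shape per sample.

    Hypothetical helper for illustration only; not part of terratorch.
    Assumes each PixelShuffle upscale layer keeps channels and doubles H and W.
    """
    # Each upscale layer is documented as upscaling 2x, so k layers upscale by 2**k.
    scale = 2 ** learned_upscale_layers
    # Same pairing as the source: [in_channels, *channel_list] -> consecutive (in, out) pairs.
    dims = [in_channels, *(channel_list or [])]
    # 3x3 convolutions with padding=1 keep H and W unchanged.
    conv_blocks = list(zip(dims[:-1], dims[1:]))
    # The final 1x1 convolution maps to a single regression channel.
    final_conv = (dims[-1], 1)
    output_chw = (1, height * scale, width * scale)
    return conv_blocks, final_conv, output_chw


blocks, final, shape = regression_head_plan(768, 14, 14, channel_list=[256, 64], learned_upscale_layers=2)
print(blocks)  # [(768, 256), (256, 64)]
print(final)   # (64, 1)
print(shape)   # (1, 56, 56)
```

With channel_list left as None, no 3x3 blocks are created and the 1x1 convolution maps in_channels straight to one output channel.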
__init__(in_channels, final_act=None, learned_upscale_layers=0, channel_list=None, batch_norm=True, dropout=0)
Constructor
Parameters:
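Name | Type | Description | Default
in_channels | int | Number of input channels. | required
final_act | nn.Module, str or None | Final activation to apply. Either a module instance or a dotted class path such as "torch.nn.Sigmoid", which is imported and instantiated without arguments. If None, no activation is applied. | None
learned_upscale_layers | int | Number of PixelShuffle upscaling layers to create. Each one upscales the spatial resolution by 2x. | 0
channel_list | list[int] or None | Output channels of each 3x3 Conv2d, BatchNorm2d and ReLU block inserted before the final 1x1 convolution. If None, no blocks are added. | None
batch_norm | bool | Whether to apply batch norm. Note that the source shown above never reads this flag: the channel_list blocks always include BatchNorm2d. | True
dropout | float | Dropout2d probability applied before the final 1x1 convolution. | 0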
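When final_act is a string, the constructor splits off the class name, imports the module, looks the class up and instantiates it with no arguments. The sketch below isolates that pattern; resolve_activation is a hypothetical name used only here, and a standard-library class stands in for an activation such as "torch.nn.Sigmoid" so that the example runs without torch.

```python
import importlib


def resolve_activation(spec):
    """Instantiate a class from a dotted path such as "torch.nn.Sigmoid".

    Hypothetical helper mirroring the final_act string handling in RegressionHead.
    """
    # "package.module.ClassName" -> ("package.module", "ClassName")
    module_name, class_name = spec.rsplit(".", 1)
    target_class = getattr(importlib.import_module(module_name), class_name)
    # The class is called without arguments, so it must have a no-argument constructor.
    return target_class()


# A stdlib class is used only so the sketch runs without torch installed.
od = resolve_activation("collections.OrderedDict")
print(type(od).__name__)  # OrderedDict
```

Because the class is instantiated without arguments, activations that need constructor parameters have to be passed as an nn.Module instance rather than as a string.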
terratorch.models.heads.segmentation_head.SegmentationHead
Bases: Module
Segmentation head
Source code in terratorch/models/heads/segmentation_head.py
class SegmentationHead(nn.Module):
    """Segmentation head"""
    def __init__(
        self, in_channels: int, num_classes: int, channel_list: list[int] | None = None, dropout: float = 0
    ) -> None:
        """Constructor
        Args:
            in_channels (int): Number of input channels
            num_classes (int): Number of output classes
            channel_list (list[int] | None, optional): List with number of channels for each Conv
                layer to be created. Defaults to None.
            dropout (float, optional): Dropout value to apply. Defaults to 0.
        """
        super().__init__()
        self.num_classes = num_classes
        if channel_list is None:
            pre_head = nn.Identity()
        else:
            def block(in_channels, out_channels):
                return nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), nn.ReLU()
                )
            channel_list = [in_channels, *channel_list]
            pre_head = nn.Sequential(
                *[block(channel_list[i], channel_list[i + 1]) for i in range(len(channel_list) - 1)]
            )
            in_channels = channel_list[-1]
        dropout = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.head = nn.Sequential(
            pre_head,
            dropout,
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=num_classes,
                kernel_size=1,
            ),
        )
    def forward(self, x):
        return self.head(x)
__init__(in_channels, num_classes, channel_list=None, dropout=0)
Constructor
Parameters:
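Name | Type | Description | Default
in_channels | int | Number of input channels. | required
num_classes | int | Number of output classes, i.e. the output channels of the final 1x1 convolution. | required
channel_list | list[int] or None | Output channels of each 3x3 Conv2d and ReLU block inserted before the final 1x1 convolution. If None, no blocks are added. | None
dropout | float | Dropout probability applied before the final 1x1 convolution. A value of 0 disables dropout. | 0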
Decoder compatibilities
Not every encoder is compatible with every decoder. Some caveats are listed below.
Some decoders expect pyramidal (multi-scale) features, but some encoders, such as vanilla ViT models, produce all of their features at a single spatial resolution.
In this case, the InterpolateToPyramidal, MaxpoolToPyramidal and LearnedInterpolateToPyramidal necks, which turn such single-resolution features into a pyramid, may be particularly useful.
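The idea behind an interpolation-based pyramid neck can be sketched with plain PyTorch: take several feature maps that share one resolution and resize each of them to a different scale. interpolate_to_pyramid is a hypothetical helper written for illustration, not terratorch's InterpolateToPyramidal. The scale factors (4, 2, 1, 0.5), the bilinear mode and the reduced channel count are assumptions, and the inputs are assumed to be already reshaped from tokens into (B, C, H, W) images.

```python
import torch
import torch.nn.functional as F


def interpolate_to_pyramid(features, scale_factors=(4.0, 2.0, 1.0, 0.5)):
    """Resize same-resolution (B, C, H, W) feature maps into a multi-scale pyramid.

    Hypothetical helper for illustration; not terratorch's InterpolateToPyramidal.
    """
    if len(features) != len(scale_factors):
        raise ValueError("expected one scale factor per feature map")
    return [
        # A factor of 1.0 keeps the map as is; other factors upsample or downsample it.
        feat if s == 1.0 else F.interpolate(feat, scale_factor=s, mode="bilinear", align_corners=False)
        for feat, s in zip(features, scale_factors)
    ]


# Four ViT-style outputs on a 14x14 grid (224 / 16 patches per side); 64 channels keep it light.
feats = [torch.randn(2, 64, 14, 14) for _ in range(4)]
pyramid = interpolate_to_pyramid(feats)
print([tuple(p.shape[-2:]) for p in pyramid])  # [(56, 56), (28, 28), (14, 14), (7, 7)]
```

A fixed interpolation adds no trainable parameters; a learned variant would replace the resizing with trainable upsampling layers at the cost of extra weights.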
SMP decoders
Not every smp (segmentation_models_pytorch) decoder is guaranteed to work with every encoder unless additional necks are added.
Check the smp documentation for the spatial dimensions each decoder expects of its input embeddings.
In particular, smp appears to assume that the first feature in the passed feature list has the same spatial resolution as the input. This is not always true, and when it is not, some decoders may break.
In addition, some decoders expect the final two features to have the same spatial resolution.
Adding the AddBottleneckLayer neck makes such encoders compatible.
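Before wiring an encoder to an smp decoder, it can help to check its feature sizes against the two assumptions above. smp_feature_issues is a hypothetical diagnostic written for illustration, not part of terratorch or smp, and it only encodes the two conditions described in this section; a given decoder may have further requirements.

```python
def smp_feature_issues(input_hw, feature_hws):
    """Flag the two smp assumptions described above for a list of feature sizes.

    Hypothetical diagnostic; not part of terratorch or smp.
    input_hw: (H, W) of the model input.
    feature_hws: (H, W) of each feature map, in the order the decoder receives them.
    """
    issues = []
    # smp appears to expect the first feature at the input resolution.
    if feature_hws and tuple(feature_hws[0]) != tuple(input_hw):
        issues.append("first feature does not have the input resolution")
    # Some decoders expect the last two features to share a resolution.
    if len(feature_hws) >= 2 and tuple(feature_hws[-1]) != tuple(feature_hws[-2]):
        issues.append("final two features differ in resolution (some decoders expect them to match)")
    return issues


# A 224x224 input with a four-level pyramid at strides 4, 8, 16 and 32: both issues are reported.
print(smp_feature_issues((224, 224), [(56, 56), (28, 28), (14, 14), (7, 7)]))
```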
Some smp decoders require additional parameters, such as decoder_channels. These must be passed through the factory.
For example, decoder_channels is passed as decoder_decoder_channels: the first decoder_ prefix routes the parameter to the decoder, which receives it as decoder_channels.
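The prefix routing can be sketched in a few lines of plain Python. route_prefixed_kwargs is a hypothetical, simplified helper, not terratorch's implementation: it strips only the first matching prefix, which is why decoder_decoder_channels reaches the decoder as decoder_channels, and it treats any key without a known prefix as unused. terratorch's own routing and validation may differ in detail.

```python
PREFIXES = ("backbone_", "decoder_", "head_")


def route_prefixed_kwargs(kwargs, prefixes=PREFIXES):
    """Split factory-style keyword arguments by prefix, stripping one prefix per key.

    Hypothetical sketch of the routing rule; not terratorch's code.
    """
    routed = {p.rstrip("_"): {} for p in prefixes}
    unused = []
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                # Only the first prefix is removed, so "decoder_decoder_channels" -> "decoder_channels".
                routed[prefix.rstrip("_")][key[len(prefix):]] = value
                break
        else:
            unused.append(key)
    if unused:
        raise ValueError(f"Unused arguments: {unused}")
    return routed


routed = route_prefixed_kwargs({"decoder_decoder_channels": [256, 128, 64], "backbone_pretrained": True})
print(routed["decoder"])  # {'decoder_channels': [256, 128, 64]}
```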
MMSegmentation decoders
MMSegmentation decoders are available through the DECODER_REGISTRY.
Warning
MMSegmentation currently requires mmcv==2.1.0. Pre-built wheels for this version only exist for torch==2.1.0.
To use mmseg without building mmcv from source, you must downgrade torch to that version.
Install mmseg with:
pip install -U openmim
mim install mmengine
mim install mmcv==2.1.0
pip install regex ftfy mmsegmentation
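Because the warning above pins both torch and mmcv, a quick version check can save a failed import later. check_mmseg_prereqs is a hypothetical helper, not part of terratorch or mmseg; it reads installed versions with the standard library's importlib.metadata, and its version getter can be swapped out for testing. Stripping a local label such as "+cu121" before comparing is an assumption about how you want CUDA-specific torch builds treated.

```python
from importlib.metadata import PackageNotFoundError, version

# Version pins taken from the warning above.
REQUIRED = {"torch": "2.1.0", "mmcv": "2.1.0"}


def check_mmseg_prereqs(required=REQUIRED, get_version=version):
    """Return a list of version problems; an empty list means every pin matches.

    Hypothetical helper; get_version is injectable so the check can be tested offline.
    """
    problems = []
    for package, wanted in required.items():
        try:
            found = get_version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed (need {wanted})")
            continue
        # PyTorch wheels may carry a local version label such as "+cu121"; compare the part before it.
        if found.split("+")[0] != wanted:
            problems.append(f"{package}=={found} installed, need {wanted}")
    return problems


if __name__ == "__main__":
    for problem in check_mmseg_prereqs():
        print(problem)
```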
We provide access to mmseg decoders as an external source of decoders, but are not directly responsible for the maintenance of that library.
Some mmseg decoders require the parameter in_index, which performs the same function as the SelectIndices neck.
For pixel-wise regression, mmseg decoders should be given num_classes=1.