Skip to content

Extra Model Structures

terratorch.models.model.Model

Bases: ABC, Module

Source code in terratorch/models/model.py
class Model(ABC, nn.Module):
    """Abstract base for terratorch models.

    Mixes ``abc.ABC`` into ``torch.nn.Module`` so that concrete subclasses are
    forced to implement ``freeze_encoder``, ``freeze_decoder`` and ``forward``
    before they can be instantiated.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pass every argument straight through to the cooperative base-class
        # constructors in the MRO.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def freeze_encoder(self):
        """Intended to freeze the encoder part of the model; semantics are defined by subclasses."""

    @abstractmethod
    def freeze_decoder(self):
        """Intended to freeze the decoder part of the model; semantics are defined by subclasses."""

    @abstractmethod
    def forward(self, *args, **kwargs) -> ModelOutput:
        """Run the model and return its result wrapped in a ``ModelOutput``."""

terratorch.models.model.AuxiliaryHead (dataclass)

Class containing all information to create auxiliary heads.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | Name of the head. Should match the name given to the auxiliary loss. | required |
| `decoder` | `str` | Name of the decoder class to be used. | required |
| `decoder_args` | `dict \| None` | Parameters to be passed to the decoder constructor. Parameters for the decoder should be prefixed with `decoder_`. Parameters for the head should be prefixed with `head_`. | required |
Source code in terratorch/models/model.py
@dataclass
class AuxiliaryHead:
    """Class containing all information to create auxiliary heads.

    Args:
        name (str): Name of the head. Should match the name given to the auxiliary loss.
        decoder (str): Name of the decoder class to be used.
        decoder_args (dict | None): parameters to be passed to the decoder constructor.
            Parameters for the decoder should be prefixed with `decoder_`.
            Parameters for the head should be prefixed with `head_`.
    """

    name: str
    decoder: str
    decoder_args: dict | None

terratorch.models.model.ModelOutput (dataclass)

Source code in terratorch/models/model.py
@dataclass
class ModelOutput:
    output: Tensor
    auxiliary_heads: dict[str, Tensor] = None

terratorch.registry.registry.MultiSourceRegistry

Bases: Mapping[str, T], Generic[T]

Registry that searches multiple sources.

Correct functioning of this class depends on each source registry raising a KeyError when a model is not found.

Source code in terratorch/registry/registry.py
class MultiSourceRegistry(Mapping[str, T], typing.Generic[T]):
    """Registry that searches in multiple sources

    Correct functioning of this class depends on registries raising a KeyError when the model is not found.
    """

    def __init__(self, **sources) -> None:
        self._sources: OrderedDict[str, T] = OrderedDict(sources)

    def _parse_prefix(self, name) -> tuple[str, str] | None:
        split = name.split("_")
        if len(split) > 1 and split[0] in self._sources:
            prefix = split[0]
            name_without_prefix = "_".join(split[1:])
            return prefix, name_without_prefix
        return None

    def find_registry(self, name: str) -> T:
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            registry = self._sources[prefix]
            return registry

        # if no prefix is given, go through all sources in order
        for registry in self._sources.values():
            if name in registry:
                return registry
        msg = f"Model {name} not found in any registry"
        raise KeyError(msg)

    def find_class(self, name: str) -> type:
        parsed_prefix = self._parse_prefix(name)
        registry = self.find_registry(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            return registry[name_without_prefix]
        return registry[name]

    def build(self, name: str, *constructor_args, **constructor_kwargs):
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            registry = self._sources[prefix]
            return registry.build(name_without_prefix, *constructor_args, **constructor_kwargs)

        # if no prefix, try to build in order
        for source in self._sources.values():
            with suppress(KeyError):
                return source.build(name, *constructor_args, **constructor_kwargs)

        msg = f"Could not instantiate model {name} not from any source."
        raise KeyError(msg)

    def register_source(self, prefix: str, registry: T) -> None:
        """Register a source in the registry"""
        if prefix in self._sources:
            msg = f"Source for prefix {prefix} already exists."
            raise KeyError(msg)
        self._sources[prefix] = registry

    def __iter__(self):
        for prefix in self._sources:
            for element in self._sources[prefix]:
                yield prefix + "_" + element

    def __len__(self):
        return sum(len(source) for source in self._sources.values())

    def __getitem__(self, name):
        return self._sources[name]

    def __contains__(self, name):
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            return name_without_prefix in self._sources[prefix]
        return any(name in source for source in self._sources.values())

    @_recursive_repr()
    def __repr__(self):
        args = [f"{name}={source!r}" for name, source in self._sources.items()]
        return f"{self.__class__.__name__}({', '.join(args)})"

    def __str__(self):
        sources_str = str(" | ".join([f"{prefix}: {source!s}" for prefix, source in self._sources.items()]))
        return f"Multi source registry with {len(self)} items: {sources_str}"

    def keys(self):
        return self._sources.keys()

register_source(prefix, registry)

Register a source in the registry

Source code in terratorch/registry/registry.py
def register_source(self, prefix: str, registry: T) -> None:
    """Attach ``registry`` as a new source reachable under ``prefix``.

    Raises:
        KeyError: if a source is already registered under ``prefix``; the
            existing source is left untouched.
    """
    sources = self._sources
    if prefix not in sources:
        sources[prefix] = registry
        return
    msg = f"Source for prefix {prefix} already exists."
    raise KeyError(msg)

terratorch.registry.registry.Registry

Bases: Set

Registry holding model constructors (use MultiSourceRegistry to combine several registries as sources).

This registry behaves as a set of strings (model names), each mapped to the model class or function that instantiates it.

In addition, it can instantiate models with the build method.

Add constructors to the registry by decorating them with @registry.register.

>>> registry = Registry()
>>> @registry.register
... def model(*args, **kwargs):
...     return object()
>>> "model" in registry
True
>>> model_instance = registry.build("model")

Source code in terratorch/registry/registry.py
class Registry(Set):
    """Registry holding model constructors and multiple additional sources.

    This registry behaves as a set of strings, which are model names,
    to model classes or functions which instantiate model classes.

    In addition, it can instantiate models with the build method.

    Add constructors to the registry by annotating them with @registry.register.
    ```
    >>> registry = Registry()
    >>> @registry.register
    ... def model(*args, **kwargs):
    ...     return object()
    >>> "model" in registry
    True
    >>> model_instance = registry.build("model")
    ```
    """

    def __init__(self, **elements) -> None:
        self._registry: dict[str, Callable] = dict(elements)

    def register(self, constructor: Callable | type) -> Callable:
        """Register a component in the registry. Used as a decorator.

        Args:
            constructor (Callable | type): Function or class to be decorated with @register.
        """
        if not callable(constructor):
            msg = f"Invalid argument. Decorate a function or class with @{self.__class__.__name__}.register"
            raise TypeError(msg)
        self._registry[constructor.__name__] = constructor
        return constructor

    def build(self, name: str, *constructor_args, **constructor_kwargs):
        """Build and return the component.
        Use prefixes ending with _ to forward to a specific source
        """
        return self._registry[name](*constructor_args, **constructor_kwargs)

    def __iter__(self):
        return iter(self._registry)

    def __getitem__(self, key):
        return self._registry[key]

    def __len__(self):
        return len(self._registry)

    def __contains__(self, key):
        return key in self._registry

    def __repr__(self):
        return f"{self.__class__.__name__}({self._registry!r})"

    def __str__(self):
        return f"Registry with {len(self)} registered items"

build(name, *constructor_args, **constructor_kwargs)

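Build and return the component registered under name, forwarding any positional and keyword arguments to its constructor. (Prefix-based forwarding to a specific source is done by MultiSourceRegistry, not by this method.)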

Source code in terratorch/registry/registry.py
def build(self, name: str, *constructor_args, **constructor_kwargs):
    """Build and return the component registered under ``name``.

    Positional and keyword arguments are forwarded unchanged to the stored
    constructor. No prefix handling happens here: ``name`` must match a
    registered name exactly.

    Raises:
        KeyError: if ``name`` is not registered.
    """
    return self._registry[name](*constructor_args, **constructor_kwargs)

register(constructor)

Register a component in the registry. Used as a decorator.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `constructor` | `Callable \| type` | Function or class to be decorated with @register. | required |
Source code in terratorch/registry/registry.py
def register(self, constructor: Callable | type) -> Callable:
    """Register a component in the registry. Used as a decorator.

    Args:
        constructor (Callable | type): Function or class to be decorated with @register.
    """
    if not callable(constructor):
        msg = f"Invalid argument. Decorate a function or class with @{self.__class__.__name__}.register"
        raise TypeError(msg)
    self._registry[constructor.__name__] = constructor
    return constructor