Registries

TerraTorch keeps a set of registries which map strings (names) to constructors, i.e. classes or functions that build the corresponding objects. They can be imported from terratorch.registry.

Info

If you are using tasks with existing models, you may never have to interact with registries directly. The model factory will handle interactions with registries.

Registries behave like Python sets and support the usual membership test and iteration. This means you can operate on them in a pythonic way, for example "model" in registry or list(registry).

To create the desired instance, registries expose a build method, which accepts the name and the arguments to be passed to the constructor.
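The set-like behaviour and the build method can be illustrated with a minimal, self-contained sketch. SketchRegistry and tiny_model are hypothetical names used only for illustration. This is not TerraTorch's implementation (its Registry class is documented further below). It is a stand-in that behaves the same way: names map to constructors, and build forwards its arguments to the constructor.

```python
from collections.abc import Callable, Iterator, Set


class SketchRegistry(Set):
    """Hypothetical, minimal set-like registry (not TerraTorch's code)."""

    def __init__(self) -> None:
        # name -> constructor (a class or a factory function)
        self._registry: dict[str, Callable] = {}

    def register(self, constructor: Callable) -> Callable:
        # store the constructor under its own __name__ and return it unchanged,
        # so this method can be used as a decorator
        self._registry[constructor.__name__] = constructor
        return constructor

    def build(self, name: str, *args, **kwargs):
        # look up the constructor and call it with the forwarded arguments
        return self._registry[name](*args, **kwargs)

    # the three methods below are all that collections.abc.Set requires
    def __contains__(self, name: object) -> bool:
        return name in self._registry

    def __iter__(self) -> Iterator[str]:
        return iter(self._registry)

    def __len__(self) -> int:
        return len(self._registry)


registry = SketchRegistry()


@registry.register
def tiny_model(width: int = 8) -> dict:
    # hypothetical constructor returning a plain dict instead of a real model
    return {"width": width}


print("tiny_model" in registry)                # True
print(list(registry))                          # ['tiny_model']
print(registry.build("tiny_model", width=16))  # {'width': 16}
```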

Using registries
from terratorch import BACKBONE_REGISTRY

# find available prithvi models
print([model_name for model_name in BACKBONE_REGISTRY if "prithvi" in model_name])
# ['timm_prithvi_swin_B', 'timm_prithvi_swin_L', 'timm_prithvi_vit_100', 'timm_prithvi_vit_300', 'timm_prithvi_vit_tiny']

# show all models with list(BACKBONE_REGISTRY)

# check that a model is in the registry
print("timm_prithvi_swin_B" in BACKBONE_REGISTRY)
# True

# without the prefix, all internal registries are searched until the first match is found
print("prithvi_swin_B" in BACKBONE_REGISTRY)
# True

# instantiate your desired model
# the source prefix (in this case 'timm') is optional
# because the underlying registry here is timm, timm keyword arguments can be passed through
model = BACKBONE_REGISTRY.build("prithvi_vit_100", num_frames=1, pretrained=True)

# instantiate your model with more options, for instance, passing weights of your own through timm
model = BACKBONE_REGISTRY.build(
    "prithvi_vit_100", num_frames=1, pretrained=True, pretrained_cfg_overlay={"file": "<path to weights>"}
)
# Rest of your PyTorch / PyTorch Lightning code

MultiSourceRegistries

BACKBONE_REGISTRY and DECODER_REGISTRY are special registries which dynamically aggregate multiple registries. They behave as if they were a single large registry by searching over multiple registries.

For instance, the DECODER_REGISTRY holds the TERRATORCH_DECODER_REGISTRY, which is responsible for decoders implemented in terratorch, as well as the SMP_DECODER_REGISTRY and the MMSEG_DECODER_REGISTRY (if mmseg is installed).
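One visible consequence of this aggregation is that iterating over a multi-source registry yields names carrying their source prefix, which is why the prithvi listing above shows timm_ names. The sketch below is a hypothetical stand-in modelled on the __iter__ and __len__ methods in the MultiSourceRegistry source further down; SketchMultiSource, DecoderA and DecoderB are illustrative names, and plain lists of strings stand in for real registries.

```python
from collections import OrderedDict


class SketchMultiSource:
    """Hypothetical aggregate over several named sources (not TerraTorch's code)."""

    def __init__(self, **sources: list[str]) -> None:
        # keyword-argument order is preserved, so it also defines the search order
        self._sources = OrderedDict(sources)

    def __iter__(self):
        # every name is reported together with the prefix of its source
        for prefix, source in self._sources.items():
            for name in source:
                yield f"{prefix}_{name}"

    def __len__(self) -> int:
        # total number of entries over all sources
        return sum(len(source) for source in self._sources.values())

    def __contains__(self, name: str) -> bool:
        # unprefixed membership: true if any source holds the name
        return any(name in source for source in self._sources.values())


decoders = SketchMultiSource(terratorch=["FCNDecoder"], smp=["DecoderA", "DecoderB"])
print(list(decoders))  # ['terratorch_FCNDecoder', 'smp_DecoderA', 'smp_DecoderB']
print(len(decoders))   # 3
```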

To make sure an object is taken from a particular registry, prepend that registry's prefix followed by an underscore to the name (for example terratorch_FCNDecoder).

from terratorch import DECODER_REGISTRY

# decoders always take at least one argument: the list with the channel dimension of each embedding passed to them
# no prefix: the sources are searched in order until one provides FCNDecoder
DECODER_REGISTRY.build("FCNDecoder", [32, 64, 128])

# with the 'terratorch' prefix: built from the terratorch source only
DECODER_REGISTRY.build("terratorch_FCNDecoder", [32, 64, 128])

# find all prefixes
print(DECODER_REGISTRY.keys())
# odict_keys(['terratorch', 'smp', 'mmseg'])
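The prefix handling can be sketched as a small function. parse_prefix is a hypothetical helper modelled on MultiSourceRegistry._parse_prefix (shown in the source below): it treats everything before the first underscore as a candidate prefix and accepts it only if a source with that name exists.

```python
from __future__ import annotations


def parse_prefix(name: str, prefixes: set[str]) -> tuple[str, str] | None:
    """Hypothetical helper: split 'prefix_rest' if the first segment is a known prefix."""
    first, sep, rest = name.partition("_")  # split at the first underscore only
    if sep and first in prefixes:
        return first, rest
    return None  # no known prefix: the caller falls back to searching every source


prefixes = {"terratorch", "smp", "mmseg"}
print(parse_prefix("terratorch_FCNDecoder", prefixes))  # ('terratorch', 'FCNDecoder')
print(parse_prefix("FCNDecoder", prefixes))             # None
print(parse_prefix("smp_my_decoder", prefixes))         # ('smp', 'my_decoder')
```

Judging from this splitting rule, a model name whose first segment happens to equal a registered prefix is always treated as prefixed, so names such as smp_something are routed to that source rather than searched for.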

If a prefix is not added, the MultiSourceRegistry will search each registry in the order it was added (starting with the TERRATORCH_ registry) until it finds the first match.
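The ordered, first-match search can be sketched in a few lines. build_first_match is a hypothetical function that follows the same pattern as MultiSourceRegistry.build in the source below: each source is tried in insertion order, and a KeyError means "not here, try the next one". Plain dictionaries of lambdas stand in for real registries, and OnlyInSmp is an illustrative name.

```python
from collections import OrderedDict
from contextlib import suppress


def build_first_match(sources, name, *args, **kwargs):
    # try each source in insertion order; a KeyError moves on to the next source
    for source in sources.values():
        with suppress(KeyError):
            return source[name](*args, **kwargs)
    raise KeyError(f"Could not instantiate {name} from any source")


sources = OrderedDict(
    terratorch={"FCNDecoder": lambda channels: ("terratorch", channels)},
    smp={
        "FCNDecoder": lambda channels: ("smp", channels),
        "OnlyInSmp": lambda channels: ("smp", channels),
    },
)
print(build_first_match(sources, "FCNDecoder", [32, 64]))  # ('terratorch', [32, 64])
print(build_first_match(sources, "OnlyInSmp", [32, 64]))   # ('smp', [32, 64])
```

Because terratorch is the first source, it wins whenever two sources share a name.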

For both of these registries, only the TerraTorch source (TERRATORCH_BACKBONE_REGISTRY or TERRATORCH_DECODER_REGISTRY) is mutable. To register your own backbones or decoders with terratorch, decorate the constructor function (or the model class itself) with @TERRATORCH_BACKBONE_REGISTRY.register or @TERRATORCH_DECODER_REGISTRY.register.

To add a new registry to one of these top-level registries, use its register_source method, which takes the prefix that will be used for the new source and the registry itself.
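With the real registry this would be a call such as DECODER_REGISTRY.register_source("mylab", my_registry), where mylab and my_registry are hypothetical. The runnable sketch below uses a hypothetical SketchSources class that copies the logic of register_source from the source further below: prefixes must be unique, and a new source is appended after the existing ones, so it is searched last.

```python
from collections import OrderedDict


class SketchSources:
    """Hypothetical stand-in for the source table inside a MultiSourceRegistry."""

    def __init__(self, **sources):
        self._sources = OrderedDict(sources)

    def register_source(self, prefix: str, registry) -> None:
        # mirrors MultiSourceRegistry.register_source: prefixes must be unique
        if prefix in self._sources:
            raise KeyError(f"Source for prefix {prefix} already exists.")
        self._sources[prefix] = registry

    def keys(self):
        return self._sources.keys()


top = SketchSources(terratorch={"FCNDecoder"}, smp={"DecoderA"})
top.register_source("mylab", {"MyDecoder"})  # appended last, so searched last
print(list(top.keys()))  # ['terratorch', 'smp', 'mylab']

try:
    top.register_source("smp", {"Other"})  # duplicate prefix
except KeyError as err:
    print(err)
```

Since prefixes are recognized by splitting a name at its first underscore, a prefix that itself contains an underscore would, as far as the source shows, never be matched; a single-word prefix is the safer choice.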

terratorch.registry.registry.MultiSourceRegistry

Bases: Mapping[str, T], Generic[T]

Registry that searches in multiple sources

Correct functioning of this class depends on registries raising a KeyError when the model is not found.

Source code in terratorch/registry/registry.py
class MultiSourceRegistry(Mapping[str, T], typing.Generic[T]):
    """Registry that searches in multiple sources

        Correct functioning of this class depends on registries raising a KeyError when the model is not found.
    """
    def __init__(self, **sources) -> None:
        self._sources: OrderedDict[str, T] = OrderedDict(sources)

    def _parse_prefix(self, name) -> tuple[str, str] | None:
        split = name.split("_")
        if len(split) > 1 and split[0] in self._sources:
            prefix = split[0]
            name_without_prefix = "_".join(split[1:])
            return prefix, name_without_prefix
        return None

    def find_registry(self, name: str) -> T:
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            registry = self._sources[prefix]
            return registry

        # if no prefix is given, go through all sources in order
        for registry in self._sources.values():
            if name in registry:
                return registry
        msg = f"Model {name} not found in any registry"
        raise KeyError(msg)

    def build(self, name: str, *constructor_args, **constructor_kwargs):
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            registry = self._sources[prefix]
            return registry.build(name_without_prefix, *constructor_args, **constructor_kwargs)

        # if no prefix, try to build in order
        for source in self._sources.values():
            with suppress(KeyError):
                return source.build(name, *constructor_args, **constructor_kwargs)

        msg = f"Could not instantiate model {name} not from any source."
        raise KeyError(msg)

    def register_source(self, prefix: str, registry: T) -> None:
        """Register a source in the registry"""
        if prefix in self._sources:
            msg = f"Source for prefix {prefix} already exists."
            raise KeyError(msg)
        self._sources[prefix] = registry

    def __iter__(self):
        for prefix in self._sources:
            for element in self._sources[prefix]:
                yield prefix + "_" + element

    def __len__(self):
        return sum(len(source) for source in self._sources.values())

    # def __getitem__(self, name):
    #     parsed_prefix = self._parse_prefix(name)
    #     if parsed_prefix:
    #         prefix, name_without_prefix = parsed_prefix
    #         registry = self._sources[prefix]
    #         return registry[name_without_prefix]

    #     # if no prefix is given, go through all sources in order
    #     for source in self._sources.values():
    #         try:
    #             return source[name]
    #         except Exception as e:
    #             logging.debug(e)

    #     msg = f"Could not find Model {name} not from any source."
    #     raise KeyError(msg)

    def __getitem__(self, name):
        return self._sources[name]

    def __contains__(self, name):
        parsed_prefix = self._parse_prefix(name)
        if parsed_prefix:
            prefix, name_without_prefix = parsed_prefix
            return name_without_prefix in self._sources[prefix]
        return any(name in source for source in self._sources.values())

    @_recursive_repr()
    def __repr__(self):
        args = [f"{name}={source!r}" for name, source in self._sources.items()]
        return f'{self.__class__.__name__}({", ".join(args)})'

    def __str__(self):
        sources_str = str(" | ".join([f"{prefix}: {source!s}" for prefix, source in self._sources.items()]))
        return f"Multi source registry with {len(self)} items: {sources_str}"

    def keys(self):
        return self._sources.keys()
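Why the KeyError contract matters: build wraps each source in suppress(KeyError), so only a KeyError lets the search continue. The sketch below is a hypothetical illustration of that pattern (StrictSource, NonConformingSource and build_any are illustrative names, not TerraTorch classes). A source that signals a missing name with any other exception aborts the whole search, even though a later source could have built the model.

```python
from collections import OrderedDict
from contextlib import suppress


class StrictSource:
    """Hypothetical source that follows the contract: missing names raise KeyError."""

    def __init__(self, **constructors):
        self._constructors = constructors

    def build(self, name, *args, **kwargs):
        return self._constructors[name](*args, **kwargs)  # dict lookup raises KeyError


class NonConformingSource:
    """Hypothetical source that signals 'missing' with ValueError instead."""

    def build(self, name, *args, **kwargs):
        raise ValueError(f"{name} not found")


def build_any(sources, name, *args, **kwargs):
    # same fallback pattern as MultiSourceRegistry.build
    for source in sources.values():
        with suppress(KeyError):
            return source.build(name, *args, **kwargs)
    raise KeyError(name)


good = OrderedDict(a=StrictSource(), b=StrictSource(Net=lambda: "built by b"))
print(build_any(good, "Net"))  # built by b

bad = OrderedDict(a=NonConformingSource(), b=StrictSource(Net=lambda: "built by b"))
try:
    build_any(bad, "Net")
except ValueError as err:
    print("fallback aborted:", err)  # fallback aborted: Net not found
```

The same pattern has a side effect: a KeyError raised inside a constructor is also suppressed, so the search silently moves on to the next source instead of reporting that error.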

register_source(prefix, registry)

Register a source in the registry

terratorch.registry.registry.Registry

Bases: Set

Registry holding model constructors and multiple additional sources.

This registry behaves as a set of strings (model names), each mapped to a model class or to a function which instantiates a model class.

In addition, it can instantiate models with the build method.

Add constructors to the registry by decorating them with @registry.register.

>>> registry = Registry()
>>> @registry.register
... def model(*args, **kwargs):
...     return object()
>>> "model" in registry
True
>>> model_instance = registry.build("model")

Source code in terratorch/registry/registry.py
class Registry(Set):
    """Registry holding model constructors and multiple additional sources.

    This registry behaves as a set of strings, which are model names,
    to model classes or functions which instantiate model classes.

    In addition, it can instantiate models with the build method.

    Add constructors to the registry by annotating them with @registry.register.

    >>> registry = Registry()
    >>> @registry.register
    ... def model(*args, **kwargs):
    ...     return object()
    >>> "model" in registry
    True
    >>> model_instance = registry.build("model")
    """

    def __init__(self, **elements) -> None:
        self._registry: dict[str, Callable] = dict(elements)

    def register(self, constructor: Callable | type) -> Callable:
        """Register a component in the registry. Used as a decorator.

        Args:
            constructor (Callable | type): Function or class to be decorated with @register.
        """
        if not callable(constructor):
            msg = f"Invalid argument. Decorate a function or class with @{self.__class__.__name__}.register"
            raise TypeError(msg)
        self._registry[constructor.__name__] = constructor
        return constructor

    def build(self, name: str, *constructor_args, **constructor_kwargs):
        """Build and return the component.
        Use prefixes ending with _ to forward to a specific source
        """
        return self._registry[name](*constructor_args, **constructor_kwargs)

    def __iter__(self):
        return iter(self._registry)

    # def __getitem__(self, key):
    #     return self._registry[key]

    def __len__(self):
        return len(self._registry)

    def __contains__(self, key):
        return key in self._registry

    def __repr__(self):
        return f"{self.__class__.__name__}({self._registry!r})"

    def __str__(self):
        return f"Registry with {len(self)} registered items"
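Because Registry subclasses collections.abc.Set and implements __contains__, __iter__ and __len__, it should inherit Set's mixin methods such as subset comparison, equality and isdisjoint. The sketch uses a hypothetical MiniRegistry that copies only those three methods and the keyword-only __init__ from the source above, so it runs without terratorch. One caveat inferred from that source, not from TerraTorch documentation: Set's binary operators (&, |, -) build their result by calling the class with an iterable, which a keyword-only __init__ rejects, so it is safer to convert to a built-in set first.

```python
from collections.abc import Callable, Set


class MiniRegistry(Set):
    """Hypothetical copy of the three methods Registry implements for collections.abc.Set."""

    def __init__(self, **elements: Callable) -> None:
        self._registry = dict(elements)

    def __contains__(self, key) -> bool:
        return key in self._registry

    def __iter__(self):
        return iter(self._registry)

    def __len__(self) -> int:
        return len(self._registry)


reg = MiniRegistry(model_a=object, model_b=object)

# mixin methods inherited from collections.abc.Set
print(reg <= {"model_a", "model_b", "model_c"})  # True (subset test)
print(reg.isdisjoint({"model_c"}))               # True
print(reg == {"model_a", "model_b"})             # True

# & calls MiniRegistry(iterable), which the keyword-only __init__ rejects
try:
    reg & {"model_a"}
except TypeError as err:
    print("TypeError:", err)

# converting to a built-in set first avoids the problem
print(set(reg) & {"model_a", "model_z"})  # {'model_a'}
```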

build(name, *constructor_args, **constructor_kwargs)

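Build and return the component. Use prefixes ending with _ to forward to a specific source. Note that Registry.build itself performs a direct lookup by name (see the source above); prefix forwarding is implemented in MultiSourceRegistry.build.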

register(constructor)

Register a component in the registry. Used as a decorator.

Parameters:
  • constructor (Callable | type) – Function or class to be decorated with @register.

Other Registries

Additionally, terratorch has the NECK_REGISTRY, where all necks must be registered, and the MODEL_FACTORY_REGISTRY, where all model factories must be registered.