Model Factories

Model factories build the model that is fine-tuned by TerraTorch. Specifically, a backbone is used as an encoder and combined with a task-specific decoder and head. Necks are used to reshape the encoder output so that it is compatible with the decoder input.

Tip

The EncoderDecoderFactory is the default factory for segmentation, pixel-wise regression, and classification tasks.

Other commonly used factories are the ObjectDetectionModelFactory for object detection tasks and sometimes the FullModelFactory if a model is registered in the FULL_MODEL_REGISTRY and can be directly applied to a specific task.
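
For orientation, the snippet below sketches how a factory is typically used directly (when fine-tuning via configuration files, TerraTorch calls the factory for you). It is a hedged, minimal example: the backbone and neck names are illustrative and must be registered in your installation's registries.

from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory

factory = EncoderDecoderFactory()
# The factory resolves the backbone and decoder by name, wires the necks
# between them, and appends a task-specific head.
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",  # assumed backbone registry entry
    decoder="FCNDecoder",
    num_classes=2,
    necks=[
        {"name": "SelectIndices", "indices": [-1]},  # assumed neck names
        {"name": "ReshapeTokensToImage"},            # ViT tokens -> 2D feature map
    ],
)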

terratorch.models.encoder_decoder_factory.EncoderDecoderFactory #

Bases: ModelFactory

Source code in terratorch/models/encoder_decoder_factory.py
@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        backbone_kwargs: dict | None = None,
        decoder_kwargs: dict | None = None,
        head_kwargs: dict | None = None,
        num_classes: int | None = None,
        necks: list[dict] | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001,
        peft_config: dict | None = None,
        **kwargs,
    ) -> Model:
        """Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

        Further arguments can be passed to the backbone, decoder or head by prefixing them with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation", "regression" and "classification".
            backbone (str, nn.Module): Backbone to be used. If a string, will look for such models in the different
                registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it
                directly. The backbone should have an `out_channels` attribute and its `forward` should return a list[Tensor].
            decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                    If a string, will look for such decoders in the different
                    registries supported (internal terratorch registry, smp, ...).
                    If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                    For pixel wise tasks, a Conv2d is appended as the final convolution.
                    Defaults to "FCNDecoder".
            backbone_kwargs (dict, optional) : Arguments to be passed to instantiate the backbone.
            decoder_kwargs (dict, optional) : Arguments to be passed to instantiate the decoder.
            head_kwargs (dict, optional) : Arguments to be passed to the head network. 
            num_classes (int, optional): Number of classes. None for regression tasks.
            necks (list[dict]): nn.Modules to be called in succession on encoder features
                before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry.
                Expects each one to have a key "name" and subsequent keys for arguments, if any.
                Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression). Defaults to True.
            peft_config (dict): Configuration options for using [PEFT](https://huggingface.co/docs/peft/index).
                The dictionary should have the following keys:

                - "method": Which PEFT method to use. Should be one implemented in PEFT, a list is available [here](https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType).
                - "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep.
                  This should be used when the qkv matrices are merged together in a single linear layer and the PEFT
                  method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in
                  Q and V matrices). e.g. If using Prithvi this should be "qkv"
                - "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to [PeftConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig)


        Returns:
            nn.Module: Full model with encoder, decoder and head.
        """
        task = task.lower()
        if task not in SUPPORTED_TASKS:
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)

        if not backbone_kwargs:
            backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

        backbone = _get_backbone(backbone, **backbone_kwargs)


        # If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
        patch_size = backbone_kwargs.get("patch_size", None)

        if patch_size is None:
            # Infer patch size from model by checking all backbone modules
            for module in backbone.modules():
                if hasattr(module, "patch_size"):
                    patch_size = module.patch_size
                    break
        padding = backbone_kwargs.get("padding", "reflect")

        if peft_config is not None:
            if not backbone_kwargs.get("pretrained", False):
                msg = (
                    "You are using PEFT without a pretrained backbone. If you are loading a checkpoint afterwards "
                    "this is probably fine, but if you are training a model check the backbone_pretrained parameter."
                )
                warnings.warn(msg, stacklevel=1)

            backbone = get_peft_backbone(peft_config, backbone)

        try:
            out_channels = backbone.out_channels
        except AttributeError as e:
            msg = "backbone must have out_channels attribute"
            raise AttributeError(msg) from e

        if necks is None:
            necks = []
        neck_list, channel_list = build_neck_list(necks, out_channels)

        # some decoders already include a head
        # for these, we pass the num_classes to them
        # others dont include a head
        # for those, we dont pass num_classes
        if not decoder_kwargs:
            decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")

        if not head_kwargs:
            head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_")

        decoder, head_kwargs, decoder_includes_head = _get_decoder_and_head_kwargs(
            decoder, channel_list, decoder_kwargs, head_kwargs, num_classes=num_classes
        )

        if aux_decoders is None:
            _check_all_args_used(kwargs)
            return _build_appropriate_model(
                task,
                backbone,
                decoder,
                head_kwargs,
                patch_size=patch_size,
                padding=padding,
                necks=neck_list,
                decoder_includes_head=decoder_includes_head,
                rescale=rescale,
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_kwargs, args = extract_prefix_keys(args, "decoder_")
            aux_head_kwargs, args = extract_prefix_keys(args, "head_")
            aux_decoder_instance, aux_head_kwargs, aux_decoder_includes_head = _get_decoder_and_head_kwargs(
                aux_decoder.decoder, channel_list, aux_decoder_kwargs, aux_head_kwargs, num_classes=num_classes
            )
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )
            _check_all_args_used(args)

        _check_all_args_used(kwargs)

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            patch_size=patch_size,
            padding=padding,
            necks=neck_list,
            decoder_includes_head=decoder_includes_head,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )

build_model(task, backbone, decoder, backbone_kwargs=None, decoder_kwargs=None, head_kwargs=None, num_classes=None, necks=None, aux_decoders=None, rescale=True, peft_config=None, **kwargs) #

Generic model factory that combines an encoder and decoder, together with a head, for a specific task.

Further arguments can be passed to the backbone, decoder or head by prefixing them with backbone_, decoder_ and head_ respectively.

Parameters:

task (str, required): Task to be performed. Currently supports "segmentation", "regression" and "classification".

backbone (str | nn.Module, required): Backbone to be used. If a string, the factory looks the model up in the supported registries (internal terratorch registry, timm, ...). If a torch nn.Module, it is used directly. The backbone should have an out_channels attribute and its forward should return a list[Tensor].

decoder (str | nn.Module): Decoder to be used for the segmentation model. If a string, the factory looks the decoder up in the supported registries (internal terratorch registry, smp, ...). If an nn.Module, it is expected to expose a property decoder.out_channels. For pixel wise tasks, a Conv2d is appended as the final convolution. Defaults to "FCNDecoder".

backbone_kwargs (dict, optional): Arguments used to instantiate the backbone. Defaults to None.

decoder_kwargs (dict, optional): Arguments used to instantiate the decoder. Defaults to None.

head_kwargs (dict, optional): Arguments passed to the head network. Defaults to None.

num_classes (int, optional): Number of classes. None for regression tasks. Defaults to None.

necks (list[dict], optional): nn.Modules to be called in succession on encoder features before passing them to the decoder. Each neck should be registered in the NECKS_REGISTRY and is specified as a dict with a "name" key and any further keys as arguments. Defaults to None, which applies the identity function.

aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. Defaults to None.

rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

peft_config (dict, optional): Configuration options for using PEFT (https://huggingface.co/docs/peft/index). The dictionary should have the following keys (a sketch follows this parameter list):

  • "method": Which PEFT method to use. Should be one implemented in PEFT; the available types are listed at https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType.
  • "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep. This should be used when the qkv matrices are merged together in a single linear layer and the PEFT method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in the Q and V matrices). If using Prithvi, this should be "qkv".
  • "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to PeftConfig (https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig).

Defaults to None.
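
For illustration, a peft_config for applying LoRA to a Prithvi-style backbone might look like the sketch below. The "method" value follows PEFT's PeftType names, and the keys inside "peft_config_kwargs" are standard LoraConfig arguments; the target module names are assumptions that depend on the backbone after QKVSep replacement.

# A hedged sketch, not a verified configuration: module names depend on the backbone.
peft_config = {
    "method": "LORA",        # one of PEFT's PeftType values
    "replace_qkv": "qkv",    # split merged qkv layers so Q and V can be targeted separately
    "peft_config_kwargs": {
        "r": 16,                                     # LoRA rank
        "lora_alpha": 16,
        "target_modules": ["q_linear", "v_linear"],  # assumed submodule names after QKVSep
    },
}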

Returns:

Model: Full model with encoder, decoder and head.
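
The sketch below shows the prefix routing described above in practice. It is a hedged, minimal example: the backbone, decoder and prefixed argument names are illustrative and must match the entries and signatures available in your TerraTorch installation.

from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory

factory = EncoderDecoderFactory()
# backbone_*, decoder_* and head_* arguments are split off by prefix and
# forwarded to the backbone, decoder and head respectively.
model = factory.build_model(
    task="regression",
    backbone="prithvi_eo_v2_300",             # assumed backbone registry entry
    decoder="UperNetDecoder",                 # assumed decoder registry entry
    backbone_bands=["BLUE", "GREEN", "RED"],  # forwarded to the backbone
    decoder_channels=256,                     # forwarded to the decoder
    head_dropout=0.1,                         # forwarded to the head
    rescale=True,
)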


terratorch.models.object_detection_model_factory.ObjectDetectionModelFactory #

Bases: ModelFactory

Source code in terratorch/models/object_detection_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class ObjectDetectionModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        framework: str,
        num_classes: int | None = None,
        necks: list[dict] | None = None,
        **kwargs,
    ) -> Model:
        """
        Generic model factory that combines an encoder and necks with a detection model (the framework)
        from torchvision.models.detection.

        Further arguments can be passed to the backbone and framework by prefixing them with
        `backbone_` and `framework_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "object_detection".
            backbone (str, nn.Module): Backbone to be used. If a string, will look for such models in the different
                registries supported (internal terratorch registry, timm, ...). If a torch nn.Module, will use it
                directly. The backbone should have an `out_channels` attribute and its `forward` should return a list[Tensor].
            framework (str): Object detection framework to be used: "faster-rcnn", "fcos" or "retinanet"
                for object detection, or "mask-rcnn" for instance segmentation.
            num_classes (int, optional): Number of classes. Defaults to None.
            necks (list[dict]): nn.Modules to be called in succession on encoder features
                before passing them to the decoder. Should be registered in the NECKS_REGISTRY registry.
                Expects each one to have a key "name" and subsequent keys for arguments, if any.
                Defaults to None, which applies the identity function.

        Returns:
            nn.Module: Full torchvision detection model.
        """
        task = task.lower()
        if task not in SUPPORTED_TASKS:
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)
        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
        framework_kwargs, kwargs = extract_prefix_keys(kwargs, "framework_")

        backbone = _get_backbone(backbone, **backbone_kwargs)
        if "in_channels" in kwargs:
            in_channels = kwargs["in_channels"]
        else:
            in_channels = (
                len(backbone_kwargs["model_bands"])
                if "model_bands" in backbone_kwargs
                else len(backbone_kwargs["bands"])
            )

        try:
            out_channels = backbone.out_channels
        except AttributeError as e:
            msg = "backbone must have out_channels attribute"
            raise AttributeError(msg) from e
        if necks is None:
            necks = []
        neck_list, channel_list = build_neck_list(necks, out_channels)

        neck_module = nn.Sequential(*neck_list)

        combined_backbone = BackboneWrapper(backbone, neck_module, channel_list)

        if framework == 'faster-rcnn':

            sizes = ((32,), (64,), (128,), (256,), (512,))
            sizes = sizes[:len(combined_backbone.channel_list)]
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(sizes)
            anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

            roi_pooler = MultiScaleRoIAlign(
                featmap_names=['feat0', 'feat1', 'feat2', 'feat3'], output_size=7, sampling_ratio=2
            )

            model = torchvision.models.detection.FasterRCNN(
                combined_backbone,
                num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
                _skip_resize=True,
                image_mean=np.repeat(0, in_channels),
                image_std=np.repeat(1, in_channels),
                **framework_kwargs
            )
        elif framework == 'fcos':

            sizes = ((8,), (16,), (32,), (64,), (128,), (256,))
            sizes = sizes[:len(combined_backbone.channel_list)]
            aspect_ratios = ((1.0,),) * len(sizes)
            anchor_generator = AnchorGenerator(
                sizes=sizes,
                aspect_ratios=aspect_ratios,
            )

            model = torchvision.models.detection.FCOS(
                combined_backbone,
                num_classes,
                anchor_generator=anchor_generator,
                _skip_resize=True,
                image_mean=np.repeat(0, in_channels),
                image_std=np.repeat(1, in_channels),
                **framework_kwargs
            )
        elif framework == 'retinanet':

            sizes = (
                (16, 20, 25),
                (32, 40, 50),
                (64, 80, 101),
                (128, 161, 203),
                (256, 322, 406),
                (512, 645, 812),
            )
            sizes = sizes[:len(combined_backbone.channel_list)]
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(sizes)
            anchor_generator = AnchorGenerator(sizes, aspect_ratios)
            head = RetinaNetHead(
                combined_backbone.out_channels,
                anchor_generator.num_anchors_per_location()[0],
                num_classes,
                norm_layer=partial(torch.nn.GroupNorm, 32),
            )

            model = torchvision.models.detection.RetinaNet(
                combined_backbone,
                num_classes,
                anchor_generator=anchor_generator,
                head=head,
                _skip_resize=True,
                image_mean=np.repeat(0, in_channels),
                image_std=np.repeat(1, in_channels),
                **framework_kwargs
            )

        elif framework == 'mask-rcnn':

            sizes = ((32,), (64,), (128,), (256,), (512,))
            sizes = sizes[:len(combined_backbone.channel_list)]
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(sizes)
            anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

            rpn_head = torchvision.models.detection.faster_rcnn.RPNHead(
                combined_backbone.out_channels, anchor_generator.num_anchors_per_location()[0], conv_depth=2
            )
            box_head = torchvision.models.detection.faster_rcnn.FastRCNNConvFCHead(
                (combined_backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
            )
            mask_head = torchvision.models.detection.mask_rcnn.MaskRCNNHeads(
                combined_backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d
            )
            roi_pooler = MultiScaleRoIAlign(
                featmap_names=['feat0', 'feat1', 'feat2', 'feat3'], output_size=7, sampling_ratio=2
            )

            model = torchvision.models.detection.MaskRCNN(
                combined_backbone,
                num_classes=num_classes,
                rpn_anchor_generator=anchor_generator,
                rpn_head=rpn_head,
                box_head=box_head,
                box_roi_pool=roi_pooler,
                mask_roi_pool=roi_pooler,
                mask_head=mask_head,
                _skip_resize=True,
                image_mean=np.repeat(0, in_channels),
                image_std=np.repeat(1, in_channels),
                **framework_kwargs
            )

        else:
            raise ValueError(f"Framework type '{framework}' is not valid.")

        return ObjectDetectionModel(model, framework)

build_model(task, backbone, framework, num_classes=None, necks=None, **kwargs) #

Generic model factory that combines an encoder and necks with a detection model (the framework) from torchvision.models.detection.

Further arguments can be passed to the backbone and framework by prefixing them with backbone_ and framework_ respectively.

Parameters:

task (str, required): Task to be performed. Currently supports "object_detection".

backbone (str | nn.Module, required): Backbone to be used. If a string, the factory looks the model up in the supported registries (internal terratorch registry, timm, ...). If a torch nn.Module, it is used directly. The backbone should have an out_channels attribute and its forward should return a list[Tensor].

framework (str, required): Object detection framework to be used: "faster-rcnn", "fcos" or "retinanet" for object detection, or "mask-rcnn" for instance segmentation.

num_classes (int, optional): Number of classes. Defaults to None.

necks (list[dict], optional): nn.Modules to be called in succession on encoder features before passing them to the framework. Each neck should be registered in the NECKS_REGISTRY and is specified as a dict with a "name" key and any further keys as arguments. Defaults to None, which applies the identity function.

Returns:

Model: Full torchvision detection model.
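
A minimal, hedged sketch of building a detector follows; the backbone name and band list are assumptions and must match your installation's registries. As in the code above, in_channels is inferred from the backbone band list when it is not given explicitly.

from terratorch.models.object_detection_model_factory import ObjectDetectionModelFactory

factory = ObjectDetectionModelFactory()
model = factory.build_model(
    task="object_detection",
    backbone="resnet50",                      # assumed backbone registry entry
    framework="faster-rcnn",                  # or "fcos", "retinanet", "mask-rcnn"
    num_classes=5,
    backbone_bands=["RED", "GREEN", "BLUE"],  # also used to infer in_channels
)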


terratorch.models.full_model_factory.FullModelFactory #

Bases: ModelFactory

Source code in terratorch/models/full_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class FullModelFactory(ModelFactory):
    def build_model(
        self,
        model: str | nn.Module,
        rescale: bool = True,  # noqa: FBT002, FBT001
        padding: str = "reflect",
        peft_config: dict | None = None,
        **kwargs,
    ) -> nn.Module:
        """Generic model factory that wraps any model.

        All kwargs are passed to the model.

        Args:
            model (str, nn.Module): Model to be used. If a string, will look for such models in the different
                registries supported (internal terratorch registry, ...). If a torch nn.Module, will use it
                directly.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models
                (e.g. segmentation, pixel wise regression, reconstruction). Defaults to True.
            padding (str): Padding method used if images are not divisible by the patch size. Defaults to "reflect".
            peft_config (dict): Configuration options for using [PEFT](https://huggingface.co/docs/peft/index).
                The dictionary should have the following keys:
                - "method": Which PEFT method to use. Should be one implemented in PEFT, a list is available [here](https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType).
                - "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep.
                  This should be used when the qkv matrices are merged together in a single linear layer and the PEFT
                  method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in
                  Q and V matrices). e.g. If using Prithvi this should be "qkv"
                - "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to [PeftConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig)


        Returns:
            nn.Module: Full model.
        """

        model = _get_model(model, **kwargs)

        # If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
        patch_size = kwargs.get("patch_size", None)

        if patch_size is None:
            # Infer patch size from model by checking all backbone modules
            for module in model.modules():
                if hasattr(module, "patch_size"):
                    patch_size = module.patch_size
                    break

        if peft_config is not None:
            if not kwargs.get("pretrained", False):
                msg = (
                    "You are using PEFT without a pretrained backbone. If you are loading a checkpoint afterwards "
                    "this is probably fine, but if you are training a model check the backbone_pretrained parameter."
                )
                warnings.warn(msg, stacklevel=1)

            model = get_peft_backbone(peft_config, model)

        return model

build_model(model, rescale=True, padding='reflect', peft_config=None, **kwargs) #

Generic model factory that wraps any model.

All kwargs are passed to the model.

Parameters:

model (str | nn.Module, required): Model to be used. If a string, the factory looks the model up in the supported registries (internal terratorch registry, ...). If a torch nn.Module, it is used directly.

rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression, reconstruction). Defaults to True.

padding (str): Padding method used if images are not divisible by the patch size. Defaults to "reflect".

peft_config (dict, optional): Configuration options for using PEFT (https://huggingface.co/docs/peft/index). The dictionary should have the following keys:

  • "method": Which PEFT method to use. Should be one implemented in PEFT; the available types are listed at https://huggingface.co/docs/peft/package_reference/peft_types#peft.PeftType.
  • "replace_qkv": String containing a substring of the name of the submodules to replace with QKVSep. This should be used when the qkv matrices are merged together in a single linear layer and the PEFT method should be applied separately to query, key and value matrices (e.g. if LoRA is only desired in the Q and V matrices). If using Prithvi, this should be "qkv".
  • "peft_config_kwargs": Dictionary containing keyword arguments which will be passed to PeftConfig (https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig).

Defaults to None.

Returns:

nn.Module: Full model.
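
A hedged sketch follows; "my_registered_model" is a hypothetical name standing in for an entry in the FULL_MODEL_REGISTRY.

from terratorch.models.full_model_factory import FullModelFactory

factory = FullModelFactory()
model = factory.build_model(
    model="my_registered_model",  # hypothetical FULL_MODEL_REGISTRY entry
    rescale=True,
    padding="reflect",
    # any further kwargs are passed through to the model constructor
)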


terratorch.models.smp_model_factory.SMPModelFactory #

Bases: ModelFactory

Source code in terratorch/models/smp_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class SMPModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str,
        model: str,
        bands: list[HLSBands | int],
        in_channels: int | None = None,
        num_classes: int = 1,
        pretrained: str | bool | None = True,  # noqa: FBT002
        prepare_features_for_image_model: Callable | None = None,
        regression_relu: bool = False,  # noqa: FBT001, FBT002
        **kwargs,
    ) -> Model:
        """
        Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

        This factory handles the instantiation of segmentation and regression models using specified
        encoders and decoders from the SMP library, along with custom modifications and extensions such
        as auxiliary decoders or modified encoders.

        Args:
            task (str): Specifies the task for which the model is being built. Supported tasks are
                        "segmentation".
            backbone (str): Specifies the backbone model to be used.
            model (str): Specifies the SMP architecture (decoder) to be used for constructing the
                        segmentation model.
            bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                        will operate on. These are expected to be from terratorch.datasets.HLSBands.
            in_channels (int, optional): Specifies the number of input channels. Defaults to None.
            num_classes (int, optional): The number of output classes for the model.
            pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the
                        backbone. Can also specify a path to weights. Defaults to True.
            num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                        for temporal models.
            regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
            **kwargs: Additional arguments that might be passed to further customize the backbone, decoder,
                        or any auxiliary heads. These should be prefixed appropriately

        Raises:
            ValueError: If the specified decoder is not supported by SMP.
            Exception: If the specified task is not "segmentation"

        Returns:
            nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                    parameters and tasks.
        """
        if task != "segmentation":
            msg = f"SMP models can only perform segmentation, but got task {task}"
            raise Exception(msg)

        bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
        if in_channels is None:
            in_channels = len(bands)

        # Gets decoder module.
        model_module = getattr(smp, model, None)
        if model_module is None:
            msg = f"Decoder {model} is not supported in SMP."
            raise ValueError(msg)

        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
        smp_kwargs, kwargs = extract_prefix_keys(backbone_kwargs, "smp_")  # Smp model params should be prefixed smp_
        aux_params, kwargs = extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
        aux_params = None if aux_params == {} else aux_params

        if isinstance(pretrained, bool):
            if pretrained:
                pretrained = "imagenet"
            else:
                pretrained = None

        # If encoder not currently supported by SMP (custom encoder).
        if backbone not in smp_encoders:
            if backbone.startswith("tu-"):
                #for timm encoders
                timm_encoder = backbone[3:]
                if timm_encoder not in timm.list_models(pretrained=True):
                    raise ValueError(f"Backbone {timm_encoder} is not a valid pretrained timm model.")

                model_args = {
                    "encoder_name": backbone,
                    "encoder_weights": pretrained,
                    "in_channels": in_channels,
                    "classes": num_classes,
                    **smp_kwargs,
                }
            else:
                # These params must be included in the config file with appropriate prefix.
                required_params = {
                    "encoder_depth": smp_kwargs,
                    "out_channels": backbone_kwargs,
                    "output_stride": backbone_kwargs,
                }

                for param, config_dict in required_params.items():
                    if param not in config_dict:
                        msg = f"Config must include the '{param}' parameter"
                        raise ValueError(msg)

                # Using new encoder.
                backbone_class = make_smp_encoder(backbone)
                backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
                # Registering custom encoder into SMP.
                register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

                model_args = {
                    "encoder_name": "SMPEncoderWrapperWithPFFIM",
                    "encoder_weights": pretrained,
                    "in_channels": in_channels,
                    "classes": num_classes,
                    **smp_kwargs,
                }
        # Using SMP encoder.
        else:
            model_args = {
                "encoder_name": backbone,
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }

        model = model_module(**model_args, aux_params=aux_params)

        return SMPModelWrapper(
            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )

build_model(task, backbone, model, bands, in_channels=None, num_classes=1, pretrained=True, prepare_features_for_image_model=None, regression_relu=False, **kwargs) #

Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

This factory handles the instantiation of segmentation and regression models using specified encoders and decoders from the SMP library, along with custom modifications and extensions such as auxiliary decoders or modified encoders.

Parameters:

task (str): Specifies the task for which the model is being built. The supported task is "segmentation".

backbone (str): Specifies the backbone model to be used.

model (str): Specifies the SMP architecture (decoder) to be used for constructing the segmentation model.

bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model will operate on. These are expected to be from terratorch.datasets.HLSBands.

in_channels (int, optional): Specifies the number of input channels. Defaults to None, in which case it is inferred from the number of bands.

num_classes (int, optional): The number of output classes for the model. Defaults to 1.

pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the backbone. Can also specify a path to weights. Defaults to True.

num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful for temporal models.

regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks. Defaults to False.

**kwargs: Additional arguments that might be passed to further customize the backbone, decoder, or any auxiliary heads. These should be prefixed appropriately (see the sketch after the Returns section).

Raises:

ValueError: If the specified decoder is not supported by SMP.

Exception: If the specified task is not "segmentation".

Returns:

Model: A model instance wrapped in SMPModelWrapper, configured according to the specified parameters and task.
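
A hedged sketch using an SMP-native encoder follows. Note the nested prefixes in the code above: the backbone_ prefix is stripped first, then smp_ within it, so SMP constructor arguments are passed as backbone_smp_*. Here decoder_channels is a standard smp.Unet argument, and the band strings are assumed to be convertible by HLSBands.try_convert_to_hls_bands_enum.

from terratorch.models.smp_model_factory import SMPModelFactory

factory = SMPModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="resnet34",   # an encoder name known to SMP
    model="Unet",          # resolved via getattr(smp, model)
    bands=["RED", "GREEN", "BLUE"],
    num_classes=2,
    pretrained=True,       # resolves to "imagenet" encoder weights
    backbone_smp_decoder_channels=(256, 128, 64, 32, 16),  # forwarded to smp.Unet
)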

Source code in terratorch/models/smp_model_factory.py
def build_model(
    self,
    task: str,
    backbone: str,
    model: str,
    bands: list[HLSBands | int],
    in_channels: int | None = None,
    num_classes: int = 1,
    pretrained: str | bool | None = True,  # noqa: FBT002
    prepare_features_for_image_model: Callable | None = None,
    regression_relu: bool = False,  # noqa: FBT001, FBT002
    **kwargs,
) -> Model:
    """
    Factory class for creating SMP (Segmentation Models Pytorch) based models with optional customization.

    This factory handles the instantiation of segmentation and regression models using specified
    encoders and decoders from the SMP library, along with custom modifications and extensions such
    as auxiliary decoders or modified encoders.

    Attributes:
        task (str): Specifies the task for which the model is being built. Supported tasks are
                    "segmentation".
        backbone (str): Specifies the backbone model to be used.
        decoder (str): Specifies the decoder to be used for constructing the
                    segmentation model.
        bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
                    will operate on. These are expected to be from terratorch.datasets.HLSBands.
        in_channels (int, optional): Specifies the number of input channels. Defaults to None.
        num_classes (int, optional): The number of output classes for the model.
        pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the
                    backbone. Can also specify a path to weights. Defaults to True.
        num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful
                    for temporal models.
        regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks.
        **kwargs: Additional arguments that might be passed to further customize the backbone, decoder,
                    or any auxiliary heads. These should be prefixed appropriately

    Raises:
        ValueError: If the specified decoder is not supported by SMP.
        Exception: If the specified task is not "segmentation"

    Returns:
        nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
                parameters and tasks.
    """
    if task != "segmentation":
        msg = f"SMP models can only perform segmentation, but got task {task}"
        raise Exception(msg)

    bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
    if in_channels is None:
        in_channels = len(bands)

    # Gets decoder module.
    model_module = getattr(smp, model, None)
    if model_module is None:
        msg = f"Decoder {model} is not supported in SMP."
        raise ValueError(msg)

    backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
    smp_kwargs, kwargs = extract_prefix_keys(backbone_kwargs, "smp_")  # Smp model params should be prefixed smp_
    aux_params, kwargs = extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
    aux_params = None if aux_params == {} else aux_params

    if isinstance(pretrained, bool):
        if pretrained:
            pretrained = "imagenet"
        else:
            pretrained = None

    # If encoder not currently supported by SMP (custom encoder).
    if backbone not in smp_encoders:
        if backbone.startswith("tu-"):
            #for timm encoders
            timm_encoder = backbone[3:]
            if timm_encoder not in timm.list_models(pretrained=True):
                raise ValueError(f"Backbone {timm_encoder} is not a valid pretrained timm model.")

            model_args = {
                "encoder_name": backbone,
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }
        else:
            # These params must be included in the config file with appropriate prefix.
            required_params = {
                "encoder_depth": smp_kwargs,
                "out_channels": backbone_kwargs,
                "output_stride": backbone_kwargs,
            }

            for param, config_dict in required_params.items():
                if param not in config_dict:
                    msg = f"Config must include the '{param}' parameter"
                    raise ValueError(msg)

            # Using new encoder.
            backbone_class = make_smp_encoder(backbone)
            backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model
            # Registering custom encoder into SMP.
            register_custom_encoder(backbone_class, backbone_kwargs, pretrained)

            model_args = {
                "encoder_name": "SMPEncoderWrapperWithPFFIM",
                "encoder_weights": pretrained,
                "in_channels": in_channels,
                "classes": num_classes,
                **smp_kwargs,
            }
    # Using SMP encoder.
    else:
        model_args = {
            "encoder_name": backbone,
            "encoder_weights": pretrained,
            "in_channels": in_channels,
            "classes": num_classes,
            **smp_kwargs,
        }

    model = model_module(**model_args, aux_params=aux_params)

    return SMPModelWrapper(
        model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
    )
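
Because of the prefix convention above, encoder arguments reach this factory as `backbone_...`, SMP model arguments as `backbone_smp_...`, and auxiliary-head arguments as `backbone_aux_...`. Below is a minimal, illustrative sketch of calling the factory directly; the import path, encoder name, and kwargs are assumptions, not fixed defaults.

```python
from terratorch.models import SMPModelFactory  # assumed import path

factory = SMPModelFactory()
model = factory.build_model(
    task="segmentation",             # the only task SMP models support
    backbone="resnet34",             # an encoder name known to SMP
    model="Unet",                    # any decoder class exposed by smp, e.g. smp.Unet
    bands=["RED", "GREEN", "BLUE"],  # converted to HLSBands internally
    num_classes=2,                   # in_channels defaults to len(bands)
    pretrained=True,                 # True is mapped to "imagenet" weights
    backbone_smp_encoder_depth=5,    # "backbone_" then "smp_" prefixes are stripped
)
```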

terratorch.models.timm_model_factory.TimmModelFactory #

Bases: ModelFactory

Source code in terratorch/models/timm_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class TimmModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str,
        in_channels: int,
        num_classes: int,
        pretrained: str | bool = True,
        **kwargs,
    ) -> Model:
        """Build a classifier from timm

        Args:
            task (str): Must be "classification".
            backbone (str): Name of the backbone in timm.
            in_channels (int): Number of input channels.
            num_classes (int): Number of classes.

        Returns:
            Model: Timm model wrapped in TimmModelWrapper.
        """
        if task != "classification":
            msg = f"timm models can only perform classification, but got task {task}"
            raise Exception(msg)
        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
        if isinstance(pretrained, bool):
            model = create_model(
                backbone, pretrained=pretrained, num_classes=num_classes, in_chans=in_channels, **backbone_kwargs
            )
        else:
            model = create_model(backbone, num_classes=num_classes, in_chans=in_channels, **backbone_kwargs)

        # Load weights
        # Code adapted from geobench
        if pretrained and pretrained is not True:
            try:
                weights = WeightsEnum(pretrained)
                state_dict = weights.get_state_dict(progress=True)
            except ValueError:
                if os.path.exists(pretrained):
                    _, state_dict = utils.extract_backbone(pretrained)
                else:
                    state_dict = get_weight(pretrained).get_state_dict(progress=True)
            model = utils.load_state_dict(model, state_dict)

        return TimmModelWrapper(model)
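
A minimal usage sketch follows; the backbone name and kwargs are illustrative. Any `backbone_`-prefixed keyword is stripped of its prefix and forwarded to `timm.create_model`.

```python
from terratorch.models import TimmModelFactory  # assumed import path

factory = TimmModelFactory()
model = factory.build_model(
    task="classification",   # the only task this factory supports
    backbone="resnet18",     # any model name known to timm
    in_channels=6,
    num_classes=10,
    pretrained=True,         # bool: timm downloads its default weights
    backbone_drop_rate=0.1,  # forwarded to timm.create_model as drop_rate
)
```

When `pretrained` is a string rather than a bool, the model is created without weights and the string is then resolved as a torchvision weights enum name or a local checkpoint path, whose state dict is loaded into the model.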

build_model(task, backbone, in_channels, num_classes, pretrained=True, **kwargs) #

Build a classifier from timm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `str` | Must be "classification". | *required* |
| `backbone` | `str` | Name of the backbone in timm. | *required* |
| `in_channels` | `int` | Number of input channels. | *required* |
| `num_classes` | `int` | Number of classes. | *required* |
| `pretrained` | `str \| bool` | If True, use timm's default pretrained weights; a string is resolved as a weights name or checkpoint path. | `True` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Model` | `Model` | Timm model wrapped in TimmModelWrapper. |

terratorch.models.generic_model_factory.GenericModelFactory #

Bases: ModelFactory

Source code in terratorch/models/generic_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class GenericModelFactory(ModelFactory):

    def build_model(
        self,
        backbone: str | None = None,
        in_channels: int = 6,
        pretrained: str | bool | None = True,
        **kwargs,
    ) -> Model:
        """Factory to create models from any custom module.

        Args:
            model (str): The name for the model class.
            in_channels (int): Number of input channels.
            pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.

        Returns:
            Model: A wrapped generic model.
        """

        model_kwargs = _extract_prefix_keys(kwargs, "backbone_")

        try:
            model = BACKBONE_REGISTRY.build(backbone, **model_kwargs)
        except KeyError as err:
            raise KeyError(f"Model {backbone} not found in the registry.") from err

        return GenericModelWrapper(model)
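
A minimal sketch of direct use; the backbone name below is hypothetical. Only `backbone_`-prefixed keywords are forwarded to `BACKBONE_REGISTRY.build`; note that, as the body above shows, the top-level `in_channels` and `pretrained` arguments are accepted but not forwarded.

```python
from terratorch.models import GenericModelFactory  # assumed import path

factory = GenericModelFactory()
# Builds any entry of BACKBONE_REGISTRY and wraps it in GenericModelWrapper.
model = factory.build_model(
    backbone="my_registered_backbone",  # hypothetical registry key
    backbone_pretrained=False,          # forwarded to the registry as pretrained=False
)
```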

build_model(backbone=None, in_channels=6, pretrained=True, **kwargs) #

Factory to create models from any custom module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `backbone` | `str` | Name of the model class to build from the BACKBONE_REGISTRY. | `None` |
| `in_channels` | `int` | Number of input channels. | `6` |
| `pretrained` | `str \| bool` | Which weights to use for the backbone. If True, uses "imagenet"; if False or None, random weights. | `True` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Model` | `Model` | A wrapped generic model. |

terratorch.models.clay_model_factory.ClayModelFactory #

Bases: ModelFactory

Source code in terratorch/models/clay_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class ClayModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        bands: list[int] = [],
        num_classes: int | None = None,
        pretrained: bool = True,  # noqa: FBT001, FBT002
        num_frames: int = 1,
        prepare_features_for_image_model: Callable | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        checkpoint_path: str = None,
        **kwargs,
    ) -> Model:
        """Model factory for Clay models.

        Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str, nn.Module): Backbone to be used. If a string, it must refer to a Clay
                encoder (the name must contain "clay").
            decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                    If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                    If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                    Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
            in_channels (int, optional): Number of input channels. Defaults to 3.
            num_classes (int, optional): Number of classes. None for regression tasks.
            pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
                Defaults to True.
            num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
            prepare_features_for_image_model (Callable | None): Function to be called on encoder features
                before passing them to the decoder. Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

        Raises:
            NotImplementedError: If the backbone is not a Clay encoder or the task is not supported.
            DecoderNotFoundException: If the requested decoder cannot be found.

        Returns:
            nn.Module: The built model, wrapped for the requested task.
        """
        self.CPU_ONLY = not torch.cuda.is_available()

        # Path for accessing the model source code.
        self.syspath_kwarg = "model_sys_path"
        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

        # TODO: support auxiliary heads
        if not isinstance(backbone, nn.Module):
            if not "clay" in backbone:
                msg = "This class only handles models for `Clay` encoders"
                raise NotImplementedError(msg)

            task = task.lower()
            if task not in SUPPORTED_TASKS:
                msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
                raise NotImplementedError(msg)

            # Trying to find the model on HuggingFace.
            try:
                backbone: nn.Module = timm.create_model(
                    backbone,
                    pretrained=pretrained,
                    in_chans=in_channels,
                    bands=bands,
                    num_frames=num_frames,
                    features_only=True,
                    **backbone_kwargs,
                )
            except Exception as e:
                print(e, "Error loading from HF. Trying to instantiate locally ...")

        else:
            if checkpoint_path is None:
                raise ValueError("A checkpoint (checkpoint_path) must be provided to restore the model.")

            backbone: nn.Module = Embedder(ckpt_path=checkpoint_path, **backbone_kwargs)
            print("Model Clay was successfully restored.")

        # If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
        patch_size = backbone_kwargs.get("patch_size", None)
        if patch_size is None:
            # Infer patch size from model by checking all backbone modules
            for module in backbone.modules():
                if hasattr(module, "patch_size"):
                    patch_size = module.patch_size
                    break
        padding = backbone_kwargs.get("padding", "reflect")

        # allow decoder to be a module passed directly
        decoder_cls = _get_decoder(decoder)
        decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")

        # TODO: remove this
        decoder: nn.Module = decoder_cls(
            backbone.feature_info.channels(), **decoder_kwargs)
        # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

        head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_")
        if num_classes:
            head_kwargs["num_classes"] = num_classes
        if aux_decoders is None:
            return _build_appropriate_model(
                task, backbone, decoder, head_kwargs, prepare_features_for_image_model, patch_size=patch_size, padding=padding, rescale=rescale
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []

        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)

            aux_decoder_kwargs, kwargs = extract_prefix_keys(args, "decoder_")
            aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
            # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

            aux_head_kwargs, kwargs = extract_prefix_keys(args, "head_")
            if num_classes:
                aux_head_kwargs["num_classes"] = num_classes
            # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
            # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(
                    aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            prepare_features_for_image_model,
            patch_size=patch_size,
            padding=padding,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )
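
A minimal sketch of direct use; the backbone name and the head kwarg are illustrative. The backbone string must contain "clay", and the `decoder_`/`head_` prefixes route keywords to the decoder and head respectively.

```python
from terratorch.models.clay_model_factory import ClayModelFactory  # assumed import path

factory = ClayModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="clay_v1",    # hypothetical name; must contain "clay"
    decoder="FCNDecoder",
    in_channels=6,
    num_classes=2,
    pretrained=True,
    head_dropout=0.1,      # "head_" prefix is stripped and passed to the head
)
```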

build_model(task, backbone, decoder, in_channels, bands=[], num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs) #

Model factory for Clay models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with `backbone_`, `decoder_` and `head_` respectively.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `str` | Task to be performed. Currently supports "segmentation" and "regression". | *required* |
| `backbone` | `str \| nn.Module` | Backbone to be used. If a string, it must refer to a Clay encoder (the name must contain "clay"). | *required* |
| `decoder` | `str \| nn.Module` | Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in `decoder.__init__.py` with the same name. If an nn.Module, we expect it to expose a property `decoder.out_channels`. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder". | *required* |
| `in_channels` | `int` | Number of input channels. Defaults to 3. | *required* |
| `num_classes` | `int` | Number of classes. None for regression tasks. | `None` |
| `pretrained` | `bool \| Path` | Whether to load pretrained weights for the backbone, if available. | `True` |
| `num_frames` | `int` | Number of timesteps for the model to handle. | `1` |
| `prepare_features_for_image_model` | `Callable \| None` | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | `None` |
| `aux_decoders` | `list[AuxiliaryHead] \| None` | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | `None` |
| `rescale` | `bool` | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). | `True` |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If the backbone is not a Clay encoder or the task is not supported. |
| `DecoderNotFoundException` | If the requested decoder cannot be found. |

Returns:

| Type | Description |
| --- | --- |
| `Model` | The built model, wrapped for the requested task. |

terratorch.models.generic_unet_model_factory.GenericUnetModelFactory #

Bases: ModelFactory

Source code in terratorch/models/generic_unet_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class GenericUnetModelFactory(ModelFactory):
    def _check_model_availability(self, model, builtin_engine, engine, **model_kwargs):

        try:
            if builtin_engine:
                model_class = getattr(builtin_engine, model)
                print(f"Using module {model} from terratorch.")
            else:
                model_class = None
        except AttributeError:
            if engine:
                print("Module not available on terratorch.")
                print(f"Using module {model} from mmseg.")
                model_class = getattr(engine, model)
            else:
                raise Exception("mmseg is not installed.")

        if model_class:
            model = model_class(
               **model_kwargs,
            )
        else:
            model = None

        return model 

    def build_model(
        self,
        task: str = "segmentation",
        backbone: str | None = None,
        decoder: str | None = None,
        dilations: tuple[int] = (1, 6, 12, 18),
        in_channels: int = 6,
        pretrained: str | bool | None = True,
        regression_relu: bool = False,
        **kwargs,
    ) -> Model:
        """Factory to create model based on mmseg.

        Args:
            task (str): Must be "segmentation".
            model (str): Decoder architecture. Currently only supports "unet".
            in_channels (int): Number of input channels.
            pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.
            regression_relu (bool). Whether to apply a ReLU if task is regression. Defaults to False.

        Returns:
            Model: UNet model.
        """
        if task not in ["segmentation", "regression"]:
            msg = f"This model can only perform pixel wise tasks, but got task {task}"
            raise Exception(msg)

        builtin_engine_decoders = importlib.import_module("terratorch.models.decoders")
        builtin_engine_encoders = importlib.import_module("terratorch.models.backbones")

        # Default values
        backbone_builtin_engine = None
        decoder_builtin_engine = None
        backbone_engine = None 
        decoder_engine = None 
        backbone_model_kwargs = {}
        decoder_model_kwargs = {}

        try:
            engine_decoders = importlib.import_module("mmseg.models.decode_heads")
            engine_encoders = importlib.import_module("mmseg.models.backbones")
            _has_mmseg = True
        except ImportError:
            engine_decoders = None
            engine_encoders = None
            _has_mmseg = False
            print("mmseg is not installed.")

        if backbone:
            backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")
            backbone_model_kwargs = backbone_kwargs
            backbone_engine = engine_encoders
            backbone_builtin_engine = builtin_engine_encoders

        if decoder:
            decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
            decoder_model_kwargs = decoder_kwargs
            decoder_engine = engine_decoders
            decoder_builtin_engine = builtin_engine_decoders

        if not backbone and not decoder:
            msg = "It is necessary to define a backbone and/or a decoder."
            raise ValueError(msg)

        # Instantiating backbone and decoder
        backbone = self._check_model_availability(backbone, backbone_builtin_engine, backbone_engine, **backbone_model_kwargs) 
        decoder = self._check_model_availability(decoder, decoder_builtin_engine, decoder_engine, **decoder_model_kwargs) 

        return GenericUnetModelWrapper(
            backbone, decoder=decoder, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )
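
A minimal sketch of direct use; the class names below are hypothetical and are resolved first against terratorch's builtin backbones/decoders and then, if installed, against mmseg. Constructor arguments are routed by the `backbone_` and `decoder_` prefixes.

```python
from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory  # assumed path

factory = GenericUnetModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="UNet",         # hypothetical class name, looked up by getattr
    backbone_in_channels=6,  # forwarded to the backbone constructor
    decoder="FCNHead",       # hypothetical mmseg decode head
    decoder_num_classes=2,   # forwarded to the decoder constructor
)
```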

build_model(task='segmentation', backbone=None, decoder=None, dilations=(1, 6, 12, 18), in_channels=6, pretrained=True, regression_relu=False, **kwargs) #

Factory to create a UNet-style model from terratorch or mmseg components.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `str` | Must be "segmentation" or "regression". | `'segmentation'` |
| `backbone` | `str \| None` | Name of the backbone class to instantiate. | `None` |
| `decoder` | `str \| None` | Name of the decoder class to instantiate. | `None` |
| `in_channels` | `int` | Number of input channels. | `6` |
| `pretrained` | `str \| bool` | Which weights to use for the backbone. If True, uses "imagenet"; if False or None, random weights. | `True` |
| `regression_relu` | `bool` | Whether to apply a ReLU if task is regression. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Model` | `Model` | UNet model. |

terratorch.models.satmae_model_factory.SatMAEModelFactory #

Bases: ModelFactory

Source code in terratorch/models/satmae_model_factory.py
@MODEL_FACTORY_REGISTRY.register
class SatMAEModelFactory(ModelFactory):
    def build_model(
        self,
        task: str,
        backbone: str | nn.Module,
        decoder: str | nn.Module,
        in_channels: int,
        bands: list[HLSBands | int],
        num_classes: int | None = None,
        pretrained: bool = True,  # noqa: FBT001, FBT002
        num_frames: int = 1,
        prepare_features_for_image_model: Callable | None = None,
        aux_decoders: list[AuxiliaryHead] | None = None,
        rescale: bool = True,  # noqa: FBT002, FBT001
        checkpoint_path: str = None,
        **kwargs,
    ) -> Model:
        """Model factory for SatMAE  models.

        Further arguments to be passed to the backbone, decoder or head. They should be prefixed with
        `backbone_`, `decoder_` and `head_` respectively.

        Args:
            task (str): Task to be performed. Currently supports "segmentation" and "regression".
            backbone (str, nn.Module): Backbone to be used. If a string, it should name a SatMAE
                encoder class available on HuggingFace or in the SatMAE source tree.
            decoder (Union[str, nn.Module], optional): Decoder to be used for the segmentation model.
                    If a string, it will be created from a class exposed in decoder.__init__.py with the same name.
                    If an nn.Module, we expect it to expose a property `decoder.out_channels`.
                    Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder".
            in_channels (int, optional): Number of input channels. Defaults to 3.
            bands (list[terratorch.datasets.HLSBands], optional): Bands the model will be trained on.
                    Should be a list of terratorch.datasets.HLSBands.
                    Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE].
            num_classes (int, optional): Number of classes. None for regression tasks.
            pretrained (Union[bool, Path], optional): Whether to load pretrained weights for the backbone, if available.
                Defaults to True.
            num_frames (int, optional): Number of timesteps for the model to handle. Defaults to 1.
            prepare_features_for_image_model (Callable | None): Function to be called on encoder features
                before passing them to the decoder. Defaults to None, which applies the identity function.
            aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead decoders to be added to the model.
                These decoders take the input from the encoder as well.
            rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
                is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.

        Raises:
            NotImplementedError: If the backbone is not a SatMAE encoder or the task is not supported.
            DecoderNotFoundException: If the requested decoder cannot be found.

        Returns:
            nn.Module: The built model, wrapped for the requested task.
        """

        self.possible_modules = None 

        self.CPU_ONLY = not torch.cuda.is_available()

        # Path for accessing the model source code.
        self.syspath_kwarg = "model_sys_path"

        bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]

        # TODO: support auxiliary heads
        if not isinstance(backbone, nn.Module):
            if "SatMAE" not in kwargs[self.syspath_kwarg]:
                msg = "This class only handles models for `SatMAE` encoders"
                raise NotImplementedError(msg)

            task = task.lower()
            if task not in SUPPORTED_TASKS:
                msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
                raise NotImplementedError(msg)

            backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
            backbone_name = backbone

            # Trying to find the model on HuggingFace.
            try:
                backbone: nn.Module = timm.create_model(
                    backbone,
                    pretrained=pretrained,
                    in_chans=in_channels,
                    num_frames=num_frames,
                    bands=bands,
                    features_only=True,
                    **backbone_kwargs,
                )
            except Exception:

                # When the model is not on HuggingFace, it needs to be restored locally.
                print("This model is not available on HuggingFace. Trying to instantiate locally ...")

                assert checkpoint_path, "A checkpoint must be provided to restore the model."

                # The SatMAE source code must be installed or available via PYTHONPATH.
                try:
                    if self.syspath_kwarg in kwargs:
                        syspath_value = kwargs.get(self.syspath_kwarg)
                    else:
                        raise Exception(
                            f"It is necessary to define the variable {self.syspath_kwarg} in the yaml "
                            "config for restoring a local model."
                        )

                    sys.path.insert(0, syspath_value)

                    # The SatMAE repo exposes dozens of classes; search the likely modules for the backbone template.
                    backbone_template = None

                    self.possible_modules = [importlib.import_module(mod) for mod in ["models_mae", "models_vit"]]

                    for backbone_module in self.possible_modules:
                        backbone_template_ = getattr(backbone_module, backbone_name, None)
                        if backbone_template_ is not None:
                            backbone_template = backbone_template_

                except ModuleNotFoundError:

                    print(f"It is better to review the field {self.syspath_kwarg} in the yaml file.")

                # Is it a ViT or a ViT-MAE?
                backbone_kind = check_the_kind_of_vit(name=backbone_name)

                backbone: nn.Module = ModelWrapper(model=backbone_template(**backbone_kwargs), kind=backbone_kind)

                if self.CPU_ONLY:
                    model_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
                else:
                    model_dict = torch.load(checkpoint_path, weights_only=True)


                # Filtering parameters from the model state_dict (when necessary)
                model_dict = filter_cefficients_when_necessary(model_state_dict=model_dict, kind=backbone_kind)

                if backbone_kind == "vit":
                    backbone.model.fc_norm = nn.Identity()
                    backbone.model.head_drop = nn.Identity()
                    backbone.model.head = nn.Identity()
                    backbone.model.pos_embed = None  # TODO: needs to be corrected at the source

                # Load saved weights when available
                if pretrained:
                    backbone.model.load_state_dict(model_dict['model'], strict=False)

                # Print the general architecture
                backbone.summary()

                print("Model SatMAE was successfully restored.")

        # allow decoder to be a module passed directly
        decoder_cls = _get_decoder(decoder)

        decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")

        # If the backbone is a ViT-MAE, the attribute "num_patches" will be necessary
        if hasattr(backbone, "num_patches"):
            decoder_kwargs["num_patches"] = backbone.num_patches

        # TODO: remove this
        if "SatMAEHead" in decoder:
            decoder: nn.Module = decoder_cls(**decoder_kwargs)
        else:
            decoder: nn.Module = decoder_cls(backbone.channels(), **decoder_kwargs)

        head_kwargs, kwargs = extract_prefix_keys(kwargs, "head_")
        if num_classes:
            head_kwargs["num_classes"] = num_classes
        if aux_decoders is None:
            return _build_appropriate_model(
                task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
        for aux_decoder in aux_decoders:
            args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
            aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
            aux_decoder_kwargs, kwargs = extract_prefix_keys(args, "decoder_")
            aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
            # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

            aux_head_kwargs, kwargs = extract_prefix_keys(args, "head_")
            if num_classes:
                aux_head_kwargs["num_classes"] = num_classes
            # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
            # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
            to_be_aux_decoders.append(
                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
            )

        return _build_appropriate_model(
            task,
            backbone,
            decoder,
            head_kwargs,
            prepare_features_for_image_model,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )
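
Since SatMAE encoders are typically restored from a local checkpoint, a direct call usually supplies both `checkpoint_path` and `model_sys_path` (the latter must contain "SatMAE", as checked above). A minimal sketch with illustrative names and hypothetical paths:

```python
from terratorch.models.satmae_model_factory import SatMAEModelFactory  # assumed import path

factory = SatMAEModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="mae_vit_large_patch16",  # hypothetical class from models_mae / models_vit
    decoder="SatMAEHead",
    in_channels=3,
    bands=["RED", "GREEN", "BLUE"],
    num_classes=2,
    pretrained=True,                   # load the checkpoint's state dict into the backbone
    checkpoint_path="/path/to/satmae_checkpoint.pth",  # hypothetical local weights
    model_sys_path="/path/to/SatMAE",  # hypothetical path to the SatMAE source tree
)
```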

build_model(task, backbone, decoder, in_channels, bands, num_classes=None, pretrained=True, num_frames=1, prepare_features_for_image_model=None, aux_decoders=None, rescale=True, checkpoint_path=None, **kwargs) #

Model factory for SatMAE models.

Further arguments to be passed to the backbone, decoder or head. They should be prefixed with `backbone_`, `decoder_` and `head_` respectively.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `str` | Task to be performed. Currently supports "segmentation" and "regression". | *required* |
| `backbone` | `str \| nn.Module` | Backbone to be used. If a string, it should name a SatMAE encoder class available on HuggingFace or in the SatMAE source tree. | *required* |
| `decoder` | `str \| nn.Module` | Decoder to be used for the segmentation model. If a string, it will be created from a class exposed in `decoder.__init__.py` with the same name. If an nn.Module, we expect it to expose a property `decoder.out_channels`. Will be concatenated with a Conv2d for the final convolution. Defaults to "FCNDecoder". | *required* |
| `in_channels` | `int` | Number of input channels. Defaults to 3. | *required* |
| `bands` | `list[HLSBands]` | Bands the model will be trained on. Should be a list of terratorch.datasets.HLSBands. Defaults to [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE]. | *required* |
| `num_classes` | `int` | Number of classes. None for regression tasks. | `None` |
| `pretrained` | `bool \| Path` | Whether to load pretrained weights for the backbone, if available. | `True` |
| `num_frames` | `int` | Number of timesteps for the model to handle. | `1` |
| `prepare_features_for_image_model` | `Callable \| None` | Function to be called on encoder features before passing them to the decoder. Defaults to None, which applies the identity function. | `None` |
| `aux_decoders` | `list[AuxiliaryHead] \| None` | List of AuxiliaryHead decoders to be added to the model. These decoders take the input from the encoder as well. | `None` |
| `rescale` | `bool` | Whether to apply bilinear interpolation to rescale the model output if its size is different from the ground truth. Only applicable to pixel-wise models (e.g. segmentation, pixel-wise regression). | `True` |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If the backbone is not a SatMAE encoder or the task is not supported. |
| `DecoderNotFoundException` | If the requested decoder cannot be found. |

Returns:

| Type | Description |
| --- | --- |
| `Model` | The built model, wrapped for the requested task. |

Base factory:

terratorch.models.model.ModelFactory #

Bases: Protocol

Source code in terratorch/models/model.py
class ModelFactory(typing.Protocol):
    def build_model(self, *args, **kwargs) -> Model: ...
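
Any class exposing a compatible `build_model` satisfies this protocol. A minimal sketch of a custom factory registered for use by TerraTorch tasks (the registry import path is assumed):

```python
from terratorch.models.model import Model, ModelFactory
from terratorch.registry import MODEL_FACTORY_REGISTRY  # assumed import path


@MODEL_FACTORY_REGISTRY.register
class MyModelFactory(ModelFactory):
    """A minimal custom factory satisfying the ModelFactory protocol."""

    def build_model(self, *args, **kwargs) -> Model:
        # Construct and return an object implementing the Model interface here.
        raise NotImplementedError
```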