Skip to content

Generic Datamodules#

terratorch.datamodules.generic_pixel_wise_data_module.GenericNonGeoSegmentationDataModule #

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoSegmentationDatasets

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoSegmentationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        img_grep: str = "*",
        label_grep: str = "*",
        predict_data_root: Path | None = None,
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        check_stackability: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Batch size. Forwarded to the parent ``NonGeoDataModule`` constructor.
            num_workers (int): Number of data loading workers. Forwarded to the parent
                ``NonGeoDataModule`` constructor.
            train_data_root (Path): Data root passed to the dataset built for the "fit" stage.
            val_data_root (Path): Data root passed to the dataset built for the "fit" and
                "validate" stages.
            test_data_root (Path): Data root passed to the dataset built for the "test" stage.
            means (list[float] | str): Means used for normalization. The value is resolved with
                ``load_from_file_or_attribute`` (so a string is presumably a path to a file holding
                the values) and then handed to ``Normalize``.
            stds (list[float] | str): Standard deviations used for normalization. Resolved the same
                way as ``means``.
            num_classes (int): Number of classes. Passed to every dataset.
            img_grep (str): Passed to every dataset as ``image_grep``; presumably a glob pattern
                selecting the image files. Defaults to "*".
            label_grep (str): Passed to every dataset as ``label_grep``; presumably a glob pattern
                selecting the label files. Defaults to "*".
            predict_data_root (Path | None, optional): Data root passed to the dataset built for the
                "predict" stage. If unset, ``setup`` does not create a predict dataset.
                Defaults to None.
            train_label_data_root (Path | None, optional): Passed to the train dataset as
                ``label_data_root``. Defaults to None.
            val_label_data_root (Path | None, optional): Passed to the validation dataset as
                ``label_data_root``. Defaults to None.
            test_label_data_root (Path | None, optional): Passed to the test dataset as
                ``label_data_root``. Defaults to None.
            train_split (Path | None, optional): Passed to the train dataset as ``split``
                (the split file for that dataset). Defaults to None.
            val_split (Path | None, optional): Passed to the validation dataset as ``split``
                (the split file for that dataset). Defaults to None.
            test_split (Path | None, optional): Passed to the test dataset as ``split``
                (the split file for that dataset). Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the
                dataset. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be
                output by the dataset. Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
                dataset_bands with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
                output_bands with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Forwarded unchanged to every dataset as
                ``constant_scale``. Defaults to 1.
            rgb_indices (list[int] | None, optional): Forwarded unchanged to every dataset as
                ``rgb_indices``. Defaults to None.
            train_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the test dataset. It is also the
                transform given to the predict dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            no_data_replace (float | None): Replace nan values in input images with this value.
                If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to None.
            drop_last (bool): Drop the last batch if it is not complete. Only takes effect for the
                train dataloader. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.
            check_stackability (bool): Check if all the files in the dataset have the same size and
                can be stacked. Defaults to True.
            **kwargs (Any): Extra keyword arguments forwarded to the parent ``NonGeoDataModule``
                constructor.
        """
        super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last
        self.pin_memory = pin_memory

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.dataset_bands = dataset_bands
        self.output_bands = output_bands
        # Predict-time band settings fall back to the regular ones when unset (None or empty).
        self.predict_dataset_bands = predict_dataset_bands or dataset_bands
        self.predict_output_bands = predict_output_bands or output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        # A plain list of transforms is accepted as well; the helper normalizes the input.
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # means / stds may be given directly or as a string to be resolved by the helper.
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)
        self.aug = Normalize(means, stds)

        self.check_stackability = check_stackability

    def _make_dataset(
        self,
        data_root: Path,
        transform: A.Compose | None,
        dataset_bands: list | None,
        output_bands: list | None,
        **extra: Any,
    ) -> Any:
        """Instantiate ``self.dataset_class`` with the options shared by every stage.

        Args:
            data_root: Data root passed as the first positional argument to the dataset.
            transform: Transform passed to the dataset as ``transform``.
            dataset_bands: Value passed to the dataset as ``dataset_bands``.
            output_bands: Value passed to the dataset as ``output_bands``.
            **extra: Additional stage-specific keyword arguments for the dataset.

        Returns:
            The newly created dataset instance.
        """
        return self.dataset_class(
            data_root,
            self.num_classes,
            image_grep=self.img_grep,
            label_grep=self.label_grep,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=self.constant_scale,
            rgb_indices=self.rgb_indices,
            transform=transform,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension=self.expand_temporal_dimension,
            reduce_zero_label=self.reduce_zero_label,
            **extra,
        )

    def _make_split_dataset(
        self,
        data_root: Path,
        label_data_root: Path | None,
        split: Path | None,
        transform: A.Compose | None,
    ) -> Any:
        """Build a train / validation / test dataset, including label root and split options.

        Args:
            data_root: Data root passed as the first positional argument to the dataset.
            label_data_root: Value passed to the dataset as ``label_data_root``.
            split: Value passed to the dataset as ``split``.
            transform: Transform passed to the dataset as ``transform``.

        Returns:
            The newly created dataset instance.
        """
        return self._make_dataset(
            data_root,
            transform,
            self.dataset_bands,
            self.output_bands,
            label_data_root=label_data_root,
            split=split,
            ignore_split_file_extensions=self.ignore_split_file_extensions,
            allow_substring_split_file=self.allow_substring_split_file,
        )

    def setup(self, stage: str) -> None:
        """Create the datasets needed for the given stage.

        "fit" builds the train and validation datasets, "validate" only the validation
        dataset, "test" the test dataset. "predict" builds the predict dataset, but only
        when a predict data root was provided; it uses the predict-time band settings and
        the test transform, and is created without label root or split options.

        Args:
            stage: One of "fit", "validate", "test" or "predict". Any other value is a no-op.
        """
        if stage == "fit":
            self.train_dataset = self._make_split_dataset(
                self.train_root,
                self.train_label_data_root,
                self.train_split,
                self.train_transform,
            )
        if stage in ("fit", "validate"):
            self.val_dataset = self._make_split_dataset(
                self.val_root,
                self.val_label_data_root,
                self.val_split,
                self.val_transform,
            )
        if stage == "test":
            self.test_dataset = self._make_split_dataset(
                self.test_root,
                self.test_label_data_root,
                self.test_split,
                self.test_transform,
            )
        if stage == "predict" and self.predict_root:
            self.predict_dataset = self._make_dataset(
                self.predict_root,
                self.test_transform,
                self.predict_dataset_bands,
                self.predict_output_bands,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Shuffling is enabled only for the 'train' split, and incomplete final batches are
        dropped only for the 'train' split when ``drop_last`` is set. When
        ``check_stackability`` is enabled, the batch size is replaced by the value returned
        from ``check_dataset_stackability``.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        if self.check_stackability:
            logger.info(f"Checking stackability for {split} split.")
            batch_size = check_dataset_stackability(dataset, batch_size)

        is_train = split == "train"
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=is_train and self.drop_last,
            pin_memory=self.pin_memory,
        )

__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, img_grep='*', label_grep='*', predict_data_root=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, check_stackability=True, **kwargs) #

Constructor

Parameters:

Name Type Description Default
batch_size int

description

required
num_workers int

description

required
train_data_root Path

description

required
val_data_root Path

description

required
test_data_root Path

description

required
predict_data_root Path

description

None
img_grep str

description

'*'
label_grep str

description

'*'
means list[float]

description

required
stds list[float]

description

required
num_classes int

description

required
train_label_data_root Path | None

description. Defaults to None.

None
val_label_data_root Path | None

description. Defaults to None.

None
test_label_data_root Path | None

description. Defaults to None.

None
train_split Path | None

description. Defaults to None.

None
val_split Path | None

description. Defaults to None.

None
test_split Path | None

description. Defaults to None.

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

True
dataset_bands list[HLSBands | int] | None

Bands present in the dataset. Defaults to None.

None
output_bands list[HLSBands | int] | None

Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

None
predict_dataset_bands list[HLSBands | int] | None

Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

None
predict_output_bands list[HLSBands | int] | None

Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

None
constant_scale float

description. Defaults to 1.

1
rgb_indices list[int] | None

description. Defaults to None.

None
train_transform Compose | None

Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
val_transform Compose | None

Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
test_transform Compose | None

Albumentations transform to be applied to the test dataset (it is also the transform given to the predict dataset). Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float | None

Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

None
no_label_replace int | None

Replace nan values in label with this value. If none, does no replacement. Defaults to None.

None
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
reduce_zero_label bool

Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

False
drop_last bool

Drop the last batch if it is not complete. Defaults to True.

True
pin_memory bool

If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.

False
check_stackability bool

Check if all the files in the dataset have the same size and can be stacked. Defaults to True.

True
Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    img_grep: str = "*",
    label_grep: str = "*",
    predict_data_root: Path | None = None,
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    check_stackability: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Batch size. Forwarded to the parent constructor.
        num_workers (int): Number of data loading workers. Forwarded to the parent constructor.
        train_data_root (Path): Data root for the train dataset; stored as ``self.train_root``.
        val_data_root (Path): Data root for the validation dataset; stored as ``self.val_root``.
        test_data_root (Path): Data root for the test dataset; stored as ``self.test_root``.
        means (list[float] | str): Means used for normalization. The value is resolved with
            ``load_from_file_or_attribute`` (so a string is presumably a path to a file holding
            the values) and then handed to ``Normalize``.
        stds (list[float] | str): Standard deviations used for normalization. Resolved the same
            way as ``means``.
        num_classes (int): Number of classes; stored as ``self.num_classes``.
        img_grep (str): Presumably a glob pattern selecting the image files; stored as
            ``self.img_grep``. Defaults to "*".
        label_grep (str): Presumably a glob pattern selecting the label files; stored as
            ``self.label_grep``. Defaults to "*".
        predict_data_root (Path | None, optional): Data root for the predict dataset; stored as
            ``self.predict_root``. Defaults to None.
        train_label_data_root (Path | None, optional): Label data root for the train dataset.
            Defaults to None.
        val_label_data_root (Path | None, optional): Label data root for the validation dataset.
            Defaults to None.
        test_label_data_root (Path | None, optional): Label data root for the test dataset.
            Defaults to None.
        train_split (Path | None, optional): Split file for the train dataset. Defaults to None.
        val_split (Path | None, optional): Split file for the validation dataset. Defaults to None.
        test_split (Path | None, optional): Split file for the test dataset. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the
            dataset. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be
            output by the dataset. Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
            dataset_bands with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
            output_bands with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Stored unchanged as ``self.constant_scale``.
            Defaults to 1.
        rgb_indices (list[int] | None, optional): Stored unchanged as ``self.rgb_indices``.
            Defaults to None.
        train_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
            Albumentations transform to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
            Albumentations transform to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
            Albumentations transform to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        no_data_replace (float | None): Replace nan values in input images with this value.
            If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to None.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.
        check_stackability (bool): Check if all the files in the dataset have the same size and
            can be stacked. Defaults to True.
        **kwargs (Any): Extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last
    self.pin_memory = pin_memory

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.dataset_bands = dataset_bands
    self.output_bands = output_bands
    # Predict-time band settings fall back to the regular ones when unset (None or empty).
    self.predict_dataset_bands = predict_dataset_bands or dataset_bands
    self.predict_output_bands = predict_output_bands or output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    # A plain list of transforms is accepted as well; the helper normalizes the input.
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # means / stds may be given directly or as a string to be resolved by the helper.
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)
    self.aug = Normalize(means, stds)

    self.check_stackability = check_stackability

terratorch.datamodules.generic_pixel_wise_data_module.GenericNonGeoPixelwiseRegressionDataModule #

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoPixelwiseRegressionDataModule(NonGeoDataModule):
    """This is a generic datamodule class for instantiating data modules at runtime.
    Composes several
    [GenericNonGeoPixelwiseRegressionDataset][terratorch.datasets.GenericNonGeoPixelwiseRegressionDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        predict_data_root: Path | None = None,
        img_grep: str | None = "*",
        label_grep: str | None = "*",
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        check_stackability: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Batch size, forwarded to the parent data module.
            num_workers (int): Number of workers, forwarded to the parent data module.
            train_data_root (Path): Root passed to the dataset built for training.
            val_data_root (Path): Root passed to the dataset built for validation.
            test_data_root (Path): Root passed to the dataset built for testing.
            means (list[float] | str): Means used by the normalization step. Passed through
                ``load_from_file_or_attribute`` first (presumably so that a file path can be given
                instead of the values themselves).
            stds (list[float] | str): Standard deviations used by the normalization step. Handled
                the same way as ``means``.
            predict_data_root (Path | None, optional): Root passed to the dataset built for prediction.
                When not set, no predict dataset is created. Defaults to None.
            img_grep (str | None, optional): Forwarded to the datasets as ``image_grep``. Defaults to "*".
            label_grep (str | None, optional): Forwarded to the datasets as ``label_grep``. Defaults to "*".
            train_label_data_root (Path | None, optional): Forwarded to the train dataset as
                ``label_data_root``. Defaults to None.
            val_label_data_root (Path | None, optional): Forwarded to the validation dataset as
                ``label_data_root``. Defaults to None.
            test_label_data_root (Path | None, optional): Forwarded to the test dataset as
                ``label_data_root``. Defaults to None.
            train_split (Path | None, optional): Forwarded to the train dataset as ``split``.
                Defaults to None.
            val_split (Path | None, optional): Forwarded to the validation dataset as ``split``.
                Defaults to None.
            test_split (Path | None, optional): Forwarded to the test dataset as ``split``.
                Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but the files
                actually use a different extension (e.g. ".tif"). Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
                Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
                with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Forwarded to the datasets as ``constant_scale``. Defaults to 1.
            rgb_indices (list[int] | None, optional): Forwarded to the datasets as ``rgb_indices``.
                Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset (it is also the transform handed to the predict dataset).
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            no_data_replace (float | None): Replace nan values in input images with this value.
                If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to None.
            drop_last (bool): Drop the last batch if it is not complete. Only applied to the
                train data loader. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.
            check_stackability (bool): Check if all the files in the dataset have the same size and can be
                stacked. Defaults to True.
            **kwargs (Any): Extra keyword arguments forwarded to the parent data module constructor.
        """
        super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.constant_scale = constant_scale

        self.dataset_bands = dataset_bands
        # Predict-time bands fall back to the regular ones when no (or an empty) override is given.
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        self.check_stackability = check_stackability

    def _build_dataset(self, data_root, transform, dataset_bands, output_bands, **extra_kwargs: Any):
        """Instantiate ``self.dataset_class`` with the options shared by every stage.

        Args:
            data_root: Root handed to the dataset as its first positional argument.
            transform: Transform handed to the dataset as ``transform``.
            dataset_bands: Value forwarded as ``dataset_bands``.
            output_bands: Value forwarded as ``output_bands``.
            **extra_kwargs: Additional stage-specific keyword arguments for the dataset.

        Returns:
            The newly created dataset instance.
        """
        return self.dataset_class(
            data_root,
            image_grep=self.img_grep,
            label_grep=self.label_grep,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=self.constant_scale,
            rgb_indices=self.rgb_indices,
            transform=transform,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension=self.expand_temporal_dimension,
            reduce_zero_label=self.reduce_zero_label,
            **extra_kwargs,
        )

    def _build_split_dataset(self, data_root, label_data_root, split, transform):
        """Instantiate a train/val/test dataset, adding the label root and split-file options.

        Args:
            data_root: Root handed to the dataset as its first positional argument.
            label_data_root: Value forwarded as ``label_data_root``.
            split: Value forwarded as ``split``.
            transform: Transform handed to the dataset as ``transform``.

        Returns:
            The newly created dataset instance.
        """
        return self._build_dataset(
            data_root,
            transform,
            self.dataset_bands,
            self.output_bands,
            label_data_root=label_data_root,
            split=split,
            ignore_split_file_extensions=self.ignore_split_file_extensions,
            allow_substring_split_file=self.allow_substring_split_file,
        )

    def setup(self, stage: str) -> None:
        """Create the datasets required by ``stage``.

        "fit" builds the train and validation datasets, "validate" only the validation
        dataset, "test" the test dataset, and "predict" the predict dataset (only when a
        predict root was configured).

        Args:
            stage: One of "fit", "validate", "test" or "predict".
        """
        if stage in ["fit"]:
            self.train_dataset = self._build_split_dataset(
                self.train_root, self.train_label_data_root, self.train_split, self.train_transform
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self._build_split_dataset(
                self.val_root, self.val_label_data_root, self.val_split, self.val_transform
            )
        if stage in ["test"]:
            self.test_dataset = self._build_split_dataset(
                self.test_root, self.test_label_data_root, self.test_split, self.test_transform
            )

        if stage in ["predict"] and self.predict_root:
            # Prediction uses the predict-time bands and the test transform; no label root
            # or split-file options are passed, so the dataset's own defaults apply.
            self.predict_dataset = self._build_dataset(
                self.predict_root,
                self.test_transform,
                self.predict_dataset_bands,
                self.predict_output_bands,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build the PyTorch DataLoader for one split.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A DataLoader over the dataset of the requested split.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        if self.check_stackability:
            logger.info("Checking stackability.")
            # The check hands back the batch size to use for this dataset.
            batch_size = check_dataset_stackability(dataset, batch_size)

        # Shuffling and dropping the last incomplete batch only ever apply to training.
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )

__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, predict_data_root=None, img_grep='*', label_grep='*', train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, check_stackability=True, **kwargs) #

Constructor

Parameters:

Name Type Description Default
batch_size int

description

required
num_workers int

description

required
train_data_root Path

description

required
val_data_root Path

description

required
test_data_root Path

description

required
predict_data_root Path

description

None
img_grep str

description

'*'
label_grep str

description

'*'
means list[float]

description

required
stds list[float]

description

required
train_label_data_root Path | None

description. Defaults to None.

None
val_label_data_root Path | None

description. Defaults to None.

None
test_label_data_root Path | None

description. Defaults to None.

None
train_split Path | None

description. Defaults to None.

None
val_split Path | None

description. Defaults to None.

None
test_split Path | None

description. Defaults to None.

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but the files actually use a different extension (e.g. ".tif"). Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

True
dataset_bands list[HLSBands | int] | None

Bands present in the dataset. Defaults to None.

None
output_bands list[HLSBands | int] | None

Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

None
predict_dataset_bands list[HLSBands | int] | None

Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

None
predict_output_bands list[HLSBands | int] | None

Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

None
constant_scale float

description. Defaults to 1.

1
rgb_indices list[int] | None

description. Defaults to None.

None
train_transform Compose | None

Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
val_transform Compose | None

Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
test_transform Compose | None

Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float | None

Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

None
no_label_replace int | None

Replace nan values in label with this value. If none, does no replacement. Defaults to None.

None
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
reduce_zero_label bool

Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

False
drop_last bool

Drop the last batch if it is not complete. Defaults to True.

True
pin_memory bool

If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.

False
check_stackability bool

Check if all the files in the dataset have the same size and can be stacked.

True
Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    predict_data_root: Path | None = None,
    img_grep: str | None = "*",
    label_grep: str | None = "*",
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    check_stackability: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Stores the configuration on the instance; apart from building the normalization
    step and wrapping the transforms, nothing is computed here.

    Args:
        batch_size (int): Batch size, forwarded to the parent constructor.
        num_workers (int): Number of workers, forwarded to the parent constructor.
        train_data_root (Path): Stored as ``train_root``.
        val_data_root (Path): Stored as ``val_root``.
        test_data_root (Path): Stored as ``test_root``.
        means (list[float] | str): Means for the normalization step, passed through
            ``load_from_file_or_attribute`` before use.
        stds (list[float] | str): Standard deviations for the normalization step, passed
            through ``load_from_file_or_attribute`` before use.
        predict_data_root (Path | None, optional): Stored as ``predict_root``. Defaults to None.
        img_grep (str | None, optional): Stored as ``img_grep``. Defaults to "*".
        label_grep (str | None, optional): Stored as ``label_grep``. Defaults to "*".
        train_label_data_root (Path | None, optional): Label root for the train data. Defaults to None.
        val_label_data_root (Path | None, optional): Label root for the validation data. Defaults to None.
        test_label_data_root (Path | None, optional): Label root for the test data. Defaults to None.
        train_split (Path | None, optional): Split file for the train data. Defaults to None.
        val_split (Path | None, optional): Split file for the validation data. Defaults to None.
        test_split (Path | None, optional): Split file for the test data. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but the files
            actually use a different extension (e.g. ".tif"). Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Stored as ``constant_scale``. Defaults to 1.
        rgb_indices (list[int] | None, optional): Stored as ``rgb_indices``. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        no_data_replace (float | None): Replace nan values in input images with this value.
            If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to None.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.
        check_stackability (bool): Check if all the files in the dataset have the same size and can be stacked.
        **kwargs (Any): Extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)

    # Data and label locations, one per stage.
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    # File discovery and split-file handling.
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file

    # Band selection; the predict-time variants fall back to the regular bands
    # whenever no (or an empty) override was supplied.
    self.dataset_bands = dataset_bands
    self.output_bands = output_bands
    self.predict_dataset_bands = predict_dataset_bands or dataset_bands
    self.predict_output_bands = predict_output_bands or output_bands
    self.rgb_indices = rgb_indices

    # Sample-level options.
    self.constant_scale = constant_scale
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    # Data loader options.
    self.drop_last = drop_last
    self.pin_memory = pin_memory
    self.check_stackability = check_stackability

    # Normalization statistics are resolved first (means, then stds) and fed to Normalize.
    self.aug = Normalize(
        load_from_file_or_attribute(means),
        load_from_file_or_attribute(stds),
    )

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

terratorch.datamodules.generic_scalar_label_data_module.GenericNonGeoClassificationDataModule #

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoClassificationDatasets

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
class GenericNonGeoClassificationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoClassificationDatasets][terratorch.datasets.GenericNonGeoClassificationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int] | None = None,
        predict_dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        no_data_replace: float = 0,
        drop_last: bool = True,
        check_stackability: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Batch size forwarded to the parent datamodule.
            num_workers (int): Number of dataloader workers forwarded to the parent datamodule.
            train_data_root (Path): Root handed to the dataset built for the "fit" stage.
            val_data_root (Path): Root handed to the dataset built for the "fit" and "validate" stages.
            test_data_root (Path): Root handed to the dataset built for the "test" stage.
            means (list[float] | str): Normalization means. The value is resolved with
                `load_from_file_or_attribute` (a string is presumably a path to a file
                holding the values) before being given to `Normalize`.
            stds (list[float] | str): Normalization standard deviations, resolved the same way as `means`.
            num_classes (int): Number of classes, forwarded to every dataset.
            predict_data_root (Path | None, optional): Root handed to the dataset built for the "predict"
                stage. When it is None no predict dataset is created. Defaults to None.
            train_split (Path | None, optional): Split passed to the train dataset. Defaults to None.
            val_split (Path | None, optional): Split passed to the validation dataset. Defaults to None.
            test_split (Path | None, optional): Split passed to the test dataset. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None, optional): Bands forwarded as `dataset_bands` to the
                train, validation and test datasets. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None, optional): Bands forwarded as `dataset_bands`
                to the predict dataset. Falls back to `dataset_bands` when empty or None. Defaults to None.
            output_bands (list[HLSBands | int] | None, optional): Bands forwarded as `output_bands` to every
                dataset. Defaults to None.
            constant_scale (float, optional): Value forwarded as `constant_scale` to every dataset.
                Defaults to 1.
            rgb_indices (list[int] | None, optional): Value forwarded as `rgb_indices` to every dataset.
                Defaults to None.
            train_transform (Albumentations.Compose | list | None): Albumentations transform
                (or list of transforms, which goes through `wrap_in_compose_is_list`)
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | list | None): Albumentations transform
                (or list of transforms, which goes through `wrap_in_compose_is_list`)
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | list | None): Albumentations transform
                (or list of transforms, which goes through `wrap_in_compose_is_list`)
                to be applied to the test dataset. It is also the transform used for the predict dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            drop_last (bool): Drop the last batch if it is not complete. Only takes effect for the
                train dataloader. Defaults to True.
            check_stackability (bool): Check if all the files in the dataset have the same size and can be
                stacked. Defaults to True.
            **kwargs (Any): Extra keyword arguments forwarded to the parent datamodule constructor.
        """
        super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.drop_last = drop_last

        self.dataset_bands = dataset_bands
        # Prediction data reuses the training band layout unless one is given explicitly.
        self.predict_dataset_bands = predict_dataset_bands or dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # Statistics may be given inline or indirectly; resolve them before building the normalizer.
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        self.check_stackability = check_stackability

    def _common_dataset_kwargs(self, transform: Any) -> dict[str, Any]:
        """Return the keyword arguments shared by every dataset created in `setup`."""
        return {
            "output_bands": self.output_bands,
            "constant_scale": self.constant_scale,
            "rgb_indices": self.rgb_indices,
            "transform": transform,
            "no_data_replace": self.no_data_replace,
            "expand_temporal_dimension": self.expand_temporal_dimension,
        }

    def _build_split_dataset(self, root: Path, split: Path | None, transform: Any) -> Any:
        """Create a train/val/test dataset for `root`, filtered by `split`."""
        return self.dataset_class(
            root,
            self.num_classes,
            split=split,
            ignore_split_file_extensions=self.ignore_split_file_extensions,
            allow_substring_split_file=self.allow_substring_split_file,
            dataset_bands=self.dataset_bands,
            **self._common_dataset_kwargs(transform),
        )

    def setup(self, stage: str) -> None:
        """Instantiate the datasets required by `stage`.

        Args:
            stage: One of "fit", "validate", "test" or "predict". "fit" builds both the
                train and the validation dataset. Any other value creates nothing.
        """
        if stage == "fit":
            self.train_dataset = self._build_split_dataset(
                self.train_root, self.train_split, self.train_transform
            )
        if stage in ("fit", "validate"):
            self.val_dataset = self._build_split_dataset(self.val_root, self.val_split, self.val_transform)
        if stage == "test":
            self.test_dataset = self._build_split_dataset(
                self.test_root, self.test_split, self.test_transform
            )
        if stage == "predict" and self.predict_root:
            # Prediction needs no labels and no split file; it reuses the test transform.
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                require_label=False,
                dataset_bands=self.predict_dataset_bands,
                **self._common_dataset_kwargs(self.test_transform),
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build the PyTorch DataLoader for one split.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            The data loader for the requested split. Shuffling is enabled only for
            'train', and incomplete last batches are dropped only for 'train' when
            `drop_last` is set.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        if self.check_stackability:
            logger.info("Checking stackability.")
            # The helper's return value replaces the batch size used for this loader.
            batch_size = check_dataset_stackability(dataset, batch_size)

        is_train = split == "train"
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=is_train and self.drop_last,
        )

__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, predict_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, no_data_replace=0, drop_last=True, check_stackability=True, **kwargs) #

Constructor

Parameters:

Name Type Description Default
batch_size int

description

required
num_workers int

description

required
train_data_root Path

description

required
val_data_root Path

description

required
test_data_root Path

description

required
means list[float]

description

required
stds list[float]

description

required
num_classes int

description

required
predict_data_root Path

description

None
train_split Path | None

description. Defaults to None.

None
val_split Path | None

description. Defaults to None.

None
test_split Path | None

description. Defaults to None.

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif".

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

True
dataset_bands list[HLSBands | int] | None

description. Defaults to None.

None
predict_dataset_bands list[HLSBands | int] | None

description. Defaults to None.

None
output_bands list[HLSBands | int] | None

description. Defaults to None.

None
constant_scale float

description. Defaults to 1.

1
rgb_indices list[int] | None

description. Defaults to None.

None
train_transform Compose | None

Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
val_transform Compose | None

Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
test_transform Compose | None

Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float

Replace nan values in input images with this value. Defaults to 0.

0
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
drop_last bool

Drop the last batch if it is not complete. Defaults to True.

True
check_stackability bool

Check if all the files in the dataset have the same size and can be stacked.

True
Source code in terratorch/datamodules/generic_scalar_label_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int] | None = None,
    predict_dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    no_data_replace: float = 0,
    drop_last: bool = True,
    check_stackability: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Batch size forwarded to the parent datamodule.
        num_workers (int): Number of dataloader workers forwarded to the parent datamodule.
        train_data_root (Path): Root of the training data, stored as `train_root`.
        val_data_root (Path): Root of the validation data, stored as `val_root`.
        test_data_root (Path): Root of the test data, stored as `test_root`.
        means (list[float] | str): Normalization means. The value is resolved with
            `load_from_file_or_attribute` (a string is presumably a path to a file
            holding the values) before being given to `Normalize`.
        stds (list[float] | str): Normalization standard deviations, resolved the same way as `means`.
        num_classes (int): Number of classes, stored as `num_classes`.
        predict_data_root (Path | None, optional): Root of the prediction data, stored as
            `predict_root`. Defaults to None.
        train_split (Path | None, optional): Split used for the train data. Defaults to None.
        val_split (Path | None, optional): Split used for the validation data. Defaults to None.
        test_split (Path | None, optional): Split used for the test data. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset.
            Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None, optional): Bands present in the prediction
            data. Falls back to `dataset_bands` when empty or None. Defaults to None.
        output_bands (list[HLSBands | int] | None, optional): Bands to output. Defaults to None.
        constant_scale (float, optional): Stored as `constant_scale`. Defaults to 1.
        rgb_indices (list[int] | None, optional): Stored as `rgb_indices`. Defaults to None.
        train_transform (Albumentations.Compose | list | None): Albumentations transform
            (or list of transforms, which goes through `wrap_in_compose_is_list`)
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | list | None): Albumentations transform
            (or list of transforms, which goes through `wrap_in_compose_is_list`)
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | list | None): Albumentations transform
            (or list of transforms, which goes through `wrap_in_compose_is_list`)
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        check_stackability (bool): Check if all the files in the dataset have the same size and can be
            stacked. Defaults to True.
        **kwargs (Any): Extra keyword arguments forwarded to the parent datamodule constructor.
    """
    super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.drop_last = drop_last

    self.dataset_bands = dataset_bands
    # Prediction data reuses the training band layout unless one is given explicitly.
    self.predict_dataset_bands = predict_dataset_bands or dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # Statistics may be given inline or indirectly; resolve them before building the normalizer.
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)

    self.check_stackability = check_stackability

terratorch.datamodules.generic_multimodal_data_module.GenericMultiModalDataModule #

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoSegmentationDatasets

Source code in terratorch/datamodules/generic_multimodal_data_module.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
class GenericMultiModalDataModule(NonGeoDataModule):
    """
    Generic datamodule for multimodal data that is configured at runtime.

    Depending on ``task``, the datamodule instantiates GenericMultimodalSegmentationDataset,
    GenericMultimodalPixelwiseRegressionDataset, GenericMultimodalScalarDataset, or GenericMultimodalDataset
    for the train, val, test, and predict stages.
    """

    def __init__(
        self,
        batch_size: int,
        modalities: list[str],
        train_data_root: dict[str, Path],
        val_data_root: dict[str, Path],
        test_data_root: dict[str, Path],
        means: dict[str, list],
        stds: dict[str, list],
        task: str | None = None,
        num_classes: int | None = None,
        image_grep: str | dict[str, str] | None = None,
        label_grep: str | None = None,
        train_label_data_root: Path | str | None = None,
        val_label_data_root: Path | str | None = None,
        test_label_data_root: Path | str | None = None,
        predict_data_root: dict[str, Path] | str | None = None,
        train_split: Path | str | None = None,
        val_split: Path | str | None = None,
        test_split: Path | str | None = None,
        dataset_bands: dict[str, list] | None = None,
        output_bands: dict[str, list] | None = None,
        predict_dataset_bands: dict[str, list] | None = None,
        predict_output_bands: dict[str, list] | None = None,
        image_modalities: list[str] | None = None,
        rgb_modality: str | None = None,
        rgb_indices: list[int] | None = None,
        allow_substring_file_names: bool = True,
        class_names: list[str] | None = None,
        constant_scale: dict[str, float] | None = None,
        train_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
        val_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
        test_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
        shared_transforms: list | bool = True,
        expand_temporal_dimension: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: float | None = -1,
        reduce_zero_label: bool = False,
        drop_last: bool = True,
        num_workers: int = 0,
        pin_memory: bool = False,
        data_with_sample_dim: bool = False,
        allow_missing_modalities: bool = False,
        sample_num_modalities: int | None = None,
        sample_replace: bool = False,
        channel_position: int = -3,
        concat_bands: bool = False,
        check_stackability: bool = True,
        img_grep: str | dict[str, str] | None = None,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            modalities (list[str]): List of modalities.
            train_data_root (dict[Path]): Dictionary of paths to training data root directory or csv/parquet files with
                image-level data, with modalities as keys.
            val_data_root (dict[Path]): Dictionary of paths to validation data root directory or csv/parquet files with
                image-level data, with modalities as keys.
            test_data_root (dict[Path]): Dictionary of paths to test data root directory or csv/parquet files with
                image-level data, with modalities as keys.
            means (dict[list]): Dictionary of mean values as lists with modalities as keys.
            stds (dict[list]): Dictionary of std values as lists with modalities as keys.
            task (str, optional): Selected task from segmentation, regression (pixel-wise), classification,
                multilabel_classification, scalar_regression, scalar (custom image-level task), or None (no targets).
                Defaults to None.
            num_classes (int, optional): Number of classes in classification or segmentation tasks.
            predict_data_root (dict[Path], optional): Dictionary of paths to data root directory or csv/parquet files
                with image-level data, with modalities as keys.
            image_grep (str | dict[str], optional): Glob pattern appended to data_root to find input images, either
                a single pattern for all modalities or a dictionary with modalities as keys. Must be a wildcard
                followed by a suffix (e.g. '*_mod.tif'); a missing wildcard is prepended with a warning.
                Defaults to "*". Ignored when allow_substring_file_names is False.
            label_grep (str, optional): Glob pattern appended to label_data_root to find labels or mask files
                (e.g. '*_mask.tif'). Defaults to "*". Ignored when allow_substring_file_names is False.
            train_label_data_root (Path | None, optional): Path to data root directory with training labels or
                csv/parquet files with labels. Required for supervised tasks.
            val_label_data_root (Path | None, optional): Path to data root directory with validation labels or
                csv/parquet files with labels. Required for supervised tasks.
            test_label_data_root (Path | None, optional): Path to data root directory with test labels or
                csv/parquet files with labels. Required for supervised tasks.
            train_split (Path, optional): Path to file containing training samples prefixes to be used for this split.
                The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
                sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
                files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
                If not specified, search samples based on files in data_root. Defaults to None.
            val_split (Path, optional): Path to file containing validation samples prefixes to be used for this split.
                The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
                sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
                files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
                If not specified, search samples based on files in data_root. Defaults to None.
            test_split (Path, optional): Path to file containing test samples prefixes to be used for this split.
                The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
                sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
                files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
                If not specified, search samples based on files in data_root. Defaults to None.
            dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities
                as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so
                that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset
                of all modalities. Defaults to None.
            output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands,
                provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None.
            predict_dataset_bands (dict[list], optional): Overwrites dataset_bands with this value at predict time.
                Defaults to None, which does not overwrite (dataset_bands is used).
            predict_output_bands (dict[list], optional): Overwrites output_bands with this value at predict time.
                Defaults to None, which does not overwrite (output_bands is used).
            image_modalities (list[str], optional): List of pixel-level raster modalities. Defaults to modalities.
                The difference between all modalities and image_modalities are non-image modalities which are treated
                differently during the transforms and are not modified but only converted into a tensor if possible.
            rgb_modality (str, optional): Modality used for RGB plots. Defaults to the first entry of modalities.
            rgb_indices (list[int] | None, optional): Indices of the RGB channels, passed to the datasets together
                with rgb_modality. Defaults to None.
            allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding
                image or label grep to the sample prefixes. If False, treats sample prefixes as full file names.
                If True and no split file is provided, considers the file stem as prefix, otherwise the full file name.
                Defaults to True.
            class_names (list[str], optional): Names of the classes. Defaults to None.
            constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as
                keys. Can be subset of all modalities. Defaults to None.
            train_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
                modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
                non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
                per modality (no shared parameters between modalities).
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
                modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
                non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
                per modality (no shared parameters between modalities).
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
                modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
                non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
                per modality (no shared parameters between modalities).
                Defaults to None, which simply applies ToTensorV2(). Also used for the predict dataset.
            shared_transforms (bool): transforms are shared between all image modalities (e.g., similar crop).
                This setting is ignored if transforms are defined per modality. Defaults to True.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Only works with image modalities. Is only applied to modalities with defined dataset_bands.
                Defaults to False.
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no
                replacement. Defaults to None.
            no_label_replace (float | None): Replace nan values in label with this value. If none, does no
                replacement. Defaults to -1.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Only applied to the train split.
                Defaults to True.
            num_workers (int): Number of parallel workers. Defaults to 0 for single threaded process.
            pin_memory (bool): If ``True``, the data loader will copy Tensors into device/CUDA pinned memory before
                returning them. Defaults to False.
            data_with_sample_dim (bool): Use a specific collate function to concatenate samples along a existing sample
                dimension instead of stacking the samples. Defaults to False.
            allow_missing_modalities (bool): Experimental feature! Allow missing modalities during data loading.
                Forces a batch size of 1. Defaults to False.
            sample_num_modalities (int, optional): Load only a subset of modalities per batch. Defaults to None.
            sample_replace (bool): If sample_num_modalities is set, sample modalities with replacement.
                Defaults to False.
            channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3.
            concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so
                that it can be processed by single-modal models. Concatenate in the order of provided modalities.
                Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False.
            check_stackability (bool): Check if all the files in the dataset has the same size and can be stacked.
                Defaults to True.
            img_grep (str | dict[str], optional): Deprecated alias of image_grep. If provided, it takes precedence
                over image_grep and a DeprecationWarning is emitted.

        Raises:
            ValueError: If task is unknown, if a grep pattern contains an intermediate wildcard, or if the keys of
                a data root dictionary do not match the modalities.
        """

        if task == "segmentation":
            dataset_class = GenericMultimodalSegmentationDataset
        elif task == "regression":
            dataset_class = GenericMultimodalPixelwiseRegressionDataset
        elif task in ["classification", "multilabel_classification", "scalar_regression", "scalar"]:
            # All image-level tasks share one dataset class and are handled as "scalar" from here on.
            dataset_class = GenericMultimodalScalarDataset
            task = "scalar"
        elif task is None:
            dataset_class = GenericMultimodalDataset
        else:
            raise ValueError(f"Unknown task {task}, only segmentation, regression, classification, "
                             f"multilabel_classification, scalar_regression, scalar, and None are supported.")

        super().__init__(dataset_class, batch_size, num_workers)
        self.num_classes = num_classes
        self.class_names = class_names
        self.modalities = modalities
        self.image_modalities = image_modalities or modalities
        self.non_image_modalities = list(set(self.modalities) - set(self.image_modalities))
        if task == "scalar":
            # Scalar labels are not rasters and must bypass the image transforms.
            self.non_image_modalities += ["label"]

        if img_grep is not None:
            warnings.warn("img_grep was renamed to image_grep and will be removed in a future version.",
                          DeprecationWarning)
            image_grep = img_grep

        if isinstance(image_grep, dict):
            # Build a new dictionary so that the dictionary passed by the caller is left untouched.
            checked_greps = {
                key: self._check_grep(grep, "image_grep", f"image_grep[{key}]", "*_mod.tif")
                for key, grep in image_grep.items()
            }
            self.image_grep = {m: checked_greps.get(m, "*") for m in modalities}
        else:
            # `or "*"` handles None
            checked_grep = self._check_grep(image_grep or "*", "image_grep", "image_grep", "*_mod.tif")
            self.image_grep = {m: checked_grep for m in modalities}
        self.label_grep = self._check_grep(label_grep or "*", "label_grep", "label_grep", "*_mask.tif")

        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root
        self.predict_root = predict_data_root

        # Check paths and modalities
        expected_modalities = set(modalities)
        for name, data_root in (("train", train_data_root), ("val", val_data_root), ("test", test_data_root),
                                ("predict", predict_data_root)):
            if data_root is None:
                continue
            root_modalities = set(data_root.keys())
            if allow_missing_modalities:
                # Roots may cover only a subset of the modalities, but must not contain unknown ones.
                if not root_modalities <= expected_modalities:
                    raise ValueError(f"Modalities {modalities} do not match {name}_data_root: {data_root}")
            elif root_modalities != expected_modalities:
                raise ValueError(f"Paths in {name}_data_root do not match modalities {modalities}: {data_root}")

        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.allow_substring_file_names = allow_substring_file_names
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.allow_missing_modalities = allow_missing_modalities
        self.sample_num_modalities = sample_num_modalities
        self.sample_replace = sample_replace
        if allow_missing_modalities and batch_size > 1:
            warnings.warn("allow_missing_modalities is set to True. This is an experimental feature. "
                          "Stacking is currently not supported, setting batch_size to 1.")
            self.batch_size = 1

        self.dataset_bands = dataset_bands
        self.output_bands = output_bands
        # As documented, None means "do not overwrite": fall back to the bands used for the other stages.
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands is not None else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands is not None else output_bands

        self.rgb_modality = rgb_modality or modalities[0]
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label
        self.channel_position = channel_position
        self.concat_bands = concat_bands
        self.check_stackability = check_stackability

        self.train_transform = self._prepare_transform(train_transform, shared_transforms)
        self.val_transform = self._prepare_transform(val_transform, shared_transforms)
        self.test_transform = self._prepare_transform(test_transform, shared_transforms)

        if self.concat_bands:
            # Concatenate mean and std values in the order of the image modalities
            means = load_from_file_or_attribute(np.concatenate([means[m] for m in self.image_modalities]).tolist())
            stds = load_from_file_or_attribute(np.concatenate([stds[m] for m in self.image_modalities]).tolist())

            self.aug = Normalize(means, stds)
        else:
            # Apply standardization per modality
            means = {m: load_from_file_or_attribute(means[m]) for m in means.keys()}
            stds = {m: load_from_file_or_attribute(stds[m]) for m in stds.keys()}

            self.aug = MultimodalNormalize(means, stds)

        self.data_with_sample_dim = data_with_sample_dim

        self.collate_fn = collate_chunk_dicts if data_with_sample_dim else collate_samples

    @staticmethod
    def _check_grep(grep: str, arg_name: str, display_name: str, example: str) -> str:
        """Return a valid grep pattern, prepending '*' if missing and rejecting intermediate wildcards.

        Args:
            grep (str): Pattern to check.
            arg_name (str): Name of the constructor argument, used in the messages.
            display_name (str): Name shown next to the offending value in the warning (e.g. 'image_grep[S2]').
            example (str): Example of a valid pattern, used in the error message.

        Raises:
            ValueError: If the pattern contains a wildcard that is not at its start or end.
        """
        if "*" not in grep:
            warnings.warn(f"{arg_name} requires a wildcard with a suffix. Adding '*' to {display_name}={grep}.")
            grep = "*" + grep
        # Only leading/trailing wildcards (and path separators) are allowed around the suffix.
        if "*" in grep.strip("*/\\"):
            raise ValueError(f"GenericMultiModalDataModule can only handle {arg_name} with suffixes "
                             f"(e.g. '{example}'). Intermediate wildcards do not work, found {grep}.")
        return grep

    def _prepare_transform(self, transform, shared_transforms):
        """Wrap a transform definition into the form that is passed to the datasets.

        A dict is treated as per-modality transforms (modalities without an entry get None). Otherwise the
        transform is either wrapped once for all modalities (shared_transforms) or wrapped separately per modality.
        """
        if isinstance(transform, dict):
            return {m: wrap_in_compose_is_list(transform[m]) if m in transform else None
                    for m in self.modalities}
        if shared_transforms:
            return wrap_in_compose_is_list(transform,
                                           image_modalities=self.image_modalities,
                                           non_image_modalities=self.non_image_modalities)
        return {m: wrap_in_compose_is_list(transform) for m in self.modalities}

    def _common_dataset_kwargs(self) -> dict:
        """Return the dataset keyword arguments that are identical for all stages."""
        return {
            "num_classes": self.num_classes,
            "image_grep": self.image_grep,
            "label_grep": self.label_grep,
            "allow_missing_modalities": self.allow_missing_modalities,
            "allow_substring_file_names": self.allow_substring_file_names,
            "constant_scale": self.constant_scale,
            "image_modalities": self.image_modalities,
            "rgb_modality": self.rgb_modality,
            "rgb_indices": self.rgb_indices,
            "no_data_replace": self.no_data_replace,
            "no_label_replace": self.no_label_replace,
            "expand_temporal_dimension": self.expand_temporal_dimension,
            "reduce_zero_label": self.reduce_zero_label,
            "channel_position": self.channel_position,
            "data_with_sample_dim": self.data_with_sample_dim,
            "concat_bands": self.concat_bands,
        }

    def setup(self, stage: str) -> None:
        """Instantiate the datasets required for the given stage.

        Args:
            stage (str): One of "fit" (train and val datasets), "validate" (val dataset), "test" (test dataset),
                or "predict" (predict dataset, only if predict_data_root was provided).
        """
        common_kwargs = self._common_dataset_kwargs()

        if stage == "fit":
            self.train_dataset = self.dataset_class(
                data_root=self.train_root,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                transform=self.train_transform,
                **common_kwargs,
            )
            logger.info(f"Train dataset: {len(self.train_dataset)}")
        if stage in ("fit", "validate"):
            self.val_dataset = self.dataset_class(
                data_root=self.val_root,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                transform=self.val_transform,
                **common_kwargs,
            )
            logger.info(f"Val dataset: {len(self.val_dataset)}")
        if stage == "test":
            self.test_dataset = self.dataset_class(
                data_root=self.test_root,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                transform=self.test_transform,
                **common_kwargs,
            )
            logger.info(f"Test dataset: {len(self.test_dataset)}")
        if stage == "predict" and self.predict_root:
            # No split is passed at predict time, and the test transform is reused.
            self.predict_dataset = self.dataset_class(
                data_root=self.predict_root,
                label_data_root=None,  # Prediction mode
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                transform=self.test_transform,
                **common_kwargs,
            )
            logger.info(f"Predict dataset: {len(self.predict_dataset)}")

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either "train", "val", "test", or "predict".

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        if self.check_stackability and batch_size > 1:
            logger.info(f'Checking dataset stackability for {split} split')
            # The check returns the batch size to use for this split.
            if self.concat_bands:
                batch_size = check_dataset_stackability(dataset, batch_size)
            else:
                batch_size = check_dataset_stackability_dict(dataset, batch_size)

        # Shuffle and (optionally) drop incomplete batches only during training.
        is_train = split == "train"
        sampler = RandomSampler(dataset) if is_train else SequentialSampler(dataset)
        drop_last = is_train and self.drop_last

        if self.sample_num_modalities:
            # Custom batch sampler for sampling modalities per batch
            batch_sampler = MultiModalBatchSampler(
                self.modalities, self.sample_num_modalities, self.sample_replace,
                sampler,
                batch_size=batch_size,
                drop_last=drop_last
            )
        else:
            batch_sampler = BatchSampler(
                sampler,
                batch_size=batch_size,
                drop_last=drop_last
            )

        return DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
        )

__init__(batch_size, modalities, train_data_root, val_data_root, test_data_root, means, stds, task=None, num_classes=None, image_grep=None, label_grep=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, predict_data_root=None, train_split=None, val_split=None, test_split=None, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, image_modalities=None, rgb_modality=None, rgb_indices=None, allow_substring_file_names=True, class_names=None, constant_scale=None, train_transform=None, val_transform=None, test_transform=None, shared_transforms=True, expand_temporal_dimension=False, no_data_replace=None, no_label_replace=-1, reduce_zero_label=False, drop_last=True, num_workers=0, pin_memory=False, data_with_sample_dim=False, allow_missing_modalities=False, sample_num_modalities=None, sample_replace=False, channel_position=-3, concat_bands=False, check_stackability=True, img_grep=None) #

Constructor

Parameters:

Name Type Description Default
batch_size int

Number of samples per batch.

required
modalities list[str]

List of modalities.

required
train_data_root dict[Path]

Dictionary of paths to training data root directory or csv/parquet files with image-level data, with modalities as keys.

required
val_data_root dict[Path]

Dictionary of paths to validation data root directory or csv/parquet files with image-level data, with modalities as keys.

required
test_data_root dict[Path]

Dictionary of paths to test data root directory or csv/parquet files with image-level data, with modalities as keys.

required
means dict[list]

Dictionary of mean values as lists with modalities as keys.

required
stds dict[list]

Dictionary of std values as lists with modalities as keys.

required
task str

Selected task from segmentation, regression (pixel-wise), classification, multilabel_classification, scalar_regression, scalar (custom image-level task), or None (no targets). Defaults to None.

None
num_classes int

Number of classes in classification or segmentation tasks.

None
predict_data_root dict[Path]

Dictionary of paths to data root directory or csv/parquet files with image-level data, with modalities as keys.

None
image_grep dict[str]

Dictionary with regular expression appended to data_root to find input images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False.

None
label_grep str

Regular expression appended to label_data_root to find labels or mask files. Defaults to "*". Ignored when allow_substring_file_names is False.

None
train_label_data_root Path | None

Path to data root directory with training labels or csv/parquet files with labels. Required for supervised tasks.

None
val_label_data_root Path | None

Path to data root directory with validation labels or csv/parquet files with labels. Required for supervised tasks.

None
test_label_data_root Path | None

Path to data root directory with test labels or csv/parquet files with labels. Required for supervised tasks.

None
train_split Path

Path to file containing training samples prefixes to be used for this split. The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). If not specified, search samples based on files in data_root. Defaults to None.

None
val_split Path

Path to file containing validation samples prefixes to be used for this split. The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). If not specified, search samples based on files in data_root. Defaults to None.

None
test_split Path

Path to file containing test samples prefixes to be used for this split. The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). If not specified, search samples based on files in data_root. Defaults to None.

None
dataset_bands dict[list]

Bands present in the dataset, provided in a dictionary with modalities as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset of all modalities. Defaults to None.

None
output_bands dict[list]

Bands that should be output by the dataset as named by dataset_bands, provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None.

None
predict_dataset_bands dict[list]

Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

None
predict_output_bands dict[list]

Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

None
image_modalities list[str]

List of pixel-level raster modalities. Defaults to all entries of modalities. The difference between all modalities and image_modalities are non-image modalities which are treated differently during the transforms and are not modified but only converted into a tensor if possible.

None
rgb_modality str

Modality used for RGB plots. Defaults to the first entry of modalities.

None
rgb_indices list[int] | None

Channel indices to use as RGB (presumably within rgb_modality, for plots); stored unchanged. Defaults to None.

None
allow_substring_file_names bool

Allow substrings during sample identification by adding image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. If True and no split file is provided, considers the file stem as prefix, otherwise the full file name. Defaults to True.

True
class_names list[str]

Names of the classes. Defaults to None.

None
constant_scale dict[float]

Factor to multiply data values by, provided as a dictionary with modalities as keys. Can be subset of all modalities. Defaults to None.

None
train_transform Compose | dict | None

Albumentations transform to be applied to all image modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to non-image data, which is only converted to tensors if possible. If dict, can include separate transforms per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2().

None
val_transform Compose | dict | None

Albumentations transform to be applied to all image modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to non-image data, which is only converted to tensors if possible. If dict, can include separate transforms per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2().

None
test_transform Compose | dict | None

Albumentations transform to be applied to all image modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to non-image data, which is only converted to tensors if possible. If dict, can include separate transforms per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2().

None
shared_transforms bool

transforms are shared between all image modalities (e.g., similar crop). This setting is ignored if transforms are defined per modality. Defaults to True.

True
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False.

False
no_data_replace float | None

Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

None
no_label_replace int | None

Replace nan values in label with this value. If none, does no replacement. Defaults to -1.

-1
reduce_zero_label bool

Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

False
drop_last bool

Drop the last batch if it is not complete. Defaults to True.

True
num_workers int

Number of parallel workers. Defaults to 0 for single threaded process.

0
pin_memory bool

If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.

False
data_with_sample_dim bool

Use a specific collate function to concatenate samples along an existing sample dimension instead of stacking the samples. Defaults to False.

False
allow_missing_modalities bool

Experimental feature! Allow missing modalities during data loading. Defaults to False.

False
sample_num_modalities int

Load only a subset of modalities per batch. Defaults to None.

None
sample_replace bool

If sample_num_modalities is set, sample modalities with replacement. Defaults to False.

False
channel_position int

Position of the channel dimension in the image modalities. Defaults to -3.

-3
concat_bands bool

Concatenate all image modalities along the band dimension into a single "image", so that it can be processed by single-modal models. Concatenate in the order of provided modalities. Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False.

False
check_stackability bool

Check if all the files in the dataset have the same size and can be stacked. Defaults to True.

True
Source code in terratorch/datamodules/generic_multimodal_data_module.py
def __init__(
    self,
    batch_size: int,
    modalities: list[str],
    train_data_root: dict[str, Path],
    val_data_root: dict[str, Path],
    test_data_root: dict[str, Path],
    means: dict[str, list],
    stds: dict[str, list],
    task: str | None = None,
    num_classes: int | None = None,
    image_grep: str | dict[str, str] | None = None,
    label_grep: str | None = None,
    train_label_data_root: Path | str | None = None,
    val_label_data_root: Path | str | None = None,
    test_label_data_root: Path | str | None = None,
    predict_data_root: dict[str, Path] | str | None = None,
    train_split: Path | str | None = None,
    val_split: Path | str | None = None,
    test_split: Path | str | None = None,
    dataset_bands: dict[str, list] | None = None,
    output_bands: dict[str, list] | None = None,
    predict_dataset_bands: dict[str, list] | None = None,
    predict_output_bands: dict[str, list] | None = None,
    image_modalities: list[str] | None = None,
    rgb_modality: str | None = None,
    rgb_indices: list[int] | None = None,
    allow_substring_file_names: bool = True,
    class_names: list[str] | None = None,
    constant_scale: dict[str, float] | None = None,
    train_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
    val_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
    test_transform: dict | A.Compose | None | list[A.BasicTransform] = None,
    shared_transforms: list | bool = True,
    expand_temporal_dimension: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: float | None = -1,
    reduce_zero_label: bool = False,
    drop_last: bool = True,
    num_workers: int = 0,
    pin_memory: bool = False,
    data_with_sample_dim: bool = False,
    allow_missing_modalities: bool = False,
    sample_num_modalities: int | None = None,
    sample_replace: bool = False,
    channel_position: int = -3,
    concat_bands: bool = False,
    check_stackability: bool = True,
    img_grep: str | dict[str, str] | None = None,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        modalities (list[str]): List of modalities.
        train_data_root (dict[Path]): Dictionary of paths to training data root directory or csv/parquet files with
            image-level data, with modalities as keys.
        val_data_root (dict[Path]): Dictionary of paths to validation data root directory or csv/parquet files with
            image-level data, with modalities as keys.
        test_data_root (dict[Path]): Dictionary of paths to test data root directory or csv/parquet files with
            image-level data, with modalities as keys.
        means (dict[list]): Dictionary of mean values as lists with modalities as keys.
        stds (dict[list]): Dictionary of std values as lists with modalities as keys.
        task (str, optional): Selected task from segmentation, regression (pixel-wise), classification,
            multilabel_classification, scalar_regression, scalar (custom image-level task), or None (no targets).
            Defaults to None.
        num_classes (int, optional): Number of classes in classification or segmentation tasks.
        predict_data_root (dict[Path], optional): Dictionary of paths to data root directory or csv/parquet files
            with image-level data, with modalities as keys.
        image_grep (str | dict[str], optional): Pattern appended to data_root to find input images, either a
            single string used for all modalities or a dictionary with modalities as keys (missing modalities
            fall back to "*"). Only a leading wildcard followed by a suffix is supported (e.g. '*_mod.tif');
            a missing wildcard is prepended with a warning. Defaults to "*". Ignored when
            allow_substring_file_names is False.
        label_grep (str, optional): Pattern appended to label_data_root to find labels or mask files. Same
            restrictions as image_grep (e.g. '*_mask.tif'). Defaults to "*". Ignored when
            allow_substring_file_names is False.
        train_label_data_root (Path | None, optional): Path to data root directory with training labels or
            csv/parquet files with labels. Required for supervised tasks.
        val_label_data_root (Path | None, optional): Path to data root directory with validation labels or
            csv/parquet files with labels. Required for supervised tasks.
        test_label_data_root (Path | None, optional): Path to data root directory with test labels or
            csv/parquet files with labels. Required for supervised tasks.
        train_split (Path, optional): Path to file containing training samples prefixes to be used for this split.
            The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
            sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
            files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            If not specified, search samples based on files in data_root. Defaults to None.
        val_split (Path, optional): Path to file containing validation samples prefixes to be used for this split.
            The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
            sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
            files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            If not specified, search samples based on files in data_root. Defaults to None.
        test_split (Path, optional): Path to file containing test samples prefixes to be used for this split.
            The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated
            sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise,
            files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            If not specified, search samples based on files in data_root. Defaults to None.
        dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities
            as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so
            that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset
            of all modalities. Defaults to None.
        output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands,
            provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None.
        predict_dataset_bands (dict[list], optional): Overwrites dataset_bands with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (dict[list], optional): Overwrites output_bands with this value at predict time.
            Defaults to None, which does not overwrite.
        image_modalities (list[str], optional): List of pixel-level raster modalities. Defaults to all entries of
            modalities. The difference between all modalities and image_modalities are non-image modalities which
            are treated differently during the transforms and are not modified but only converted into a tensor
            if possible.
        rgb_modality (str, optional): Modality used for RGB plots. Defaults to the first entry of modalities.
        rgb_indices (list[int] | None, optional): Channel indices to use as RGB (presumably within rgb_modality,
            for plots); stored unchanged. Defaults to None.
        allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding
            image or label grep to the sample prefixes. If False, treats sample prefixes as full file names.
            If True and no split file is provided, considers the file stem as prefix, otherwise the full file name.
            Defaults to True.
        class_names (list[str], optional): Names of the classes. Defaults to None.
        constant_scale (dict[float], optional): Factor to multiply data values by, provided as a dictionary with
            modalities as keys. Can be subset of all modalities. Defaults to None.
        train_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
            modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
            non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
            per modality (no shared parameters between modalities).
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
            modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
            non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
            per modality (no shared parameters between modalities).
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image
            modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to
            non-image data, which is only converted to tensors if possible. If dict, can include separate transforms
            per modality (no shared parameters between modalities).
            Defaults to None, which simply applies ToTensorV2().
        shared_transforms (bool): transforms are shared between all image modalities (e.g., similar crop).
            This setting is ignored if transforms are defined per modality. Defaults to True.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Only works with image modalities. Is only applied to modalities with defined dataset_bands.
            Defaults to False.
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no
            replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement.
            Defaults to -1.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        num_workers (int): Number of parallel workers. Defaults to 0 for single threaded process.
        pin_memory (bool): If ``True``, the data loader will copy Tensors into device/CUDA pinned memory before
            returning them. Defaults to False.
        data_with_sample_dim (bool): Use a specific collate function to concatenate samples along an existing
            sample dimension instead of stacking the samples. Defaults to False.
        allow_missing_modalities (bool): Experimental feature! Allow missing modalities during data loading.
            Forces a batch size of 1. Defaults to False.
        sample_num_modalities (int, optional): Load only a subset of modalities per batch. Defaults to None.
        sample_replace (bool): If sample_num_modalities is set, sample modalities with replacement.
            Defaults to False.
        channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3.
        concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so
            that it can be processed by single-modal models. Concatenate in the order of provided modalities.
            Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False.
        check_stackability (bool): Check if all the files in the dataset have the same size and can be stacked.
            Defaults to True.
        img_grep (str | dict[str], optional): Deprecated alias of image_grep. If provided, it overrides
            image_grep and a DeprecationWarning is issued.

    Raises:
        ValueError: If task is not one of the supported values, if image_grep or label_grep contains a wildcard
            anywhere but at its start or end, or if the keys of a data root dictionary do not match modalities
            (they may be a subset when allow_missing_modalities is True).
    """

    if task == "segmentation":
        dataset_class = GenericMultimodalSegmentationDataset
    elif task == "regression":
        dataset_class = GenericMultimodalPixelwiseRegressionDataset
    elif task in ["classification", "multilabel_classification", "scalar_regression", "scalar"]:
        dataset_class = GenericMultimodalScalarDataset
        # All image-level tasks are handled identically from here on.
        task = "scalar"
    elif task is None:
        dataset_class = GenericMultimodalDataset
    else:
        raise ValueError(
            f"Unknown task {task}, supported tasks are segmentation, regression, classification, "
            f"multilabel_classification, scalar_regression, scalar, or None."
        )

    super().__init__(dataset_class, batch_size, num_workers)
    self.num_classes = num_classes
    self.class_names = class_names
    self.modalities = modalities
    self.image_modalities = image_modalities or modalities
    self.non_image_modalities = list(set(self.modalities) - set(self.image_modalities))
    if task == "scalar":
        # Image-level labels are not rasters, so they skip the image transforms.
        self.non_image_modalities += ["label"]

    if img_grep is not None:
        warnings.warn("img_grep was renamed to image_grep and will be removed in a future version.",
                      DeprecationWarning)
        image_grep = img_grep

    if isinstance(image_grep, dict):
        # Validate into a new dict so the caller's dictionary is left untouched.
        checked_grep = {}
        for key, grep in image_grep.items():
            if "*" not in grep:
                warnings.warn(f"image_grep requires a wildcard with a suffix. "
                              f"Adding '*' to image_grep[{key}]={grep}.")
                checked_grep[key] = "*" + grep
            else:
                checked_grep[key] = grep
            # A wildcard that survives stripping the outer '*' and path separators sits mid-pattern.
            if "*" in grep.strip("*/\\"):
                raise ValueError(f"GenericMultiModalDataModule can only handle image_grep with suffixes "
                                 f"(e.g. '*_mod.tif'). Intermediate wildcards do not work, found {grep}.")
        self.image_grep = {m: checked_grep.get(m, "*") for m in modalities}
    else:
        image_grep = image_grep or "*"  # Handle None
        if "*" not in image_grep:
            warnings.warn(f"image_grep requires a wildcard with a suffix. Adding '*' to image_grep={image_grep}.")
            image_grep = "*" + image_grep
        if "*" in image_grep.strip("*/\\"):
            raise ValueError(f"GenericMultiModalDataModule can only handle image_grep with suffixes "
                             f"(e.g. '*_mod.tif'). Intermediate wildcards do not work, found {image_grep}.")
        self.image_grep = {m: image_grep for m in modalities}

    label_grep = label_grep or "*"  # Handle None
    # Check if label_grep is valid
    if "*" not in label_grep:
        warnings.warn(f"label_grep requires a wildcard with a suffix. Adding '*' to label_grep={label_grep}")
        label_grep = "*" + label_grep
    if "*" in label_grep.strip("*/\\"):
        raise ValueError(f"GenericMultiModalDataModule can only handle label_grep with suffixes "
                         f"(e.g. '*_mask.tif'). Intermediate wildcards do not work, found {label_grep}.")
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root
    self.predict_root = predict_data_root

    # Check paths and modalities
    for name, data_root in [("train", train_data_root), ("val", val_data_root), ("test", test_data_root),
                            ("predict", predict_data_root)]:
        if data_root is None:
            continue
        if allow_missing_modalities:
            # Roots may cover only a subset of the declared modalities.
            if not set(data_root.keys()) <= set(modalities):
                raise ValueError(f"Modalities {modalities} do not match {name}_data_root: {data_root}")
        elif set(data_root.keys()) != set(modalities):
            raise ValueError(f"Paths in {name}_data_root do not match modalities {modalities}: {data_root}")

    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.allow_substring_file_names = allow_substring_file_names
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last
    self.pin_memory = pin_memory
    self.allow_missing_modalities = allow_missing_modalities
    self.sample_num_modalities = sample_num_modalities
    self.sample_replace = sample_replace
    if allow_missing_modalities and batch_size > 1:
        warnings.warn("allow_missing_modalities is set to True. This is an experimental feature. "
                      "Stacking is currently not supported, setting batch_size to 1.")
        self.batch_size = 1

    self.dataset_bands = dataset_bands
    self.output_bands = output_bands
    self.predict_dataset_bands = predict_dataset_bands
    self.predict_output_bands = predict_output_bands

    self.rgb_modality = rgb_modality or modalities[0]
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label
    self.channel_position = channel_position
    self.concat_bands = concat_bands
    self.check_stackability = check_stackability

    def _prepare_transform(transform):
        """Wrap a split's transform: per-modality dict, one shared transform, or one copy per modality."""
        if isinstance(transform, dict):
            # Modalities without an entry get no transform.
            return {m: wrap_in_compose_is_list(transform[m]) if m in transform else None
                    for m in modalities}
        if shared_transforms:
            return wrap_in_compose_is_list(transform,
                                           image_modalities=self.image_modalities,
                                           non_image_modalities=self.non_image_modalities)
        return {m: wrap_in_compose_is_list(transform) for m in modalities}

    self.train_transform = _prepare_transform(train_transform)
    self.val_transform = _prepare_transform(val_transform)
    self.test_transform = _prepare_transform(test_transform)

    if self.concat_bands:
        # Concatenate mean and std values
        means = load_from_file_or_attribute(np.concatenate([means[m] for m in self.image_modalities]).tolist())
        stds = load_from_file_or_attribute(np.concatenate([stds[m] for m in self.image_modalities]).tolist())

        self.aug = Normalize(means, stds)
    else:
        # Apply standardization per modality
        means = {m: load_from_file_or_attribute(value) for m, value in means.items()}
        stds = {m: load_from_file_or_attribute(value) for m, value in stds.items()}

        self.aug = MultimodalNormalize(means, stds)

    self.data_with_sample_dim = data_with_sample_dim

    self.collate_fn = collate_chunk_dicts if data_with_sample_dim else collate_samples