Data

We rely on TorchGeo for the implementation of datasets and data modules.

Check out the TorchGeo tutorials on datasets for more in-depth information.

In general, we recommend creating a TorchGeo dataset specifically for your data. This gives you complete control and flexibility over how data is loaded, which transforms are applied to it, and even how it is plotted if you log with tools like TensorBoard.

TorchGeo provides GeoDataset and NonGeoDataset.

  • If your data is already nicely tiled and ready for consumption by a neural network, you can inherit from NonGeoDataset. This is essentially a wrapper around a regular torch dataset.
  • If your data consists of large GeoTiffs you would like to sample from during training, you can leverage the powerful GeoDataset from TorchGeo. This will automatically align your input data and labels and enable a variety of geo-aware samplers.
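
As a rough sketch of the first option, a pre-tiled dataset only needs to report its length and return one sample dict per index. The class below is a plain-Python stand-in so that it runs without TorchGeo installed: TiledDataset, halve_image, and the toy list data are hypothetical, and a real implementation would inherit from torchgeo.datasets.NonGeoDataset and return tensors.

```python
class TiledDataset:
    """Plain-Python stand-in for a dataset over pre-cut tiles (hypothetical).

    A real implementation would subclass torchgeo.datasets.NonGeoDataset,
    read image files from disk, and return tensors instead of lists.
    """

    def __init__(self, tiles, masks, transforms=None):
        if len(tiles) != len(masks):
            raise ValueError("tiles and masks must have the same length")
        self.tiles = tiles
        self.masks = masks
        self.transforms = transforms  # callable taking and returning a sample dict

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, index):
        # Samples are dicts with "image" and "mask" keys.
        sample = {"image": self.tiles[index], "mask": self.masks[index]}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample


def halve_image(sample):
    """Toy transform: scale pixel values, leave the mask untouched."""
    return {**sample, "image": [value / 2 for value in sample["image"]]}


dataset = TiledDataset(
    tiles=[[2, 4], [6, 8]],
    masks=[[0, 1], [1, 0]],
    transforms=halve_image,
)
print(len(dataset))  # 2
print(dataset[1])    # {'image': [3.0, 4.0], 'mask': [1, 0]}
```

Keeping the transform as a single callable over the whole sample dict means image and mask can be transformed together.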

Using Datasets already implemented in TorchGeo

Using existing TorchGeo DataModules is very easy! Just plug them in! For instance, to use the EuroSATDataModule, in your config file, set the data as:

data:
  class_path: torchgeo.datamodules.EuroSATDataModule
  init_args:
    batch_size: 32
    num_workers: 8
  dict_kwargs:
    root: /dccstor/geofm-pre/EuroSat
    download: True
    bands:
      - B02
      - B03
      - B04
      - B08A
      - B09
      - B10
Modify each parameter as you see fit.

You can also do this outside of config files! Simply instantiate the data module as normal and plug it in.
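
To make the config mechanics more concrete, the sketch below mimics in plain Python roughly what a class_path entry amounts to: the class is imported by its dotted path and called with init_args and dict_kwargs merged as keyword arguments. This is a simplification, not LightningCLI's actual implementation. instantiate_from_spec is a hypothetical helper, and datetime.timedelta merely stands in for a data module so the example runs without TorchGeo installed.

```python
import importlib


def instantiate_from_spec(spec):
    """Import the class named by `class_path` and call it with the merged kwargs."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # init_args and dict_kwargs both end up as keyword arguments of the constructor.
    kwargs = {**spec.get("init_args", {}), **spec.get("dict_kwargs", {})}
    return cls(**kwargs)


# Stand-in spec: datetime.timedelta plays the role of the data module class.
spec = {
    "class_path": "datetime.timedelta",
    "init_args": {"hours": 1},
    "dict_kwargs": {"minutes": 30},
}
obj = instantiate_from_spec(spec)
print(obj)  # 1:30:00
```

Instantiating a data module directly in Python corresponds to the last line of the helper: call the class with the same keyword arguments you would have written in the config.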

Warning

To define transforms to be passed to DataModules from TorchGeo from config files, you must use the following format:

data:
  class_path: terratorch.datamodules.TorchNonGeoDataModule
  init_args:
    cls: torchgeo.datamodules.EuroSATDataModule
    transforms:
      - class_path: albumentations.augmentations.geometric.resize.Resize
        init_args:
          height: 224
          width: 224
      - class_path: ToTensorV2
Note the class_path is TorchNonGeoDataModule and the class to be used is passed through cls (there is also a TorchGeoDataModule for geo modules). This has to be done as the transforms argument is passed through **kwargs in TorchGeo, making it difficult to instantiate with LightningCLI. See more details below.

terratorch.datamodules.torchgeo_data_module

Ugly proxy objects so parsing config file works with transforms.

These are necessary because LightningCLI can only instantiate arguments as objects from the config if they have type annotations.

In TorchGeo, transforms is passed in **kwargs, so it has no type annotations! To get around that, we create these wrappers that have transforms type annotated. They create the transforms and forward all method and attribute calls to the original TorchGeo datamodule.
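
The wrapper idea can be sketched in a few lines of plain Python. Everything here is hypothetical and simplified: KwargsOnlyModule stands in for a TorchGeo data module, TypedProxy stands in for the wrappers in this module, and forwarding is done with __getattr__ rather than with the explicit method-by-method delegation shown in the source listings.

```python
from typing import Any, Callable, List, Optional


class KwargsOnlyModule:
    """Stand-in for a data module that receives `transforms` only via **kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self.batch_size = kwargs.get("batch_size", 1)
        self.transforms = kwargs.get("transforms")

    def describe(self) -> str:
        return f"batch_size={self.batch_size}"


class TypedProxy:
    """Hypothetical wrapper exposing `transforms` as an annotated argument."""

    def __init__(
        self,
        cls: type,
        batch_size: Optional[int] = None,
        transforms: Optional[List[Callable[[dict], dict]]] = None,
        **kwargs: Any,
    ) -> None:
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            kwargs["transforms"] = self._chain(transforms)
        self._proxy = cls(**kwargs)

    @staticmethod
    def _chain(transforms):
        """Combine a list of dict-to-dict transforms into one callable."""
        def apply_all(sample: dict) -> dict:
            for transform in transforms:
                sample = transform(sample)
            return sample
        return apply_all

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, so wrapper attributes win.
        if name == "_proxy":
            raise AttributeError(name)
        return getattr(self._proxy, name)


def mark_seen(sample: dict) -> dict:
    return {**sample, "seen": True}


proxy = TypedProxy(KwargsOnlyModule, batch_size=4, transforms=[mark_seen])
print(proxy.describe())                # batch_size=4
print(proxy.transforms({"image": 1}))  # {'image': 1, 'seen': True}
```

Because the annotated constructor is the only thing a config parser needs to inspect, the wrapped class can keep its untyped **kwargs signature.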

Additionally, TorchGeo datasets pass the data to the transforms callable as a single dict whose values are tensors.

Albumentations expects this data not as a dict but as separate keyword arguments, and as NumPy arrays rather than tensors. We handle that conversion here.
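
The shape of that conversion can be illustrated without any dependencies. In the sketch below, to_dict_transform is a hypothetical helper (not the function used in this module), reverse_rows is a stand-in for an Albumentations-style transform that takes keyword arguments and returns a dict, and the list/tuple converters stand in for the tensor-to-NumPy and NumPy-to-tensor steps.

```python
def to_dict_transform(keyword_transform, to_input=lambda value: value, to_output=lambda value: value):
    """Wrap a keyword-argument transform so it accepts and returns a sample dict."""
    def wrapped(sample):
        # Convert each value to the type the wrapped transform expects.
        converted = {key: to_input(value) for key, value in sample.items()}
        # Unpack the dict into keyword arguments, as Albumentations expects.
        result = keyword_transform(**converted)
        # Convert the results back to the original value type.
        return {key: to_output(value) for key, value in result.items()}
    return wrapped


def reverse_rows(image, mask):
    """Albumentations-style stand-in: keyword arguments in, dict out."""
    return {"image": image[::-1], "mask": mask[::-1]}


# tuple -> list -> tuple stands in for tensor -> NumPy -> tensor.
transform = to_dict_transform(reverse_rows, to_input=list, to_output=tuple)
result = transform({"image": (1, 2, 3), "mask": (1, 1, 0)})
print(result)  # {'image': (3, 2, 1), 'mask': (0, 1, 1)}
```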

TorchGeoDataModule

Bases: GeoDataModule

Proxy object for using Geo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for LightningCLI and config files. As such, all getattr and setattr will be redirected to the underlying class.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchGeoDataModule(GeoDataModule):
    """Proxy object for using Geo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[GeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    @property
    def patch_size(self):
        return self._proxy.patch_size

    @property
    def length(self):
        return self._proxy.length

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return self._proxy.transfer_batch_to_device(batch, device, dataloader_idx)

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:

  • cls (type[GeoDataModule], required): TorchGeo DataModule class to be instantiated.
  • batch_size (int | None): Batch size forwarded to cls; omitted when None. Defaults to None.
  • num_workers (int): Number of workers forwarded to cls. Defaults to 0.
  • transforms (None | list[BasicTransform]): List of Albumentations transforms. Should end with ToTensorV2. Defaults to None.
  • **kwargs (Any): Arguments passed to instantiate cls. Defaults to {}.

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[GeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

TorchNonGeoDataModule

Bases: NonGeoDataModule

Proxy object for using NonGeo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for LightningCLI and config files. As such, all getattr and setattr will be redirected to the underlying class.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchNonGeoDataModule(NonGeoDataModule):
    """Proxy object for using NonGeo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[NonGeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:

  • cls (type[NonGeoDataModule], required): TorchGeo DataModule class to be instantiated.
  • batch_size (int | None): Batch size forwarded to cls; omitted when None. Defaults to None.
  • num_workers (int): Number of workers forwarded to cls. Defaults to 0.
  • transforms (None | list[BasicTransform]): List of Albumentations transforms. Should end with ToTensorV2. Defaults to None.
  • **kwargs (Any): Arguments passed to instantiate cls. Defaults to {}.

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[NonGeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

Generic datasets and data modules

For the NonGeoDataset case, we also provide "generic" datasets and datamodules. These can be used when you would like to load data from given directories, in a style similar to the MMLab libraries.
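
The generic datasets locate files by combining the prefixes listed in a split file with a glob pattern, in the form Path(data_root).glob(prefix + [image or label grep]). The sketch below illustrates only that lookup under simplifying assumptions: find_split_files is a hypothetical helper, the file names are made up, and the ignore_split_file_extensions and allow_substring_split_file options are not modeled.

```python
import tempfile
from pathlib import Path


def find_split_files(data_root, split_file, grep="*"):
    """Return the files under data_root matching any split prefix plus the glob pattern."""
    lines = Path(split_file).read_text().splitlines()
    prefixes = [line.strip() for line in lines if line.strip()]
    found = []
    for prefix in prefixes:
        found.extend(Path(data_root).glob(prefix + grep))
    return sorted(found)


# Build a throwaway directory with three empty tiles and a two-entry split file.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ["tile_001_img.tif", "tile_002_img.tif", "tile_003_img.tif"]:
        (root / name).touch()
    split = root / "train.txt"
    split.write_text("tile_001\ntile_003\n")
    train_files = [path.name for path in find_split_files(root, split, "*_img.tif")]

print(train_files)  # ['tile_001_img.tif', 'tile_003_img.tif']
```

Using separate patterns for images and labels lets both live in the same directory while still being told apart by suffix.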

Generic Datasets

terratorch.datasets.generic_pixel_wise_dataset

Module containing generic dataset classes

GenericNonGeoPixelwiseRegressionDataset

Bases: GenericPixelWiseDataset

GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoPixelwiseRegressionDataset(GenericPixelWiseDataset):
    """GenericNonGeoPixelwiseRegressionDataset"""

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should be a newline-separated list of prefixes contained in the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].float()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
            suptitle (str|None): optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
        )

    @staticmethod
    def _plot_sample(image, label, prediction=None, suptitle=None):
        num_images = 4 if prediction is not None else 3
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max())
        ax[0].axis("off")
        ax[0].title.set_text("Image")
        ax[0].imshow(image)

        ax[1].axis("off")
        ax[1].title.set_text("Ground Truth Mask")
        ax[1].imshow(label, cmap="Greens", norm=norm)

        ax[2].axis("off")
        ax[2].title.set_text("GT Mask on Image")
        ax[2].imshow(image)
        ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm)
        # ax[2].legend()

        if prediction is not None:
            ax[3].title.set_text("Predicted Mask")
            ax[3].imshow(prediction, cmap="Greens", norm=norm)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:

  • data_root (Path, required): Path to data root directory.
  • label_data_root (Path): Path to data root directory with labels. If not specified, will use the same as for images. Defaults to None.
  • image_grep (str): Glob pattern appended to data_root to find input images. Defaults to "*".
  • label_grep (str): Glob pattern appended to data_root to find ground truth masks. Defaults to "*".
  • split (Path): Path to file containing files to be used for this split. The file should be a newline-separated list of prefixes contained in the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). Defaults to None.
  • ignore_split_file_extensions (bool): Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.
  • allow_substring_split_file (bool): Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.
  • rgb_indices (list[int]): Indices of RGB channels. Defaults to None, in which case [0, 1, 2] is used.
  • dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
  • output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. Defaults to None.
  • constant_scale (float): Factor to multiply image values by. Defaults to 1.
  • transform (Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().
  • no_data_replace (float | None): Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.
  • no_label_replace (int | None): Replace NaN values in label with this value. If None, does no replacement. Defaults to None.
  • expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.
  • reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Glob pattern appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
            Defaults to "*".
        split (Path, optional): Path to file containing files to be used for this split.
            The file should be a newline-separated list of prefixes contained in the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__(
        data_root,
        label_data_root=label_data_root,
        image_grep=image_grep,
        label_grep=label_grep,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        no_label_replace=no_label_replace,
        expand_temporal_dimension=expand_temporal_dimension,
        reduce_zero_label=reduce_zero_label,
    )
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:

  • sample (dict[str, Tensor], required): A sample returned by __getitem__.
  • suptitle (str | None): Optional string to use as a suptitle. Defaults to None.

Returns:

  • Figure: A matplotlib Figure with the rendered sample.

.. versionadded:: 0.2

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
        suptitle (str|None): optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample

    .. versionadded:: 0.2
    """
    image = sample["image"]
    if len(image.shape) == 5:
        return
    if isinstance(image, Tensor):
        image = image.numpy()
    image = image.take(self.rgb_indices, axis=0)
    image = np.transpose(image, (1, 2, 0))
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    image = np.clip(image, 0, 1)

    label_mask = sample["mask"]
    if isinstance(label_mask, Tensor):
        label_mask = label_mask.numpy()

    showing_predictions = "prediction" in sample
    if showing_predictions:
        prediction_mask = sample["prediction"]
        if isinstance(prediction_mask, Tensor):
            prediction_mask = prediction_mask.numpy()

    return self._plot_sample(
        image,
        label_mask,
        prediction=prediction_mask if showing_predictions else None,
        suptitle=suptitle,
    )
GenericNonGeoSegmentationDataset

Bases: GenericPixelWiseDataset

GenericNonGeoSegmentationDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
    """GenericNonGeoSegmentationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should be a newline-separated list of prefixes contained in the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            self.num_classes,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
            class_names=self.class_names,
        )

    @staticmethod
    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
        num_images = 5 if prediction is not None else 4
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        # for legend
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(label, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)

        if prediction is not None:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = []
        for i in range(num_classes):
            class_name = class_names[i] if class_names else str(i)
            data = [i, cmap(norm(i)), class_name]
            legend_data.append(data)
        handles = [Rectangle((0, 0), 1, 1, color=tuple(c)) for _, c, _ in legend_data]
        labels = [n for _, _, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, num_classes, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:

Name Type Description Default
data_root Path

Path to data root directory

required
num_classes int

Number of classes in the dataset

required
label_data_root Path

Path to data root directory with labels. If not specified, will use the same as for images.

None
image_grep str

Glob pattern appended to data_root to find input images. Defaults to "*".

'*'
label_grep str

Glob pattern appended to label_data_root (or data_root when label_data_root is not given) to find ground truth masks. Defaults to "*".

'*'
split Path

Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

True
rgb_indices list[int]

Indices of RGB channels. Defaults to [0, 1, 2].

None
dataset_bands list[HLSBands | int] | None

Bands present in the dataset.

None
output_bands list[HLSBands | int] | None

Bands that should be output by the dataset.

None
class_names list[str]

Class names. Defaults to None.

None
constant_scale float

Factor to multiply image values by. Defaults to 1.

1
transform Compose | None

Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float | None

Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

None
no_label_replace int | None

Replace nan values in label with this value. If none, does no replacement. Defaults to None.

None
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
reduce_zero_label bool

Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

False
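The `expand_temporal_dimension` option regroups a stacked first axis of size time*channels into separate channel and time axes. Judging by the einops pattern used in `GenericPixelWiseDataset.__getitem__` (`"(channels time) h w -> channels time h w"`), the stacked axis is read channel-major, so flat index `c * time + t` lands at position `[c][t]`. The following is a minimal pure-Python sketch of that index mapping, with strings standing in for (h, w) planes. The helper name `expand_temporal` and the band labels are hypothetical and not part of TerraTorch.

```python
def expand_temporal(stacked, channels):
    """Regroup a channel-major stacked axis into [channel][time].

    `stacked` is a list of length channels * time; each element stands in
    for one (h, w) plane of the image.
    """
    if len(stacked) % channels != 0:
        raise ValueError("stacked length must be a multiple of channels")
    time = len(stacked) // channels
    # Flat index c * time + t maps to position [c][t].
    return [stacked[c * time:(c + 1) * time] for c in range(channels)]


# Two bands, three time steps, stored band by band (assumed layout).
planes = ["B04_t0", "B04_t1", "B04_t2", "B08_t0", "B08_t1", "B08_t2"]
expanded = expand_temporal(planes, channels=2)
print(expanded)
```

If a file stores all bands of the first time step before those of the second (time-major), this regrouping would mix bands and time steps, so it is worth checking the layout of your data before enabling the option.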
Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    num_classes: int,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    class_names: list[str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        num_classes (int): Number of classes in the dataset
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Glob pattern appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Glob pattern appended to label_data_root (or data_root when
            label_data_root is not given) to find ground truth masks. Defaults to "*".
        split (Path, optional): Path to a file listing the files to be used for this split.
            The file should contain newline-separated prefixes of the desired file names.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        class_names (list[str], optional): Class names. Defaults to None.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__(
        data_root,
        label_data_root=label_data_root,
        image_grep=image_grep,
        label_grep=label_grep,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        no_label_replace=no_label_replace,
        expand_temporal_dimension=expand_temporal_dimension,
        reduce_zero_label=reduce_zero_label,
    )
    self.num_classes = num_classes
    self.class_names = class_names
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:

Name Type Description Default
sample dict[str, Tensor]

a sample returned by __getitem__

required
suptitle str | None

optional string to use as a suptitle

None

Returns:

Type Description
Figure

a matplotlib Figure with the rendered sample

.. versionadded:: 0.2

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample

    .. versionadded:: 0.2
    """
    image = sample["image"]
    if len(image.shape) == 5:
        return
    if isinstance(image, Tensor):
        image = image.numpy()
    image = image.take(self.rgb_indices, axis=0)
    image = np.transpose(image, (1, 2, 0))
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    image = np.clip(image, 0, 1)

    label_mask = sample["mask"]
    if isinstance(label_mask, Tensor):
        label_mask = label_mask.numpy()

    showing_predictions = "prediction" in sample
    if showing_predictions:
        prediction_mask = sample["prediction"]
        if isinstance(prediction_mask, Tensor):
            prediction_mask = prediction_mask.numpy()

    return self._plot_sample(
        image,
        label_mask,
        self.num_classes,
        prediction=prediction_mask if showing_predictions else None,
        suptitle=suptitle,
        class_names=self.class_names,
    )
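The `plot` method above rescales the RGB channels with `(image - min) * (1 / max)` and then clips to [0, 1]. For comparison, here is a sketch of the more common min-max scaling, `(x - min) / (max - min)`, applied to a single channel held in a plain list. This is an alternative technique, not what the TerraTorch source does, and the helper name `min_max_scale` and the sample values are made up for illustration.

```python
def min_max_scale(values, eps=1e-12):
    """Scale a flat list of pixel values to the range [0, 1]."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        # Constant channel: avoid dividing by zero.
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


red_channel = [200.0, 400.0, 1000.0]
scaled = min_max_scale(red_channel)
print(scaled)
```

With these values the min-max version spans the full [0, 1] range, whereas the expression used in `plot` would map the same pixels to roughly 0.0, 0.2 and 0.8, because it divides by the original maximum rather than by the value range.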
GenericPixelWiseDataset

Bases: NonGeoDataset, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericPixelWiseDataset(NonGeoDataset, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to label_data_root (or data_root when
                label_data_root is not given) to find ground truth masks. Defaults to "*".
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired file names.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__()

        self.split_file = split

        label_data_root = label_data_root if label_data_root is not None else data_root
        self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension

        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
            self.segmentation_mask_files = filter_valid_files(
                self.segmentation_mask_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = self._generate_bands_intervals(dataset_bands)
        self.output_bands = self._generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform
        # self.transform = transform if transform else ToTensorV2()

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
        # to channels last
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
        image = np.moveaxis(image, 0, -1)

        if self.filter_indices:
            image = image[..., self.filter_indices]
        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
                0
            ]
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
        if bands_intervals is None:
            return None
        bands = []
        for element in bands_intervals:
            # if its an interval
            if isinstance(element, tuple):
                if len(element) != 2:  # noqa: PLR2004
                    msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
                    raise Exception(msg)
                expanded_element = list(range(element[0], element[1] + 1))
                bands.extend(expanded_element)
            else:
                bands.append(element)
        return bands
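The band handling in the class above can be sketched in isolation: `(start, end)` tuples in `dataset_bands` and `output_bands` are expanded into inclusive integer ranges, and the channels to keep are found by looking up each output band's position in the dataset band list. The snippet below mirrors that logic as a standalone function. The name `expand_band_intervals` and the band names are illustrative assumptions, not TerraTorch API.

```python
def expand_band_intervals(bands):
    """Expand (start, end) tuples into inclusive integer ranges."""
    if bands is None:
        return None
    expanded = []
    for element in bands:
        if isinstance(element, tuple):
            if len(element) != 2:
                raise ValueError("an interval must be a (start, end) tuple")
            start, end = element
            # Both ends are inclusive, hence end + 1.
            expanded.extend(range(start, end + 1))
        else:
            expanded.append(element)
    return expanded


# Three named bands followed by three bands identified by integer index.
dataset_bands = expand_band_intervals(["BLUE", "GREEN", "RED", (3, 5)])
output_bands = expand_band_intervals(["RED", (4, 5)])

# Position of each requested band within the dataset's channel order.
filter_indices = [dataset_bands.index(band) for band in output_bands]
print(dataset_bands)
print(filter_indices)
```

The resulting indices are what `__getitem__` uses to slice the last (channel) axis of the image, so the order of `output_bands` determines the order of the output channels.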
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:

Name Type Description Default
data_root Path

Path to data root directory

required
label_data_root Path

Path to data root directory with labels. If not specified, will use the same as for images.

None
image_grep str

Glob pattern appended to data_root to find input images. Defaults to "*".

'*'
label_grep str

Glob pattern appended to label_data_root (or data_root when label_data_root is not given) to find ground truth masks. Defaults to "*".

'*'
split Path

Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

True
rgb_indices list[int]

Indices of RGB channels. Defaults to [0, 1, 2].

None
dataset_bands list[HLSBands | int | tuple[int, int] | str] | None

Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.

None
output_bands list[HLSBands | int | tuple[int, int] | str] | None

Bands that should be output by the dataset as named by dataset_bands.

None
constant_scale float

Factor to multiply image values by. Defaults to 1.

1
transform Compose | None

Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float | None

Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

None
no_label_replace int | None

Replace nan values in label with this value. If none, does no replacement. Defaults to None.

None
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
reduce_zero_label bool

Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

False
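To make the interaction between `ignore_split_file_extensions` and `allow_substring_split_file` concrete, here is a rough sketch of how a split file could be used to filter a list of paths. The real work is done by TerraTorch's `filter_valid_files`, whose implementation is not shown on this page; the function `filter_by_split` below is a hypothetical stand-in written only from the parameter descriptions, and the file names are invented.

```python
import os


def filter_by_split(files, valid_names, ignore_extensions=True, allow_substring=True):
    """Keep the files whose base name matches an entry of the split file."""
    if ignore_extensions:
        # Compare names without their extensions on both sides.
        valid_names = {os.path.splitext(name)[0] for name in valid_names}
    kept = []
    for path in files:
        name = os.path.basename(path)
        if ignore_extensions:
            name = os.path.splitext(name)[0]
        if allow_substring:
            matched = any(entry in name for entry in valid_names)
        else:
            matched = name in valid_names
        if matched:
            kept.append(path)
    return kept


files = ["data/tile_001_img.tif", "data/tile_002_img.tif", "data/tile_010_img.tif"]
split_entries = {"tile_001.jpg", "tile_010.jpg"}

substring_matches = filter_by_split(files, split_entries)
exact_matches = filter_by_split(files, split_entries, allow_substring=False)
print(substring_matches)
print(exact_matches)
```

In this sketch the substring mode keeps the two tiles named in the split file even though the file names carry an extra `_img` suffix and a different extension, while the exact mode keeps nothing because no base name equals a split entry.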
Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Glob pattern appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Glob pattern appended to label_data_root (or data_root when
            label_data_root is not given) to find ground truth masks. Defaults to "*".
        split (Path, optional): Path to a file listing the files to be used for this split.
            The file should contain newline-separated prefixes of the desired file names.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__()

    self.split_file = split

    label_data_root = label_data_root if label_data_root is not None else data_root
    self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
    self.reduce_zero_label = reduce_zero_label
    self.expand_temporal_dimension = expand_temporal_dimension

    if self.expand_temporal_dimension and output_bands is None:
        msg = "Please provide output_bands when expand_temporal_dimension is True"
        raise Exception(msg)
    if self.split_file is not None:
        with open(self.split_file) as f:
            split = f.readlines()
        valid_files = {rf"{substring.strip()}" for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )
    self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

    self.dataset_bands = self._generate_bands_intervals(dataset_bands)
    self.output_bands = self._generate_bands_intervals(output_bands)

    if self.output_bands and not self.dataset_bands:
        msg = "If output bands provided, dataset_bands must also be provided"
        raise Exception(msg)

    # There is a special condition if the bands are defined as simple strings.
    if self.output_bands:
        if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
            msg = "Output bands must be a subset of dataset bands"
            raise Exception(msg)

        self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

    else:
        self.filter_indices = None

    # If no transform is given, only convert the sample to torch tensors
    self.transform = transform if transform else default_transform

terratorch.datasets.generic_scalar_label_dataset

Module containing generic dataset classes

GenericNonGeoClassificationDataset

Bases: GenericScalarLabelDataset

GenericNonGeoClassificationDataset

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericNonGeoClassificationDataset(GenericScalarLabelDataset):
    """GenericNonGeoClassificationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired file names.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        super().__init__(
            data_root,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            expand_temporal_dimension=expand_temporal_dimension,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["label"] = torch.tensor(item["label"]).long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        pass
__init__(data_root, num_classes, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

Constructor

Parameters:

Name Type Description Default
data_root Path

Path to data root directory

required
num_classes int

Number of classes in the dataset

required
split Path

Path to file containing files to be used for this split. The file should list newline-separated prefixes contained in the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

True
rgb_indices list[int]

Indices of RGB channels. Defaults to [0, 1, 2].

None
dataset_bands list[HLSBands | int] | None

Bands present in the dataset.

None
output_bands list[HLSBands | int] | None

Bands that should be output by the dataset.

None
class_names list[str]

Class names. Defaults to None.

None
constant_scale float

Factor to multiply image values by. Defaults to 1.

1
transform Compose | None

Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float

Replace nan values in input images with this value. Defaults to 0.

0
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
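The interaction between ignore_split_file_extensions and allow_substring_split_file is easier to see in code. The sketch below is a standalone illustration of the matching rules as described in the parameter table; it is not TerraTorch's own filtering code, and the helper name `select_files` and the file names are hypothetical.

```python
import os


def select_files(files, split_entries, ignore_extensions=True, allow_substring=True):
    """Keep the files that match at least one split-file entry (illustrative only)."""

    def key(name):
        # Compare on the base name, optionally without its extension.
        base = os.path.basename(name)
        return os.path.splitext(base)[0] if ignore_extensions else base

    entries = {key(entry.strip()) for entry in split_entries if entry.strip()}
    selected = []
    for path in files:
        name = key(path)
        if allow_substring:
            keep = any(entry in name for entry in entries)
        else:
            keep = name in entries
        if keep:
            selected.append(path)
    return selected


# Hypothetical layout: the split file lists ".jpg" names, the data is stored as ".tif".
files = ["data/Forest/Forest_1.tif", "data/Forest/Forest_12.tif", "data/River/River_1.tif"]
split_entries = ["Forest_1.jpg\n"]

exact = select_files(files, split_entries, allow_substring=False)
substring = select_files(files, split_entries, allow_substring=True)
strict_ext = select_files(files, split_entries, ignore_extensions=False, allow_substring=False)
```

With exact matching only `Forest_1.tif` is kept, while substring matching also keeps `Forest_12.tif`. If extensions are not ignored, nothing matches in this example, which is why the Eurosat-style case needs ignore_split_file_extensions set to True.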
Source code in terratorch/datasets/generic_scalar_label_dataset.py
def __init__(
    self,
    data_root: Path,
    num_classes: int,
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    class_names: list[str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float = 0,
    expand_temporal_dimension: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        num_classes (int): Number of classes in the dataset
        split (Path, optional): Path to file containing files to be used for this split.
            The file should list newline-separated prefixes contained in the desired file names.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        class_names (list[str], optional): Class names. Defaults to None.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
    """
    super().__init__(
        data_root,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        expand_temporal_dimension=expand_temporal_dimension,
    )
    self.num_classes = num_classes
    self.class_names = class_names
GenericScalarLabelDataset

Bases: NonGeoDataset, ImageFolder, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            split (Path, optional): Path to file containing files to be used for this split.
                The file should list newline-separated prefixes contained in the desired file names.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        self.split_file = split

        self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
        self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )

            def is_valid_file(x):
                return x in self.image_files

        else:

            def is_valid_file(x):
                return True

        super().__init__(
            root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
        )

        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = self._generate_bands_intervals(dataset_bands)
        self.output_bands = self._generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None
        # If no transform is given, only convert the sample to a torch tensor
        self.transforms = transform if transform else default_transform
        # self.transform = transform if transform else ToTensorV2()

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image, label = ImageFolder.__getitem__(self, index)
        if self.expand_temporal_dimension:
            image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.output_bands))
        if self.filter_indices:
            image = image[..., self.filter_indices]

        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "label": label,  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
        }
        if self.transforms:
            output = self.transforms(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        data = data.fillna(self.no_data_replace)
        return data
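The `rearrange` pattern used in `__getitem__`, `"h w (channels time) -> time h w channels"`, splits the last axis of the loaded image into a channel index and a time index. Assuming einops' usual row-major convention for grouped axes, the first name inside the parentheses varies slowest. The sketch below shows the resulting index mapping in plain Python; the helper `split_channel_time` is hypothetical and only illustrates the layout this pattern expects.

```python
def split_channel_time(flat_index, num_timesteps):
    """Map a position on the stacked '(channels time)' axis to (channel, time)."""
    channel, time = divmod(flat_index, num_timesteps)
    return channel, time


num_channels, num_timesteps = 2, 3
layout = [split_channel_time(i, num_timesteps) for i in range(num_channels * num_timesteps)]
# layout == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

Under this reading, the stacked axis holds every time step of the first band, then every time step of the second band, and so on. Since `channels` is taken from `len(self.output_bands)`, this is consistent with the constructor requiring output_bands whenever expand_temporal_dimension is True.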
__init__(data_root, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

Constructor

Parameters:

Name Type Description Default
data_root Path

Path to data root directory

required
split Path

Path to file containing files to be used for this split. The file should list newline-separated prefixes contained in the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

None
ignore_split_file_extensions bool

Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

True
allow_substring_split_file bool

Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

True
rgb_indices list[int]

Indices of RGB channels. Defaults to [0, 1, 2].

None
dataset_bands list[HLSBands | int | tuple[int, int] | str] | None

Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.

None
output_bands list[HLSBands | int | tuple[int, int] | str] | None

Bands that should be output by the dataset as named by dataset_bands.

None
constant_scale float

Factor to multiply image values by. Defaults to 1.

1
transform Compose | None

Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

None
no_data_replace float

Replace nan values in input images with this value. Defaults to 0.

0
expand_temporal_dimension bool

Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

False
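The relationship between dataset_bands and output_bands can be summarised as a lookup: dataset_bands names the channels in file order, and output_bands selects a subset of those names. The sketch below mirrors the checks and the index computation in the constructor using plain strings as band names. It leaves out the expansion of int ranges, and the helper name `band_filter_indices`, the band names, and the pixel values are illustrative assumptions.

```python
def band_filter_indices(dataset_bands, output_bands):
    """Return the channel positions of output_bands within dataset_bands."""
    if output_bands and not dataset_bands:
        raise ValueError("If output_bands is provided, dataset_bands must also be provided")
    if not output_bands:
        return None  # no filtering: every channel is kept
    if not set(output_bands) <= set(dataset_bands):
        raise ValueError("Output bands must be a subset of dataset bands")
    return [dataset_bands.index(band) for band in output_bands]


# Hypothetical six-band image; the output keeps three bands in R, G, B order.
dataset_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
output_bands = ["RED", "GREEN", "BLUE"]
indices = band_filter_indices(dataset_bands, output_bands)

# Applying the indices to one pixel's channel values.
pixel = [0.11, 0.22, 0.33, 0.44, 0.55, 0.66]
selected = [pixel[i] for i in indices]
```

Note that the order of output_bands is preserved, so it can also be used to reorder channels, not only to drop them.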
Source code in terratorch/datasets/generic_scalar_label_dataset.py
def __init__(
    self,
    data_root: Path,
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float = 0,
    expand_temporal_dimension: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        split (Path, optional): Path to file containing files to be used for this split.
            The file should list newline-separated prefixes contained in the desired file names.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
    """
    self.split_file = split

    self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
    self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.expand_temporal_dimension = expand_temporal_dimension
    if self.expand_temporal_dimension and output_bands is None:
        msg = "Please provide output_bands when expand_temporal_dimension is True"
        raise Exception(msg)
    if self.split_file is not None:
        with open(self.split_file) as f:
            split = f.readlines()
        valid_files = {rf"{substring.strip()}" for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )

        def is_valid_file(x):
            return x in self.image_files

    else:

        def is_valid_file(x):
            return True

    super().__init__(
        root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
    )

    self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

    self.dataset_bands = self._generate_bands_intervals(dataset_bands)
    self.output_bands = self._generate_bands_intervals(output_bands)

    if self.output_bands and not self.dataset_bands:
        msg = "If output bands provided, dataset_bands must also be provided"
        raise Exception(msg)

    # There is a special condition if the bands are defined as simple strings.
    if self.output_bands:
        if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
            msg = "Output bands must be a subset of dataset bands"
            raise Exception(msg)

        self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

    else:
        self.filter_indices = None
    # If no transform is given, only convert the sample to a torch tensor
    self.transforms = transform if transform else default_transform

Generic Data Modules

terratorch.datamodules.generic_pixel_wise_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoPixelwiseRegressionDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoPixelwiseRegressionDataset instances.

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoPixelwiseRegressionDataModule(NonGeoDataModule):
    """This is a generic datamodule class for instantiating data modules at runtime.
    Composes several
    [GenericNonGeoPixelwiseRegressionDataset][terratorch.datasets.GenericNonGeoPixelwiseRegressionDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        predict_data_root: Path | None = None,
        img_grep: str | None = "*",
        label_grep: str | None = "*",
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by the data loaders.
            train_data_root (Path): Root directory of the training images.
            val_data_root (Path): Root directory of the validation images.
            test_data_root (Path): Root directory of the test images.
            predict_data_root (Path | None, optional): Root directory of the images used for prediction.
                Defaults to None.
            img_grep (str): Glob pattern used to find image files. Defaults to "*".
            label_grep (str): Glob pattern used to find label files. Defaults to "*".
            means (list[float] | str): Per-band means used for normalization, or a file to load them from.
            stds (list[float] | str): Per-band standard deviations used for normalization, or a file to load
                them from.
            train_label_data_root (Path | None, optional): Root directory of the training labels. Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of the validation labels. Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of the test labels. Defaults to None.
            train_split (Path | None, optional): Path to the split file for the training set. Defaults to None.
            val_split (Path | None, optional): Path to the split file for the validation set. Defaults to None.
            test_split (Path | None, optional): Path to the split file for the test set. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands present in the
                dataset. Defaults to None.
            predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands present
                in the prediction dataset. If not given, dataset_bands is used. Defaults to None.
            predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands that
                should be output for prediction. If not given, output_bands is used. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands that should be
                output by the dataset. Defaults to None.
            constant_scale (float, optional): Factor to multiply image values by. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of RGB channels. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        """
        super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.drop_last = drop_last
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.constant_scale = constant_scale

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
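The shuffle and drop_last arguments passed to DataLoader in _dataloader_factory depend only on the split name and the drop_last setting. The following is a minimal pure-Python sketch of how they resolve. The helpers loader_flags and num_batches are hypothetical names used for illustration and are not part of terratorch; the batch count follows the usual drop_last semantics of a PyTorch DataLoader.

```python
import math


def loader_flags(split: str, drop_last: bool = True) -> dict[str, bool]:
    """Mirror the shuffle/drop_last expressions used in _dataloader_factory."""
    return {
        "shuffle": split == "train",
        "drop_last": split == "train" and drop_last,
    }


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields for num_samples items (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


for split in ("train", "val", "test", "predict"):
    print(split, loader_flags(split))

# With 10 samples and batch_size=4, the training loader drops the incomplete
# batch while the validation loader keeps it.
print(num_batches(10, 4, loader_flags("train")["drop_last"]))  # 2
print(num_batches(10, 4, loader_flags("val")["drop_last"]))    # 3
```

Only the training split is shuffled, and drop_last has no effect on the validation, test, or predict loaders.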
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, predict_data_root=None, img_grep='*', label_grep='*', train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, predict_output_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, **kwargs)

Constructor

Parameters:

  • batch_size (int, required): Number of samples per batch.
  • num_workers (int, required): Number of worker processes used by each DataLoader.
  • train_data_root (Path, required): Root directory of the training images.
  • val_data_root (Path, required): Root directory of the validation images.
  • test_data_root (Path, required): Root directory of the test images.
  • predict_data_root (Path | None, default None): Root directory of the images used for prediction. If not set, no predict dataset is created.
  • img_grep (str | None, default '*'): Pattern used to match image files.
  • label_grep (str | None, default '*'): Pattern used to match label files.
  • means (list[float] | str, required): Means used to normalize the images, given as a list or as a string resolved by load_from_file_or_attribute.
  • stds (list[float] | str, required): Standard deviations used to normalize the images, given as a list or as a string resolved by load_from_file_or_attribute.
  • train_label_data_root (Path | None, default None): Root directory of the training labels.
  • val_label_data_root (Path | None, default None): Root directory of the validation labels.
  • test_label_data_root (Path | None, default None): Root directory of the test labels.
  • train_split (Path | None, default None): Split file that determines which files are included in the training set.
  • val_split (Path | None, default None): Split file that determines which files are included in the validation set.
  • test_split (Path | None, default None): Split file that determines which files are included in the test set.
  • ignore_split_file_extensions (bool, default True): Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif".
  • allow_substring_split_file (bool, default True): Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT).
  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands present in the dataset.
  • predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands present in the prediction data. If not set, dataset_bands is used.
  • predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands output by the predict dataset. If not set, output_bands is used.
  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands output by the dataset.
  • constant_scale (float, default 1): Constant scaling factor applied to the image values.
  • rgb_indices (list[int] | None, default None): Indices of the RGB channels.
  • train_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • val_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • test_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the test dataset (also used for the predict dataset). Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • no_data_replace (float | None, default None): Replace NaN values in input images with this value. If None, no replacement is done.
  • no_label_replace (int | None, default None): Replace NaN values in the label with this value. If None, no replacement is done.
  • expand_temporal_dimension (bool, default False): Go from shape (time*channels, h, w) to (channels, time, h, w).
  • reduce_zero_label (bool, default False): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0.
  • drop_last (bool, default True): Drop the last batch if it is not complete. Only applies to the training loader.
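The interaction between ignore_split_file_extensions and allow_substring_split_file can be sketched in plain Python. This is an illustration of the matching rules as described in the parameter list, not the dataset's actual implementation; the function in_split and the example file names are hypothetical.

```python
from pathlib import Path


def in_split(
    file_path: str,
    split_entries: set[str],
    ignore_extensions: bool = True,
    allow_substring: bool = True,
) -> bool:
    """Decide whether file_path is selected by the entries of a split file."""
    name = Path(file_path).name
    if ignore_extensions:
        # Compare stems only, so "a.jpg" in the split file matches "a.tif" on disk.
        name = Path(name).stem
        split_entries = {Path(entry).stem for entry in split_entries}
    if allow_substring:
        # mmsegmentation style: an entry only has to occur somewhere in the name.
        return any(entry in name for entry in split_entries)
    # Exact-match style: the (possibly stem-only) name must equal an entry.
    return name in split_entries


entries = {"Forest_12.jpg", "River_7.jpg"}
print(in_split("data/Forest_12.tif", entries))  # True
print(in_split("data/Forest_12.tif", entries,
               ignore_extensions=False, allow_substring=False))  # False
print(in_split("data/Forest_123.tif", entries, allow_substring=False))  # False
print(in_split("data/Forest_123.tif", entries, allow_substring=True))   # True
```

The last two calls show why exact matching matters for datasets whose file names share prefixes: with substring matching, the entry Forest_12 also selects Forest_123.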
Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    predict_data_root: Path | None = None,
    img_grep: str | None = "*",
    label_grep: str | None = "*",
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by each DataLoader.
        train_data_root (Path): Root directory of the training images.
        val_data_root (Path): Root directory of the validation images.
        test_data_root (Path): Root directory of the test images.
        predict_data_root (Path | None, optional): Root directory of the images used for prediction.
            If not set, no predict dataset is created. Defaults to None.
        img_grep (str | None, optional): Pattern used to match image files. Defaults to "*".
        label_grep (str | None, optional): Pattern used to match label files. Defaults to "*".
        means (list[float] | str): Means used to normalize the images, given as a list or as a
            string resolved by load_from_file_or_attribute.
        stds (list[float] | str): Standard deviations used to normalize the images, given as a list
            or as a string resolved by load_from_file_or_attribute.
        train_label_data_root (Path | None, optional): Root directory of the training labels. Defaults to None.
        val_label_data_root (Path | None, optional): Root directory of the validation labels. Defaults to None.
        test_label_data_root (Path | None, optional): Root directory of the test labels. Defaults to None.
        train_split (Path | None, optional): Split file that determines which files are included
            in the training set. Defaults to None.
        val_split (Path | None, optional): Split file that determines which files are included
            in the validation set. Defaults to None.
        test_split (Path | None, optional): Split file that determines which files are included
            in the test set. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands present
            in the dataset. Defaults to None.
        predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands
            present in the prediction data. If not set, dataset_bands is used. Defaults to None.
        predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands
            output by the predict dataset. If not set, output_bands is used. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands output
            by the dataset. Defaults to None.
        constant_scale (float, optional): Constant scaling factor applied to the image values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
        train_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the test dataset (also used for the predict dataset).
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value.
            If None, no replacement is done. Defaults to None.
        no_label_replace (int | None): Replace NaN values in the label with this value.
            If None, no replacement is done. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Only applies to the training
            loader. Defaults to True.
    """
    super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.drop_last = drop_last
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.constant_scale = constant_scale

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
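The constructor resolves the predict-time band settings with a truthiness check, so any falsy override falls back to the training-time value. Here is a small sketch of that expression in isolation; the function name resolve_predict_bands is hypothetical and the band names are placeholder strings.

```python
def resolve_predict_bands(predict_bands, default_bands):
    """Same expression as in the constructor: fall back when the override is falsy."""
    return predict_bands if predict_bands else default_bands


dataset_bands = ["BLUE", "GREEN", "RED"]  # placeholder band names

print(resolve_predict_bands(None, dataset_bands))     # ['BLUE', 'GREEN', 'RED']
print(resolve_predict_bands(["RED"], dataset_bands))  # ['RED']
# An empty list is falsy, so it also triggers the fallback.
print(resolve_predict_bands([], dataset_bands))       # ['BLUE', 'GREEN', 'RED']
```

Because the check is on truthiness rather than on None, passing an empty list cannot be used to request "no bands" at predict time.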
GenericNonGeoSegmentationDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoSegmentationDatasets

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoSegmentationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        img_grep: str,
        label_grep: str,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by each DataLoader.
            train_data_root (Path): Root directory of the training images.
            val_data_root (Path): Root directory of the validation images.
            test_data_root (Path): Root directory of the test images.
            predict_data_root (Path | None, optional): Root directory of the images used for prediction.
                If not set, no predict dataset is created. Defaults to None.
            img_grep (str): Pattern used to match image files.
            label_grep (str): Pattern used to match label files.
            means (list[float] | str): Means used to normalize the images, given as a list or as a
                string resolved by load_from_file_or_attribute.
            stds (list[float] | str): Standard deviations used to normalize the images, given as a
                list or as a string resolved by load_from_file_or_attribute.
            num_classes (int): Number of classes in the segmentation labels.
            train_label_data_root (Path | None, optional): Root directory of the training labels. Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of the validation labels. Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of the test labels. Defaults to None.
            train_split (Path | None, optional): Split file that determines which files are included
                in the training set. Defaults to None.
            val_split (Path | None, optional): Split file that determines which files are included
                in the validation set. Defaults to None.
            test_split (Path | None, optional): Split file that determines which files are included
                in the test set. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands present
                in the dataset. Defaults to None.
            predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands
                present in the prediction data. If not set, dataset_bands is used. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands output
                by the dataset. Defaults to None.
            constant_scale (float, optional): Constant scaling factor applied to the image values. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
            train_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
                to be applied to the test dataset (also used for the predict dataset).
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value.
                If None, no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace NaN values in the label with this value.
                If None, no replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Only applies to the training
                loader. Defaults to True.
        """
        super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                # This class never assigns self.predict_output_bands, so use output_bands.
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
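Which datasets setup creates depends on the stage string it receives. The sketch below restates the conditions from setup as a standalone function; datasets_for_stage is a hypothetical helper for illustration and only returns the attribute names that would be assigned.

```python
def datasets_for_stage(stage: str, has_predict_root: bool = False) -> list[str]:
    """Return the dataset attributes that setup(stage) would assign."""
    created = []
    if stage in ["fit"]:
        created.append("train_dataset")
    if stage in ["fit", "validate"]:
        created.append("val_dataset")
    if stage in ["test"]:
        created.append("test_dataset")
    # The predict dataset is only built when a predict root was configured.
    if stage in ["predict"] and has_predict_root:
        created.append("predict_dataset")
    return created


for stage in ("fit", "validate", "test", "predict"):
    print(stage, datasets_for_stage(stage, has_predict_root=True))
```

Note that the "fit" stage builds both the train and validation datasets, while "predict" builds nothing unless predict_data_root was passed to the constructor.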
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, img_grep, label_grep, means, stds, num_classes, predict_data_root=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, **kwargs)

Constructor

Parameters:

  • batch_size (int, required): Number of samples per batch.
  • num_workers (int, required): Number of worker processes used by each DataLoader.
  • train_data_root (Path, required): Root directory of the training images.
  • val_data_root (Path, required): Root directory of the validation images.
  • test_data_root (Path, required): Root directory of the test images.
  • predict_data_root (Path | None, default None): Root directory of the images used for prediction. If not set, no predict dataset is created.
  • img_grep (str, required): Pattern used to match image files.
  • label_grep (str, required): Pattern used to match label files.
  • means (list[float] | str, required): Means used to normalize the images, given as a list or as a string resolved by load_from_file_or_attribute.
  • stds (list[float] | str, required): Standard deviations used to normalize the images, given as a list or as a string resolved by load_from_file_or_attribute.
  • num_classes (int, required): Number of classes in the segmentation labels.
  • train_label_data_root (Path | None, default None): Root directory of the training labels.
  • val_label_data_root (Path | None, default None): Root directory of the validation labels.
  • test_label_data_root (Path | None, default None): Root directory of the test labels.
  • train_split (Path | None, default None): Split file that determines which files are included in the training set.
  • val_split (Path | None, default None): Split file that determines which files are included in the validation set.
  • test_split (Path | None, default None): Split file that determines which files are included in the test set.
  • ignore_split_file_extensions (bool, default True): Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif".
  • allow_substring_split_file (bool, default True): Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT).
  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands present in the dataset.
  • predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands present in the prediction data. If not set, dataset_bands is used.
  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default None): Bands output by the dataset.
  • constant_scale (float, default 1): Constant scaling factor applied to the image values.
  • rgb_indices (list[int] | None, default None): Indices of the RGB channels.
  • train_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • val_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • test_transform (A.Compose | list[A.BasicTransform] | None, default None): Albumentations transform to be applied to the test dataset (also used for the predict dataset). Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, only ToTensorV2() is applied.
  • no_data_replace (float | None, default None): Replace NaN values in input images with this value. If None, no replacement is done.
  • no_label_replace (int | None, default None): Replace NaN values in the label with this value. If None, no replacement is done.
  • expand_temporal_dimension (bool, default False): Go from shape (time*channels, h, w) to (channels, time, h, w).
  • reduce_zero_label (bool, default False): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0.
  • drop_last (bool, default True): Drop the last batch if it is not complete. Only applies to the training loader.
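The effect of no_data_replace and reduce_zero_label can be illustrated on flat Python lists. This is a sketch of the behavior described in the parameter list under simplifying assumptions (lists instead of image arrays), not the dataset's implementation; replace_nan and reduce_zero_label are hypothetical function names.

```python
from __future__ import annotations

import math


def replace_nan(values: list[float], replacement: float | None) -> list[float]:
    """Replace NaN entries with replacement; None leaves the data untouched."""
    if replacement is None:
        return list(values)
    return [replacement if math.isnan(v) else v for v in values]


def reduce_zero_label(labels: list[int]) -> list[int]:
    """Shift labels down by one so that classes numbered from 1 start at 0."""
    return [label - 1 for label in labels]


pixels = [0.2, float("nan"), 0.7]
print(replace_nan(pixels, 0.0))      # [0.2, 0.0, 0.7]
print(reduce_zero_label([1, 2, 3]))  # [0, 1, 2]
```

In this sketch a label that is already 0 becomes -1 after the shift, so the option is only appropriate when label values start at 1.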
Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    img_grep: str,
    label_grep: str,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by each DataLoader.
        train_data_root (Path): Root directory of the training images.
        val_data_root (Path): Root directory of the validation images.
        test_data_root (Path): Root directory of the test images.
        predict_data_root (Path | None, optional): Root directory of the images used for prediction.
            If not set, no predict dataset is created. Defaults to None.
        img_grep (str): Pattern used to match image files.
        label_grep (str): Pattern used to match label files.
        means (list[float] | str): Means used to normalize the images, given as a list or as a
            string resolved by load_from_file_or_attribute.
        stds (list[float] | str): Standard deviations used to normalize the images, given as a list
            or as a string resolved by load_from_file_or_attribute.
        num_classes (int): Number of classes in the segmentation labels.
        train_label_data_root (Path | None, optional): Root directory of the training labels. Defaults to None.
        val_label_data_root (Path | None, optional): Root directory of the validation labels. Defaults to None.
        test_label_data_root (Path | None, optional): Root directory of the test labels. Defaults to None.
        train_split (Path | None, optional): Split file that determines which files are included
            in the training set. Defaults to None.
        val_split (Path | None, optional): Split file that determines which files are included
            in the validation set. Defaults to None.
        test_split (Path | None, optional): Split file that determines which files are included
            in the test set. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands present
            in the dataset. Defaults to None.
        predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands
            present in the prediction data. If not set, dataset_bands is used. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None, optional): Bands output
            by the dataset. Defaults to None.
        constant_scale (float, optional): Constant scaling factor applied to the image values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
        train_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (A.Compose | list[A.BasicTransform] | None): Albumentations transform
            to be applied to the test dataset (also used for the predict dataset).
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value.
            If None, no replacement is done. Defaults to None.
        no_label_replace (int | None): Replace NaN values in the label with this value.
            If None, no replacement is done. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Only applies to the training
            loader. Defaults to True.
    """
    super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
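
The no-data and label options documented above (no_data_replace, no_label_replace, reduce_zero_label) can be illustrated on plain Python lists. This is a minimal sketch under stated assumptions, not the library's implementation: the helper names `replace_nan` and `shift_labels_to_zero` are hypothetical, and the real dataset operates on image arrays rather than lists.

```python
import math


def replace_nan(values, replacement):
    """Return a copy of `values` with NaN entries replaced (hypothetical helper).

    A replacement of None means "leave the data untouched".
    """
    if replacement is None:
        return list(values)
    return [
        replacement if isinstance(v, float) and math.isnan(v) else v
        for v in values
    ]


def shift_labels_to_zero(labels):
    """Subtract 1 from every label so that labels starting at 1 start at 0."""
    return [label - 1 for label in labels]


pixels = [0.2, float("nan"), 0.7]
labels = [1, 2, 3, 1]

clean_pixels = replace_nan(pixels, 0.0)         # NaN becomes 0.0
untouched_pixels = replace_nan(pixels, None)    # NaN is kept
shifted_labels = shift_labels_to_zero(labels)   # labels now start at 0
```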

terratorch.datamodules.generic_scalar_label_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoClassificationDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoClassificationDatasets.

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
class GenericNonGeoClassificationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoClassificationDatasets][terratorch.datasets.GenericNonGeoClassificationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int] | None = None,
        predict_dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        no_data_replace: float = 0,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used for data loading.
            train_data_root (Path): Root directory of the training data.
            val_data_root (Path): Root directory of the validation data.
            test_data_root (Path): Root directory of the test data.
            means (list[float] | str): Per-band means used for normalization, given as a list
                or as a path to a file to load them from.
            stds (list[float] | str): Per-band standard deviations used for normalization, given
                as a list or as a path to a file to load them from.
            num_classes (int): Number of classes.
            predict_data_root (Path | None, optional): Root directory of the data used for prediction.
                If None, no predict dataset is set up. Defaults to None.
            train_split (Path | None, optional): Split file listing the files to include in the
                training set. Defaults to None.
            val_split (Path | None, optional): Split file listing the files to include in the
                validation set. Defaults to None.
            test_split (Path | None, optional): Split file listing the files to include in the
                test set. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset.
                Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None, optional): Bands present in the
                prediction dataset. Defaults to None, in which case dataset_bands is used.
            output_bands (list[HLSBands | int] | None, optional): Bands to output from the dataset.
                Defaults to None.
            constant_scale (float, optional): Factor by which image values are multiplied. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace NaN values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            drop_last (bool): Drop the last training batch if it is not complete. Defaults to True.
        """
        super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.drop_last = drop_last

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
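
The `drop_last=split == "train" and self.drop_last` expression in `_dataloader_factory` means that incomplete batches are only ever dropped for the training split. The following is a small sketch of the resulting batch counts; the function `num_batches` is a hypothetical helper written for illustration and is not part of the library.

```python
import math


def num_batches(num_samples: int, batch_size: int, split: str, drop_last: bool = True) -> int:
    """Count the batches a loader would yield under the rule used above."""
    # Mirrors `drop_last=split == "train" and self.drop_last`.
    effective_drop_last = split == "train" and drop_last
    if effective_drop_last:
        # The trailing incomplete batch is discarded.
        return num_samples // batch_size
    # The trailing incomplete batch is kept.
    return math.ceil(num_samples / batch_size)


train_batches = num_batches(10, 4, "train")                        # incomplete batch dropped
val_batches = num_batches(10, 4, "val")                            # incomplete batch kept
train_batches_keep = num_batches(10, 4, "train", drop_last=False)  # incomplete batch kept
```

With 10 samples and a batch size of 4, validation, test, and predict loaders always see all 10 samples, while the training loader sees 8 per epoch unless drop_last is set to False.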
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, predict_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, no_data_replace=0, drop_last=True, **kwargs)

Constructor

Parameters:

batch_size (int, required): Number of samples per batch.
num_workers (int, required): Number of worker processes used for data loading.
train_data_root (Path, required): Root directory of the training data.
val_data_root (Path, required): Root directory of the validation data.
test_data_root (Path, required): Root directory of the test data.
means (list[float] | str, required): Per-band means used for normalization, given as a list or as a path to a file to load them from.
stds (list[float] | str, required): Per-band standard deviations used for normalization, given as a list or as a path to a file to load them from.
num_classes (int, required): Number of classes.
predict_data_root (Path | None, default None): Root directory of the data used for prediction. If None, no predict dataset is set up.
train_split (Path | None, default None): Split file listing the files to include in the training set.
val_split (Path | None, default None): Split file listing the files to include in the validation set.
test_split (Path | None, default None): Split file listing the files to include in the test set.
ignore_split_file_extensions (bool, default True): Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif".
allow_substring_split_file (bool, default True): Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat).
dataset_bands (list[HLSBands | int] | None, default None): Bands present in the dataset.
predict_dataset_bands (list[HLSBands | int] | None, default None): Bands present in the prediction dataset. If None, dataset_bands is used.
output_bands (list[HLSBands | int] | None, default None): Bands to output from the dataset.
constant_scale (float, default 1): Factor by which image values are multiplied.
rgb_indices (list[int] | None, default None): Indices of the RGB channels.
train_transform (Compose | None, default None): Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, simply applies ToTensorV2().
val_transform (Compose | None, default None): Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, simply applies ToTensorV2().
test_transform (Compose | None, default None): Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. If None, simply applies ToTensorV2().
no_data_replace (float, default 0): Replace NaN values in input images with this value.
expand_temporal_dimension (bool, default False): Go from shape (time*channels, h, w) to (channels, time, h, w).
drop_last (bool, default True): Drop the last training batch if it is not complete.
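
How ignore_split_file_extensions and allow_substring_split_file interact can be sketched with a small filter over file names. This is an illustration only, assuming a hypothetical helper `select_files` and made-up file names; the library's own filtering code may differ in details.

```python
import os


def select_files(files, split_entries, ignore_extensions=True, allow_substring=True):
    """Keep the files named by a split file (hypothetical helper, not the library's code)."""
    entries = [entry.strip() for entry in split_entries if entry.strip()]
    if ignore_extensions:
        # Compare on the stem only, so "x.jpg" in the split file can match "x.tif" on disk.
        entries = [os.path.splitext(entry)[0] for entry in entries]

    selected = []
    for path in files:
        name = os.path.basename(path)
        if ignore_extensions:
            name = os.path.splitext(name)[0]
        if allow_substring:
            keep = any(entry in name for entry in entries)
        else:
            keep = name in entries
        if keep:
            selected.append(path)
    return sorted(selected)


# Made-up file names: data stored as .tif, split file written with .jpg names.
files = ["data/Forest_1.tif", "data/Forest_2.tif", "data/River_1.tif"]
split_entries = ["Forest_1.jpg\n", "River_1.jpg\n"]

matched = select_files(files, split_entries, ignore_extensions=True, allow_substring=False)
unmatched = select_files(files, split_entries, ignore_extensions=False, allow_substring=False)
loose = select_files(["data/Forest_10.tif"], ["Forest_1"], allow_substring=True)
```

Note that substring matching is permissive: in this sketch an entry such as Forest_1 also selects Forest_10, so exact matching is the safer choice when file names share prefixes.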
Source code in terratorch/datamodules/generic_scalar_label_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int] | None = None,
    predict_dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    no_data_replace: float = 0,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used for data loading.
        train_data_root (Path): Root directory of the training data.
        val_data_root (Path): Root directory of the validation data.
        test_data_root (Path): Root directory of the test data.
        means (list[float] | str): Per-band means used for normalization, given as a list
            or as a path to a file to load them from.
        stds (list[float] | str): Per-band standard deviations used for normalization, given
            as a list or as a path to a file to load them from.
        num_classes (int): Number of classes.
        predict_data_root (Path | None, optional): Root directory of the data used for prediction.
            If None, no predict dataset is set up. Defaults to None.
        train_split (Path | None, optional): Split file listing the files to include in the
            training set. Defaults to None.
        val_split (Path | None, optional): Split file listing the files to include in the
            validation set. Defaults to None.
        test_split (Path | None, optional): Split file listing the files to include in the
            test set. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset.
            Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None, optional): Bands present in the
            prediction dataset. Defaults to None, in which case dataset_bands is used.
        output_bands (list[HLSBands | int] | None, optional): Bands to output from the dataset.
            Defaults to None.
        constant_scale (float, optional): Factor by which image values are multiplied. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace NaN values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        drop_last (bool): Drop the last training batch if it is not complete. Defaults to True.
    """
    super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.drop_last = drop_last

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )

    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
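
The reshaping performed by expand_temporal_dimension, from (time*channels, h, w) to (channels, time, h, w), can be pictured as regrouping a flat stack of 2D planes. The sketch below uses strings as stand-ins for the (h, w) planes. It assumes the flat axis is ordered channel-major (all time steps of the first channel, then all time steps of the second, and so on); the documentation above does not state the ordering, so verify it against your data. The function name is hypothetical.

```python
def expand_temporal_planes(planes, num_channels):
    """Regroup a flat list of time*channels planes into [channel][time] order.

    Assumption: the flat axis is channel-major. This is not confirmed by the
    documentation and is only chosen here for illustration.
    """
    if len(planes) % num_channels != 0:
        raise ValueError("number of planes must be a multiple of num_channels")
    num_times = len(planes) // num_channels
    return [planes[c * num_times:(c + 1) * num_times] for c in range(num_channels)]


# Six placeholder planes standing in for (h, w) arrays: 2 channels x 3 time steps.
flat = ["c0t0", "c0t1", "c0t2", "c1t0", "c1t1", "c1t2"]
expanded = expand_temporal_planes(flat, num_channels=2)
```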

Custom datasets and data modules

Below is a documented example of how a custom dataset and data module class can be implemented.

terratorch.datasets.fire_scars

FireScarsHLS

Bases: RasterDataset

RasterDataset implementation for fire scars input images.

Source code in terratorch/datasets/fire_scars.py
class FireScarsHLS(RasterDataset):
    """RasterDataset implementation for fire scars input images."""

    filename_glob = "subsetted*_merged.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
    date_format = "%Y%j"
    is_image = True
    separate_files = False
    all_bands = ["B02", "B03", "B04", "B8A", "B11", "B12"]
    rgb_bands = ["B04", "B03", "B02"]
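
To see what filename_regex and date_format extract, the pattern can be tried on a single file name with the standard library. The file name below is hypothetical and was constructed only to satisfy the pattern; the sketch shows plain `re` and `datetime` usage and makes no claim about how TorchGeo applies these attributes internally.

```python
import re
from datetime import datetime

FILENAME_REGEX = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
DATE_FORMAT = "%Y%j"  # four-digit year followed by the day of the year

# Hypothetical file name, constructed only to satisfy the pattern above.
filename = "subsetted_512x512_HLS.S30.T10SEH.2020248.v1.4_merged.tif"

match = re.match(FILENAME_REGEX, filename)
date_string = match.group("date")                   # the named group captures the date digits
acquired = datetime.strptime(date_string, DATE_FORMAT)
acquired_iso = acquired.strftime("%Y-%m-%d")        # day 248 of 2020 as a calendar date
```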
FireScarsNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for fire scars.

Source code in terratorch/datasets/fire_scars.py
class FireScarsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for fire scars."""

    def __init__(self, data_root: Path) -> None:
        super().__init__()
        self.image_files = sorted(glob.glob(os.path.join(data_root, "subsetted*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_root, "subsetted*.mask.tif")))
        self.rgb_indices = [0, 1, 2]

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        output = {
            "image": self._load_file(self.image_files[index]).astype(np.float32),
            "mask": self._load_file(self.segmentation_mask_files[index]).astype(np.int64),
        }
        return output

    def _load_file(self, path: Path):
        data = rioxarray.open_rasterio(path)
        return data.to_numpy()

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        image = sample["image"].take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        label_mask = np.transpose(label_mask, (1, 2, 0))

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]

        return self._plot_sample(
            image,
            label_mask,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
        )

    @staticmethod
    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
        num_images = 5 if prediction is not None else 4
        fig, ax = plt.subplots(1, num_images, figsize=(8, 6))

        # for legend
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(label, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)

        if prediction is not None:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = []
        for i in range(num_classes):
            class_name = class_names[i] if class_names else str(i)
            data = [i, cmap(norm(i)), class_name]
            legend_data.append(data)
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
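
FireScarsNonGeo pairs each image with its mask purely by position in two sorted file lists. The sketch below reproduces that pairing on empty placeholder files in a temporary directory. The tile names are hypothetical; only the suffixes follow the glob patterns used in the class.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as data_root:
    # Hypothetical tile names, created out of order on purpose.
    for tile in ("tileB", "tileA"):
        for suffix in ("_merged.tif", ".mask.tif"):
            open(os.path.join(data_root, f"subsetted_{tile}{suffix}"), "w").close()

    image_files = sorted(glob.glob(os.path.join(data_root, "subsetted*_merged.tif")))
    mask_files = sorted(glob.glob(os.path.join(data_root, "subsetted*.mask.tif")))

    # Positional pairing only lines up if both lists have the same length.
    same_length = len(image_files) == len(mask_files)

    # Pair by position, as __getitem__ does with a shared index.
    pairs = [
        (os.path.basename(image), os.path.basename(mask))
        for image, mask in zip(image_files, mask_files)
    ]
```

Because the pairing is positional, a missing mask or image shifts every later pair, so a length check like the one above is a cheap safeguard when adapting this pattern to your own data.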
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:

sample (dict[str, Tensor], required): A sample returned by __getitem__.
suptitle (str | None, default None): Optional string to use as a suptitle.

Returns:

Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/fire_scars.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    image = sample["image"].take(self.rgb_indices, axis=0)
    image = np.transpose(image, (1, 2, 0))
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    image = np.clip(image, 0, 1)

    label_mask = sample["mask"]
    label_mask = np.transpose(label_mask, (1, 2, 0))

    showing_predictions = "prediction" in sample
    if showing_predictions:
        prediction_mask = sample["prediction"]

    return self._plot_sample(
        image,
        label_mask,
        prediction=prediction_mask if showing_predictions else None,
        suptitle=suptitle,
    )
FireScarsSegmentationMask

Bases: RasterDataset

RasterDataset implementation for fire scars segmentation mask. Can be easily merged with input images using the & operator.

Source code in terratorch/datasets/fire_scars.py
class FireScarsSegmentationMask(RasterDataset):
    """RasterDataset implementation for fire scars segmentation mask.
    Can be easily merged with input images using the & operator.
    """

    filename_glob = "subsetted*.mask.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4.mask.tif"
    date_format = "%Y%j"
    is_image = False
    separate_files = False

terratorch.datamodules.fire_scars

FireScarsDataModule

Bases: GeoDataModule

Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsDataModule(GeoDataModule):
    """Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(FireScarsSegmentationMask, 4, 224, 100, 0, **kwargs)
        self.train_aug = AugmentationSequential(K.RandomCrop((224, 224)), K.Normalize(MEANS, STDS))
        self.aug = AugmentationSequential(K.Normalize(MEANS, STDS))

    def setup(self, stage: str) -> None:
        self.images = FireScarsHLS(
            "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training/"
        )
        self.labels = FireScarsSegmentationMask(
            "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training/"
        )
        self.dataset = self.images & self.labels
        self.train_aug = AugmentationSequential(K.RandomCrop((224, 224)), K.Normalize(MEANS, STDS))

        self.images_test = FireScarsHLS(
            "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation/"
        )
        self.labels_test = FireScarsSegmentationMask(
            "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation/"
        )
        self.val_dataset = self.images_test & self.labels_test

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, self.patch_size, self.batch_size, None)
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
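
The validation and test samplers above are given self.patch_size for both the patch size and the stride, so consecutive patches do not overlap. The following is a simplified pixel-space sketch of such a grid. It is not TorchGeo's algorithm: GridGeoSampler works in the coordinates of the dataset's CRS and may treat image edges differently, and `grid_origins` is a hypothetical helper that keeps only patches fitting fully inside the image.

```python
def grid_origins(height, width, patch_size, stride):
    """Top-left pixel offsets of patches that fit fully inside the image."""
    rows = range(0, height - patch_size + 1, stride)
    cols = range(0, width - patch_size + 1, stride)
    return [(r, c) for r in rows for c in cols]


# A 512x512 tile with patch_size 224 and stride equal to patch_size.
origins = grid_origins(512, 512, patch_size=224, stride=224)

# With stride == patch_size, neighbouring patches share no pixels.
no_overlap = all(
    abs(r1 - r2) >= 224 or abs(c1 - c2) >= 224
    for i, (r1, c1) in enumerate(origins)
    for (r2, c2) in origins[i + 1:]
)
```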
FireScarsNonGeoDataModule

Bases: NonGeoDataModule

NonGeo Fire Scars data module implementation

Source code in terratorch/datamodules/fire_scars.py
class FireScarsNonGeoDataModule(NonGeoDataModule):
    """NonGeo Fire Scars data module implementation"""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(FireScarsNonGeo, 16, 8, **kwargs)
        # applied for training
        self.train_aug = AugmentationSequential(
            K.Normalize(MEANS, STDS),
            K.RandomCrop((224, 224)),
            data_keys=["image", "mask"],
        )
        self.aug = AugmentationSequential(K.Normalize(MEANS, STDS), data_keys=["image", "mask"])

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/training/"
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation/"
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                "/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended/validation/"
            )