Data

We rely on TorchGeo for the implementation of datasets and data modules.

Check out the TorchGeo tutorials on datasets for more in-depth information.

In general, it is recommended that you create a TorchGeo dataset specifically for your data. This gives you complete control and flexibility over how data is loaded, which transforms are applied to it, and even how it is plotted if you log with tools like TensorBoard.

TorchGeo provides GeoDataset and NonGeoDataset.

  • If your data is already nicely tiled and ready for consumption by a neural network, you can inherit from NonGeoDataset. This is essentially a wrapper around a regular torch Dataset (see the sketch after this list).
  • If your data consists of large GeoTiffs you would like to sample from during training, you can leverage the powerful GeoDataset from TorchGeo. This will automatically align your input data and labels and enable a variety of geo-aware samplers.
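
As an illustration of the first case, below is a minimal sketch of a custom NonGeoDataset. The directory layout, file patterns and the MyTiledDataset name are assumptions for the example; TorchGeo only requires __len__ and __getitem__ returning sample dicts.

import glob
import os
from typing import Any

import numpy as np
import rioxarray
import torch
from torchgeo.datasets import NonGeoDataset


class MyTiledDataset(NonGeoDataset):
    """Loads pre-tiled image/mask GeoTIFF pairs from a directory."""

    def __init__(self, data_root: str) -> None:
        self.image_files = sorted(glob.glob(os.path.join(data_root, "images", "*.tif")))
        self.mask_files = sorted(glob.glob(os.path.join(data_root, "masks", "*.tif")))

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        # Read channels-first arrays and return the dict format TorchGeo expects
        image = rioxarray.open_rasterio(self.image_files[index]).to_numpy()
        mask = rioxarray.open_rasterio(self.mask_files[index]).to_numpy()[0]
        return {
            "image": torch.from_numpy(image.astype(np.float32)),
            "mask": torch.from_numpy(mask).long(),
        }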

Using Datasets already implemented in TorchGeo

Using existing TorchGeo DataModules is very easy! Just plug them in! For instance, to use the EuroSATDataModule, set the data section in your config file as:

data:
  class_path: torchgeo.datamodules.EuroSATDataModule
  init_args:
    batch_size: 32
    num_workers: 8
  dict_kwargs:
    root: /dccstor/geofm-pre/EuroSat
    download: True
    bands:
      - B02
      - B03
      - B04
      - B08A
      - B09
      - B10
Modify each parameter as you see fit.

You can also do this outside of config files! Simply instantiate the data module as normal and plug it in, as in the sketch below.
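
For example, the config above translates to the following (a sketch; model stands in for whatever LightningModule you would train):

from lightning.pytorch import Trainer
from torchgeo.datamodules import EuroSATDataModule

datamodule = EuroSATDataModule(
    batch_size=32,
    num_workers=8,
    root="/dccstor/geofm-pre/EuroSat",
    download=True,
    bands=["B02", "B03", "B04", "B08A", "B09", "B10"],
)

trainer = Trainer(max_epochs=10)
# trainer.fit(model, datamodule=datamodule)  # model: any LightningModule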

Warning

To pass transforms to TorchGeo DataModules from config files, you must use the following format:

data:
  class_path: terratorch.datamodules.TorchNonGeoDataModule
  init_args:
    cls: torchgeo.datamodules.EuroSATDataModule
    transforms:
      - class_path: albumentations.augmentations.geometric.resize.Resize
        init_args:
          height: 224
          width: 224
      - class_path: ToTensorV2
Note that the class_path is TorchNonGeoDataModule, and the class to be used is passed through cls (there is also a TorchGeoDataModule for geo modules). This is necessary because the transforms argument is passed through **kwargs in TorchGeo, making it difficult for LightningCLI to instantiate it from the config. See more details below.
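
The same wrapper can also be used directly in Python, which makes the mechanics of cls and transforms clearer (a sketch, reusing the illustrative root path from above):

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchgeo.datamodules import EuroSATDataModule

from terratorch.datamodules import TorchNonGeoDataModule

datamodule = TorchNonGeoDataModule(
    cls=EuroSATDataModule,
    batch_size=32,
    transforms=[A.Resize(height=224, width=224), ToTensorV2()],
    root="/dccstor/geofm-pre/EuroSat",
    download=True,
)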

terratorch.datamodules.torchgeo_data_module

Ugly proxy objects so that parsing config files works with transforms.

These are necessary since, for LightningCLI to instantiate arguments as objects from the config, they must have type annotations.

In TorchGeo, transforms is passed in **kwargs, so it has no type annotations! To get around that, we create these wrappers that have transforms type annotated. They create the transforms and forward all method and attribute calls to the original TorchGeo datamodule.

Additionally, TorchGeo datasets pass data to the transforms callable as a dict of torch tensors, whereas Albumentations expects separate keyword arguments holding numpy arrays. We handle that conversion here.
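
Conceptually, the wrapper turns each Albumentations pipeline into a callable like the following. This is a rough sketch of the idea, not the actual helper functions (albumentations_to_callable_with_dict and build_callable_transform_from_torch_tensor):

import albumentations as A


def as_torchgeo_transform(pipeline: A.Compose):
    """Adapt an Albumentations pipeline to TorchGeo's dict-of-tensors convention."""

    def apply(sample: dict) -> dict:
        # TorchGeo provides channels-first torch tensors in a dict;
        # Albumentations wants channels-last numpy arrays as keyword arguments.
        image = sample["image"].permute(1, 2, 0).numpy()
        transformed = pipeline(image=image)
        # With ToTensorV2 at the end of the pipeline, "image" is a tensor again
        sample["image"] = transformed["image"]
        return sample

    return apply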

TorchGeoDataModule

Bases: GeoDataModule

Proxy object for using Geo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for lightningcli and config files. As such, all getattr and setattr will be redirected to the underlying class.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchGeoDataModule(GeoDataModule):
    """Proxy object for using Geo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[GeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    @property
    def patch_size(self):
        return self._proxy.patch_size

    @property
    def length(self):
        return self._proxy.length

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return self._proxy.transfer_batch_to_device(batch, device, dataloader_idx)

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[GeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[GeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

TorchNonGeoDataModule

Bases: NonGeoDataModule

Proxy object for using NonGeo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for lightningcli and config files. As such, all getattr and setattr will be redirected to the underlying class.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchNonGeoDataModule(NonGeoDataModule):
    """Proxy object for using NonGeo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[NonGeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[NonGeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[NonGeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

Generic datasets and data modules

For the NonGeoDataset case, we also provide "generic" datasets and datamodules. These can be used when you would like to load data from given directories, in a style similar to the MMLab libraries.
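
For instance, a segmentation dataset can be assembled from a directory of images and masks as follows (a sketch; the paths, grep patterns and band names are assumptions):

from terratorch.datasets import GenericNonGeoSegmentationDataset

dataset = GenericNonGeoSegmentationDataset(
    data_root="/path/to/dataset",
    num_classes=2,
    image_grep="images/*.tif",
    label_grep="masks/*.tif",
    dataset_bands=["BLUE", "GREEN", "RED", "NIR_NARROW"],
    output_bands=["BLUE", "GREEN", "RED"],
)
sample = dataset[0]  # dict with "image", "mask" and "filename" keys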

Generic Datasets

terratorch.datasets.generic_pixel_wise_dataset

Module containing generic dataset classes

GenericNonGeoPixelwiseRegressionDataset

Bases: GenericPixelWiseDataset

GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoPixelwiseRegressionDataset(GenericPixelWiseDataset):
    """GenericNonGeoPixelwiseRegressionDataset"""

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Regular expression appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Regular expression appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain new-line separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].float()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
            suptitle (str|None): optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
        )

    @staticmethod
    def _plot_sample(image, label, prediction=None, suptitle=None):
        num_images = 4 if prediction is not None else 3
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max())
        ax[0].axis("off")
        ax[0].title.set_text("Image")
        ax[0].imshow(image)

        ax[1].axis("off")
        ax[1].title.set_text("Ground Truth Mask")
        ax[1].imshow(label, cmap="Greens", norm=norm)

        ax[2].axis("off")
        ax[2].title.set_text("GT Mask on Image")
        ax[2].imshow(image)
        ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm)
        # ax[2].legend()

        if prediction is not None:
            ax[3].title.set_text("Predicted Mask")
            ax[3].imshow(prediction, cmap="Greens", norm=norm)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Regular expression appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Regular expression appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain new-line separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If none, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Regular expression appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Regular expression appended to data_root to find ground truth masks.
            Defaults to "*".
        split (Path, optional): Path to file containing files to be used for this split.
            The file should contain new-line separated prefixes of the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__(
        data_root,
        label_data_root=label_data_root,
        image_grep=image_grep,
        label_grep=label_grep,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        no_label_replace=no_label_replace,
        expand_temporal_dimension=expand_temporal_dimension,
        reduce_zero_label=reduce_zero_label,
    )
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by :meth:__getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

.. versionadded:: 0.2

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
        suptitle (str|None): optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample

    .. versionadded:: 0.2
    """
    image = sample["image"]
    if len(image.shape) == 5:
        return
    if isinstance(image, Tensor):
        image = image.numpy()
    image = image.take(self.rgb_indices, axis=0)
    image = np.transpose(image, (1, 2, 0))
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    image = np.clip(image, 0, 1)

    label_mask = sample["mask"]
    if isinstance(label_mask, Tensor):
        label_mask = label_mask.numpy()

    showing_predictions = "prediction" in sample
    if showing_predictions:
        prediction_mask = sample["prediction"]
        if isinstance(prediction_mask, Tensor):
            prediction_mask = prediction_mask.numpy()

    return self._plot_sample(
        image,
        label_mask,
        prediction=prediction_mask if showing_predictions else None,
        suptitle=suptitle,
    )
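
Putting it together, a regression dataset can be built and a sample inspected like this (paths and grep patterns are illustrative):

import matplotlib.pyplot as plt

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset

dataset = GenericNonGeoPixelwiseRegressionDataset(
    data_root="/path/to/dataset",
    image_grep="images/*.tif",
    label_grep="targets/*.tif",
)
fig = dataset.plot(dataset[0], suptitle="Sample 0")
plt.show()
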
GenericNonGeoSegmentationDataset

Bases: GenericPixelWiseDataset

GenericNonGeoSegmentationDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
    """GenericNonGeoSegmentationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Regular expression appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Regular expression appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain new-line separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            self.num_classes,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
            class_names=self.class_names,
        )

    @staticmethod
    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
        num_images = 5 if prediction is not None else 4
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        # for legend
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(label, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)

        if prediction is not None:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = []
        for i, _ in enumerate(range(num_classes)):
            class_name = class_names[i] if class_names else str(i)
            data = [i, cmap(norm(i)), class_name]
            legend_data.append(data)
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, num_classes, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Regular expression appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Regular expression appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain new-line separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If none, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    num_classes: int,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    class_names: list[str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        num_classes (int): Number of classes in the dataset
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Regular expression appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Regular expression appended to data_root to find ground truth masks.
            Defaults to "*".
        split (Path, optional): Path to file containing files to be used for this split.
            The file should contain new-line separated prefixes of the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        class_names (list[str], optional): Class names. Defaults to None.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__(
        data_root,
        label_data_root=label_data_root,
        image_grep=image_grep,
        label_grep=label_grep,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        no_label_replace=no_label_replace,
        expand_temporal_dimension=expand_temporal_dimension,
        reduce_zero_label=reduce_zero_label,
    )
    self.num_classes = num_classes
    self.class_names = class_names
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by :meth:__getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

.. versionadded:: 0.2

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample

    .. versionadded:: 0.2
    """
    image = sample["image"]
    if len(image.shape) == 5:
        return
    if isinstance(image, Tensor):
        image = image.numpy()
    image = image.take(self.rgb_indices, axis=0)
    image = np.transpose(image, (1, 2, 0))
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    image = np.clip(image, 0, 1)

    label_mask = sample["mask"]
    if isinstance(label_mask, Tensor):
        label_mask = label_mask.numpy()

    showing_predictions = "prediction" in sample
    if showing_predictions:
        prediction_mask = sample["prediction"]
        if isinstance(prediction_mask, Tensor):
            prediction_mask = prediction_mask.numpy()

    return self._plot_sample(
        image,
        label_mask,
        self.num_classes,
        prediction=prediction_mask if showing_predictions else None,
        suptitle=suptitle,
        class_names=self.class_names,
    )
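
As with the regression variant, plot renders a sample; passing class_names fills the legend (a sketch with illustrative values):

import matplotlib.pyplot as plt

from terratorch.datasets import GenericNonGeoSegmentationDataset

dataset = GenericNonGeoSegmentationDataset(
    data_root="/path/to/dataset",
    num_classes=2,
    class_names=["unburned", "burned"],
    image_grep="images/*.tif",
    label_grep="masks/*.tif",
)
fig = dataset.plot(dataset[0], suptitle="Sample 0")
plt.show()
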
GenericPixelWiseDataset

Bases: NonGeoDataset, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericPixelWiseDataset(NonGeoDataset, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Regular expression appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Regular expression appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain new-line separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__()

        self.split_file = split

        label_data_root = label_data_root if label_data_root is not None else data_root
        self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension

        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
            self.segmentation_mask_files = filter_valid_files(
                self.segmentation_mask_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None

        # If no transform is given, apply only the conversion to torch tensor
        self.transform = transform if transform else default_transform
        # self.transform = transform if transform else ToTensorV2()

        import warnings

        import rasterio

        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
        # to channels last
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
        image = np.moveaxis(image, 0, -1)

        if self.filter_indices:
            image = image[..., self.filter_indices]
        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
                0
            ]
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data
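
The dataset_bands/output_bands mechanism simply selects and reorders channels: dataset_bands names every channel stored in the files, and output_bands picks the subset to return, in the order given. A sketch (band names are assumptions):

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset

# Files store six channels; only the three visible bands are returned,
# reordered as RED, GREEN, BLUE.
dataset = GenericNonGeoPixelwiseRegressionDataset(
    data_root="/path/to/dataset",
    image_grep="images/*.tif",
    label_grep="targets/*.tif",
    dataset_bands=["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    output_bands=["RED", "GREEN", "BLUE"],
)
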
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Regular expression appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Regular expression appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain new-line separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep]).

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If none, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
def __init__(
    self,
    data_root: Path,
    label_data_root: Path | None = None,
    image_grep: str | None = "*",
    label_grep: str | None = "*",
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        label_data_root (Path, optional): Path to data root directory with labels.
            If not specified, will use the same as for images.
        image_grep (str, optional): Glob pattern appended to data_root to find input images.
            Defaults to "*".
        label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
            Defaults to "*".
        split (Path, optional): Path to file containing files to be used for this split.
            The file should contain a newline-separated list of prefixes of the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
    """
    super().__init__()

    self.split_file = split

    label_data_root = label_data_root if label_data_root is not None else data_root
    self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
    self.reduce_zero_label = reduce_zero_label
    self.expand_temporal_dimension = expand_temporal_dimension

    if self.expand_temporal_dimension and output_bands is None:
        msg = "Please provide output_bands when expand_temporal_dimension is True"
        raise Exception(msg)
    if self.split_file is not None:
        with open(self.split_file) as f:
            split = f.readlines()
        valid_files = {rf"{substring.strip()}" for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )
    self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

    self.dataset_bands = generate_bands_intervals(dataset_bands)
    self.output_bands = generate_bands_intervals(output_bands)

    if self.output_bands and not self.dataset_bands:
        msg = "If output bands provided, dataset_bands must also be provided"
        raise Exception(msg)

    # There is a special condition if the bands are defined as simple strings.
    if self.output_bands:
        if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
            msg = "Output bands must be a subset of dataset bands"
            raise Exception(msg)

        self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

    else:
        self.filter_indices = None

    # If no transform is given, only convert to a torch tensor
    self.transform = transform if transform else default_transform

    import warnings

    import rasterio

    warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
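For orientation, below is a minimal sketch of constructing one of the concrete pixel-wise datasets from this module directly in Python (all paths, greps, and band names are hypothetical and assume a tiled GeoTIFF layout):

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset

# Hypothetical directory layout and file naming scheme
dataset = GenericNonGeoPixelwiseRegressionDataset(
    data_root="data/images",
    label_data_root="data/targets",
    image_grep="*_img.tif",
    label_grep="*_target.tif",
    dataset_bands=["BLUE", "GREEN", "RED", "NIR_NARROW"],
    output_bands=["RED", "GREEN", "BLUE"],
    no_data_replace=0,
)
sample = dataset[0]  # a dict holding the image, its pixel-wise target and the filename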

terratorch.datasets.generic_scalar_label_dataset

Module containing generic dataset classes

GenericNonGeoClassificationDataset

Bases: GenericScalarLabelDataset

GenericNonGeoClassificationDataset

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericNonGeoClassificationDataset(GenericScalarLabelDataset):
    """GenericNonGeoClassificationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """A generic Non-Geo dataset for classification.

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain a newline-separated list of prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        super().__init__(
            data_root,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            expand_temporal_dimension=expand_temporal_dimension,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["label"] = torch.tensor(item["label"]).long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        pass
__init__(data_root, num_classes, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

A generic Non-Geo dataset for classification.

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain a newline-separated list of prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

Source code in terratorch/datasets/generic_scalar_label_dataset.py
def __init__(
    self,
    data_root: Path,
    num_classes: int,
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
    allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    class_names: list[str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float = 0,
    expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
) -> None:
    """A generic Non-Geo dataset for classification.

    Args:
        data_root (Path): Path to data root directory
        num_classes (int): Number of classes in the dataset
        split (Path, optional): Path to file containing files to be used for this split.
            The file should contain a newline-separated list of prefixes of the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
        class_names (list[str], optional): Class names. Defaults to None.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
    """
    super().__init__(
        data_root,
        split=split,
        ignore_split_file_extensions=ignore_split_file_extensions,
        allow_substring_split_file=allow_substring_split_file,
        rgb_indices=rgb_indices,
        dataset_bands=dataset_bands,
        output_bands=output_bands,
        constant_scale=constant_scale,
        transform=transform,
        no_data_replace=no_data_replace,
        expand_temporal_dimension=expand_temporal_dimension,
    )
    self.num_classes = num_classes
    self.class_names = class_names
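A minimal usage sketch for this class (the directory layout is hypothetical, and the import assumes the class is exported at the package level like the other generic datasets; since GenericScalarLabelDataset builds on torchvision's ImageFolder, each class is expected to live in its own subdirectory of data_root):

from terratorch.datasets import GenericNonGeoClassificationDataset

# Hypothetical layout: data_root/<class_name>/<image>.tif, one subdirectory per class
dataset = GenericNonGeoClassificationDataset(
    data_root="data/classification",
    num_classes=10,
    dataset_bands=[0, 1, 2, 3],
    output_bands=[0, 1, 2],
)
item = dataset[0]
print(item["image"].shape, item["label"], item["filename"])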
GenericScalarLabelDataset

Bases: NonGeoDataset, ImageFolder, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain a newline-separated list of prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This
                parameter gives identifiers to input channels (bands) so that they can then be referred to by
                output_bands. Can use the HLSBands enum, ints, int ranges, or strings. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the
                dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        self.split_file = split

        self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
        self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )

            def is_valid_file(x):
                return x in self.image_files

        else:

            def is_valid_file(x):
                return True

        super().__init__(
            root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
        )

        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None
        # If no transform is given, only convert to a torch tensor
        self.transforms = transform if transform else default_transform

        import warnings

        import rasterio
        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image, label = ImageFolder.__getitem__(self, index)
        if self.expand_temporal_dimension:
            image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.output_bands))
        if self.filter_indices:
            image = image[..., self.filter_indices]

        image = image.astype(np.float32) * self.constant_scale

        if self.transforms:
            image = self.transforms(image=image)["image"]  # albumentations returns dict

        output = {
            "image": image,
            "label": label,  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
            "filename": self.image_files[index]
        }

        return output

    def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
        if bands_intervals is None:
            return None
        bands = []
        for element in bands_intervals:
            # if its an interval
            if isinstance(element, tuple):
                if len(element) != 2:  # noqa: PLR2004
                    msg = "When defining an interval, a tuple of two integers should be passed,\
                    defining start and end indices inclusive"
                    raise Exception(msg)
                expanded_element = list(range(element[0], element[1] + 1))
                bands.extend(expanded_element)
            else:
                bands.append(element)
        return bands

    def _load_file(self, path) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        data = data.fillna(self.no_data_replace)
        return data
__init__(data_root, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain a newline-separated list of prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter gives identifiers to input channels (bands) so that they can then be referred to by output_bands. Can use the HLSBands enum, ints, int ranges, or strings. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

Source code in terratorch/datasets/generic_scalar_label_dataset.py
def __init__(
    self,
    data_root: Path,
    split: Path | None = None,
    ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
    allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
    rgb_indices: list[int] | None = None,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    constant_scale: float = 1,
    transform: A.Compose | None = None,
    no_data_replace: float = 0,
    expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
) -> None:
    """Constructor

    Args:
        data_root (Path): Path to data root directory
        split (Path, optional): Path to file containing files to be used for this split.
            The file should contain a newline-separated list of prefixes of the desired files.
            Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
        dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This
            parameter gives identifiers to input channels (bands) so that they can then be referred to by
            output_bands. Can use the HLSBands enum, ints, int ranges, or strings. Defaults to None.
        output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the
            dataset as named by dataset_bands.
        constant_scale (float): Factor to multiply image values by. Defaults to 1.
        transform (Albumentations.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
    """
    self.split_file = split

    self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
    self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.expand_temporal_dimension = expand_temporal_dimension
    if self.expand_temporal_dimension and output_bands is None:
        msg = "Please provide output_bands when expand_temporal_dimension is True"
        raise Exception(msg)
    if self.split_file is not None:
        with open(self.split_file) as f:
            split = f.readlines()
        valid_files = {rf"{substring.strip()}" for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=ignore_split_file_extensions,
            allow_substring=allow_substring_split_file,
        )

        def is_valid_file(x):
            return x in self.image_files

    else:

        def is_valid_file(x):
            return True

    super().__init__(
        root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
    )

    self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

    self.dataset_bands = generate_bands_intervals(dataset_bands)
    self.output_bands = generate_bands_intervals(output_bands)

    if self.output_bands and not self.dataset_bands:
        msg = "If output bands provided, dataset_bands must also be provided"
        raise Exception(msg)

    # There is a special condition if the bands are defined as simple strings.
    if self.output_bands:
        if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
            msg = "Output bands must be a subset of dataset bands"
            raise Exception(msg)

        self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

    else:
        self.filter_indices = None
    # If no transform is given, only convert to a torch tensor
    self.transforms = transform if transform else default_transform

    import warnings

    import rasterio
    warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
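Note how dataset_bands and output_bands accept (start, end) tuples: these are expanded into inclusive integer ranges before output_bands is matched against dataset_bands. A standalone sketch of that expansion logic (mirroring _generate_bands_intervals above, not the library function itself):

def expand_band_intervals(bands: list | None) -> list | None:
    """Expand (start, end) tuples into inclusive integer ranges; keep other entries as-is."""
    if bands is None:
        return None
    expanded = []
    for element in bands:
        if isinstance(element, tuple):
            start, end = element
            expanded.extend(range(start, end + 1))
        else:
            expanded.append(element)
    return expanded

print(expand_band_intervals([(0, 3), "NIR_NARROW"]))  # [0, 1, 2, 3, 'NIR_NARROW']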

Generic Data Modules

terratorch.datamodules.generic_pixel_wise_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoPixelwiseRegressionDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoPixelwiseRegressionDataModule(NonGeoDataModule):
    """This is a generic datamodule class for instantiating data modules at runtime.
    Composes several
    [GenericNonGeoPixelwiseRegressionDataset][terratorch.datasets.GenericNonGeoPixelwiseRegressionDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        predict_data_root: Path | None = None,
        img_grep: str | None = "*",
        label_grep: str | None = "*",
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by the dataloaders.
            train_data_root (Path): Path to the root directory of training images.
            val_data_root (Path): Path to the root directory of validation images.
            test_data_root (Path): Path to the root directory of test images.
            predict_data_root (Path): Path to the root directory of images used for prediction.
            img_grep (str): Glob pattern appended to the data roots to find input images.
            label_grep (str): Glob pattern appended to the data roots to find ground truth masks.
            means (list[float] | str): Per-band means used for normalization, or a path to a file containing them.
            stds (list[float] | str): Per-band standard deviations used for normalization, or a path to a file containing them.
            train_label_data_root (Path | None, optional): Path to the root directory of training labels. Defaults to None.
            val_label_data_root (Path | None, optional): Path to the root directory of validation labels. Defaults to None.
            test_label_data_root (Path | None, optional): Path to the root directory of test labels. Defaults to None.
            train_split (Path | None, optional): Path to the train split file. Defaults to None.
            val_split (Path | None, optional): Path to the validation split file. Defaults to None.
            test_split (Path | None, optional): Path to the test split file. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
                Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
                with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Factor to multiply image values by. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of RGB channels for plotting. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.

        """
        super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.constant_scale = constant_scale

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                image_grep=self.img_grep,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, predict_data_root=None, img_grep='*', label_grep='*', train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
  • batch_size (int) –

    Number of samples per batch.

  • num_workers (int) –

    Number of worker processes used by the dataloaders.

  • train_data_root (Path) –

    Path to the root directory of training images.

  • val_data_root (Path) –

    Path to the root directory of validation images.

  • test_data_root (Path) –

    Path to the root directory of test images.

  • predict_data_root (Path, default: None ) –

    Path to the root directory of images used for prediction. Defaults to None.

  • img_grep (str, default: '*' ) –

    Glob pattern appended to the data roots to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to the data roots to find ground truth masks. Defaults to "*".

  • means (list[float]) –

    Per-band means used for normalization, or a path to a file containing them.

  • stds (list[float]) –

    Per-band standard deviations used for normalization, or a path to a file containing them.

  • train_label_data_root (Path | None, default: None ) –

    Path to the root directory of training labels. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Path to the root directory of validation labels. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Path to the root directory of test labels. Defaults to None.

  • train_split (Path | None, default: None ) –

    Path to the train split file. Defaults to None.

  • val_split (Path | None, default: None ) –

    Path to the validation split file. Defaults to None.

  • test_split (Path | None, default: None ) –

    Path to the test split file. Defaults to None.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of RGB channels for plotting. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    predict_data_root: Path | None = None,
    img_grep: str | None = "*",
    label_grep: str | None = "*",
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by the dataloaders.
        train_data_root (Path): Path to the root directory of training images.
        val_data_root (Path): Path to the root directory of validation images.
        test_data_root (Path): Path to the root directory of test images.
        predict_data_root (Path): Path to the root directory of images used for prediction.
        img_grep (str): Glob pattern appended to the data roots to find input images.
        label_grep (str): Glob pattern appended to the data roots to find ground truth masks.
        means (list[float] | str): Per-band means used for normalization, or a path to a file containing them.
        stds (list[float] | str): Per-band standard deviations used for normalization, or a path to a file containing them.
        train_label_data_root (Path | None, optional): Path to the root directory of training labels. Defaults to None.
        val_label_data_root (Path | None, optional): Path to the root directory of validation labels. Defaults to None.
        test_label_data_root (Path | None, optional): Path to the root directory of test labels. Defaults to None.
        train_split (Path | None, optional): Path to the train split file. Defaults to None.
        val_split (Path | None, optional): Path to the validation split file. Defaults to None.
        test_split (Path | None, optional): Path to the test split file. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Factor to multiply image values by. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of RGB channels for plotting. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.

    """
    super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.drop_last = drop_last
    self.pin_memory = pin_memory
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.constant_scale = constant_scale

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices

    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
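Putting the constructor arguments together, here is a minimal sketch of instantiating this data module directly in Python (paths and per-band statistics are hypothetical; in a config-driven run the same arguments would appear under init_args instead):

from terratorch.datamodules import GenericNonGeoPixelwiseRegressionDataModule

# Hypothetical paths and normalization statistics
datamodule = GenericNonGeoPixelwiseRegressionDataModule(
    batch_size=16,
    num_workers=4,
    train_data_root="data/train/images",
    val_data_root="data/val/images",
    test_data_root="data/test/images",
    train_label_data_root="data/train/targets",
    val_label_data_root="data/val/targets",
    test_label_data_root="data/test/targets",
    img_grep="*_img.tif",
    label_grep="*_target.tif",
    means=[411.4, 558.5, 632.2],
    stds=[547.4, 340.7, 376.1],
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()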
GenericNonGeoSegmentationDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoSegmentationDatasets

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoSegmentationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        img_grep: str,
        label_grep: str,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples in each batch.
            num_workers (int): Number of subprocesses used for data loading.
            train_data_root (Path): Root directory of training images.
            val_data_root (Path): Root directory of validation images.
            test_data_root (Path): Root directory of test images.
            predict_data_root (Path): Root directory of images to run inference on.
            img_grep (str): Glob expression used to search for input images.
            label_grep (str): Glob expression used to search for labels.
            means (list[float]): Per-band means for normalization, or path to a file containing them.
            stds (list[float]): Per-band standard deviations for normalization, or path to a file containing them.
            num_classes (int): Number of segmentation classes.
            train_label_data_root (Path | None, optional): Root directory of training labels. Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of validation labels. Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of test labels. Defaults to None.
            train_split (Path | None, optional): Path to file listing the training samples. Defaults to None.
            val_split (Path | None, optional): Path to file listing the validation samples. Defaults to None.
            test_split (Path | None, optional): Path to file listing the test samples. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
                Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
                with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Factor by which to scale image values. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB bands, used for plotting. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.
        """
        super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last
        self.pin_memory = pin_memory

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
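
Note that train_transform, val_transform and test_transform may be passed either as an already built Albumentations Compose or as a plain list of transforms, which is wrapped by wrap_in_compose_is_list. Below is a minimal sketch of the behaviour implied by its name and usage here; the actual helper lives in terratorch and may differ in detail:

import albumentations as A

def wrap_in_compose_is_list(transform):
    # If a plain list of Albumentations transforms is passed, wrap it in
    # A.Compose; otherwise (None or an existing Compose) pass it through.
    return A.Compose(transform) if isinstance(transform, list) else transform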
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, img_grep, label_grep, means, stds, num_classes, predict_data_root=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
  • batch_size (int) –

    Number of samples in each batch.

  • num_workers (int) –

    Number of subprocesses used for data loading.

  • train_data_root (Path) –

    Root directory of training images.

  • val_data_root (Path) –

    Root directory of validation images.

  • test_data_root (Path) –

    Root directory of test images.

  • predict_data_root (Path, default: None ) –

    Root directory of images to run inference on.

  • img_grep (str) –

    Glob expression used to search for input images.

  • label_grep (str) –

    Glob expression used to search for labels.

  • means (list[float]) –

    Per-band means for normalization, or path to a file containing them.

  • stds (list[float]) –

    Per-band standard deviations for normalization, or path to a file containing them.

  • num_classes (int) –

    Number of segmentation classes.

  • train_label_data_root (Path | None, default: None ) –

    Root directory of training labels. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Root directory of validation labels. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Root directory of test labels. Defaults to None.

  • train_split (Path | None, default: None ) –

    Path to file listing the training samples. Defaults to None.

  • val_split (Path | None, default: None ) –

    Path to file listing the validation samples. Defaults to None.

  • test_split (Path | None, default: None ) –

    Path to file listing the test samples. Defaults to None.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    Factor by which to scale image values. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the RGB bands, used for plotting. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If none, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If none, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    img_grep: str,
    label_grep: str,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples in each batch.
        num_workers (int): Number of subprocesses used for data loading.
        train_data_root (Path): Root directory of training images.
        val_data_root (Path): Root directory of validation images.
        test_data_root (Path): Root directory of test images.
        predict_data_root (Path): Root directory of images to run inference on.
        img_grep (str): Glob expression used to search for input images.
        label_grep (str): Glob expression used to search for labels.
        means (list[float]): Per-band means for normalization, or path to a file containing them.
        stds (list[float]): Per-band standard deviations for normalization, or path to a file containing them.
        num_classes (int): Number of segmentation classes.
        train_label_data_root (Path | None, optional): Root directory of training labels. Defaults to None.
        val_label_data_root (Path | None, optional): Root directory of validation labels. Defaults to None.
        test_label_data_root (Path | None, optional): Root directory of test labels. Defaults to None.
        train_split (Path | None, optional): Path to file listing the training samples. Defaults to None.
        val_split (Path | None, optional): Path to file listing the validation samples. Defaults to None.
        test_split (Path | None, optional): Path to file listing the test samples. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Factor by which to scale image values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB bands, used for plotting. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.
    """
    super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last
    self.pin_memory = pin_memory

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
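
Putting this together, a config file entry for this data module could look like the following sketch; paths, grep patterns, statistics and class counts are placeholders. Note that means and stds may also be given as paths to files containing the statistics:

data:
  class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 8
    num_workers: 4
    num_classes: 2
    train_data_root: /path/to/train/images
    train_label_data_root: /path/to/train/labels
    val_data_root: /path/to/val/images
    val_label_data_root: /path/to/val/labels
    test_data_root: /path/to/test/images
    test_label_data_root: /path/to/test/labels
    img_grep: "*_img.tif"
    label_grep: "*_label.tif"
    means:
      - 0.1
      - 0.2
      - 0.3
    stds:
      - 0.1
      - 0.2
      - 0.3
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2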

terratorch.datamodules.generic_scalar_label_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoClassificationDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. Composes several GenericNonGeoClassificationDatasets

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
class GenericNonGeoClassificationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoClassificationDatasets][terratorch.datasets.GenericNonGeoClassificationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int] | None = None,
        predict_dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        no_data_replace: float = 0,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples in each batch.
            num_workers (int): Number of subprocesses used for data loading.
            train_data_root (Path): Root directory of training images.
            val_data_root (Path): Root directory of validation images.
            test_data_root (Path): Root directory of test images.
            means (list[float]): Per-band means for normalization, or path to a file containing them.
            stds (list[float]): Per-band standard deviations for normalization, or path to a file containing them.
            num_classes (int): Number of classes in the dataset.
            predict_data_root (Path): Root directory of images to run inference on.
            train_split (Path | None, optional): Path to file listing the training samples. Defaults to None.
            val_split (Path | None, optional): Path to file listing the validation samples. Defaults to None.
            test_split (Path | None, optional): Path to file listing the test samples. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None, optional): Overwrites dataset_bands with this value at predict time. Defaults to None.
            output_bands (list[HLSBands | int] | None, optional): Bands that should be output by the dataset. Defaults to None.
            constant_scale (float, optional): Factor by which to scale image values. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB bands, used for plotting. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        """
        super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.drop_last = drop_last

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, predict_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, no_data_replace=0, drop_last=True, **kwargs)

Constructor

Parameters:
  • batch_size (int) –

    Number of samples in each batch.

  • num_workers (int) –

    Number of subprocesses used for data loading.

  • train_data_root (Path) –

    Root directory of training images.

  • val_data_root (Path) –

    Root directory of validation images.

  • test_data_root (Path) –

    Root directory of test images.

  • means (list[float]) –

    Per-band means for normalization, or path to a file containing them.

  • stds (list[float]) –

    Per-band standard deviations for normalization, or path to a file containing them.

  • num_classes (int) –

    Number of classes in the dataset.

  • predict_data_root (Path, default: None ) –

    Root directory of images to run inference on.

  • train_split (Path | None, default: None ) –

    Path to file listing the training samples. Defaults to None.

  • val_split (Path | None, default: None ) –

    Path to file listing the validation samples. Defaults to None.

  • test_split (Path | None, default: None ) –

    Path to file listing the test samples. Defaults to None.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for Eurosat, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor by which to scale image values. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the RGB bands, used for plotting. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int] | None = None,
    predict_dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    no_data_replace: float = 0,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples in each batch.
        num_workers (int): Number of subprocesses used for data loading.
        train_data_root (Path): Root directory of training images.
        val_data_root (Path): Root directory of validation images.
        test_data_root (Path): Root directory of test images.
        means (list[float]): Per-band means for normalization, or path to a file containing them.
        stds (list[float]): Per-band standard deviations for normalization, or path to a file containing them.
        num_classes (int): Number of classes in the dataset.
        predict_data_root (Path): Root directory of images to run inference on.
        train_split (Path | None, optional): Path to file listing the training samples. Defaults to None.
        val_split (Path | None, optional): Path to file listing the validation samples. Defaults to None.
        test_split (Path | None, optional): Path to file listing the test samples. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for Eurosat, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None, optional): Overwrites dataset_bands with this value at predict time. Defaults to None.
        output_bands (list[HLSBands | int] | None, optional): Bands that should be output by the dataset. Defaults to None.
        constant_scale (float, optional): Factor by which to scale image values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB bands, used for plotting. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
    """
    super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.drop_last = drop_last

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )

    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
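
These generic modules can also be instantiated directly in Python rather than through a config file. A hypothetical sketch follows; all paths and statistics are placeholders, and the import assumes the class is exported from terratorch.datamodules:

from terratorch.datamodules import GenericNonGeoClassificationDataModule

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=16,
    num_workers=4,
    num_classes=10,
    train_data_root="/path/to/train",
    val_data_root="/path/to/val",
    test_data_root="/path/to/test",
    means=[0.1, 0.2, 0.3],  # placeholder per-band statistics
    stds=[0.1, 0.2, 0.3],
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()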

Custom datasets and data modules

Below is a documented example of how a custom dataset and data module class can be implemented.

terratorch.datasets.fire_scars

FireScarsHLS

Bases: RasterDataset

RasterDataset implementation for fire scars input images.

Source code in terratorch/datasets/fire_scars.py
class FireScarsHLS(RasterDataset):
    """RasterDataset implementation for fire scars input images."""

    filename_glob = "subsetted*_merged.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
    date_format = "%Y%j"
    is_image = True
    separate_files = False
    all_bands = ["B02", "B03", "B04", "B8A", "B11", "B12"]
    rgb_bands = ["B04", "B03", "B02"]

FireScarsNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for fire scars.

Source code in terratorch/datasets/fire_scars.py
class FireScarsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for fire scars."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 2
    splits = {"train": "training", "val": "validation"}   # Only train and val splits available

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'. Defaults to 'train'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to -1.
            use_metadata (bool): whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        input_dir = self.data_root / split_name
        self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        # If no transform is given, only convert to a torch tensor
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        base_filename = os.path.basename(file_name)

        filename_regex = r"subsetted_512x512_HLS\.S30\.T[0-9A-Z]{5}\.(?P<date>[0-9]+)\.v1\.4_merged\.tif"
        match = re.match(filename_regex, base_filename)
        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:])

        return torch.tensor([[year, julian_day]], dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image_percentile(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=0, no_label_replace=-1, use_metadata=False)

Constructor

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train' or 'val'. Defaults to 'train'.

  • bands (list[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2().

  • no_data_replace (float | None, default: 0 ) –

    Replace nan values in input images with this value. If None, does no replacement. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replace nan values in label with this value. If none, does no replacement. Defaults to -1.

  • use_metadata (bool, default: False ) –

    whether to return metadata info (time and location).

Source code in terratorch/datasets/fire_scars.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'. Defaults to 'train'.
        bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the corresponding data module,
            should not include normalization. Defaults to None, which applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value.
            If None, does no replacement. Defaults to 0.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to -1.
        use_metadata (bool): whether to return metadata info (time and location).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {self.splits}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.data_root = Path(data_root)

    input_dir = self.data_root / split_name
    self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

    self.use_metadata = use_metadata
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    # If no transform is given, only convert to a torch tensor
    self.transform = transform if transform else default_transform
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by :meth:__getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

Source code in terratorch/datasets/fire_scars.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = 4

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    # RGB -> channels-last
    image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
    mask = sample["mask"].numpy()

    image = clip_image_percentile(image)

    if "prediction" in sample:
        prediction = sample["prediction"]
        num_images += 1
    else:
        prediction = None

    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    ax[1].axis("off")
    ax[1].title.set_text("Image")
    ax[1].imshow(image)

    ax[2].axis("off")
    ax[2].title.set_text("Ground Truth Mask")
    ax[2].imshow(mask, cmap="jet", norm=norm)

    ax[3].axis("off")
    ax[3].title.set_text("GT Mask on Image")
    ax[3].imshow(image)
    ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

    if "prediction" in sample:
        ax[4].title.set_text("Predicted Mask")
        ax[4].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")
    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
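
A quick usage sketch for plot, continuing from the dataset constructed above (the output filename is arbitrary):

# Renders the legend, RGB composite, ground-truth mask, and overlay;
# include a "prediction" key in the sample to add a fifth panel.
fig = dataset.plot(dataset[0], suptitle="Fire scar sample")
fig.savefig("fire_scar_sample.png")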

FireScarsSegmentationMask

Bases: RasterDataset

RasterDataset implementation for fire scars segmentation mask. Can be easily merged with input images using the & operator.

Source code in terratorch/datasets/fire_scars.py
class FireScarsSegmentationMask(RasterDataset):
    """RasterDataset implementation for fire scars segmentation mask.
    Can be easily merged with input images using the & operator.
    """

    filename_glob = "subsetted*.mask.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4.mask.tif"
    date_format = "%Y%j"
    is_image = False
    separate_files = False

terratorch.datamodules.fire_scars

FireScarsDataModule

Bases: GeoDataModule

Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsDataModule(GeoDataModule):
    """Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks."""

    def __init__(self, data_root: str, **kwargs: Any) -> None:
        # dataset_class, batch_size, patch_size, length, num_workers
        super().__init__(FireScarsSegmentationMask, 4, 224, 100, 0, **kwargs)
        means = list(MEANS.values())
        stds = list(STDS.values())
        self.train_aug = AugmentationSequential(K.RandomCrop(224, 224), K.Normalize(means, stds))
        self.aug = AugmentationSequential(K.Normalize(means, stds))
        self.data_root = data_root

    def setup(self, stage: str) -> None:
        self.images = FireScarsHLS(
            os.path.join(self.data_root, "training/")
        )
        self.labels = FireScarsSegmentationMask(
            os.path.join(self.data_root, "training/")
        )
        self.dataset = self.images & self.labels

        self.images_test = FireScarsHLS(
            os.path.join(self.data_root, "validation/")
        )
        self.labels_test = FireScarsSegmentationMask(
            os.path.join(self.data_root, "validation/")
        )
        self.val_dataset = self.images_test & self.labels_test

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, self.patch_size, self.batch_size, None)
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
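
A minimal sketch of driving this data module directly from Python (the data root is a placeholder; the class is assumed importable from terratorch.datamodules):

from terratorch.datamodules import FireScarsDataModule

dm = FireScarsDataModule(data_root="/path/to/fire_scars")  # placeholder path
dm.setup("fit")  # builds the intersection datasets and geo samplers
train_loader = dm.train_dataloader()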

FireScarsNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for Fire Scars

Source code in terratorch/datamodules/fire_scars.py
class FireScarsNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for Fire Scars"""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = FireScarsNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
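
From a config file, this data module can be wired up in the usual LightningCLI format. A sketch with placeholder values:

data:
  class_path: terratorch.datamodules.FireScarsNonGeoDataModule
  init_args:
    data_root: /path/to/fire_scars
    batch_size: 8
    num_workers: 4
    train_transform:
      - class_path: albumentations.RandomCrop
        init_args:
          height: 224
          width: 224
      - class_path: ToTensorV2

Because the transform arguments here are explicitly type-annotated init args, they can be parsed directly from the config without any proxy wrapper.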