Data Processing

We rely on TorchGeo for the implementation of datasets and data modules.

Check out the TorchGeo tutorials on datasets for more in-depth information.

In general, it is recommended that you create a TorchGeo dataset specifically for your dataset. This gives you complete control and flexibility over how data is loaded, which transforms are applied to it, and even how it is plotted if you log with tools like TensorBoard.

TorchGeo provides GeoDataset and NonGeoDataset.

  • If your data is already nicely tiled and ready for consumption by a neural network, you can inherit from NonGeoDataset. This is essentially a wrapper around a regular PyTorch dataset.
  • If your data consists of large GeoTIFFs you would like to sample from during training, you can leverage the powerful GeoDataset from TorchGeo. It automatically aligns your input data and labels and enables a variety of geo-aware samplers.
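For the pre-tiled case, the NonGeoDataset contract is small: a map-style dataset with `__len__` and a `__getitem__` that returns a dict of arrays keyed by names such as "image" and "mask", optionally passed through a transforms callable. The sketch below mirrors that shape with a plain class and NumPy arrays so it runs without TorchGeo; in a real project you would subclass `torchgeo.datasets.NonGeoDataset`, read chips from disk and return tensors. The class name `TiledChipDataset` and the in-memory data are hypothetical.

```python
import numpy as np


class TiledChipDataset:
    """Hypothetical stand-in for a NonGeoDataset subclass over pre-tiled chips.

    A real implementation would inherit from torchgeo.datasets.NonGeoDataset
    and load chips from disk; here they are held in memory for illustration.
    """

    def __init__(self, images: np.ndarray, masks: np.ndarray, transforms=None):
        # images: (N, C, H, W), masks: (N, H, W)
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> dict:
        # TorchGeo-style samples are dicts keyed by "image" and "mask"
        sample = {"image": self.images[index], "mask": self.masks[index]}
        if self.transforms is not None:
            # The transforms callable receives and returns the whole dict
            sample = self.transforms(sample)
        return sample


# Four 6-band 64x64 chips with integer class masks
images = np.random.rand(4, 6, 64, 64).astype(np.float32)
masks = np.random.randint(0, 3, size=(4, 64, 64))
dataset = TiledChipDataset(images, masks)
sample = dataset[0]
print(len(dataset), sample["image"].shape, sample["mask"].shape)
```

Passing the whole dict to the transforms callable matches how TorchGeo datasets hand samples to their transforms, which is the behaviour the proxy classes further down have to adapt.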

Using Datasets already implemented in TorchGeo

Using existing TorchGeo DataModules is very easy: just plug them in! For instance, to use the EuroSATDataModule, set the data section of your config file as follows:

data:
  class_path: torchgeo.datamodules.EuroSATDataModule
  init_args:
    batch_size: 32
    num_workers: 8
  dict_kwargs:
    root: /dccstor/geofm-pre/EuroSat
    download: True
    bands:
      - B02
      - B03
      - B04
      - B08A
      - B09
      - B10
Modify each parameter as you see fit.

You can also do this outside of config files! Simply instantiate the data module as normal and plug it in.
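To see how the config above maps onto a plain constructor call, here is a simplified sketch: `class_path` is resolved to a class, and `init_args` plus `dict_kwargs` (the latter for arguments the class only accepts through `**kwargs`) are merged into keyword arguments. This only illustrates the idea; it is not how LightningCLI or jsonargparse are implemented, and the helper name `instantiate_from_config` is hypothetical. It is exercised with a standard-library class so it runs without TorchGeo.

```python
import importlib
from typing import Any


def instantiate_from_config(config: dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from class_path/init_args/dict_kwargs."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # Typed constructor arguments and extra **kwargs end up in one call
    kwargs = {**config.get("init_args", {}), **config.get("dict_kwargs", {})}
    return cls(**kwargs)


# Stand-in for the data section: fractions.Fraction instead of EuroSATDataModule
config = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3},
    "dict_kwargs": {"denominator": 4},
}
print(instantiate_from_config(config))  # 3/4
```

With TorchGeo installed, the direct equivalent of the EuroSAT config is `EuroSATDataModule(batch_size=32, num_workers=8, root="/dccstor/geofm-pre/EuroSat", download=True, bands=[...])`, where `root`, `download` and `bands` are forwarded to the underlying EuroSAT dataset.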

Warning

To pass transforms to a TorchGeo DataModule from a config file, you must use the following format:

data:
  class_path: terratorch.datamodules.TorchNonGeoDataModule
  init_args:
    cls: torchgeo.datamodules.EuroSATDataModule
    transforms:
      - class_path: albumentations.augmentations.geometric.resize.Resize
        init_args:
          height: 224
          width: 224
      - class_path: ToTensorV2
Note that the class_path is TorchNonGeoDataModule and the class to be used is passed through cls (there is also a TorchGeoDataModule for geo modules). This is necessary because the transforms argument is passed through **kwargs in TorchGeo, which makes it difficult to instantiate with LightningCLI. See more details below.

terratorch.datamodules.torchgeo_data_module

Ugly proxy objects so that parsing config files works with transforms.

These are necessary because, for LightningCLI to instantiate arguments as objects from the config, they must have type annotations.

In TorchGeo, transforms are passed through **kwargs, so they have no type annotation! To get around that, we create wrappers whose transforms argument is type-annotated. They build the transforms and forward method and attribute calls to the original TorchGeo data module.
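The effect of the wrapper can be checked with `inspect.signature`, which reports what an annotation-driven parser can see: a constructor that only receives `transforms` through `**kwargs` exposes no `transforms` parameter, while the wrapper does. `LibraryDataModule` and `TypedProxy` are hypothetical stand-ins for a TorchGeo data module and the TerraTorch proxy; the forwarding follows the `_proxy` pattern in the source below, reduced to one method.

```python
import inspect
from typing import Any, Callable, Optional


class LibraryDataModule:
    """Hypothetical stand-in for a TorchGeo data module: transforms hides in **kwargs."""

    def __init__(self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs

    def setup(self, stage: str) -> str:
        return f"setup({stage}) with {sorted(self.kwargs)}"


class TypedProxy:
    """Hypothetical proxy: annotates transforms, then forwards to the wrapped module."""

    def __init__(self, cls: type, transforms: Optional[list[Callable]] = None, **kwargs: Any):
        if transforms is not None:
            kwargs["transforms"] = transforms
        self._proxy = cls(**kwargs)

    def setup(self, stage: str) -> str:
        # Explicit forwarding, as TorchGeoDataModule does for its methods
        return self._proxy.setup(stage)


print("transforms" in inspect.signature(LibraryDataModule).parameters)  # False
print("transforms" in inspect.signature(TypedProxy).parameters)  # True
proxy = TypedProxy(LibraryDataModule, transforms=[abs], batch_size=8)
print(proxy.setup("fit"))  # setup(fit) with ['transforms']
```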

Additionally, TorchGeo datasets pass the data to the transforms callable as a single dict of tensors.

Albumentations, by contrast, expects this data as separate keyword arguments rather than a dict, and as NumPy arrays rather than tensors. We handle that conversion here.
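A minimal sketch of that conversion, assuming samples hold channel-first (C, H, W) images while Albumentations works on channel-last (H, W, C) arrays passed as keyword arguments. NumPy arrays stand in for tensors so the sketch runs without PyTorch; with real tensors you would call `.numpy()` on the way in and `torch.from_numpy` on the way out (or end the pipeline with ToTensorV2). `wrap_keyword_transform` and `fake_hflip` are hypothetical names, not TerraTorch's helpers, whose exact behaviour may differ.

```python
from typing import Any, Callable

import numpy as np


def wrap_keyword_transform(transform: Callable[..., dict]) -> Callable[[dict], dict]:
    """Hypothetical adapter from a dict sample to a keyword-style transform.

    Assumes transform(image=..., mask=...) returns a dict with the same keys,
    as Albumentations transforms do, and that images arrive channel-first.
    """

    def apply(sample: dict[str, Any]) -> dict[str, Any]:
        kwargs = dict(sample)  # shallow copy so the caller's sample is left untouched
        # (C, H, W) -> (H, W, C), the image layout Albumentations works with
        kwargs["image"] = np.transpose(np.asarray(kwargs["image"]), (1, 2, 0))
        out = transform(**kwargs)
        # Back to channel-first; a pipeline ending in ToTensorV2 would do this itself
        out["image"] = np.transpose(out["image"], (2, 0, 1))
        return out

    return apply


def fake_hflip(**data: np.ndarray) -> dict[str, np.ndarray]:
    """Stand-in for an Albumentations transform: horizontal flip of image and mask."""
    return {
        "image": data["image"][:, ::-1, :],  # (H, W, C): reverse the width axis
        "mask": data["mask"][:, ::-1],  # (H, W)
    }


sample = {
    "image": np.arange(2 * 2 * 3).reshape(2, 2, 3),  # C=2, H=2, W=3
    "mask": np.array([[0, 1, 2], [3, 4, 5]]),  # H=2, W=3
}
flipped = wrap_keyword_transform(fake_hflip)(sample)
print(flipped["image"].shape)  # (2, 2, 3)
print(flipped["mask"].tolist())  # [[2, 1, 0], [5, 4, 3]]
```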

TorchGeoDataModule

Bases: GeoDataModule

Proxy object for using Geo data modules defined by TorchGeo.

Allows transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type, which LightningCLI needs in order to build it from a config file. As such, its properties and data-loading methods are redirected to the wrapped TorchGeo data module.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchGeoDataModule(GeoDataModule):
    """Proxy object for using Geo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[GeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    @property
    def patch_size(self):
        return self._proxy.patch_size

    @property
    def length(self):
        return self._proxy.length

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # Delegate device transfer to the wrapped data module
        return self._proxy.transfer_batch_to_device(batch, device, dataloader_idx)
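The class above forwards a fixed list of methods and properties explicitly. If you instead wanted every unknown attribute lookup and every attribute write to fall through to the wrapped module, one possible approach is a `__getattr__`/`__setattr__` pair. This is a sketch of that alternative design, not what TerraTorch does, and `ForwardingProxy` and `Inner` are hypothetical.

```python
from typing import Any


class ForwardingProxy:
    """Hypothetical proxy that delegates attribute reads and writes to `_proxy`."""

    def __init__(self, wrapped: Any):
        # Write through __dict__ directly so the __setattr__ below is not triggered
        self.__dict__["_proxy"] = wrapped

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for names not on the proxy itself
        return getattr(self.__dict__["_proxy"], name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Redirect every attribute write to the wrapped object
        setattr(self.__dict__["_proxy"], name, value)


class Inner:
    def __init__(self):
        self.batch_size = 32

    def train_dataloader(self) -> str:
        return f"loader(batch_size={self.batch_size})"


proxy = ForwardingProxy(Inner())
proxy.batch_size = 16  # stored on the wrapped Inner instance
print(proxy.train_dataloader())  # loader(batch_size=16)
```

One trade-off of the blanket approach is that attributes a framework sets on the wrapper itself also end up on the wrapped object, which is one possible reason to prefer explicit forwarding as in the source above.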

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[GeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.

TorchNonGeoDataModule

Bases: NonGeoDataModule

Proxy object for using NonGeo data modules defined by TorchGeo.

Allows transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type, which LightningCLI needs in order to build it from a config file. As such, its properties and data-loading methods are redirected to the wrapped TorchGeo data module.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchNonGeoDataModule(NonGeoDataModule):
    """Proxy object for using NonGeo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[NonGeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[NonGeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.

Generic datasets and data modules

For the NonGeoDataset case, we also provide "generic" datasets and datamodules. These can be used when you would like to load data from given directories, in a style similar to the MMLab libraries.
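To make the directory-based loading concrete, here is a simplified sketch of how a split file could be combined with `image_grep`, following the parameter descriptions below: each line of the split file is a prefix, files are found with `Path(data_root).glob(prefix + image_grep)`, and with `ignore_split_file_extensions` the extension listed in the split file is dropped first. It is not TerraTorch's implementation (which also handles labels and the exact-match mode), and `select_split_files` and the file names are hypothetical.

```python
import tempfile
from pathlib import Path


def select_split_files(
    data_root: Path,
    split_file: Path,
    image_grep: str = "*",
    ignore_split_file_extensions: bool = True,
) -> list[Path]:
    """Hypothetical sketch: keep files matching Path(data_root).glob(prefix + image_grep)."""
    prefixes = [line.strip() for line in split_file.read_text().splitlines() if line.strip()]
    if ignore_split_file_extensions:
        # "AnnualCrop_1.jpg" -> "AnnualCrop_1", so it can match files with another extension
        prefixes = [str(Path(p).with_suffix("")) for p in prefixes]
    matches: set[Path] = set()
    for prefix in prefixes:
        matches.update(Path(data_root).glob(prefix + image_grep))
    return sorted(matches)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ["AnnualCrop_1_img.tif", "AnnualCrop_1_mask.tif", "AnnualCrop_2_img.tif",
                 "Forest_1_img.tif", "Forest_1_mask.tif"]:
        (root / name).touch()
    split = root / "train.txt"
    split.write_text("AnnualCrop_1.jpg\nForest_1.jpg\n")
    selected = [p.name for p in select_split_files(root, split, image_grep="*_img.tif")]
print(selected)  # ['AnnualCrop_1_img.tif', 'Forest_1_img.tif']
```

Because the prefix is followed by a wildcard, "AnnualCrop_1" would also pick up a file such as "AnnualCrop_10_img.tif"; that is the usual trade-off of substring matching, which the exact-match behaviour described for `allow_substring_split_file` avoids.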

Generic Datasets

terratorch.datasets.generic_pixel_wise_dataset

Module containing generic dataset classes

GenericNonGeoPixelwiseRegressionDataset

Bases: GenericPixelWiseDataset

GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoPixelwiseRegressionDataset(GenericPixelWiseDataset):
    """GenericNonGeoPixelwiseRegressionDataset"""

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files
                themselves have a different extension. Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value.
                If None, no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace NaN values in labels with this value.
                If None, no replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].float()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
            suptitle (str|None): optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        # Per-channel min-max stretch to [0, 1]
        image_min = image.min(axis=(0, 1))
        image = (image - image_min) / np.maximum(image.max(axis=(0, 1)) - image_min, 1e-12)
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
        )

    @staticmethod
    def _plot_sample(image, label, prediction=None, suptitle=None):
        num_images = 4 if prediction is not None else 3
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max())
        ax[0].axis("off")
        ax[0].title.set_text("Image")
        ax[0].imshow(image)

        ax[1].axis("off")
        ax[1].title.set_text("Ground Truth Mask")
        ax[1].imshow(label, cmap="Greens", norm=norm)

        ax[2].axis("off")
        ax[2].title.set_text("GT Mask on Image")
        ax[2].imshow(image)
        ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm)
        # ax[2].legend()

        if prediction is not None:
            ax[3].title.set_text("Predicted Mask")
            ax[3].imshow(prediction, cmap="Greens", norm=norm)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files themselves have a different extension. Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
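The numeric options above can be illustrated on a single NumPy sample. This sketch applies `constant_scale`, `no_data_replace`, `expand_temporal_dimension` and `reduce_zero_label` as their descriptions state; the order in which TerraTorch applies them and the band ordering it assumes when splitting time from channels are assumptions here (time-major stacking, i.e. all bands of the first date come first). If your files stack all dates of one band together instead, reshape straight to (C, T, H, W) without the transpose. `preprocess_sample` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def preprocess_sample(
    image: np.ndarray,
    mask: np.ndarray,
    num_bands: int,
    constant_scale: float = 1.0,
    no_data_replace: Optional[float] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch of the options above; `image` is (time*bands, H, W)."""
    image = image.astype(np.float32) * constant_scale
    if no_data_replace is not None:
        # Replace NaN pixels only; other values are left unchanged
        image = np.where(np.isnan(image), no_data_replace, image)
    if expand_temporal_dimension:
        stacked, height, width = image.shape
        # ASSUMPTION: time-major stacking [t0b0, t0b1, ..., t1b0, t1b1, ...]
        image = image.reshape(stacked // num_bands, num_bands, height, width)  # (T, C, H, W)
        image = image.transpose(1, 0, 2, 3)  # (C, T, H, W)
    if reduce_zero_label:
        mask = mask - 1  # labels 1..K become 0..K-1
    return image, mask


# Two dates of three bands each, 4x4 pixels, with one missing value
image = np.arange(6 * 4 * 4, dtype=np.float32).reshape(6, 4, 4)
image[0, 0, 0] = np.nan
mask = np.ones((4, 4), dtype=np.int64)
out_image, out_mask = preprocess_sample(
    image, mask, num_bands=3, constant_scale=0.5,
    no_data_replace=0.0, expand_temporal_dimension=True, reduce_zero_label=True,
)
print(out_image.shape)  # (3, 2, 4, 4)
```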

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

Added in version 0.2.
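The plot method builds an RGB preview by selecting `rgb_indices`, moving channels last and rescaling each channel to [0, 1]. A conventional per-channel min-max stretch for that step looks like the NumPy sketch below; `to_rgb_preview` is a hypothetical helper, and the small epsilon guarding constant channels is a choice made here.

```python
import numpy as np


def to_rgb_preview(image: np.ndarray, rgb_indices: list[int]) -> np.ndarray:
    """Hypothetical helper: (C, H, W) array -> (H, W, 3) float array in [0, 1]."""
    rgb = np.transpose(image.take(rgb_indices, axis=0), (1, 2, 0)).astype(np.float64)
    low = rgb.min(axis=(0, 1))
    high = rgb.max(axis=(0, 1))
    # Per-channel min-max stretch; the epsilon avoids division by zero on flat channels
    rgb = (rgb - low) / np.maximum(high - low, 1e-12)
    return np.clip(rgb, 0, 1)


image = np.stack([
    np.full((2, 2), 5.0),                          # flat band
    np.array([[0.0, 10.0], [20.0, 40.0]]),
    np.array([[100.0, 200.0], [300.0, 500.0]]),
])
preview = to_rgb_preview(image, rgb_indices=[2, 1, 0])
print(preview.shape)  # (2, 2, 3)
```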

GenericNonGeoSegmentationDataset

Bases: GenericPixelWiseDataset

GenericNonGeoSegmentationDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
    """GenericNonGeoSegmentationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files
                themselves have a different extension. Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value.
                If None, no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace NaN values in labels with this value.
                If None, no replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
        image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            self.num_classes,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
            class_names=self.class_names,
        )

    @staticmethod
    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
        num_images = 5 if prediction is not None else 4
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        # for legend
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(label, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)

        if prediction is not None:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = []
        for i in range(num_classes):
            class_name = class_names[i] if class_names else str(i)
            data = [i, cmap(norm(i)), class_name]
            legend_data.append(data)
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, num_classes, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to label_data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should be a newline-separated list of prefixes contained in the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in the label with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
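How split, ignore_split_file_extensions and allow_substring_split_file interact is easiest to see on a toy file list. The sketch below is a hypothetical, simplified re-implementation of the documented behaviour in plain Python; filter_by_split is an illustrative name, not TerraTorch's filter_valid_files, whose exact matching rules may differ. It also shows why substring matching is risky for EuroSAT-style names: the entry "AnnualCrop_1" is a substring of "AnnualCrop_10".

```python
import os


def filter_by_split(files, valid_entries, ignore_extensions=True, allow_substring=True):
    """Hypothetical sketch of split-file filtering (not TerraTorch's actual helper)."""

    def strip_ext(name):
        # "AnnualCrop_1.jpg" -> "AnnualCrop_1"
        return os.path.splitext(name)[0]

    # Optionally drop extensions from the split entries so ".jpg" entries can match ".tif" files.
    entries = {strip_ext(e) for e in valid_entries} if ignore_extensions else set(valid_entries)

    kept = []
    for path in files:
        name = os.path.basename(path)
        key = strip_ext(name) if ignore_extensions else name
        if allow_substring:
            # Keep the file if any split entry occurs anywhere in its name.
            if any(entry in key for entry in entries):
                kept.append(path)
        elif key in entries:
            # Exact match on the (optionally extension-less) file name.
            kept.append(path)
    return kept


files = ["data/AnnualCrop_1.tif", "data/AnnualCrop_10.tif", "data/Forest_3.tif"]
split = ["AnnualCrop_1.jpg", "Forest_3.jpg"]  # split file lists ".jpg" names

print(filter_by_split(files, split, allow_substring=False))
# ['data/AnnualCrop_1.tif', 'data/Forest_3.tif']
print(filter_by_split(files, split, allow_substring=True))
# all three files: "AnnualCrop_1" also matches "AnnualCrop_10"
```

With exact matching and extensions ignored, only the intended files survive; with extensions kept, nothing matches because the split entries end in ".jpg" while the files end in ".tif".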

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

.. versionadded:: 0.2
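The image preparation at the start of plot() can be reproduced in isolation. Below is a minimal NumPy sketch of those steps for a (C, H, W) array; to_display_rgb is a hypothetical helper name and the six-band layout in the example is made up for illustration.

```python
import numpy as np


def to_display_rgb(image, rgb_indices=(0, 1, 2)):
    """Mirror the RGB preparation steps of plot() for a (C, H, W) array."""
    image = image.take(list(rgb_indices), axis=0)  # keep the RGB channels -> (3, H, W)
    image = np.transpose(image, (1, 2, 0))         # channels last -> (H, W, 3)
    # Same rescaling as plot(): subtract the per-channel minimum,
    # then multiply by the reciprocal of the per-channel maximum.
    image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1)))
    return np.clip(image, 0, 1)                    # imshow expects floats in [0, 1]


# Hypothetical six-band sample whose red, green, blue bands sit at indices 2, 1, 0.
rng = np.random.default_rng(0)
sample = rng.uniform(100, 3000, size=(6, 4, 4)).astype(np.float32)
rgb = to_display_rgb(sample, rgb_indices=[2, 1, 0])
print(rgb.shape)  # (4, 4, 3)
```

Because the rescale multiplies by the reciprocal of the per-channel maximum rather than dividing by the max-minus-min range, non-negative inputs already land in [0, 1]; np.clip only changes values when a band contains negatives.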

GenericPixelWiseDataset

Bases: NonGeoDataset, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericPixelWiseDataset(NonGeoDataset, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to label_data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should be a newline-separated list of prefixes contained in the desired file names.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace NaN values in the label with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__()

        self.split_file = split

        label_data_root = label_data_root if label_data_root is not None else data_root
        self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension

        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
            self.segmentation_mask_files = filter_valid_files(
                self.segmentation_mask_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None

        # If no transform is given, only convert the sample to a torch tensor
        self.transform = transform if transform else default_transform

        import warnings

        import rasterio

        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
        # to channels last
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
        image = np.moveaxis(image, 0, -1)

        if self.filter_indices:
            image = image[..., self.filter_indices]
        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
                0
            ]
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data
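To make the order of operations in __getitem__ concrete, here is a NumPy-only sketch of the same steps applied to in-memory arrays instead of files read with rioxarray. Assumptions: bands are named with plain strings (TerraTorch also accepts HLSBands, ints and (start, end) ranges, which generate_bands_intervals expands and this sketch does not model); NaN replacement uses np.where in place of DataArray.fillna; the Albumentations transform and the temporal rearrangement are omitted; select_band_indices and prepare_sample are hypothetical names.

```python
import numpy as np


def select_band_indices(dataset_bands, output_bands):
    """Map output band names to channel positions, following the constructor's checks."""
    if output_bands and not dataset_bands:
        raise ValueError("If output bands provided, dataset_bands must also be provided")
    if not output_bands:
        return None
    if len(set(output_bands) & set(dataset_bands)) != len(output_bands):
        raise ValueError("Output bands must be a subset of dataset bands")
    return [dataset_bands.index(band) for band in output_bands]


def prepare_sample(image, mask, filter_indices=None, constant_scale=1.0,
                   no_data_replace=None, no_label_replace=None, reduce_zero_label=False):
    """Sketch of the per-sample steps in __getitem__, before any transform is applied."""
    if no_data_replace is not None:
        image = np.where(np.isnan(image), no_data_replace, image)  # like DataArray.fillna
    if no_label_replace is not None:
        mask = np.where(np.isnan(mask), no_label_replace, mask)
    image = np.moveaxis(image, 0, -1)           # (C, H, W) -> (H, W, C)
    if filter_indices:
        image = image[..., filter_indices]      # keep only the requested bands
    out = {"image": image.astype(np.float32) * constant_scale, "mask": mask[0]}
    if reduce_zero_label:
        out["mask"] = out["mask"] - 1           # labels 1..N become 0..N-1
    return out


dataset_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
idx = select_band_indices(dataset_bands, ["RED", "GREEN", "BLUE"])  # [2, 1, 0]

image = np.arange(6 * 2 * 2, dtype=np.float64).reshape(6, 2, 2)
image[0, 0, 0] = np.nan                     # a no-data pixel in the BLUE band
mask = np.array([[[1, 2], [np.nan, 1]]])    # single-band label raster with 1-based classes
sample = prepare_sample(image, mask, filter_indices=idx, constant_scale=0.0001,
                        no_data_replace=0, no_label_replace=0, reduce_zero_label=True)
print(sample["image"].shape)  # (2, 2, 3)
print(sample["mask"])
```

Note that with no_label_replace=0 and reduce_zero_label=True, no-data label pixels end up as -1 after the subtraction.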
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to label_data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should be a newline-separated list of prefixes contained in the desired file names. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in the label with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
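With expand_temporal_dimension=True, __getitem__ reshapes the flat band stack with einops.rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands)), which is why output_bands is required in that mode. The NumPy sketch below (toy sizes chosen for illustration) shows that this pattern amounts to a C-order reshape with channels as the outer factor, so band k of the file is read as channel k // time at time step k % time.

```python
import numpy as np

channels, time, h, w = 3, 2, 4, 4  # toy sizes
# A flat stack of channels * time bands, as read from a single raster file.
flat = np.arange(channels * time * h * w).reshape(channels * time, h, w)

# einops.rearrange(flat, "(channels time) h w -> channels time h w", channels=channels)
# splits the leading axis with `channels` as the outer factor, i.e. a C-order reshape:
expanded = flat.reshape(channels, time, h, w)

# Band k of the file becomes channel k // time at time step k % time.
for k in range(channels * time):
    assert np.array_equal(expanded[k // time, k % time], flat[k])

# __getitem__ then moves channels last, giving (time, h, w, channels),
# so band filtering with filter_indices acts on the final axis.
channels_last = np.moveaxis(expanded, 0, -1)
print(expanded.shape, channels_last.shape)  # (3, 2, 4, 4) (2, 4, 4, 3)
```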


terratorch.datasets.generic_scalar_label_dataset

Module containing generic dataset classes

GenericNonGeoClassificationDataset

Bases: GenericScalarLabelDataset

GenericNonGeoClassificationDataset

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericNonGeoClassificationDataset(GenericScalarLabelDataset):
    """GenericNonGeoClassificationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """A generic Non-Geo dataset for classification.

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            split (Path, optional): Path to file containing files to be used for this split.
                The file should be a newline-separated list of prefixes contained in the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        super().__init__(
            data_root,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            expand_temporal_dimension=expand_temporal_dimension,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["label"] = torch.tensor(item["label"]).long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        pass
__init__(data_root, num_classes, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

A generic Non-Geo dataset for classification.

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should be a newline-separated list of prefixes contained in the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.
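A concrete file list shows how `ignore_split_file_extensions` and `allow_substring_split_file` interact. The sketch below is a hypothetical re-implementation of the documented semantics. The name `filter_by_split` and the file names are illustrative. TerraTorch's own `filter_valid_files` helper is not shown on this page and may differ in detail.

```python
import os


def filter_by_split(files, split_lines, ignore_extensions=True, allow_substring=True):
    """Hypothetical sketch of split-file filtering; not TerraTorch's filter_valid_files."""
    entries = {line.strip() for line in split_lines if line.strip()}
    if ignore_extensions:
        # Compare names without extensions, e.g. "AnnualCrop_1.jpg" -> "AnnualCrop_1".
        entries = {os.path.splitext(entry)[0] for entry in entries}
    kept = []
    for path in files:
        name = os.path.basename(path)
        if ignore_extensions:
            name = os.path.splitext(name)[0]
        if allow_substring:
            # Keep the file if any split entry appears anywhere in its name.
            match = any(entry in name for entry in entries)
        else:
            # Keep the file only on an exact name match.
            match = name in entries
        if match:
            kept.append(path)
    return kept


files = [
    "root/AnnualCrop/AnnualCrop_1.tif",
    "root/AnnualCrop/AnnualCrop_10.tif",
    "root/Forest/Forest_2.tif",
]
split_lines = ["AnnualCrop_1.jpg\n", "Forest_2.jpg\n"]

print(filter_by_split(files, split_lines, allow_substring=False))
# ['root/AnnualCrop/AnnualCrop_1.tif', 'root/Forest/Forest_2.tif']
print(filter_by_split(files, split_lines, allow_substring=True))
# all three files, because "AnnualCrop_1" is also a substring of "AnnualCrop_10"
```

With substring matching, "AnnualCrop_1" also pulls in "AnnualCrop_10". For datasets whose file names share numeric prefixes, exact matching is therefore the safer choice.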

GenericScalarLabelDataset

Bases: NonGeoDataset, ImageFolder, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.
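The class also inherits from torchvision's ImageFolder. As a result, each sample's scalar label is the index of its parent folder among the alphabetically sorted class folders under data_root, so data_root is expected to contain one sub-folder per class. The sketch below reproduces that mapping using only the standard library. `find_classes_sketch` is a hypothetical name that is not part of TerraTorch or torchvision, and the folder names are made up.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Mimic how torchvision's ImageFolder maps class folders to integer labels."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: one sub-folder per class, created in non-sorted order.
    for name in ["River", "AnnualCrop", "Forest"]:
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes_sketch(root)

print(classes)       # ['AnnualCrop', 'Forest', 'River']
print(class_to_idx)  # {'AnnualCrop': 0, 'Forest': 1, 'River': 2}
```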

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            split (Path, optional): Path to file containing files to be used for this split.
                The file should be a newline-separated list of prefixes contained in the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This
                parameter gives identifiers to input channels (bands) so that they can then be referred to by
                output_bands. Can use the HLSBands enum, ints, int ranges, or strings. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the
                dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        self.split_file = split

        self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
        self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )

            def is_valid_file(x):
                return x in self.image_files

        else:

            def is_valid_file(x):
                return True

        super().__init__(
            root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
        )

        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None
        # If no transform is given, only convert the image to a torch tensor
        self.transforms = transform if transform else default_transform

        import warnings

        import rasterio
        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image, label = ImageFolder.__getitem__(self, index)
        if self.expand_temporal_dimension:
            image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.output_bands))
        if self.filter_indices:
            image = image[..., self.filter_indices]

        image = image.astype(np.float32) * self.constant_scale

        if self.transforms:
            image = self.transforms(image=image)["image"]  # albumentations returns dict

        output = {
            "image": image,
            "label": label,  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
            "filename": self.image_files[index]
        }

        return output

    def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int, int]] | None = None):
        if bands_intervals is None:
            return None
        bands = []
        for element in bands_intervals:
            # if it's an interval
            if isinstance(element, tuple):
                if len(element) != 2:  # noqa: PLR2004
                    msg = (
                        "When defining an interval, a tuple of two integers should be passed, "
                        "defining start and end indices inclusive"
                    )
                    raise Exception(msg)
                expanded_element = list(range(element[0], element[1] + 1))
                bands.extend(expanded_element)
            else:
                bands.append(element)
        return bands

    def _load_file(self, path) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        data = data.fillna(self.no_data_replace)
        return data
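The array steps in `__getitem__` are temporal unpacking, band filtering and scaling, and they can be reproduced with NumPy alone. The sketch below makes several assumptions. `preprocess_sketch` is a hypothetical name. The input is assumed to be a channels-last `(h, w, channels*time)` array, as the `rearrange` pattern implies. A plain reshape/transpose stands in for the einops call. NaN replacement and the albumentations transform are left out.

```python
import numpy as np


def preprocess_sketch(image, num_bands, filter_indices=None, constant_scale=1.0,
                      expand_temporal_dimension=False):
    """Hypothetical NumPy version of the array steps in __getitem__ above."""
    if expand_temporal_dimension:
        h, w, stacked = image.shape
        time = stacked // num_bands
        # Equivalent of rearrange("h w (channels time) -> time h w channels"):
        # channels is the outer factor of the stacked axis, time the inner one.
        image = image.reshape(h, w, num_bands, time).transpose(3, 0, 1, 2)
    if filter_indices:
        # Select (and possibly reorder) bands along the last, channel axis.
        image = image[..., filter_indices]
    return image.astype(np.float32) * constant_scale


stacked = np.arange(2 * 2 * 6).reshape(2, 2, 6)  # 2 bands x 3 time steps per pixel
out = preprocess_sketch(stacked, num_bands=2, expand_temporal_dimension=True)
print(out.shape)  # (3, 2, 2, 2), i.e. (time, h, w, channels)
```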
__init__(data_root, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should be a newline-separated list of prefixes contained in the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter gives identifiers to input channels (bands) so that they can then be referred to by output_bands. Can use the HLSBands enum, ints, int ranges, or strings. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.
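`dataset_bands` and `output_bands` may mix single identifiers with `(start, end)` integer ranges. The `_generate_bands_intervals` method listed above expands each tuple into an inclusive range. The constructor calls a module-level `generate_bands_intervals` that is not shown here and presumably behaves the same way. The sketch below performs the expansion on its own; `expand_band_intervals` is a hypothetical name.

```python
def expand_band_intervals(bands):
    """Expand (start, end) tuples into inclusive integer ranges; keep other entries as-is."""
    if bands is None:
        return None
    expanded = []
    for element in bands:
        if isinstance(element, tuple):
            if len(element) != 2:
                raise ValueError("An interval must be a (start, end) tuple, inclusive on both ends")
            start, end = element
            expanded.extend(range(start, end + 1))
        else:
            expanded.append(element)
    return expanded


dataset_bands = expand_band_intervals([(0, 5), "NDVI"])
print(dataset_bands)  # [0, 1, 2, 3, 4, 5, 'NDVI']
output_bands = expand_band_intervals([(1, 3)])
print([dataset_bands.index(band) for band in output_bands])  # [1, 2, 3]
```

Here `[(0, 5), "NDVI"]` declares six integer-named channels plus one string-named channel, and `output_bands=[(1, 3)]` then selects channels 1 through 3.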

Generic Data Modules

terratorch.datamodules.generic_pixel_wise_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoPixelwiseRegressionDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. It composes several GenericNonGeoPixelwiseRegressionDataset instances.
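The constructor below resolves `means` and `stds` with `load_from_file_or_attribute` and wraps them in a `Normalize` transform. This is why the per-split transforms should not normalize. The NumPy sketch below assumes that `Normalize` performs the usual per-channel standardization `(x - mean) / std` on channels-first batches. `load_stats` and `normalize_batch` are hypothetical names. The text-file format (one value per line) is an assumption.

```python
import os
import tempfile

import numpy as np


def load_stats(value):
    """Hypothetical stand-in for load_from_file_or_attribute: a list or a text-file path."""
    if isinstance(value, str):
        # Assumed file format: one number per line.
        return np.atleast_1d(np.loadtxt(value)).tolist()
    return list(value)


def normalize_batch(images, means, stds):
    """Per-channel standardization of a (batch, channels, h, w) array."""
    means = np.asarray(means, dtype=np.float32).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=np.float32).reshape(1, -1, 1, 1)
    return (images - means) / stds


with tempfile.TemporaryDirectory() as tmp:
    means_file = os.path.join(tmp, "means.txt")
    with open(means_file, "w") as f:
        f.write("10.0\n0.0\n5.0\n")
    means = load_stats(means_file)

stds = load_stats([1.0, 2.0, 5.0])
batch = np.full((2, 3, 4, 4), 10.0, dtype=np.float32)
normalized = normalize_batch(batch, means, stds)
print(normalized[0, :, 0, 0])  # [0. 5. 1.]
```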

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoPixelwiseRegressionDataModule(NonGeoDataModule):
    """This is a generic datamodule class for instantiating data modules at runtime.
    Composes several
    [GenericNonGeoPixelwiseRegressionDataset][terratorch.datasets.GenericNonGeoPixelwiseRegressionDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        predict_data_root: Path | None = None,
        img_grep: str | None = "*",
        label_grep: str | None = "*",
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by the data loaders.
            train_data_root (Path): Root directory of the training images.
            val_data_root (Path): Root directory of the validation images.
            test_data_root (Path): Root directory of the test images.
            predict_data_root (Path | None, optional): Root directory of the images to run prediction on.
                Defaults to None.
            img_grep (str): Glob pattern used to select image files. Defaults to "*".
            label_grep (str): Glob pattern used to select label files. Defaults to "*".
            means (list[float] | str): Per-band means used for normalization, or a path to a file containing them.
            stds (list[float] | str): Per-band standard deviations used for normalization, or a path to a file
                containing them.
            train_label_data_root (Path | None, optional): Root directory of the training labels. Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of the validation labels. Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of the test labels. Defaults to None.
            train_split (Path | None, optional): Path to the split file selecting the training files.
                Defaults to None.
            val_split (Path | None, optional): Path to the split file selecting the validation files.
                Defaults to None.
            test_split (Path | None, optional): Path to the split file selecting the test files.
                Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset.
                Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the
                dataset. Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites dataset_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Factor to multiply image values by. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB channels. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None, does no
                replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in labels with this value. If None, does no
                replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.

        """
        super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.constant_scale = constant_scale

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                image_grep=self.img_grep,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
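`_dataloader_factory` shuffles only the training split, and `drop_last` takes effect only there as well, so the validation, test and predict loaders always keep their final partial batch. The stdlib-only sketch below mirrors those two expressions and the resulting batch count. `loader_flags` and `num_batches` are hypothetical helpers written for illustration, not part of TerraTorch, and the batch size is assumed to be whatever remains after `check_dataset_stackability`.

```python
import math


def loader_flags(split: str, drop_last: bool) -> dict[str, bool]:
    """Hypothetical mirror of the shuffle/drop_last expressions in _dataloader_factory."""
    return {
        "shuffle": split == "train",
        "drop_last": split == "train" and drop_last,
    }


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


if __name__ == "__main__":
    for split in ("train", "val", "test", "predict"):
        flags = loader_flags(split, drop_last=True)
        print(split, flags, num_batches(10, 4, flags["drop_last"]))
```

With 10 samples and a batch size of 4, the training loader yields 2 batches while the other loaders yield 3.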
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, predict_data_root=None, img_grep='*', label_grep='*', train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
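  • batch_size (int) –

    Number of samples per batch.

  • num_workers (int) –

    Number of worker processes used by each data loader.

  • train_data_root (Path) –

    Root directory of the training images.

  • val_data_root (Path) –

    Root directory of the validation images.

  • test_data_root (Path) –

    Root directory of the test images.

  • predict_data_root (Path | None, default: None ) –

    Root directory of the images to run prediction on. If None, no prediction dataset is created.

  • img_grep (str | None, default: '*' ) –

    Pattern used to find image files under each data root, e.g. '*_merged.tif'.

  • label_grep (str | None, default: '*' ) –

    Pattern used to find label files, e.g. '*_mask.tif'.

  • means (list[float] | str) –

    Per-channel means used to normalize the images, given either as a list or as a path to a file containing them.

  • stds (list[float] | str) –

    Per-channel standard deviations used to normalize the images, given either as a list or as a path to a file containing them.

  • train_label_data_root (Path | None, default: None ) –

    Root directory of the training labels, if they are stored separately from the images. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Root directory of the validation labels, if they are stored separately from the images. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Root directory of the test labels, if they are stored separately from the images. Defaults to None.

  • train_split (Path | None, default: None ) –

    Split file listing the samples to include in the training dataset. If None, no split-file filtering is applied. Defaults to None.

  • val_split (Path | None, default: None ) –

    Split file listing the samples to include in the validation dataset. If None, no split-file filtering is applied. Defaults to None.

  • test_split (Path | None, default: None ) –

    Split file listing the samples to include in the test dataset. If None, no split-file filtering is applied. Defaults to None.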

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard file extensions when using the split file to determine which files to include in the dataset. This is necessary for EuroSAT, for example, since its split files list ".jpg" files while the images are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split file contains substrings that must be present in a file name for it to be included (as in mmsegmentation), or exact file names (as in EuroSAT). Defaults to True.

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    Constant factor by which image values are multiplied. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the red, green and blue channels, used for plotting. Defaults to None.

  • train_transform (Compose | list[BasicTransform] | None, default: None ) –

    Albumentations transform to be applied to the training dataset; a list of transforms is wrapped in a Compose. Should end with ToTensorV2(). When used through the generic data module, it should not include normalization, which the data module applies itself. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | list[BasicTransform] | None, default: None ) –

    Albumentations transform to be applied to the validation dataset; a list of transforms is wrapped in a Compose. Should end with ToTensorV2(). When used through the generic data module, it should not include normalization, which the data module applies itself. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | list[BasicTransform] | None, default: None ) –

    Albumentations transform to be applied to the test dataset, and also to the prediction dataset; a list of transforms is wrapped in a Compose. Should end with ToTensorV2(). When used through the generic data module, it should not include normalization, which the data module applies itself. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Value that replaces NaN values in input images. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Value that replaces NaN values in labels. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

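  • drop_last (bool, default: True ) –

    Drop the last batch of the training data loader if it is incomplete. The validation, test and predict loaders always keep it. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.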
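To make ignore_split_file_extensions and allow_substring_split_file concrete, here is a minimal sketch of how a split file could be matched against file names. It is an illustration under stated assumptions, not TerraTorch's implementation: the helper `in_split` is hypothetical. With substring matching, a file is kept if any split entry occurs anywhere in its name; otherwise its name must equal an entry, optionally after dropping the extensions on both sides.

```python
import os


def in_split(
    filename: str,
    split_entries: list[str],
    ignore_extensions: bool = True,
    allow_substring: bool = True,
) -> bool:
    """Hypothetical split-file filter illustrating the two flags."""
    name = os.path.basename(filename)
    entries = [e.strip() for e in split_entries if e.strip()]
    if ignore_extensions:
        # Compare stems, so "AnnualCrop_1.jpg" in the split file matches "AnnualCrop_1.tif" on disk.
        name = os.path.splitext(name)[0]
        entries = [os.path.splitext(e)[0] for e in entries]
    if allow_substring:
        # mmsegmentation style: the entry only has to appear inside the file name.
        return any(e in name for e in entries)
    # EuroSAT style: the (possibly extension-less) name must match an entry exactly.
    return name in entries
```

Substring matching is convenient when split entries are scene IDs shared by several tiles, but an entry such as "scene_1" would also match "scene_10"; exact matching avoids that.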

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    predict_data_root: Path | None = None,
    img_grep: str | None = "*",
    label_grep: str | None = "*",
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by each data loader.
        train_data_root (Path): Root directory of the training images.
        val_data_root (Path): Root directory of the validation images.
        test_data_root (Path): Root directory of the test images.
        predict_data_root (Path | None, optional): Root directory of the images to run prediction on.
            If None, no prediction dataset is created. Defaults to None.
        img_grep (str, optional): Pattern used to find image files, e.g. "*_merged.tif". Defaults to "*".
        label_grep (str, optional): Pattern used to find label files, e.g. "*_mask.tif". Defaults to "*".
        means (list[float] | str): Per-channel means used for normalization, as a list or a path to a file.
        stds (list[float] | str): Per-channel standard deviations used for normalization, as a list
            or a path to a file.
        train_label_data_root (Path | None, optional): Root directory of the training labels, if stored
            separately from the images. Defaults to None.
        val_label_data_root (Path | None, optional): Root directory of the validation labels, if stored
            separately from the images. Defaults to None.
        test_label_data_root (Path | None, optional): Root directory of the test labels, if stored
            separately from the images. Defaults to None.
        train_split (Path | None, optional): Split file listing the training samples. If None, no
            split-file filtering is applied. Defaults to None.
        val_split (Path | None, optional): Split file listing the validation samples. If None, no
            split-file filtering is applied. Defaults to None.
        test_split (Path | None, optional): Split file listing the test samples. If None, no
            split-file filtering is applied. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Constant factor by which image values are multiplied.
            Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the RGB channels, used for plotting.
            Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset (and to the predict dataset).
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace NaN values in labels with this value. If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last incomplete batch of the training data loader. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.

    """
    super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.drop_last = drop_last
    self.pin_memory = pin_memory
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.constant_scale = constant_scale

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
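The constructor wraps `means` and `stds` in `Normalize`, which standardizes each channel as (x - mean_c) / std_c; this is why the transforms passed to the data module should not normalize again. The sketch below shows only that per-channel arithmetic on nested lists laid out as (channels, h, w). `normalize_channels` is a hypothetical stdlib stand-in, not the `Normalize` class the module uses, and it assumes one mean and one standard deviation per output channel.

```python
def normalize_channels(
    image: list[list[list[float]]],
    means: list[float],
    stds: list[float],
) -> list[list[list[float]]]:
    """Hypothetical stand-in: apply (x - mean_c) / std_c to a (channels, h, w) image."""
    if not (len(image) == len(means) == len(stds)):
        raise ValueError("need one mean and one std per channel")
    return [
        [[(value - mean) / std for value in row] for row in channel]
        for channel, mean, std in zip(image, means, stds)
    ]


if __name__ == "__main__":
    image = [
        [[10.0, 20.0], [30.0, 40.0]],  # channel 0
        [[0.0, 0.5], [1.0, 1.5]],      # channel 1
    ]
    print(normalize_channels(image, means=[25.0, 0.75], stds=[10.0, 0.5]))
```

Both channels end up on the same scale ([[-1.5, -0.5], [0.5, 1.5]]) even though their raw ranges differ by a factor of 20, which is the point of per-channel statistics.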
GenericNonGeoSegmentationDataModule

Bases: NonGeoDataModule

This is a generic data module class for instantiating data modules at runtime. It composes several GenericNonGeoSegmentationDataset instances, one per split.
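`setup` builds only the datasets a given Lightning stage needs: `fit` creates the training and validation datasets, `validate` only the validation one, `test` the test one, and `predict` creates a prediction dataset only when `predict_data_root` is set. The hypothetical helper below simply tabulates that branching from the source listing so it can be checked at a glance; it is not part of TerraTorch.

```python
def datasets_for_stage(stage: str, has_predict_root: bool) -> list[str]:
    """Hypothetical helper: dataset attributes that setup(stage) would assign."""
    built = []
    if stage in ["fit"]:
        built.append("train_dataset")
    if stage in ["fit", "validate"]:
        built.append("val_dataset")
    if stage in ["test"]:
        built.append("test_dataset")
    if stage in ["predict"] and has_predict_root:
        built.append("predict_dataset")
    return built
```

One consequence: running prediction without `predict_data_root` leaves `predict_dataset` unset, so the predict loader should fail in `_valid_attribute` with the MisconfigurationException documented in `_dataloader_factory`.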

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoSegmentationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        img_grep: str,
        label_grep: str,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by each data loader.
            train_data_root (Path): Root directory of the training images.
            val_data_root (Path): Root directory of the validation images.
            test_data_root (Path): Root directory of the test images.
            predict_data_root (Path | None, optional): Root directory of the images to run prediction on.
                If None, no prediction dataset is created. Defaults to None.
            img_grep (str): Pattern used to find image files, e.g. "*_merged.tif".
            label_grep (str): Pattern used to find label files, e.g. "*_mask.tif".
            means (list[float] | str): Per-channel means used for normalization, as a list or a path to a file.
            stds (list[float] | str): Per-channel standard deviations used for normalization, as a list
                or a path to a file.
            num_classes (int): Number of segmentation classes.
            train_label_data_root (Path | None, optional): Root directory of the training labels, if stored
                separately from the images. Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of the validation labels, if stored
                separately from the images. Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of the test labels, if stored
                separately from the images. Defaults to None.
            train_split (Path | None, optional): Split file listing the training samples. If None, no
                split-file filtering is applied. Defaults to None.
            val_split (Path | None, optional): Split file listing the validation samples. If None, no
                split-file filtering is applied. Defaults to None.
            test_split (Path | None, optional): Split file listing the test samples. If None, no
                split-file filtering is applied. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
                Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
                with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Constant factor by which image values are multiplied.
                Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the RGB channels, used for plotting.
                Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset (and to the predict dataset).
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace NaN values in labels with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last incomplete batch of the training data loader. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.
        """
        super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last
        self.pin_memory = pin_memory

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                image_grep=self.img_grep,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
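`no_data_replace`, `no_label_replace` and `reduce_zero_label` are forwarded unchanged to every dataset that `setup` creates. Going only by their parameter descriptions, a minimal sketch of their effect on pixel values: NaNs are replaced with the given value (nothing happens when it is None), and `reduce_zero_label` subtracts 1 from every label. Both helpers are hypothetical, and the order in which the real dataset applies these steps is not specified here.

```python
import math
from typing import Optional


def replace_nan(values: list[float], replacement: Optional[float]) -> list[float]:
    """Hypothetical helper: replace NaNs with `replacement`, or copy unchanged if it is None."""
    if replacement is None:
        return list(values)
    return [replacement if math.isnan(v) else v for v in values]


def reduce_zero(labels: list[int]) -> list[int]:
    """Hypothetical helper: shift labels that start at 1 so that they start at 0."""
    return [label - 1 for label in labels]


if __name__ == "__main__":
    print(replace_nan([0.2, float("nan"), 0.7], 0.0))  # [0.2, 0.0, 0.7]
    print(reduce_zero([1, 2, 3]))                        # [0, 1, 2]
```

Under this plain "subtract 1" description a label of 0 would become -1, so `reduce_zero_label` is meant for datasets whose labels genuinely start at 1.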
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, img_grep, label_grep, means, stds, num_classes, predict_data_root=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
  • batch_size (int) –

    description

  • num_workers (int) –

    description

  • train_data_root (Path) –

    description

  • val_data_root (Path) –

    description

  • test_data_root (Path) –

    description

  • predict_data_root (Path, default: None ) –

    description

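  • img_grep (str) –

    Pattern used to find the input image files.

  • label_grep (str) –

    Pattern used to find the label files.

  • means (list[float] | str) –

    Per-band means used to normalize the input images, or a path to a file containing them.

  • stds (list[float] | str) –

    Per-band standard deviations used to normalize the input images, or a path to a file containing them.

  • num_classes (int) –

    Number of classes.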

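  • train_label_data_root (Path | None, default: None ) –

    Root directory of the training labels. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Root directory of the validation labels. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Root directory of the test labels. Defaults to None.

  • train_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the training set. Defaults to None.

  • val_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the validation set. Defaults to None.

  • test_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the test set. Defaults to None.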

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard file extensions when using the split file to determine which files to include in the dataset. This is necessary for EuroSAT, for example, since its split files list ".jpg" files while the actual files are ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names for them to be included (as in mmsegmentation), or exact matches (as in EuroSAT). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    Constant scaling factor applied to the input image values. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the red, green and blue channels. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.
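The two split-file options interact. The sketch below is one reading of the semantics described above, written as a standalone function. The name in_split and the exact matching rules (stripping only the final extension with os.path.splitext, skipping blank lines) are assumptions for illustration, not TerraTorch's implementation.

```python
import os


def in_split(
    filename: str,
    split_entries: list[str],
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
) -> bool:
    """Return True if filename is selected by the entries of a split file.

    Hypothetical sketch of the two options documented above.
    """
    name = os.path.basename(filename)
    # Skip blank lines so an empty entry cannot match every file.
    entries = [entry.strip() for entry in split_entries if entry.strip()]
    if ignore_split_file_extensions:
        # Compare names without their final extension, e.g. ".jpg" vs ".tif".
        name = os.path.splitext(name)[0]
        entries = [os.path.splitext(entry)[0] for entry in entries]
    if allow_substring_split_file:
        # mmsegmentation style: an entry only needs to occur inside the file name.
        return any(entry in name for entry in entries)
    # EuroSAT style: the names must match exactly.
    return name in entries


print(in_split("data/AnnualCrop_1.tif", ["AnnualCrop_1.jpg"], allow_substring_split_file=False))  # True
print(in_split("data/tile_01_merged.tif", ["tile_01"]))  # True
```

With extensions ignored, an entry such as AnnualCrop_1.jpg selects AnnualCrop_1.tif. With substring matching, an entry such as tile_01 selects every file whose name contains it, which is convenient for prefixes but can over-select if one entry is a prefix of another. The file names here are illustrative only.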

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    img_grep: str,
    label_grep: str,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
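        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by the data loaders.
        train_data_root (Path): Root directory of the training data.
        val_data_root (Path): Root directory of the validation data.
        test_data_root (Path): Root directory of the test data.
        predict_data_root (Path | None, optional): Root directory of the data to run
            prediction on. Defaults to None.
        img_grep (str): Pattern used to find the input image files.
        label_grep (str): Pattern used to find the label files.
        means (list[float] | str): Per-band means used to normalize the input images,
            or a path to a file containing them.
        stds (list[float] | str): Per-band standard deviations used to normalize the
            input images, or a path to a file containing them.
        num_classes (int): Number of classes.
        train_label_data_root (Path | None, optional): Root directory of the training labels.
            Defaults to None.
        val_label_data_root (Path | None, optional): Root directory of the validation labels.
            Defaults to None.
        test_label_data_root (Path | None, optional): Root directory of the test labels.
            Defaults to None.
        train_split (Path | None, optional): Path to the split file for the training set.
            Defaults to None.
        val_split (Path | None, optional): Path to the split file for the validation set.
            Defaults to None.
        test_split (Path | None, optional): Path to the split file for the test set.
            Defaults to None.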
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): Constant scaling factor applied to the input image
            values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the red, green and blue channels.
            Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value.
            If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace NaN values in labels with this value.
            If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.
    """
    super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last
    self.pin_memory = pin_memory

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
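The constructor only stores no_data_replace, no_label_replace and reduce_zero_label. They are presumably forwarded to the dataset, where they act on each sample. As a rough sketch of what the three options describe, the hypothetical clean_sample function below applies them to a NumPy image and mask. Shifting the labels before replacing missing ones (so that an ignore value such as -1 is not shifted again) is an assumption about the order, not something the source above specifies.

```python
from __future__ import annotations

import numpy as np


def clean_sample(
    image: np.ndarray,
    mask: np.ndarray,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    reduce_zero_label: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Apply the documented NaN-replacement and label-shift options to one sample.

    Hypothetical helper. The mask is returned as floats so that NaNs left in it
    (when no_label_replace is None) remain representable.
    """
    image = image.astype(np.float32)  # astype copies, so the inputs are untouched
    mask = mask.astype(np.float32)
    if no_data_replace is not None:
        image[np.isnan(image)] = no_data_replace
    if reduce_zero_label:
        # Labels starting at 1 become 0-based; NaN - 1 is still NaN.
        mask -= 1
    if no_label_replace is not None:
        # Applied after the shift so the replacement value is not shifted too.
        mask[np.isnan(mask)] = no_label_replace
    return image, mask


image = np.array([[0.5, np.nan]])
mask = np.array([[1.0, np.nan]])
clean_image, clean_mask = clean_sample(
    image, mask, no_data_replace=0, no_label_replace=-1, reduce_zero_label=True
)
print(clean_image.tolist(), clean_mask.tolist())  # [[0.5, 0.0]] [[0.0, -1.0]]
```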

terratorch.datamodules.generic_scalar_label_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoClassificationDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. It composes several GenericNonGeoClassificationDatasets.

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
class GenericNonGeoClassificationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoClassificationDatasets][terratorch.datasets.GenericNonGeoClassificationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int] | None = None,
        predict_dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        no_data_replace: float = 0,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
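            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used by the data loaders.
            train_data_root (Path): Root directory of the training data.
            val_data_root (Path): Root directory of the validation data.
            test_data_root (Path): Root directory of the test data.
            means (list[float] | str): Per-band means used to normalize the input images,
                or a path to a file containing them.
            stds (list[float] | str): Per-band standard deviations used to normalize the
                input images, or a path to a file containing them.
            num_classes (int): Number of classes.
            predict_data_root (Path | None, optional): Root directory of the data to run
                prediction on. If None, no predict dataset is created. Defaults to None.
            train_split (Path | None, optional): Path to the split file for the training set.
                Defaults to None.
            val_split (Path | None, optional): Path to the split file for the validation set.
                Defaults to None.
            test_split (Path | None, optional): Path to the split file for the test set.
                Defaults to None.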
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset.
                Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None, optional): Overwrites dataset_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            output_bands (list[HLSBands | int] | None, optional): Bands that should be output by
                the dataset. Defaults to None.
            constant_scale (float, optional): Constant scaling factor applied to the input image
                values. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of the red, green and blue channels.
                Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace NaN values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        """
        super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.drop_last = drop_last

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
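In setup, the predict dataset is built with predict_dataset_bands, which the constructor replaces with dataset_bands when it is not given. output_bands selects which of those bands each dataset returns. The sketch below shows that name-based selection on a channel-first NumPy array, using the fire scars band names from later on this page. The select_bands helper, the channel-first layout and the convention that output_bands=None keeps every band are assumptions for illustration.

```python
from __future__ import annotations

import numpy as np


def select_bands(image: np.ndarray, dataset_bands: list[str], output_bands: list[str] | None) -> np.ndarray:
    """Return the channels of a (C, H, W) image named in output_bands.

    Hypothetical helper: channel i of the image is assumed to be dataset_bands[i],
    and output_bands=None is assumed to keep every band.
    """
    if output_bands is None:
        return image
    indices = [dataset_bands.index(band) for band in output_bands]
    return image[indices]


dataset_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
output_bands = ["RED", "GREEN", "BLUE"]
predict_dataset_bands = None

# Same fallback as the constructor: without predict bands, reuse dataset_bands.
effective_predict_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands

image = np.arange(6 * 2 * 2).reshape(6, 2, 2)
rgb = select_bands(image, effective_predict_bands, output_bands)
print(rgb.shape)  # (3, 2, 2)
```

Because the constructor's fallback tests truthiness, passing an empty list for predict_dataset_bands behaves the same as passing None.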
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, predict_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, no_data_replace=0, drop_last=True, **kwargs)

Constructor

Parameters:
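  • batch_size (int) –

    Number of samples per batch.

  • num_workers (int) –

    Number of worker processes used by the data loaders.

  • train_data_root (Path) –

    Root directory of the training data.

  • val_data_root (Path) –

    Root directory of the validation data.

  • test_data_root (Path) –

    Root directory of the test data.

  • means (list[float] | str) –

    Per-band means used to normalize the input images, or a path to a file containing them.

  • stds (list[float] | str) –

    Per-band standard deviations used to normalize the input images, or a path to a file containing them.

  • num_classes (int) –

    Number of classes.

  • predict_data_root (Path | None, default: None ) –

    Root directory of the data to run prediction on. If None, no predict dataset is created. Defaults to None.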

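  • train_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the training set. Defaults to None.

  • val_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the validation set. Defaults to None.

  • test_split (Path | None, default: None ) –

    Path to the split file listing the samples to include in the test set. Defaults to None.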

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard file extensions when using the split file to determine which files to include in the dataset. This is necessary for EuroSAT, for example, since its split files list ".jpg" files while the actual files are ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names for them to be included (as in mmsegmentation), or exact matches (as in EuroSAT). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Constant scaling factor applied to the input image values. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the red, green and blue channels. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace NaN values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.
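expand_temporal_dimension turns a stack of time steps into an explicit time axis. The sketch below performs the reshape with NumPy. It assumes the stacked axis is ordered time step by time step (all channels of the first time step, then all channels of the next). If your files interleave bands the other way, the reshape order has to change. The helper name expand_temporal is hypothetical.

```python
import numpy as np


def expand_temporal(image: np.ndarray, num_channels: int) -> np.ndarray:
    """Reshape (time*channels, h, w) to (channels, time, h, w).

    Assumes the stacked axis is time-major: all channels of time step 0,
    then all channels of time step 1, and so on.
    """
    stacked, height, width = image.shape
    if stacked % num_channels:
        raise ValueError(f"{stacked} stacked bands are not a multiple of {num_channels} channels")
    num_steps = stacked // num_channels
    # (time*channels, h, w) -> (time, channels, h, w) -> (channels, time, h, w)
    return image.reshape(num_steps, num_channels, height, width).transpose(1, 0, 2, 3)


# 3 time steps x 2 channels, each 4 x 4 pixels.
stack = np.arange(6 * 4 * 4).reshape(6, 4, 4)
cube = expand_temporal(stack, num_channels=2)
print(cube.shape)  # (2, 3, 4, 4)
```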

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int] | None = None,
    predict_dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    no_data_replace: float = 0,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
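        batch_size (int): Number of samples per batch.
        num_workers (int): Number of worker processes used by the data loaders.
        train_data_root (Path): Root directory of the training data.
        val_data_root (Path): Root directory of the validation data.
        test_data_root (Path): Root directory of the test data.
        means (list[float] | str): Per-band means used to normalize the input images,
            or a path to a file containing them.
        stds (list[float] | str): Per-band standard deviations used to normalize the
            input images, or a path to a file containing them.
        num_classes (int): Number of classes.
        predict_data_root (Path | None, optional): Root directory of the data to run
            prediction on. If None, no predict dataset is created. Defaults to None.
        train_split (Path | None, optional): Path to the split file for the training set.
            Defaults to None.
        val_split (Path | None, optional): Path to the split file for the validation set.
            Defaults to None.
        test_split (Path | None, optional): Path to the split file for the test set.
            Defaults to None.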
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. EuroSAT). Defaults to True.
        dataset_bands (list[HLSBands | int] | None, optional): Bands present in the dataset.
            Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None, optional): Overwrites dataset_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        output_bands (list[HLSBands | int] | None, optional): Bands that should be output by
            the dataset. Defaults to None.
        constant_scale (float, optional): Constant scaling factor applied to the input image
            values. Defaults to 1.
        rgb_indices (list[int] | None, optional): Indices of the red, green and blue channels.
            Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace NaN values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
    """
    super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.drop_last = drop_last

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )

    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
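The means and stds arguments end up in Normalize(means, stds). One way to obtain them is to compute per-band statistics over the training images, as the hypothetical helpers below do. The sketch assumes channel-first (C, H, W) arrays and that normalization is the usual channel-wise (x - mean) / std. The resulting lists can be passed as means and stds.

```python
import numpy as np


def band_statistics(images: list[np.ndarray]) -> tuple[list[float], list[float]]:
    """Per-band mean and standard deviation over (C, H, W) images, ignoring NaNs."""
    # Flatten each image to (C, H*W) and join along the pixel axis: (C, N*H*W).
    pixels = np.concatenate([img.reshape(img.shape[0], -1) for img in images], axis=1)
    means = np.nanmean(pixels, axis=1)
    stds = np.nanstd(pixels, axis=1)
    return means.tolist(), stds.tolist()


def normalize(image: np.ndarray, means: list[float], stds: list[float]) -> np.ndarray:
    """Channel-wise (x - mean) / std for a (C, H, W) image."""
    mean = np.asarray(means, dtype=np.float64)[:, None, None]
    std = np.asarray(stds, dtype=np.float64)[:, None, None]
    return (image - mean) / std


train_images = [
    np.array([[[0.0, 2.0]], [[10.0, 20.0]]]),  # shape (2, 1, 2)
    np.array([[[4.0, 6.0]], [[30.0, 40.0]]]),
]
means, stds = band_statistics(train_images)
print(means)  # [3.0, 25.0]
normalized = normalize(train_images[0], means, stds)
```

After normalization with these statistics, each band of the training set has zero mean and unit standard deviation.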

Custom datasets and data modules

Below is a documented example of how custom dataset and data module classes can be implemented.

Datasets

terratorch.datasets.fire_scars

FireScarsHLS

Bases: RasterDataset

RasterDataset implementation for fire scars input images.

Source code in terratorch/datasets/fire_scars.py
from torchgeo.datasets import RasterDataset


class FireScarsHLS(RasterDataset):
    """RasterDataset implementation for fire scars input images."""

    filename_glob = "subsetted*_merged.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
    date_format = "%Y%j"
    is_image = True
    separate_files = False
    # Plain class attributes: dataclasses.field() only takes effect inside a @dataclass,
    # and its default_factory must be a zero-argument callable, not a list.
    all_bands = ("B02", "B03", "B04", "B8A", "B11", "B12")
    rgb_bands = ("B04", "B03", "B02")
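torchgeo's RasterDataset uses filename_glob to find files and parses the date group of filename_regex with date_format to timestamp each file. The snippet below applies the same three attribute values with the standard library, which is a quick way to check a new pattern before training. The file name is an illustrative example in the HLS naming style, assumed rather than taken from the dataset. The format %Y%j means a four-digit year followed by the day of the year.

```python
import fnmatch
import re
from datetime import datetime

# Attribute values copied from FireScarsHLS above.
filename_glob = "subsetted*_merged.tif"
filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
date_format = "%Y%j"

# Illustrative HLS-style file name (an assumption, not taken from the dataset).
name = "subsetted_512x512_HLS.S30.T10SDH.2020248.v1.4_merged.tif"

print(fnmatch.fnmatch(name, filename_glob))  # True
match = re.match(filename_regex, name)
if match is None:
    raise ValueError(f"{name} does not match filename_regex")
acquired = datetime.strptime(match.group("date"), date_format)
print(acquired.date())  # 2020-09-04
```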
FireScarsNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for fire scars.

Source code in terratorch/datasets/fire_scars.py
class FireScarsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for fire scars."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 2
    splits = {"train": "training", "val": "validation"}   # Only train and val splits available

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): Dataset split, either "train" or "val". Defaults to "train".
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace NaN values in label with this value.
                If None, does no replacement. Defaults to -1.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        input_dir = self.data_root / split_name
        self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        base_filename = os.path.basename(file_name)

        filename_regex = r"subsetted_512x512_HLS\.S30\.T[0-9A-Z]{5}\.(?P<date>[0-9]+)\.v1\.4_merged\.tif"
        match = re.match(filename_regex, base_filename)
        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:])

        return torch.tensor([[year, julian_day]], dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image_percentile(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
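Before the transform runs, `__getitem__` fills missing pixels, moves the band axis last (the layout Albumentations expects), and keeps only the requested bands. The sketch below mirrors those steps with plain NumPy: `np.nan_to_num` stands in for `DataArray.fillna`, since `open_rasterio(..., masked=True)` turns nodata pixels into NaN. `prepare_image` is a hypothetical helper, not part of TerraTorch.

```python
import numpy as np

ALL_BAND_NAMES = ("BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2")


def prepare_image(raster: np.ndarray, bands, nan_replace: float = 0.0) -> np.ndarray:
    """Mimic FireScarsNonGeo: fill NaNs, go channels-last, keep only `bands`."""
    band_indices = np.asarray([ALL_BAND_NAMES.index(b) for b in bands])
    image = np.nan_to_num(raster, nan=nan_replace)  # like DataArray.fillna(...)
    image = np.moveaxis(image, 0, -1)  # (C, H, W) -> (H, W, C)
    return image[..., band_indices].astype(np.float32)


# Synthetic 6-band, 2x2 raster where band i holds the value i
raster = np.stack([np.full((2, 2), float(i)) for i in range(6)])
raster[0, 0, 0] = np.nan  # one missing BLUE pixel
out = prepare_image(raster, bands=("RED", "GREEN", "BLUE"))
print(out.shape, out[0, 0])
```

Band order in the output follows the `bands` argument, not the file order, which is why `plot` looks up RGB positions with `self.bands.index(...)`.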
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=0, no_label_replace=-1, use_metadata=False)

Constructor

Parameters:
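  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    Dataset split to load, either "train" or "val". Defaults to "train".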

  • bands (list[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2().

  • no_data_replace (float | None, default: 0 ) –

    Replace NaN values in input images with this value. If None, does no replacement. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replace NaN values in label with this value. If None, does no replacement. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location).
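Only two splits exist on disk, and the constructor maps the user-facing name to a directory through the class-level `splits` dictionary. A minimal sketch of that lookup follows; `resolve_split_dir` and the `/data/fire_scars` root are hypothetical.

```python
from pathlib import Path

SPLITS = {"train": "training", "val": "validation"}  # as in FireScarsNonGeo.splits


def resolve_split_dir(data_root: str, split: str) -> Path:
    """Map a split name to its directory, rejecting unknown splits."""
    if split not in SPLITS:
        raise ValueError(f"Incorrect split '{split}', please choose one of {SPLITS}.")
    return Path(data_root) / SPLITS[split]


print(resolve_split_dir("/data/fire_scars", "val"))
```

Requesting `split="test"` raises a `ValueError`, because no test directory is expected.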

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample
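`plot` passes the RGB bands through `clip_image_percentile` before `imshow`, because Matplotlib expects float RGB data in [0, 1]. The implementation of that helper is not shown here, so the sketch below is a hypothetical stand-in: it clips each channel at assumed 2nd/98th percentiles and rescales to [0, 1].

```python
import numpy as np


def clip_percentile(image: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
    """Rescale an (H, W, 3) image to [0, 1] after clipping outliers per channel.

    Hypothetical stand-in for ``clip_image_percentile``; the percentiles are assumptions.
    """
    lo = np.percentile(image, low, axis=(0, 1), keepdims=True)
    hi = np.percentile(image, high, axis=(0, 1), keepdims=True)
    scaled = (image - lo) / np.maximum(hi - lo, 1e-12)  # guard against flat channels
    return np.clip(scaled, 0.0, 1.0)


# Synthetic reflectance-like values far outside [0, 1]
rgb = np.random.default_rng(0).normal(1000.0, 200.0, size=(8, 8, 3))
display = clip_percentile(rgb)
print(display.min(), display.max())
```

Clipping per channel keeps a few very bright pixels (clouds, smoke) from compressing the rest of the scene into dark tones.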

FireScarsSegmentationMask

Bases: RasterDataset

RasterDataset implementation for fire scars segmentation mask. Can be easily merged with input images using the & operator.

Source code in terratorch/datasets/fire_scars.py
class FireScarsSegmentationMask(RasterDataset):
    """RasterDataset implementation for fire scars segmentation mask.
    Can be easily merged with input images using the & operator.
    """

    filename_glob = "subsetted*.mask.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1\.4\.mask\.tif"
    date_format = "%Y%j"
    is_image = False
    separate_files = False
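In TorchGeo, combining two `GeoDataset`s with `&` builds an `IntersectionDataset` that pairs samples where both overlap in space and time. For image/mask pairing to work, each class must pick up only its own files from a shared directory. The sketch checks that the two `filename_glob` patterns are disjoint using `fnmatch`; the file names are hypothetical.

```python
from fnmatch import fnmatch

IMAGE_GLOB = "subsetted*_merged.tif"  # FireScarsHLS.filename_glob
MASK_GLOB = "subsetted*.mask.tif"  # FireScarsSegmentationMask.filename_glob

# Hypothetical file names following the HLS naming scheme
files = [
    "subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif",
    "subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4.mask.tif",
]
images = [f for f in files if fnmatch(f, IMAGE_GLOB)]
masks = [f for f in files if fnmatch(f, MASK_GLOB)]
print(images, masks)
```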

terratorch.datasets.biomassters

BioMasstersNonGeo

Bases: BioMassters

BioMassters Dataset for Aboveground Biomass prediction.

Dataset intended for Aboveground Biomass (AGB) prediction over Finnish forests based on Sentinel 1 and 2 data with corresponding target AGB mask values generated by Light Detection and Ranging (LiDAR).

Dataset Format:

  • .tif files for Sentinel 1 and 2 data
  • .tif file for pixel-wise AGB target mask
  • .csv files for metadata regarding features and targets
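The dataset derives a numeric month (`num_month`) from each feature file name by taking the third underscore-separated field and dropping the extension; index 0 is September. The sketch below assumes a `<chip_id>_<sensor>_<MM>.tif` layout consistent with that parsing; the example file name is hypothetical and the month list is inferred from the "first month is September" comment in the source.

```python
# Months in dataset order, inferred from "first month is September (index 0)"
MONTHS = ["September", "October", "November", "December", "January", "February",
          "March", "April", "May", "June", "July", "August"]


def parse_feature_filename(filename: str) -> tuple[str, str, int]:
    """Split a feature file name into (chip_id, sensor, num_month).

    Assumes the hypothetical pattern ``<chip_id>_<sensor>_<MM>.tif``.
    """
    chip_id, sensor, rest = filename.split("_")[:3]
    num_month = int(rest.split(".")[0])  # "04.tif" -> 4
    return chip_id, sensor, num_month


chip, sensor, month = parse_feature_filename("0003d2eb_S2_04.tif")
print(chip, sensor, MONTHS[month])
```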

Dataset Features:

  • 13,000 target AGB masks of size (256x256px)
  • 12 months of data per target mask
  • Sentinel 1 and Sentinel 2 data for each location
  • Sentinel 1 available for every month
  • Sentinel 2 available for most months (some months are missing because ESA halted acquisitions over the region during particular periods)
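Because Sentinel 2 is missing for some months, the per-time-step mode (`as_time_series=False`) keeps only `(chip_id, month)` pairs that have an image for every requested sensor, while time-series mode yields one sample per chip. The plain-Python sketch below mirrors the pandas `groupby`/`merge` logic in `__init__`; the metadata rows are invented, and it assumes at most one file per sensor per month.

```python
from collections import Counter

# Hypothetical metadata rows: (chip_id, month, satellite)
rows = [
    ("chipA", "September", "S1"), ("chipA", "September", "S2"),
    ("chipA", "October", "S1"),  # S2 missing for this month
    ("chipB", "September", "S1"), ("chipB", "September", "S2"),
]
sensors = ("S1", "S2")

# as_time_series=False: one sample per (chip, month) that has every requested sensor
counts = Counter((chip, month) for chip, month, sat in rows if sat in sensors)
per_step = sorted(key for key, n in counts.items() if n == len(sensors))

# as_time_series=True: one sample per chip, holding all of its months
per_chip = sorted({chip for chip, _, _ in rows})
print(per_step, per_chip)
```

With only `sensors=("S1",)`, the October pair would be kept as well, so the number of samples depends on the sensor selection.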

If you use this dataset in your research, please cite the following paper:

  • https://nascetti-a.github.io/BioMasster/

Added in version 0.5.

Source code in terratorch/datasets/biomassters.py
class BioMasstersNonGeo(BioMassters):
    """BioMassters Dataset for Aboveground Biomass prediction.

    Dataset intended for Aboveground Biomass (AGB) prediction
    over Finnish forests based on Sentinel 1 and 2 data with
    corresponding target AGB mask values generated by Light Detection
    and Ranging (LiDAR).

    Dataset Format:

    * .tif files for Sentinel 1 and 2 data
    * .tif file for pixel-wise AGB target mask
    * .csv files for metadata regarding features and targets

    Dataset Features:

    * 13,000 target AGB masks of size (256x256px)
    * 12 months of data per target mask
    * Sentinel 1 and Sentinel 2 data for each location
    * Sentinel 1 available for every month
    * Sentinel 2 available for almost every month
      (not available for every month due to ESA acquisition halt over the region
      during particular periods)

    If you use this dataset in your research, please cite the following paper:

    * https://nascetti-a.github.io/BioMasster/

    .. versionadded:: 0.5
    """

    S1_BAND_NAMES = ["VV_Asc", "VH_Asc", "VV_Desc", "VH_Desc", "RVI_Asc", "RVI_Desc"]
    S2_BAND_NAMES = [
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    ]

    all_band_names = {
        "S1": S1_BAND_NAMES,
        "S2": S2_BAND_NAMES,
    }

    rgb_bands = {
        "S1": [],
        "S2": ["RED", "GREEN", "BLUE"],
    }

    valid_splits = ("train", "test")
    valid_sensors = ("S1", "S2")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        bands: dict[str, Sequence[str]] | Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        mask_mean: float | None = 63.4584,
        mask_std: float | None = 72.21242,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False
    ) -> None:
        """Initialize a new instance of BioMassters dataset.

        If ``as_time_series=False`` (the default), each time step becomes its own
        sample with the target being shared across multiple samples.

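        Args:
            root: root directory where dataset can be found
            split: train or test split
            bands: bands to return, keyed by sensor ('S1', 'S2')
            transform: Albumentations transform to apply; with multiple sensors, a dict
                mapping sensor name to transform (sensors without an entry use
                ``default_transform``). If None, samples are only converted to tensors
            mask_mean: mean of the AGB target values, stored as ``mask_mean``
            mask_std: standard deviation of the AGB target values, stored as ``mask_std``
            sensors: which sensors to consider for the sample, Sentinel 1 and/or
                Sentinel 2 ('S1', 'S2')
            as_time_series: whether or not to return all available
                time-steps or just a single one for a given target location
            metadata_filename: metadata file to be used
            max_cloud_percentage: maximum allowed cloud percentage for images
            max_red_mean: maximum allowed red_mean value for images
            include_corrupt: whether to include images marked as corrupted
            subset: fraction of samples to keep (only applied to the train split)
            seed: random seed used when subsampling
            use_four_frames: keep at most four frames per chip, preferring the lowest
                cloud percentage when that column is available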

        Raises:
            AssertionError: if ``split`` or ``sensors`` is invalid
            DatasetNotFoundError: If dataset is not found.
        """
        self.root = root
        self.sensors = sensors
        self.bands = bands
        assert (
            split in self.valid_splits
        ), f"Please choose one of the valid splits: {self.valid_splits}."
        self.split = split

        assert set(sensors).issubset(
            set(self.valid_sensors)
        ), f"Please choose a subset of valid sensors: {self.valid_sensors}."

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            self.band_indices = [
                self.all_band_names[sens].index(band) for band in self.bands[sens]
            ]
        else:
            self.band_indices = {
                sens: [self.all_band_names[sens].index(band) for band in self.bands[sens]]
                for sens in self.sensors
            }

        self.mask_mean = mask_mean
        self.mask_std = mask_std
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

        self._verify()

        # open metadata csv files
        self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename))

        # Filter sensors
        self.df = self.df[self.df["satellite"].isin(self.sensors)]

        # Filter split
        self.df = self.df[self.df["split"] == self.split]

        # Optional filtering
        self._filter_and_select_data()

        # generate numerical month from filename since first month is September
        # and has numerical index of 0
        self.df["num_month"] = (
            self.df["filename"]
            .str.split("_", expand=True)[2]
            .str.split(".", expand=True)[0]
            .astype(int)
        )

        # Set dataframe index depending on the task for easier indexing
        if self.as_time_series:
            self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup()
        else:
            filter_df = (
                self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
            )
            filter_df = filter_df[
                filter_df["satellite"] == len(self.sensors)
            ].drop("satellite", axis=1)
            # Guarantee that each sample has corresponding number of images available
            self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner")

            self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup()

        # Optional subsampling; it samples "num_index" values, so it runs after indexing
        self._random_subsample()
        # Keep "num_index" contiguous (0..len-1), since __getitem__ looks samples up by it
        self.df["num_index"] = (self.df["num_index"].rank(method="dense") - 1).astype(int)

        # Adjust transforms based on the number of sensors
        if len(self.sensors) == 1:
            self.transform = transform if transform else default_transform
        elif transform is None:
            self.transform = MultimodalToTensor(self.sensors)
        else:
            transform = {
                s: transform[s] if s in transform else default_transform
                for s in self.sensors
            }
            self.transform = MultimodalTransforms(transform, shared=False)

        if self.use_four_frames:
            self._select_4_frames()

    def __len__(self) -> int:
        return len(self.df["num_index"].unique())

    def _load_input(self, filenames: list[Path]) -> Tensor:
        """Load the input imagery at the index.

        Args:
            filenames: list of filenames corresponding to input

        Returns:
            input image
        """
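        filepaths = [
            os.path.join(self.root, f"{self.split}_features", f) for f in filenames
        ]
        # Read each file inside a context manager so file handles are closed
        arr_list = []
        for fp in filepaths:
            with rasterio.open(fp) as src:
                arr_list.append(src.read())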

        if self.as_time_series:
            arr = np.stack(arr_list, axis=0) # (T, C, H, W)
        else:
            arr = np.concatenate(arr_list, axis=0)
        return arr.astype(np.int32)

    def _load_target(self, filename: Path) -> Tensor:
        """Load the target mask at the index.

        Args:
            filename: filename of target to index

        Returns:
            target mask
        """
        with rasterio.open(os.path.join(self.root, f"{self.split}_agbm", filename), "r") as src:
            arr: np.typing.NDArray[np.float64] = src.read()

        return arr

    def _compute_rvi(self, img: np.ndarray, linear: np.ndarray, sens: str) -> np.ndarray:
        """Compute the RVI indices for S1 data."""
        rvi_channels = []
        if self.as_time_series:
            if "RVI_Asc" in self.bands[sens]:
                try:
                    vv_asc_index = self.all_band_names["S1"].index("VV_Asc")
                    vh_asc_index = self.all_band_names["S1"].index("VH_Asc")
                except ValueError as e:
                    msg = f"RVI_Asc needs band: {e}"
                    raise ValueError(msg) from e

                VV = linear[:, vv_asc_index, :, :]
                VH = linear[:, vh_asc_index, :, :]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=1)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                try:
                    vv_desc_index = self.all_band_names["S1"].index("VV_Desc")
                    vh_desc_index = self.all_band_names["S1"].index("VH_Desc")
                except ValueError as e:
                    msg = f"RVI_Desc needs band: {e}"
                    raise ValueError(msg) from e

                VV_desc = linear[:, vv_desc_index, :, :]
                VH_desc = linear[:, vh_desc_index, :, :]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=1)
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=1)
                img = np.concatenate([img, rvi_concat], axis=1)
        else:
            if "RVI_Asc" in self.bands[sens]:
                if linear.shape[0] < 2:
                    msg = f"Not enough bands to calculate RVI_Asc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV = linear[0]
                VH = linear[1]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=0)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                if linear.shape[0] < 4:
                    msg = f"Not enough bands to calculate RVI_Desc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV_desc = linear[2]
                VH_desc = linear[3]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=0) 
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=0)
                img = np.concatenate([linear, rvi_concat], axis=0)
        return img

    def _select_4_frames(self):
        """Filter the dataset to select only 4 frames per sample."""

        if "cloud_percentage" in self.df.columns:
            self.df = self.df.sort_values(by=["chip_id", "cloud_percentage"])
        else:
            self.df = self.df.sort_values(by=["chip_id", "num_month"])

        self.df = (
            self.df.groupby("chip_id")
            .head(4)  # Select the first 4 frames per chip
            .reset_index(drop=True)
        )

    def _process_sensor_images(self, sens: str, sens_filepaths: list[str]) -> np.ndarray:
        """Process images for a given sensor."""
        img = self._load_input(sens_filepaths)
        if sens == "S1":
            img = img.astype(np.float32)
            linear = 10 ** (img / 10)
            img = self._compute_rvi(img, linear, sens)
        if self.as_time_series:
            img = img.transpose(0, 2, 3, 1)  # (T, H, W, C)
        else:
            img = img.transpose(1, 2, 0)  # (H, W, C)
        if len(self.sensors) == 1:
            img = img[..., self.band_indices]
        else:
            img = img[..., self.band_indices[sens]]
        return img

    def __getitem__(self, index: int) -> dict:
        sample_df = self.df[self.df["num_index"] == index].copy()
        # Sort by satellite and month
        sample_df.sort_values(
            by=["satellite", "num_month"], inplace=True, ascending=True
        )

        filepaths = sample_df["filename"].tolist()
        output = {}

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            sens_filepaths = [fp for fp in filepaths if sens in fp]
            img = self._process_sensor_images(sens, sens_filepaths)
            output["image"] = img.astype(np.float32)
        else:
            for sens in self.sensors:
                sens_filepaths = [fp for fp in filepaths if sens in fp]
                img = self._process_sensor_images(sens, sens_filepaths)
                output[sens] = img.astype(np.float32)

        # Load target
        target_filename = sample_df["corresponding_agbm"].unique()[0]
        target = np.array(self._load_target(Path(target_filename)))
        target = target.transpose(1, 2, 0)
        output["mask"] = target
        if self.transform:
            if len(self.sensors) == 1:
                output = self.transform(**output)
            else:
                output = self.transform(output)
        output["mask"] = output["mask"].squeeze().float()
        return output

    def _filter_and_select_data(self):
        if (
            self.max_cloud_percentage is not None
            and "cloud_percentage" in self.df.columns
        ):
            self.df = self.df[self.df["cloud_percentage"] <= self.max_cloud_percentage]

        if self.max_red_mean is not None and "red_mean" in self.df.columns:
            self.df = self.df[self.df["red_mean"] <= self.max_red_mean]

        if not self.include_corrupt and "corrupt_values" in self.df.columns:
            # Element-wise comparison; `column is False` would evaluate to a single bool
            self.df = self.df[self.df["corrupt_values"].eq(False)]

    def _random_subsample(self):
        if self.split == "train" and self.subset < 1.0:
            num_samples = int(len(self.df["num_index"].unique()) * self.subset)
            if self.seed is not None:
                random.seed(self.seed)
            selected_indices = random.sample(
                list(self.df["num_index"].unique()), num_samples
            )
            self.df = self.df[self.df["num_index"].isin(selected_indices)]
            self.df.reset_index(drop=True, inplace=True)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Determine if the sample contains multiple sensors or a single sensor
        if isinstance(sample["image"], dict):
            ncols = len(self.sensors) + 1
        else:
            ncols = 2  # One for the image and one for the mask

        showing_predictions = "prediction" in sample
        if showing_predictions:
            ncols += 1

        fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10))

        if isinstance(sample["image"], dict):
            # Multiple sensors case
            for idx, sens in enumerate(self.sensors):
                img = sample["image"][sens].numpy()
                if self.as_time_series:
                    # Plot last time step
                    img = img[:, -1, ...]
                if sens == "S2":
                    img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                    img = percentile_normalization(img)
                else:
                    co_polarization = img[0]  # transmit == receive
                    cross_polarization = img[1]  # transmit != receive
                    ratio = co_polarization / (cross_polarization + 1e-6)

                    co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                    cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                    ratio = np.clip(ratio / 25, 0, 1)

                    img = np.stack(
                        (co_polarization, cross_polarization, ratio), axis=0
                    )
                    img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

                axs[idx].imshow(img)
                axs[idx].axis("off")
                if show_titles:
                    axs[idx].set_title(sens)
            mask_idx = len(self.sensors)
        else:
            # Single sensor case
            sens = self.sensors[0]
            img = sample["image"].numpy()
            if self.as_time_series:
                # Plot last time step
                img = img[:, -1, ...]
            if sens == "S2":
                img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                img = percentile_normalization(img)
            else:
                co_polarization = img[0]  # transmit == receive
                cross_polarization = img[1]  # transmit != receive
                ratio = co_polarization / (cross_polarization + 1e-6)

                co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                ratio = np.clip(ratio / 25, 0, 1)

                img = np.stack(
                    (co_polarization, cross_polarization, ratio), axis=0
                )
                img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

            axs[0].imshow(img)
            axs[0].axis("off")
            if show_titles:
                axs[0].set_title(sens)
            mask_idx = 1

        # Plot target mask
        if "mask" in sample:
            target = sample["mask"].squeeze()
            target_im = axs[mask_idx].imshow(target, cmap="YlGn")
            plt.colorbar(target_im, ax=axs[mask_idx], fraction=0.046, pad=0.04)
            axs[mask_idx].axis("off")
            if show_titles:
                axs[mask_idx].set_title("Target")

        # Plot prediction if available
        if showing_predictions:
            pred_idx = mask_idx + 1
            prediction = sample["prediction"].squeeze()
            pred_im = axs[pred_idx].imshow(prediction, cmap="YlGn")
            plt.colorbar(pred_im, ax=axs[pred_idx], fraction=0.046, pad=0.04)
            axs[pred_idx].axis("off")
            if show_titles:
                axs[pred_idx].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(root='data', split='train', bands=BAND_SETS['all'], transform=None, mask_mean=63.4584, mask_std=72.21242, sensors=['S1', 'S2'], as_time_series=False, metadata_filename=default_metadata_filename, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True, subset=1, seed=42, use_four_frames=False)

Initialize a new instance of the BioMassters dataset.

If as_time_series=False (the default), each time step becomes its own sample with the target being shared across multiple samples.
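The following is a minimal pandas sketch of how the constructor source below builds the sample index. It uses a small hypothetical metadata table with the `chip_id`, `month`, and `satellite` columns that the code reads. With `as_time_series=False`, a (chip_id, month) pair is kept only if it has an image for every requested sensor, and each kept pair becomes one sample. With `as_time_series=True`, each chip becomes one sample.

```python
import pandas as pd

# Hypothetical metadata: chip "b" has only an S2 image in month 1.
df = pd.DataFrame(
    {
        "chip_id": ["a", "a", "a", "a", "b", "b", "b"],
        "month": [0, 0, 1, 1, 0, 0, 1],
        "satellite": ["S1", "S2", "S1", "S2", "S1", "S2", "S2"],
    }
)
sensors = ["S1", "S2"]


def index_samples(df: pd.DataFrame, as_time_series: bool) -> pd.DataFrame:
    """Assign a sample index following the same steps as the constructor."""
    df = df.copy()
    if as_time_series:
        # One sample per chip: all of its time steps are returned together.
        df["num_index"] = df.groupby(["chip_id"]).ngroup()
        return df
    # Keep only (chip_id, month) pairs for which every requested sensor exists.
    counts = df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
    complete = counts[counts["satellite"] == len(sensors)].drop("satellite", axis=1)
    df = df.merge(complete, on=["chip_id", "month"], how="inner")
    # One sample per (chip_id, month); the chip's target is shared across months.
    df["num_index"] = df.groupby(["chip_id", "month"]).ngroup()
    return df


per_step = index_samples(df, as_time_series=False)
per_chip = index_samples(df, as_time_series=True)
print(per_step["num_index"].nunique())  # 3 samples: (a, 0), (a, 1), (b, 0)
print(per_chip["num_index"].nunique())  # 2 samples: a, b
```

Because the inner merge drops the incomplete (b, 1) pair, month 1 of chip "b" never becomes a per-time-step sample.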

Parameters:
  • root (str, default: 'data' ) –

    root directory where dataset can be found

  • split (str, default: 'train' ) –

    train or test split

  • sensors (Sequence[str], default: ['S1', 'S2'] ) –

    which sensors to consider for the sample, Sentinel 1 and/or Sentinel 2 ('S1', 'S2')

  • as_time_series (bool, default: False ) –

    whether or not to return all available time-steps or just a single one for a given target location

  • metadata_filename (str, default: default_metadata_filename ) –

    metadata file to be used

  • max_cloud_percentage (float | None, default: None ) –

    maximum allowed cloud percentage for images

  • max_red_mean (float | None, default: None ) –

    maximum allowed red_mean value for images

  • include_corrupt (bool, default: True ) –

    whether to include images marked as corrupted

Raises:
  • AssertionError

    if split or sensors is invalid

  • DatasetNotFoundError

    If dataset is not found.

Source code in terratorch/datasets/biomassters.py
def __init__(
    self,
    root: str = "data",
    split: str = "train",
    bands: dict[str, Sequence[str]] | Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    mask_mean: float | None = 63.4584,
    mask_std: float | None = 72.21242,
    sensors: Sequence[str] = ["S1", "S2"],
    as_time_series: bool = False,
    metadata_filename: str = default_metadata_filename,
    max_cloud_percentage: float | None = None,
    max_red_mean: float | None = None,
    include_corrupt: bool = True,
    subset: float = 1,
    seed: int = 42,
    use_four_frames: bool = False
) -> None:
    """Initialize a new instance of the BioMassters dataset.

    If ``as_time_series=False`` (the default), each time step becomes its own
    sample with the target being shared across multiple samples.

    Args:
        root: root directory where dataset can be found
        split: train or test split
        sensors: which sensors to consider for the sample, Sentinel 1 and/or
            Sentinel 2 ('S1', 'S2')
        as_time_series: whether or not to return all available
            time-steps or just a single one for a given target location
        metadata_filename: metadata file to be used
        max_cloud_percentage: maximum allowed cloud percentage for images
        max_red_mean: maximum allowed red_mean value for images
        include_corrupt: whether to include images marked as corrupted

    Raises:
        AssertionError: if ``split`` or ``sensors`` is invalid
        DatasetNotFoundError: If dataset is not found.
    """
    self.root = root
    self.sensors = sensors
    self.bands = bands
    assert (
        split in self.valid_splits
    ), f"Please choose one of the valid splits: {self.valid_splits}."
    self.split = split

    assert set(sensors).issubset(
        set(self.valid_sensors)
    ), f"Please choose a subset of valid sensors: {self.valid_sensors}."

    if len(self.sensors) == 1:
        sens = self.sensors[0]
        self.band_indices = [
            self.all_band_names[sens].index(band) for band in self.bands[sens]
        ]
    else:
        self.band_indices = {
            sens: [self.all_band_names[sens].index(band) for band in self.bands[sens]]
            for sens in self.sensors
        }

    self.mask_mean = mask_mean
    self.mask_std = mask_std
    self.as_time_series = as_time_series
    self.metadata_filename = metadata_filename
    self.max_cloud_percentage = max_cloud_percentage
    self.max_red_mean = max_red_mean
    self.include_corrupt = include_corrupt
    self.subset = subset
    self.seed = seed
    self.use_four_frames = use_four_frames

    self._verify()

    # open metadata csv files
    self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename))

    # Filter sensors
    self.df = self.df[self.df["satellite"].isin(self.sensors)]

    # Filter split
    self.df = self.df[self.df["split"] == self.split]

    # Optional filtering
    self._filter_and_select_data()

    # Optional subsampling
    self._random_subsample()

    # generate numerical month from filename since first month is September
    # and has numerical index of 0
    self.df["num_month"] = (
        self.df["filename"]
        .str.split("_", expand=True)[2]
        .str.split(".", expand=True)[0]
        .astype(int)
    )

    # Set dataframe index depending on the task for easier indexing
    if self.as_time_series:
        self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup()
    else:
        filter_df = (
            self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
        )
        filter_df = filter_df[
            filter_df["satellite"] == len(self.sensors)
        ].drop("satellite", axis=1)
        # Guarantee that each sample has corresponding number of images available
        self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner")

        self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup()

    # Adjust transforms based on the number of sensors
    if len(self.sensors) == 1:
        self.transform = transform if transform else default_transform
    elif transform is None:
        self.transform = MultimodalToTensor(self.sensors)
    else:
        transform = {
            s: transform[s] if s in transform else default_transform
            for s in self.sensors
        }
        self.transform = MultimodalTransforms(transform, shared=False)

    if self.use_four_frames:
        self._select_4_frames()
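The constructor above parses the `num_month` column from each filename. The sketch below assumes filenames follow a `<chip_id>_<sensor>_<month>.tif` pattern, which is what the splits on `_` and `.` imply. The chip IDs are made up.

```python
import pandas as pd

# Hypothetical filenames in the assumed "<chip_id>_<sensor>_<month>.tif" form.
df = pd.DataFrame(
    {"filename": ["0003d2eb_S1_00.tif", "0003d2eb_S2_05.tif", "ffe9a3c1_S2_11.tif"]}
)

df["num_month"] = (
    df["filename"]
    .str.split("_", expand=True)[2]  # "00.tif", "05.tif", "11.tif"
    .str.split(".", expand=True)[0]  # "00", "05", "11"
    .astype(int)                     # 0, 5, 11 (index 0 is September)
)
print(df["num_month"].tolist())  # [0, 5, 11]
```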
plot(sample, show_titles=True, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • show_titles (bool, default: True ) –

    flag indicating whether to show titles above each panel

  • suptitle (str | None, default: None ) –

    optional suptitle to use for figure

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

Source code in terratorch/datasets/biomassters.py
def plot(
    self,
    sample: dict[str, Tensor],
    show_titles: bool = True,
    suptitle: str | None = None,
) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        show_titles: flag indicating whether to show titles above each panel
        suptitle: optional suptitle to use for figure

    Returns:
        a matplotlib Figure with the rendered sample
    """
    # Determine if the sample contains multiple sensors or a single sensor
    if isinstance(sample["image"], dict):
        ncols = len(self.sensors) + 1
    else:
        ncols = 2  # One for the image and one for the mask

    showing_predictions = "prediction" in sample
    if showing_predictions:
        ncols += 1

    fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10))

    if isinstance(sample["image"], dict):
        # Multiple sensors case
        for idx, sens in enumerate(self.sensors):
            img = sample["image"][sens].numpy()
            if self.as_time_series:
                # Plot last time step
                img = img[:, -1, ...]
            if sens == "S2":
                img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                img = percentile_normalization(img)
            else:
                co_polarization = img[0]  # transmit == receive
                cross_polarization = img[1]  # transmit != receive
                ratio = co_polarization / (cross_polarization + 1e-6)

                co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                ratio = np.clip(ratio / 25, 0, 1)

                img = np.stack(
                    (co_polarization, cross_polarization, ratio), axis=0
                )
                img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

            axs[idx].imshow(img)
            axs[idx].axis("off")
            if show_titles:
                axs[idx].set_title(sens)
        mask_idx = len(self.sensors)
    else:
        # Single sensor case
        sens = self.sensors[0]
        img = sample["image"].numpy()
        if self.as_time_series:
            # Plot last time step
            img = img[:, -1, ...]
        if sens == "S2":
            img = img[[2, 1, 0], ...].transpose(1, 2, 0)
            img = percentile_normalization(img)
        else:
            co_polarization = img[0]  # transmit == receive
            cross_polarization = img[1]  # transmit != receive
            ratio = co_polarization / (cross_polarization + 1e-6)

            co_polarization = np.clip(co_polarization / 0.3, 0, 1)
            cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
            ratio = np.clip(ratio / 25, 0, 1)

            img = np.stack(
                (co_polarization, cross_polarization, ratio), axis=0
            )
            img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

        axs[0].imshow(img)
        axs[0].axis("off")
        if show_titles:
            axs[0].set_title(sens)
        mask_idx = 1

    # Plot target mask
    if "mask" in sample:
        target = sample["mask"].squeeze()
        target_im = axs[mask_idx].imshow(target, cmap="YlGn")
        plt.colorbar(target_im, ax=axs[mask_idx], fraction=0.046, pad=0.04)
        axs[mask_idx].axis("off")
        if show_titles:
            axs[mask_idx].set_title("Target")

    # Plot prediction if available
    if showing_predictions:
        pred_idx = mask_idx + 1
        prediction = sample["prediction"].squeeze()
        pred_im = axs[pred_idx].imshow(prediction, cmap="YlGn")
        plt.colorbar(pred_im, ax=axs[pred_idx], fraction=0.046, pad=0.04)
        axs[pred_idx].axis("off")
        if show_titles:
            axs[pred_idx].set_title("Prediction")

    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
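For non-S2 sensors, `plot` renders a false-color composite from two polarization channels. The code's comments identify channel 0 as co-polarized (transmit == receive) and channel 1 as cross-polarized (transmit != receive). The NumPy sketch below reproduces that scaling as a standalone function. The input arrays are synthetic, and the function name `s1_false_color` is hypothetical.

```python
import numpy as np


def s1_false_color(img: np.ndarray) -> np.ndarray:
    """Build an (H, W, 3) composite from a (C, H, W) Sentinel-1 array."""
    co_pol = img[0]
    cross_pol = img[1]
    ratio = co_pol / (cross_pol + 1e-6)  # epsilon avoids division by zero

    # Fixed scaling constants from the plotting code, each clipped to [0, 1].
    rgb = np.stack(
        (
            np.clip(co_pol / 0.3, 0, 1),
            np.clip(cross_pol / 0.05, 0, 1),
            np.clip(ratio / 25, 0, 1),
        ),
        axis=0,
    )
    return rgb.transpose(1, 2, 0)  # (3, H, W) -> (H, W, 3)


rng = np.random.default_rng(0)
single = rng.uniform(0.0, 0.5, size=(4, 8, 8))     # (C, H, W), synthetic
series = rng.uniform(0.0, 0.5, size=(4, 3, 8, 8))  # (C, T, H, W), synthetic

rgb = s1_false_color(single)
last = s1_false_color(series[:, -1, ...])  # plot shows the last time step
print(rgb.shape, last.shape)  # (8, 8, 3) (8, 8, 3)
```

Fixed divisors keep colors comparable across samples, unlike the per-image percentile stretch used for S2.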

terratorch.datasets.burn_intensity

BurnIntensityNonGeo

Bases: NonGeoDataset

Dataset implementation for Burn Intensity classification.
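In the class source below, `_get_coords` reduces each raster to the center of its footprint. rioxarray's `rio.resolution()` reports the y pixel size as a negative number for north-up rasters. Because of this, an offset from the top edge has to add `row * res_y` rather than subtract it. The midpoint of the bounds sidesteps the sign question entirely. The plain-Python sketch below uses a hypothetical raster geometry to show this. The helper `footprint_center` is illustrative only.

```python
def footprint_center(
    left: float, top: float, width: int, height: int, res_x: float, res_y: float
) -> tuple[float, float]:
    """Center (x, y) of a north-up raster from its top-left corner and signed pixel size."""
    right = left + width * res_x
    bottom = top + height * res_y  # res_y < 0, so bottom < top
    return (left + right) / 2, (bottom + top) / 2


# Hypothetical raster: 100 x 50 pixels of 30 m, top-left corner at
# (500000, 4200000) in a projected CRS. Note the negative y resolution.
left, top, width, height = 500000.0, 4200000.0, 100, 50
res_x, res_y = 30.0, -30.0

cx, cy = footprint_center(left, top, width, height, res_x, res_y)
print(cx, cy)  # 501500.0 4199250.0

# Offsetting from the top edge works only when the signed resolution is added;
# subtracting it lands half a raster height above the image.
assert cy == top + (height / 2) * res_y
print(top - (height / 2) * res_y)  # 4200750.0
```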

Source code in terratorch/datasets/burn_intensity.py
class BurnIntensityNonGeo(NonGeoDataset):
    """Dataset implementation for Burn Intensity classification."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    class_names = (
        "No burn",
        "Unburned to Very Low",
        "Low Severity",
        "Moderate Severity",
        "High Severity"
    )

    CSV_FILES = {
        "limited": "BS_files_with_less_than_25_percent_zeros.csv",
        "full": "BS_files_raw.csv",
    }

    num_classes = 5
    splits = {"train": "train", "val": "val"}
    time_steps = ["pre", "during", "post"]

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Initialize the BurnIntensity dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'.
            bands (Sequence[str]): Bands to output. Defaults to all bands.
            transform (Optional[A.Compose]): Albumentations transform to be applied.
            use_metadata (bool): Whether to return metadata info (location).
            use_full_data (bool): Whether to use the full data or only data with less than 25 percent zeros.
            no_data_replace (Optional[float]): Value to replace NaNs in images.
            no_label_replace (Optional[int]): Value to replace NaNs in labels.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)

        # Read the CSV file to get the list of cases to include
        csv_file_key = "full" if use_full_data else "limited"
        csv_path = self.data_root / self.CSV_FILES[csv_file_key]
        df = pd.read_csv(csv_path)
        casenames = df["Case_Name"].tolist()

        split_file = self.data_root / f"{split}.txt"
        with open(split_file) as f:
            split_images = [line.strip() for line in f.readlines()]

        split_images = [img for img in split_images if self._extract_casename(img) in casenames]

        # Build the samples list
        self.samples = []
        for image_filename in split_images:
            image_files = []
            for time_step in self.time_steps:
                image_file = self.data_root / time_step / image_filename
                image_files.append(str(image_file))
            mask_filename = image_filename.replace("HLS_", "BS_")
            mask_file = self.data_root / "pre" / mask_filename
            self.samples.append({
                "image_files": image_files,
                "mask_file": str(mask_file),
                "casename": self._extract_casename(image_filename),
            })

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.transform = transform if transform else default_transform

    def _extract_basename(self, filepath: str) -> str:
        """Extract the base filename without extension."""
        return os.path.splitext(os.path.basename(filepath))[0]

    def _extract_casename(self, filename: str) -> str:
        """Extract the casename from the filename."""
        basename = self._extract_basename(filename)
        # Remove 'HLS_' or 'BS_' prefix
        casename = basename.replace("HLS_", "").replace("BS_", "")
        return casename

    def __len__(self) -> int:
        return len(self.samples)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        """Return the center of the image footprint as (y, x) in the image CRS.

        No reprojection is applied, so the values are latitude and longitude
        only when the raster is stored in a geographic CRS.
        """
        left, bottom, right, top = image.rio.bounds()

        # Midpoint of the bounds. Unlike offsetting ``top`` by the pixel size,
        # this does not depend on the sign of ``rio.resolution()``, whose y
        # component is negative for north-up rasters.
        center_lon = (left + right) / 2
        center_lat = (bottom + top) / 2

        lat_lon = np.asarray([center_lat, center_lon])
        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        sample = self.samples[index]
        image_files = sample["image_files"]
        mask_file = sample["mask_file"]

        images = []
        for idx, image_file in enumerate(image_files):
            image = self._load_file(Path(image_file), nan_replace=self.no_data_replace)
            if idx == 0 and self.use_metadata:
                location_coords = self._get_coords(image)
            image = image.to_numpy()
            image = np.moveaxis(image, 0, -1)
            image = image[..., self.band_indices]
            images.append(image)

        images = np.stack(images, axis=0)  # (T, H, W, C)

        output = {
            "image": images.astype(np.float32),
            "mask": self._load_file(Path(mask_file), nan_replace=self.no_label_replace).to_numpy()[0]
        }

        if self.transform:
            output = self.transform(**output)

        output["mask"] = output["mask"].long()
        if self.use_metadata:
            output["location_coords"] = location_coords

        return output

    def _load_file(self, path: Path, nan_replace: float | int | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data


    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered sample.
        """
        num_images = len(self.time_steps) + 2
        if "prediction" in sample:
            num_images += 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        images = sample["image"]  # (C, T, H, W)
        mask = sample["mask"].numpy()
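        # Use the full label set so colors and legend entries stay aligned with
        # class values even when some classes are absent from this mask.
        num_classes = self.num_classes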

        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5))

        for i in range(len(self.time_steps)):
            image = images[:, i, :, :]  # (C, H, W)
            image = np.transpose(image, (1, 2, 0))  # (H, W, C)
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"{self.time_steps[i].capitalize()} Image")

        cmap = plt.get_cmap("jet", num_classes)
        norm = Normalize(vmin=0, vmax=num_classes - 1)

        mask_ax_index = len(self.time_steps)
        ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm)
        ax[mask_ax_index].axis("off")
        ax[mask_ax_index].set_title("Ground Truth Mask")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            pred_ax_index = mask_ax_index + 1
            ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm)
            ax[pred_ax_index].axis("off")
            ax[pred_ax_index].set_title("Predicted Mask")

        legend_ax_index = -1
        class_names = sample.get("class_names", self.class_names)
        positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5]

        legend_handles = [
            mpatches.Patch(color=cmap(pos), label=class_names[i])
            for i, pos in enumerate(positions)
        ]
        ax[legend_ax_index].legend(handles=legend_handles, loc="center")
        ax[legend_ax_index].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
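In the class source above, `__getitem__` reads one raster per time step (`pre`, `during`, `post`). For each raster it moves the band axis last and keeps the requested bands. It then stacks the results into a `(T, H, W, C)` float32 array. The NumPy sketch below mirrors that reshaping. Synthetic arrays stand in for the files that `_load_file` would read, and the RGB band selection is only an example.

```python
import numpy as np

all_band_names = ("BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2")
bands = ("RED", "GREEN", "BLUE")  # example selection
band_indices = np.asarray([all_band_names.index(b) for b in bands])  # [2, 1, 0]

# Synthetic (C, H, W) rasters for pre/during/post: value = time step + band index.
rasters = [
    np.full((6, 4, 4), t, dtype=np.float64) + np.arange(6)[:, None, None]
    for t in range(3)
]

images = []
for raster in rasters:
    image = np.moveaxis(raster, 0, -1)  # (C, H, W) -> (H, W, C)
    image = image[..., band_indices]    # keep only the requested bands
    images.append(image)

stacked = np.stack(images, axis=0).astype(np.float32)  # (T, H, W, C)
print(stacked.shape)  # (3, 4, 4, 3)
```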
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, use_full_data=True, no_data_replace=0.0001, no_label_replace=-1, use_metadata=False)

Initialize the BurnIntensity dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train' or 'val'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to output. Defaults to all bands.

  • transform (Optional[Compose], default: None ) –

    Albumentations transform to be applied.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (location).

  • use_full_data (bool, default: True ) –

    Whether to use the full data or only data with less than 25 percent zeros.

  • no_data_replace (Optional[float], default: 0.0001 ) –

    Value to replace NaNs in images.

  • no_label_replace (Optional[int], default: -1 ) –

    Value to replace NaNs in labels.

Source code in terratorch/datasets/burn_intensity.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    use_full_data: bool = True,
    no_data_replace: float | None = 0.0001,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
) -> None:
    """Initialize the BurnIntensity dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'.
        bands (Sequence[str]): Bands to output. Defaults to all bands.
        transform (Optional[A.Compose]): Albumentations transform to be applied.
        use_metadata (bool): Whether to return metadata info (location).
        use_full_data (bool): Whether to use the full data or only data with less than 25 percent zeros.
        no_data_replace (Optional[float]): Value to replace NaNs in images.
        no_label_replace (Optional[int]): Value to replace NaNs in labels.
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)

    # Read the CSV file to get the list of cases to include
    csv_file_key = "full" if use_full_data else "limited"
    csv_path = self.data_root / self.CSV_FILES[csv_file_key]
    df = pd.read_csv(csv_path)
    casenames = df["Case_Name"].tolist()

    split_file = self.data_root / f"{split}.txt"
    with open(split_file) as f:
        split_images = [line.strip() for line in f.readlines()]

    split_images = [img for img in split_images if self._extract_casename(img) in casenames]

    # Build the samples list
    self.samples = []
    for image_filename in split_images:
        image_files = []
        for time_step in self.time_steps:
            image_file = self.data_root / time_step / image_filename
            image_files.append(str(image_file))
        mask_filename = image_filename.replace("HLS_", "BS_")
        mask_file = self.data_root / "pre" / mask_filename
        self.samples.append({
            "image_files": image_files,
            "mask_file": str(mask_file),
            "casename": self._extract_casename(image_filename),
        })

    self.use_metadata = use_metadata
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.transform = transform if transform else default_transform
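The constructor keeps only those split entries whose case name appears in the selected CSV. It then derives every file path from the image filename. The sketch below reproduces that path logic without reading any files. The root directory, filenames, and case list are all hypothetical.

```python
from pathlib import Path

time_steps = ["pre", "during", "post"]
data_root = Path("/data/burn_intensity")  # hypothetical root directory


def extract_casename(filename: str) -> str:
    # Same rule as _extract_casename: drop the extension, then the prefixes.
    stem = Path(filename).stem
    return stem.replace("HLS_", "").replace("BS_", "")


def build_sample(image_filename: str) -> dict:
    image_files = [str(data_root / t / image_filename) for t in time_steps]
    # The label raster shares the name with a "BS_" prefix and lives under "pre".
    mask_file = data_root / "pre" / image_filename.replace("HLS_", "BS_")
    return {
        "image_files": image_files,
        "mask_file": str(mask_file),
        "casename": extract_casename(image_filename),
    }


casenames = ["example_case"]  # hypothetical "Case_Name" column from the CSV
split_images = ["HLS_example_case.tif", "HLS_other_case.tif"]  # hypothetical split file
kept = [f for f in split_images if extract_casename(f) in casenames]

sample = build_sample(kept[0])
print(sample["casename"])             # example_case
print(Path(sample["mask_file"]).name)  # BS_example_case.tif
```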
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Any

    A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/burn_intensity.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Any:
    """Plot a sample from the dataset.

    Args:
        sample: A sample returned by `__getitem__`.
        suptitle: Optional string to use as a suptitle.

    Returns:
        A matplotlib Figure with the rendered sample.
    """
    num_images = len(self.time_steps) + 2
    if "prediction" in sample:
        num_images += 1

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    images = sample["image"]  # (C, T, H, W)
    mask = sample["mask"].numpy()
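    # Use the full label set so colors and legend entries stay aligned with
    # class values even when some classes are absent from this mask.
    num_classes = self.num_classes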

    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5))

    for i in range(len(self.time_steps)):
        image = images[:, i, :, :]  # (C, H, W)
        image = np.transpose(image, (1, 2, 0))  # (H, W, C)
        rgb_image = image[..., rgb_indices]
        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
        rgb_image = np.clip(rgb_image, 0, 1)
        ax[i].imshow(rgb_image)
        ax[i].axis("off")
        ax[i].set_title(f"{self.time_steps[i].capitalize()} Image")

    cmap = plt.get_cmap("jet", num_classes)
    norm = Normalize(vmin=0, vmax=num_classes - 1)

    mask_ax_index = len(self.time_steps)
    ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm)
    ax[mask_ax_index].axis("off")
    ax[mask_ax_index].set_title("Ground Truth Mask")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        pred_ax_index = mask_ax_index + 1
        ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm)
        ax[pred_ax_index].axis("off")
        ax[pred_ax_index].set_title("Predicted Mask")

    legend_ax_index = -1
    class_names = sample.get("class_names", self.class_names)
    positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5]

    legend_handles = [
        mpatches.Patch(color=cmap(pos), label=class_names[i])
        for i, pos in enumerate(positions)
    ]
    ax[legend_ax_index].legend(handles=legend_handles, loc="center")
    ax[legend_ax_index].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    return fig
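Each time-step panel in `plot` is a min-max stretch of the RGB bands, computed jointly over the three channels of that time step. The NumPy sketch below reproduces the stretch on synthetic `(C, T, H, W)` data. The `bands` tuple is a hypothetical selection, and `to_display_rgb` is an illustrative name.

```python
import numpy as np

rgb_bands = ("RED", "GREEN", "BLUE")
bands = ("BLUE", "GREEN", "RED", "NIR")  # hypothetical band selection
rgb_indices = [bands.index(b) for b in rgb_bands if b in bands]  # [2, 1, 0]

rng = np.random.default_rng(0)
images = rng.uniform(100, 3000, size=(len(bands), 3, 16, 16))  # (C, T, H, W), synthetic


def to_display_rgb(images: np.ndarray, t: int) -> np.ndarray:
    image = np.transpose(images[:, t, :, :], (1, 2, 0))  # (C, H, W) -> (H, W, C)
    rgb = image[..., rgb_indices]
    # Joint min-max stretch over the three channels; epsilon guards flat images.
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
    return np.clip(rgb, 0, 1)


frames = [to_display_rgb(images, t) for t in range(images.shape[1])]
print(frames[0].shape)  # (16, 16, 3)
```

Because the stretch is computed separately for each time step, panel brightness is not directly comparable between the pre, during, and post images.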

terratorch.datasets.carbonflux

CarbonFluxNonGeo

Bases: NonGeoDataset

Dataset for Carbon Flux regression from HLS images and MERRA data.
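In the class source below, `_get_date` reads the acquisition year and day of year from the chip filename using a regular expression. The standalone sketch below uses the same pattern. The filename is hypothetical and was chosen only to match that pattern.

```python
import re

# Pattern copied from _get_date.
HLS_PATTERN = r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$"


def parse_hls_date(filename: str) -> tuple[int, int]:
    """Return (year, day_of_year) from an HLS-style chip filename."""
    match = re.match(HLS_PATTERN, filename)
    if not match:
        raise ValueError(f"Filename {filename} does not match expected pattern.")
    date_str = match.group("date")  # e.g. "2018123T160829"
    return int(date_str[:4]), int(date_str[4:7])


print(parse_hls_date("HLS.S30.T17SLB.2018123T160829.v2.0.tiff"))  # (2018, 123)
```

Filenames that end in `.tif` instead of `.tiff` do not match. This is consistent with the constructor renaming the `Chip` entries to `.tiff` before matching them against the files on disk.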

Source code in terratorch/datasets/carbonflux.py
class CarbonFluxNonGeo(NonGeoDataset):
    """Dataset for Carbon Flux regression from HLS images and MERRA data."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    merra_var_names = (
        "T2MIN", "T2MAX", "T2MEAN", "TSMDEWMEAN", "GWETROOT",
        "LHLAND", "SHLAND", "SWLAND", "PARDFLAND", "PRECTOTLAND"
    )

    splits = {"train": "train", "test": "test"}

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    metadata_file = "data_train_hls_37sites_v0_1.csv"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: dict[str, A.Compose] | None = None,
        gpp_mean: float | None = None,
        gpp_std: float | None = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        modalities: Sequence[str] = ("image", "merra_vars")
    ) -> None:
        """Initialize the CarbonFluxNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): 'train' or 'test'.
            bands (Sequence[str]): Bands to use. Defaults to all bands.
            transform (Optional[dict[str, A.Compose]]): Albumentations transforms keyed
                by modality name. Modalities without an entry use the default transform.
                If None, samples are only converted to tensors.
            gpp_mean (float): Mean for GPP normalization. Required.
            gpp_std (float): Standard deviation for GPP normalization. Required.
            no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
            use_metadata (bool): Whether to return metadata (coordinates and date).
            modalities (Sequence[str]): Modalities to include in each sample.
                Defaults to ("image", "merra_vars").

        Raises:
            ValueError: If ``split`` is invalid or ``gpp_mean``/``gpp_std`` is missing.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)

        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(band) for band in bands]

        self.data_root = Path(data_root)

        # Load the CSV file with metadata
        csv_file = self.data_root / self.metadata_file
        df = pd.read_csv(csv_file)

        # Get list of image filenames in the split directory
        image_dir = self.data_root / self.split
        image_files = [f.name for f in image_dir.glob("*.tiff")]

        df["Chip"] = df["Chip"].str.replace(r"\.tif$", ".tiff", regex=True)
        # Filter the DataFrame to include only rows with 'Chip' in image_files
        df = df[df["Chip"].isin(image_files)]

        # Build the samples list
        self.samples = []
        for _, row in df.iterrows():
            image_path = image_dir / row["Chip"]
            # MERRA vectors
            merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
            # GPP target
            gpp = row["GPP"]
            self.samples.append({
                "image_path": str(image_path),
                "merra_vars": merra_vars,
                "gpp": gpp,
            })

        if gpp_mean is None or gpp_std is None:
            msg = "Mean and standard deviation for GPP must be provided."
            raise ValueError(msg)
        self.gpp_mean = gpp_mean
        self.gpp_std = gpp_std

        self.use_metadata = use_metadata
        self.modalities = modalities
        self.no_data_replace = no_data_replace

        if transform is None:
            self.transform = MultimodalToTensor(self.modalities)
        else:
            transform = {m: transform[m] if m in transform else default_transform
                for m in self.modalities}
            self.transform = MultimodalTransforms(transform, shared=False)

    def __len__(self) -> int:
        return len(self.samples)

    def _load_file(self, path: str, nan_replace: float | int | None = None):
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def _get_coords(self, image) -> torch.Tensor:
        """Extract the center coordinates (lat, lon) from the image geospatial metadata."""
        left, bottom, right, top = image.rio.bounds()

        # Center of the footprint in the image CRS. The midpoint of the bounds
        # does not depend on the sign of rio.resolution(), whose y component is
        # negative for north-up rasters.
        center_x = (left + right) / 2
        center_y = (bottom + top) / 2

        src_crs = image.rio.crs
        dst_crs = "EPSG:4326"

        transformer = pyproj.Transformer.from_crs(src_crs, dst_crs, always_xy=True)
        lon, lat = transformer.transform(center_x, center_y)

        coords = np.array([lat, lon], dtype=np.float32)
        return torch.from_numpy(coords)

    def _get_date(self, filename: str) -> torch.Tensor:
        """Extract the date from the filename."""
        base_filename = os.path.basename(filename)
        pattern = r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$"
        match = re.match(pattern, base_filename)
        if not match:
            msg = f"Filename {filename} does not match expected pattern."
            raise ValueError(msg)

        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:7])

        date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
        return date_tensor

    def __getitem__(self, idx: int) -> dict[str, Any]:
        sample = self.samples[idx]
        image_path = sample["image_path"]

        image = self._load_file(image_path, nan_replace=self.no_data_replace)

        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(os.path.basename(image_path))

        image = image.to_numpy()  # (C, H, W)
        image = image[self.band_indices, ...]
        image = np.moveaxis(image, 0, -1) # (H, W, C)

        merra_vars = np.array(sample["merra_vars"])
        target = np.array(sample["gpp"])
        target_norm = (target - self.gpp_mean) / self.gpp_std
        target_norm = torch.tensor(target_norm, dtype=torch.float32)
        output = {
            "image": image.astype(np.float32),
            "merra_vars": merra_vars,
        }

        if self.transform:
            output = self.transform(output)

        output = {
            "image": {m: output[m] for m in self.modalities if m in output},
            "mask": target_norm
        }
        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional title for the figure.

        Returns:
            A matplotlib figure with the rendered sample.
        """
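        image = sample["image"]
        if isinstance(image, dict):
            # __getitem__ nests the modalities under "image"; plot the optical one.
            image = image["image"]
        image = image.numpy()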

        image = np.transpose(image, (1, 2, 0))  # (H, W, C)

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        rgb_image = image[..., rgb_indices]

        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
        rgb_image = np.clip(rgb_image, 0, 1)

        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title("Image")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
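The `_get_coords` and `_get_date` helpers above reduce each HLS chip to a center point and a `(year, day-of-year)` pair. The sketch below reproduces that logic without external dependencies, under two assumptions. First, it takes the raster bounds directly instead of a rioxarray object. Second, it skips the reprojection to EPSG:4326 that `_get_coords` performs with pyproj. The function names and the example filename are hypothetical.

```python
import re
from datetime import date, timedelta

# Same pattern as _get_date: the 7-digit block is YYYYDDD (year + day of year).
HLS_PATTERN = re.compile(r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$")


def chip_center(bounds: tuple[float, float, float, float]) -> tuple[float, float]:
    """Return the (x, y) center of a raster from (left, bottom, right, top) bounds.

    Hypothetical helper: coordinates stay in the raster's own CRS (no reprojection).
    """
    left, bottom, right, top = bounds
    return (left + right) / 2, (bottom + top) / 2


def hls_year_doy(filename: str) -> tuple[int, int]:
    """Extract (year, day_of_year) from an HLS-style filename (hypothetical helper)."""
    match = HLS_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"Filename {filename} does not match expected pattern.")
    stamp = match.group("date")
    return int(stamp[:4]), int(stamp[4:7])


def doy_to_date(year: int, doy: int) -> date:
    """Convert a day of year back to a calendar date (day 1 is January 1)."""
    return date(year, 1, 1) + timedelta(days=doy - 1)


if __name__ == "__main__":
    name = "HLS.S30.T10SEG.2020183T183919.v2.0.merged.tiff"  # hypothetical filename
    print(hls_year_doy(name))  # (2020, 183)
    print(doy_to_date(*hls_year_doy(name)))  # 2020-07-01
    print(chip_center((500000.0, 4190000.0, 507680.0, 4197680.0)))
```

The dataset stores the day of year rather than a calendar date. `doy_to_date` shows how to recover the calendar date when needed.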
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, gpp_mean=None, gpp_std=None, no_data_replace=0.0001, use_metadata=False, modalities=('image', 'merra_vars'))

Initialize the CarbonFluxNonGeo dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    'train' or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to use. Defaults to all bands.

  • transform (Optional[Compose], default: None ) –

    Albumentations transform to be applied.

  • use_metadata (bool, default: False ) –

    Whether to return metadata (coordinates and date).

  • modalities (Sequence[str], default: ('image', 'merra_vars') ) –

    Modalities to load. They are returned as a dict under the "image" key of each sample.

  • gpp_mean (float, default: None ) –

    Mean for GPP normalization.

  • gpp_std (float, default: None ) –

    Standard deviation for GPP normalization.

  • no_data_replace (Optional[float], default: 0.0001 ) –

    Value to replace NO_DATA values in images.
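`gpp_mean` and `gpp_std` are required because `__getitem__` returns the z-scored target `(gpp - gpp_mean) / gpp_std` under the `mask` key. The sketch below shows one way to derive those statistics from training targets and how to map a prediction back to GPP units. It assumes the statistics come from the training split, and it uses the population standard deviation, which is also an assumption. The helper names and values are illustrative.

```python
import numpy as np


def gpp_stats(train_gpp: np.ndarray) -> tuple[float, float]:
    """Mean and population standard deviation of the training GPP targets."""
    return float(np.mean(train_gpp)), float(np.std(train_gpp))


def normalize_gpp(gpp: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Z-score a target, as CarbonFluxNonGeo.__getitem__ does."""
    return (gpp - mean) / std


def denormalize_gpp(pred: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Map a model output back to GPP units (inverse of normalize_gpp)."""
    return pred * std + mean


train_gpp = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)  # illustrative values
mean, std = gpp_stats(train_gpp)
z = normalize_gpp(train_gpp, mean, std)
restored = denormalize_gpp(z, mean, std)
```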

Source code in terratorch/datasets/carbonflux.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    gpp_mean: float | None = None,
    gpp_std: float | None = None,
    no_data_replace: float | None = 0.0001,
    use_metadata: bool = False,
    modalities: Sequence[str] = ("image", "merra_vars")
) -> None:
    """Initialize the CarbonFluxNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): 'train' or 'test'.
        bands (Sequence[str]): Bands to use. Defaults to all bands.
        transform (Optional[A.Compose]): Albumentations transform to be applied.
        use_metadata (bool): Whether to return metadata (coordinates and date).
        modalities (Sequence[str]): Modalities to load, returned under the "image" key of each sample.
        gpp_mean (float): Mean for GPP normalization.
        gpp_std (float): Standard deviation for GPP normalization.
        no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)

    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(band) for band in bands]

    self.data_root = Path(data_root)

    # Load the CSV file with metadata
    csv_file = self.data_root / self.metadata_file
    df = pd.read_csv(csv_file)

    # Get list of image filenames in the split directory
    image_dir = self.data_root / self.split
    image_files = [f.name for f in image_dir.glob("*.tiff")]

    df["Chip"] = df["Chip"].str.replace(".tif$", ".tiff", regex=True)
    # Filter the DataFrame to include only rows with 'Chip' in image_files
    df = df[df["Chip"].isin(image_files)]

    # Build the samples list
    self.samples = []
    for _, row in df.iterrows():
        image_filename = row["Chip"]
        image_path = image_dir / image_filename
        # MERRA vectors
        merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
        # GPP target
        gpp = row["GPP"]

        self.samples.append({
            "image_path": str(image_path),
            "merra_vars": merra_vars,
            "gpp": gpp,
        })

    if gpp_mean is None or gpp_std is None:
        msg = "Mean and standard deviation for GPP must be provided."
        raise ValueError(msg)
    self.gpp_mean = gpp_mean
    self.gpp_std = gpp_std

    self.use_metadata = use_metadata
    self.modalities = modalities
    self.no_data_replace = no_data_replace

    if transform is None:
        self.transform = MultimodalToTensor(self.modalities)
    else:
        transform = {m: transform[m] if m in transform else default_transform
            for m in self.modalities}
        self.transform = MultimodalTransforms(transform, shared=False)
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Any]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional title for the figure.

Returns:
  • Any

    A matplotlib figure with the rendered sample.

Source code in terratorch/datasets/carbonflux.py
def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
    """Plot a sample from the dataset.

    Args:
        sample: A sample returned by `__getitem__`.
        suptitle: Optional title for the figure.

    Returns:
        A matplotlib figure with the rendered sample.
    """
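    image = sample["image"]
    if isinstance(image, dict):
        # __getitem__ nests the modalities under "image"; plot the optical one.
        image = image["image"]
    image = image.numpy()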

    image = np.transpose(image, (1, 2, 0))  # (H, W, C)

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    rgb_image = image[..., rgb_indices]

    rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
    rgb_image = np.clip(rgb_image, 0, 1)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title("Image")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    return fig
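`plot` selects the red, green and blue channels by band name. It then stretches them to [0, 1] with a single min-max rescale before calling `imshow`. The sketch below performs the same display stretch on a channels-first array. The function name and the band order in the usage are hypothetical.

The stretch uses one minimum and one maximum across all three channels rather than per-channel values. This preserves the relative colour balance, and it affects display only.

```python
import numpy as np


def rgb_for_display(image_chw: np.ndarray, bands: list[str], rgb_bands: list[str]) -> np.ndarray:
    """Return an (H, W, 3) array in [0, 1] built from the named RGB bands.

    Mirrors the stretch in plot(): one min/max over all three channels
    (not per channel), plus a small epsilon to avoid division by zero.
    """
    rgb_indices = [bands.index(b) for b in rgb_bands if b in bands]
    if len(rgb_indices) != 3:
        raise ValueError("Dataset doesn't contain some of the RGB bands")
    image = np.transpose(image_chw, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    rgb = image[..., rgb_indices]
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
    return np.clip(rgb, 0, 1)


bands = ["BLUE", "GREEN", "RED", "NIR"]  # hypothetical band order
chip = np.arange(4 * 2 * 2, dtype=np.float32).reshape(4, 2, 2)
out = rgb_for_display(chip, bands, ["RED", "GREEN", "BLUE"])
```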

terratorch.datasets.forestnet

ForestNetNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for ForestNet.

Source code in terratorch/datasets/forestnet.py
class ForestNetNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for ForestNet."""


    all_band_names = (
        "RED", "GREEN", "BLUE", "NIR", "SWIR_1", "SWIR_2"
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    splits = ("train", "test", "val")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_label_map = {  # noqa: RUF012
        "Plantation": 0,
        "Smallholder agriculture": 1,
        "Grassland shrubland": 2,
        "Other": 3,
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_map: dict[str, int] = default_label_map,
        transform: A.Compose | None = None,
        fraction: float = 1.0,
        bands: Sequence[str] = BAND_SETS["all"],
        use_metadata: bool = False,
    ) -> None:
        """
        Initialize the ForestNetNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            label_map (Dict[str, int]): Mapping from label names to integer labels.
            transform: Transformations to be applied to the images.
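            fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
            bands (Sequence[str]): Bands to use. Defaults to all bands.
            use_metadata (bool): Whether to return location coordinates (closest city from osm.json) and acquisition dates.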
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.label_map = label_map

        # Load the CSV file corresponding to the split
        csv_file = self.data_root / f"{split}_filtered.csv"
        original_df = pd.read_csv(csv_file)

        # Apply stratified sampling if fraction < 1.0
        if fraction < 1.0:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
            stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
            self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
        else:
            self.dataset = original_df

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.dataset)

    def _get_coords(self, event_path: Path) -> torch.Tensor:
        auxiliary_path = event_path / "auxiliary"
        osm_json_path = auxiliary_path / "osm.json"

        with open(osm_json_path) as f:
            osm_data = json.load(f)
            lat = float(osm_data["closest_city"]["lat"])
            lon = float(osm_data["closest_city"]["lon"])
            lat_lon = np.asarray([lat, lon])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def _get_dates(self, image_files: list[str]) -> torch.Tensor:
        dates = []
        pattern = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")
        for img_path in image_files:
            match = pattern.search(img_path)
            year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
            date_obj = datetime.datetime(year, month, day)  # noqa: DTZ001
            julian_day = date_obj.timetuple().tm_yday
            date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
            dates.append(date_tensor)
        return torch.stack(dates, dim=0)

    def __getitem__(self, index: int):
        path = self.data_root / self.dataset["example_path"][index]
        label = self.map_label(index)

        visible_images, infrared_images, temporal_coords = self._load_images(path)

        visible_images = np.stack(visible_images, axis=0)
        infrared_images = np.stack(infrared_images, axis=0)
        merged_images = np.concatenate([visible_images, infrared_images], axis=-1)
        merged_images = merged_images[..., self.band_indices]  # (T, H, W, len(bands))
        output = {
            "image": merged_images.astype(np.float32)
        }

        if self.transform:
            output = self.transform(**output)

        if self.use_metadata:
            location_coords = self._get_coords(path)
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        output["label"] = label

        return output

    def _load_images(self, path: str):
        """Load visible and infrared images from the given event path"""
        visible_image_files = glob.glob(os.path.join(path, "images/visible/*_cloud_*.png"))
        infra_image_files = glob.glob(os.path.join(path, "images/infrared/*_cloud_*.npy"))

        selected_visible_images = self.select_images(visible_image_files)
        selected_infra_images = self.select_images(infra_image_files)

        dates = None
        if self.use_metadata:
            dates = self._get_dates(selected_visible_images)

        vis_images = [np.array(Image.open(img)) for img in selected_visible_images]  # T arrays of (H, W, C)
        inf_images = [np.load(img, allow_pickle=True) for img in selected_infra_images]  # T arrays of (H, W, C)
        return vis_images, inf_images, dates

    def least_cloudy_image(self, image_files):
        pattern = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")
        lowest_cloud_images = defaultdict(lambda: {"path": None, "cloud_value": float("inf")})

        for path in image_files:
            match = pattern.search(path)
            if match:
                year, cloud_value = match.group(1), int(match.group(2))
                if cloud_value < lowest_cloud_images[year]["cloud_value"]:
                    lowest_cloud_images[year] = {"path": path, "cloud_value": cloud_value}

        return [info["path"] for info in lowest_cloud_images.values()]

    def match_timesteps(self, image_files, selected_images):
        if len(selected_images) < 3:
            extra_imgs = [img for img in image_files if img not in selected_images]
            selected_images += extra_imgs[:3 - len(selected_images)]

        while len(selected_images) < 3:
            selected_images.append(selected_images[-1])
        return selected_images[:3]

    def select_images(self, image_files):
        selected = self.least_cloudy_image(image_files)
        return self.match_timesteps(image_files, selected)

    def map_label(self, index: int) -> int:
        """Map the label name to an integer label."""
        label_name = self.dataset["merged_label"][index]
        label = self.label_map[label_name]
        return label

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None):

        num_images = sample["image"].shape[1] + 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        fig, ax = plt.subplots(1, num_images, figsize=(15, 5))

        for i in range(sample["image"].shape[1]):
            image = sample["image"][:, i, :, :]
            if torch.is_tensor(image):
                image = image.permute(1, 2, 0).numpy()
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"Timestep {i + 1}")

        legend_handles = [Rectangle((0, 0), 1, 1, color="blue")]
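        # label_map maps names to integers; invert it to display the class name.
        idx_to_name = {v: k for k, v in self.label_map.items()}
        legend_label = [idx_to_name.get(int(sample["label"]), "Unknown Label")]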
        ax[-1].legend(legend_handles, legend_label, loc="center")
        ax[-1].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
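`select_images` combines two steps:

  • `least_cloudy_image` keeps, for each year, the file with the lowest cloud value encoded in its name (`YYYY_MM_DD_cloud_N`).
  • `match_timesteps` tops the list up to exactly three timesteps. It uses other files first, then repeats the last one.

The sketch below applies the same logic to plain filenames, which are made up. This version differs from the original in two ways. It copies the list instead of extending it in place. For an empty input it returns an empty list, where the original would raise an IndexError.

```python
import re

CLOUD_PATTERN = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")


def least_cloudy_per_year(files: list[str]) -> list[str]:
    """Keep the file with the smallest cloud value for each year."""
    best: dict[str, tuple[int, str]] = {}
    for path in files:
        match = CLOUD_PATTERN.search(path)
        if match is None:
            continue
        year, cloud = match.group(1), int(match.group(2))
        if year not in best or cloud < best[year][0]:
            best[year] = (cloud, path)
    return [path for _, path in best.values()]


def pad_to_timesteps(files: list[str], selected: list[str], n: int = 3) -> list[str]:
    """Top up `selected` to n entries: unused files first, then repeat the last one."""
    selected = list(selected)  # do not mutate the caller's list
    extras = [f for f in files if f not in selected]
    selected += extras[: max(0, n - len(selected))]
    while selected and len(selected) < n:
        selected.append(selected[-1])
    return selected[:n]


files = [
    "2017_03_01_cloud_40.png",
    "2017_06_12_cloud_5.png",
    "2018_01_20_cloud_12.png",
]
chosen = pad_to_timesteps(files, least_cloudy_per_year(files))
```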
__init__(data_root, split='train', label_map=default_label_map, transform=None, fraction=1.0, bands=BAND_SETS['all'], use_metadata=False)

Initialize the ForestNetNonGeo dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • label_map (Dict[str, int], default: default_label_map ) –

    Mapping from label names to integer labels.

  • transform (Compose | None, default: None ) –

    Transformations to be applied to the images.

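  • fraction (float, default: 1.0 ) –

    Fraction of the dataset to use. Defaults to 1.0 (use all data).

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to use. Defaults to all bands.

  • use_metadata (bool, default: False ) –

    Whether to return location coordinates (closest city from osm.json) and acquisition dates with each sample.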

Source code in terratorch/datasets/forestnet.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    label_map: dict[str, int] = default_label_map,
    transform: A.Compose | None = None,
    fraction: float = 1.0,
    bands: Sequence[str] = BAND_SETS["all"],
    use_metadata: bool = False,
) -> None:
    """
    Initialize the ForestNetNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        label_map (Dict[str, int]): Mapping from label names to integer labels.
        transform: Transformations to be applied to the images.
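        fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
        bands (Sequence[str]): Bands to use. Defaults to all bands.
        use_metadata (bool): Whether to return location coordinates (closest city from osm.json) and acquisition dates.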
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
        raise ValueError(msg)
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.label_map = label_map

    # Load the CSV file corresponding to the split
    csv_file = self.data_root / f"{split}_filtered.csv"
    original_df = pd.read_csv(csv_file)

    # Apply stratified sampling if fraction < 1.0
    if fraction < 1.0:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
        stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
        self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
    else:
        self.dataset = original_df

    self.transform = transform if transform else default_transform
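When `fraction < 1.0`, `__init__` keeps a class-stratified subset of the split's CSV. It uses scikit-learn's `StratifiedShuffleSplit` with `test_size=1 - fraction` and a fixed `random_state=47`, so each `merged_label` keeps roughly its original share.

The sketch below illustrates the same idea without dependencies, using per-class rounding and keeping at least one row per class. It is not the scikit-learn algorithm and will not select the same rows. The helper name and labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_subset(labels: list[str], fraction: float, seed: int = 47) -> list[int]:
    """Return sorted row indices keeping about `fraction` of each class."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    by_class: dict[str, list[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # fixed seed -> same subset on every run
    keep: list[int] = []
    for indices in by_class.values():
        # Keep at least one row per class so no label disappears entirely.
        n_keep = max(1, round(len(indices) * fraction))
        keep.extend(rng.sample(indices, n_keep))
    return sorted(keep)


labels = ["Plantation"] * 6 + ["Other"] * 4  # illustrative labels
subset = stratified_subset(labels, fraction=0.5)
```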
map_label(index)

Map the label name to an integer label.

Source code in terratorch/datasets/forestnet.py
def map_label(self, index: int) -> int:
    """Map the label name to an integer label."""
    label_name = self.dataset["merged_label"][index]
    label = self.label_map[label_name]
    return label
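`map_label` looks up the `merged_label` string in `label_map`. A custom map must therefore cover every label present in the CSV, or a `KeyError` is raised. Going the other way, for example to caption a prediction, requires the inverted map. The short sketch below uses the default mapping from the class; the helper names are hypothetical.

```python
DEFAULT_LABEL_MAP = {
    "Plantation": 0,
    "Smallholder agriculture": 1,
    "Grassland shrubland": 2,
    "Other": 3,
}


def encode(label_name: str, label_map: dict[str, int] = DEFAULT_LABEL_MAP) -> int:
    """Name -> integer, as map_label does (raises KeyError for unknown names)."""
    return label_map[label_name]


def decode(label_id: int, label_map: dict[str, int] = DEFAULT_LABEL_MAP) -> str:
    """Integer -> name via the inverted mapping; unknown ids get a placeholder."""
    inverse = {v: k for k, v in label_map.items()}
    return inverse.get(int(label_id), "Unknown Label")
```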

Datamodules

terratorch.datamodules.fire_scars

FireScarsDataModule

Bases: GeoDataModule

Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsDataModule(GeoDataModule):
    """Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks."""

    def __init__(self, data_root: str, **kwargs: Any) -> None:
        super().__init__(FireScarsSegmentationMask, 4, 224, 100, 0, **kwargs)
        means = list(MEANS.values())
        stds = list(STDS.values())
        # RandomCrop takes the crop size as a tuple; a second positional argument would be padding.
        self.train_aug = AugmentationSequential(K.RandomCrop((224, 224)), K.Normalize(means, stds))
        self.aug = AugmentationSequential(K.Normalize(means, stds))
        self.data_root = data_root

    def setup(self, stage: str) -> None:
        self.images = FireScarsHLS(
            os.path.join(self.data_root, "training/")
        )
        self.labels = FireScarsSegmentationMask(
            os.path.join(self.data_root, "training/")
        )
        self.dataset = self.images & self.labels

        self.images_test = FireScarsHLS(
            os.path.join(self.data_root, "validation/")
        )
        self.labels_test = FireScarsSegmentationMask(
            os.path.join(self.data_root, "validation/")
        )
        self.val_dataset = self.images_test & self.labels_test

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, self.patch_size, self.batch_size, None)
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
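        if stage in ["test"]:
            # Without test_dataset, the test loader would fall back to the training dataset.
            self.test_dataset = self.val_dataset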
            self.test_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
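Training draws random patches with `RandomBatchGeoSampler`. Validation and testing instead tile the mosaic with `GridGeoSampler(dataset, patch_size, patch_size)`. Because the stride equals the patch size, the 224-pixel patches do not overlap.

TorchGeo works with bounding boxes in CRS units. The sketch below only illustrates the tiling arithmetic in pixel space, and the helper names are hypothetical. It assumes that any leftover strip is covered by one extra window placed flush with the far edge. That is one common convention, not necessarily TorchGeo's exact behavior.

```python
def grid_origins(extent: int, size: int, stride: int) -> list[int]:
    """Start offsets of windows of `size` covering [0, extent) with the given stride.

    Assumption: if the regular grid leaves a remainder, one extra window is
    added flush with the far edge so every pixel is covered at least once.
    """
    if size > extent:
        raise ValueError("patch size larger than the raster extent")
    origins = list(range(0, extent - size + 1, stride))
    if origins[-1] + size < extent:
        origins.append(extent - size)
    return origins


def grid_windows(height: int, width: int, size: int, stride: int) -> list[tuple[int, int, int, int]]:
    """(row, col, height, width) windows tiling a height x width raster."""
    return [
        (r, c, size, size)
        for r in grid_origins(height, size, stride)
        for c in grid_origins(width, size, stride)
    ]


windows = grid_windows(512, 448, size=224, stride=224)
```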
FireScarsNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for Fire Scars

Source code in terratorch/datamodules/fire_scars.py
class FireScarsNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for Fire Scars"""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = FireScarsNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
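The module looks up `MEANS[b]` and `STDS[b]` for each requested band, so the statistics stay aligned with the channel order given by `bands`. It applies them with `K.Normalize(means, stds)`, which computes `(x - mean[c]) / std[c]` per channel. Because `aug` is built with `data_keys=["image"]`, only images are normalized, not masks. The NumPy sketch below shows that arithmetic; the band names and statistics are placeholders, not the real Fire Scars values.

```python
import numpy as np

# Placeholder per-band statistics (hypothetical values, keyed by band name).
MEANS = {"RED": 0.10, "GREEN": 0.12, "NIR": 0.30}
STDS = {"RED": 0.05, "GREEN": 0.04, "NIR": 0.10}


def normalize_batch(batch: np.ndarray, bands: list[str]) -> np.ndarray:
    """Per-channel z-score of a (B, C, H, W) batch whose channels follow `bands`."""
    means = np.array([MEANS[b] for b in bands], dtype=batch.dtype).reshape(1, -1, 1, 1)
    stds = np.array([STDS[b] for b in bands], dtype=batch.dtype).reshape(1, -1, 1, 1)
    return (batch - means) / stds


bands = ["NIR", "RED"]  # any subset or order works because stats are looked up by name
batch = np.stack([np.full((2, 2), 0.5), np.full((2, 2), 0.15)])[None].astype(np.float32)
out = normalize_batch(batch, bands)
```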

terratorch.datamodules.biomassters

BioMasstersNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for BioMassters.

Source code in terratorch/datamodules/biomassters.py
class BioMasstersNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for BioMassters."""

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        drop_last: bool = True,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.sensors = sensors
        if isinstance(bands, dict):
            self.bands = bands
        else:
            sens = sensors[0]
            self.bands = {sens: bands}

        self.means = {}
        self.stds = {}
        for sensor in self.sensors:
            self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
            self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

        self.mask_mean = MEANS["AGBM"]
        self.mask_std = STDS["AGBM"]
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        if len(sensors) == 1:
            self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
        else:
            self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
        self.drop_last = drop_last
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )

    def _dataloader_factory(self, split: str):
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
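The split-dependent arguments in `_dataloader_factory` can be restated as a small, dependency-free sketch. It reproduces only the keyword logic shown above: samples are shuffled only for the training split, and the last incomplete batch is dropped only for training and only when `drop_last` is enabled. `dataloader_kwargs` is a hypothetical helper for illustration, not part of TerraTorch.

```python
def dataloader_kwargs(split: str, drop_last: bool) -> dict[str, bool]:
    """Hypothetical helper mirroring the split-dependent DataLoader flags."""
    is_train = split == "train"
    return {
        # Only the training loader reshuffles samples every epoch.
        "shuffle": is_train,
        # The final incomplete batch is dropped only for training, and only if requested.
        "drop_last": is_train and drop_last,
    }


print(dataloader_kwargs("train", drop_last=True))  # {'shuffle': True, 'drop_last': True}
print(dataloader_kwargs("val", drop_last=True))  # {'shuffle': False, 'drop_last': False}
```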

terratorch.datamodules.burn_intensity

BurnIntensityNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for BurnIntensity.
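Each `*_transform` argument accepts an `A.Compose`, `None`, or a plain list of albumentations transforms, and `wrap_in_compose_is_list` normalizes that input. The sketch below assumes the helper wraps lists in a `Compose` and returns anything else unchanged; `Compose` here is a minimal stand-in class, not albumentations' own, and `wrap_if_list` and `add_one` are hypothetical names, so the example runs without dependencies.

```python
from collections.abc import Callable, Sequence
from typing import Any


class Compose:
    """Minimal stand-in for albumentations.Compose: applies transforms in order."""

    def __init__(self, transforms: Sequence[Callable[..., dict]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, **data: Any) -> dict:
        for t in self.transforms:
            data = t(**data)
        return data


def wrap_if_list(transform: Any) -> Any:
    """Hypothetical helper: wrap a list of transforms in Compose, else pass through."""
    if isinstance(transform, list):
        return Compose(transform)
    return transform  # None or an existing Compose is returned unchanged


def add_one(**data: Any) -> dict:
    """Toy transform that increments the "image" entry."""
    data["image"] = data["image"] + 1
    return data


pipeline = wrap_if_list([add_one, add_one])
print(pipeline(image=0)["image"])  # 2
print(wrap_if_list(None))  # None
```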

Source code in terratorch/datamodules/burn_intensity.py
class BurnIntensityNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for BurnIntensity."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = NormalizeWithTimesteps(means, stds)
        self.use_full_data = use_full_data
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
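`setup` builds only the datasets needed for the requested Lightning stage, and for BurnIntensity the test and predict datasets are both built from the `"val"` split. The mapping below restates that dispatch in plain Python as a quick reference; `BURN_INTENSITY_STAGES` and `datasets_for_stage` are illustrative names, not TerraTorch APIs.

```python
# Stage -> (attribute name, dataset split) pairs created by
# BurnIntensityNonGeoDataModule.setup, restated from the source above.
BURN_INTENSITY_STAGES: dict[str, list[tuple[str, str]]] = {
    "fit": [("train_dataset", "train"), ("val_dataset", "val")],
    "validate": [("val_dataset", "val")],
    "test": [("test_dataset", "val")],
    "predict": [("predict_dataset", "val")],
}


def datasets_for_stage(stage: str) -> list[tuple[str, str]]:
    """Return the (attribute, split) pairs that setup(stage) would create."""
    # Unknown stages build nothing, matching the chain of `if` checks in setup.
    return BURN_INTENSITY_STAGES.get(stage, [])


print(datasets_for_stage("fit"))  # [('train_dataset', 'train'), ('val_dataset', 'val')]
```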

terratorch.datamodules.carbonflux

CarbonFluxNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for CarbonFlux.

Source code in terratorch/datamodules/carbonflux.py
class CarbonFluxNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for CarbonFlux."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = CarbonFluxNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential | None = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(CarbonFluxNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = {
            m: ([MEANS[m][band] for band in bands] if m == "image" else MEANS[m])
            for m in MEANS.keys()
        }
        stds = {
            m: ([STDS[m][band] for band in bands] if m == "image" else STDS[m])
            for m in STDS.keys()
        }
        self.mask_means = MEANS["mask"]
        self.mask_std = STDS["mask"]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = MultimodalNormalize(means, stds) if aug is None else aug
        self.no_data_replace = no_data_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
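The `means` and `stds` dictionaries in `CarbonFluxNonGeoDataModule.__init__` are built per modality: for the `"image"` key only the statistics of the selected bands are kept, in the order given by `bands`, while every other modality keeps its statistics unchanged. The sketch reproduces that comprehension with made-up numbers; the `"aux"` key, the band names, and all values are hypothetical and do not reflect the real `MEANS` table.

```python
# Hypothetical per-modality statistics; the real MEANS table in
# terratorch/datamodules/carbonflux.py has different keys and values.
MEANS = {
    "image": {"BLUE": 0.1, "GREEN": 0.2, "RED": 0.3, "NIR": 0.4},
    "aux": [10.0, 20.0],
    "mask": 5.0,
}


def select_means(means: dict, bands: list[str]) -> dict:
    """Keep only the selected bands for "image"; pass other modalities through."""
    return {
        m: ([means[m][band] for band in bands] if m == "image" else means[m])
        for m in means.keys()
    }


print(select_means(MEANS, ["RED", "NIR"]))
# {'image': [0.3, 0.4], 'aux': [10.0, 20.0], 'mask': 5.0}
```

Because the list comprehension iterates over `bands`, the order of the selected statistics follows the requested band order, not the order of the table.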

terratorch.datamodules.forestnet

ForestNetNonGeoDataModule

Bases: NonGeoDataModule

NonGeo datamodule implementation for ForestNet.

Source code in terratorch/datamodules/forestnet.py
class ForestNetNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for ForestNet."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        label_map: dict[str, int] = ForestNetNonGeo.default_label_map,
        bands: Sequence[str] = ForestNetNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        fraction: float = 1.0,
        aug: AugmentationSequential | None = None,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(ForestNetNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.label_map = label_map
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = Normalize(self.means, self.stds) if aug is None else aug
        self.fraction = fraction
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.train_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.val_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.test_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.predict_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
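`label_map` maps class names to the integer ids used as classification targets, defaulting to `ForestNetNonGeo.default_label_map`. The sketch shows how such a mapping encodes string labels, including the case where several fine-grained names share one id; the class names and ids are hypothetical and do not reproduce the real default map, and `encode` is an illustrative helper.

```python
# Hypothetical label map: several fine-grained names may share one class id.
label_map: dict[str, int] = {
    "plantation_oil_palm": 0,
    "plantation_timber": 0,
    "agriculture": 1,
    "grassland": 2,
    "other": 3,
}


def encode(labels: list[str], mapping: dict[str, int]) -> list[int]:
    """Convert string labels to integer class ids; unknown names raise KeyError."""
    return [mapping[name] for name in labels]


print(encode(["plantation_timber", "other"], label_map))  # [0, 3]
```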