Specific Datamodules#

terratorch.datamodules.torchgeo_data_module #

Ugly proxy objects so that parsing config files works with transforms.

These are necessary because, for LightningCLI to instantiate arguments as objects from the config, they must have type annotations.

In TorchGeo, transforms is passed in **kwargs, so it has no type annotation! To get around that, we create these wrappers that have the transforms argument type-annotated. They create the transforms and forward all method and attribute calls to the original TorchGeo datamodule.

Additionally, TorchGeo datasets pass the data to the transforms callable as a dict of tensors.

Albumentations instead expects the data as separate keyword arguments containing numpy arrays. We handle that conversion here.
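
For illustration, here is a minimal sketch of that conversion, assuming a standard Albumentations pipeline ending in ToTensorV2. The wrapper below is a simplified stand-in for the helpers used in this module (albumentations_to_callable_with_dict and build_callable_transform_from_torch_tensor), not their actual implementation.

```python
# Simplified sketch (not the actual terratorch implementation) of converting a
# TorchGeo-style dict of tensors into the kwargs/numpy form Albumentations expects.
import albumentations as A
import numpy as np


def wrap_albumentations_for_torchgeo(transforms: list[A.BasicTransform]):
    pipeline = A.Compose(transforms)

    def apply(sample: dict) -> dict:
        # TorchGeo provides a dict of tensors; Albumentations wants numpy kwargs.
        image = sample["image"].numpy()
        image = np.moveaxis(image, 0, -1)  # assume channels-first -> channels-last
        transformed = pipeline(image=image)
        sample["image"] = transformed["image"]  # ToTensorV2 at the end returns a tensor
        return sample

    return apply
```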

TorchGeoDataModule #

Bases: GeoDataModule

Proxy object for using Geo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for lightningcli and config files. As such, all getattr and setattr will be redirected to the underlying class.
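
The attribute redirection described above is the usual proxy pattern; the snippet below is a minimal illustration of that mechanism, not the exact terratorch code.

```python
# Illustrative proxy pattern: unknown attribute reads/writes are forwarded to
# the wrapped datamodule stored in self._proxy. Not the exact terratorch code.
class _AttributeForwardingExample:
    def __init__(self, wrapped):
        self.__dict__["_proxy"] = wrapped  # avoid recursing through __setattr__

    def __getattr__(self, name):
        # Called only when normal lookup fails; delegate to the wrapped object.
        return getattr(self.__dict__["_proxy"], name)

    def __setattr__(self, name, value):
        setattr(self.__dict__["_proxy"], name, value)
```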

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchGeoDataModule(GeoDataModule):
    """Proxy object for using Geo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[GeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    @property
    def patch_size(self):
        return self._proxy.patch_size

    @property
    def length(self):
        return self._proxy.length

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return self._proxy.transfer_batch_to_device(batch, device, dataloader_idx)
__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs) #

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cls` | `type[GeoDataModule]` | TorchGeo DataModule class to be instantiated. | required |
| `batch_size` | `int \| None` | Batch size. Defaults to None. | `None` |
| `num_workers` | `int` | Number of workers. Defaults to 0. | `0` |
| `transforms` | `None \| list[BasicTransform]` | List of Albumentations transforms. Should end with ToTensorV2. Defaults to None. | `None` |
| `**kwargs` | `Any` | Arguments passed to instantiate `cls`. | `{}` |

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[GeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

TorchNonGeoDataModule #

Bases: NonGeoDataModule

Proxy object for using NonGeo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type. This is required for lightningcli and config files. As such, all getattr and setattr will be redirected to the underlying class.
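
As a usage sketch, a TorchGeo NonGeoDataModule can be wrapped as follows. The wrapped EuroSAT100DataModule, the root path, and the download flag are illustrative assumptions; any dataset-specific keyword arguments are simply forwarded.

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchgeo.datamodules import EuroSAT100DataModule

from terratorch.datamodules.torchgeo_data_module import TorchNonGeoDataModule

# Wrap a TorchGeo NonGeoDataModule so that transforms can also be declared in a
# LightningCLI config. Extra kwargs (root, download, ...) are forwarded unchanged.
datamodule = TorchNonGeoDataModule(
    cls=EuroSAT100DataModule,
    batch_size=16,
    num_workers=4,
    transforms=[A.HorizontalFlip(p=0.5), ToTensorV2()],  # should end with ToTensorV2
    root="data/eurosat100",  # assumed path
    download=True,
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```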

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchNonGeoDataModule(NonGeoDataModule):
    """Proxy object for using NonGeo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[NonGeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()
__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs) #

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cls` | `type[NonGeoDataModule]` | TorchGeo DataModule class to be instantiated. | required |
| `batch_size` | `int \| None` | Batch size. Defaults to None. | `None` |
| `num_workers` | `int` | Number of workers. Defaults to 0. | `0` |
| `transforms` | `None \| list[BasicTransform]` | List of Albumentations transforms. Should end with ToTensorV2. Defaults to None. | `None` |
| `**kwargs` | `Any` | Arguments passed to instantiate `cls`. | `{}` |

Source code in terratorch/datamodules/torchgeo_data_module.py
def __init__(
    self,
    cls: type[NonGeoDataModule],
    batch_size: int | None = None,
    num_workers: int = 0,
    transforms: None | list[BasicTransform] = None,
    **kwargs: Any,
):
    """Constructor

    Args:
        cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
        batch_size (int | None, optional): batch_size. Defaults to None.
        num_workers (int, optional): num_workers. Defaults to 0.
        transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
            Should end with ToTensorV2. Defaults to None.
        **kwargs (Any): Arguments passed to instantiate `cls`.
    """
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    if transforms is not None:
        transforms_as_callable = albumentations_to_callable_with_dict(transforms)
        kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
    # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
    self._proxy = cls(num_workers=num_workers, **kwargs)
    super().__init__(self._proxy.dataset_class)  # dummy arg

terratorch.datamodules.biomassters #

BioMasstersNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for BioMassters datamodule.
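
A brief usage sketch (the data_root path below is an illustrative assumption):

```python
from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule

# Use both sensors as a time series; bands falls back to
# BioMasstersNonGeo.all_band_names when not given. data_root is an assumed path.
datamodule = BioMasstersNonGeoDataModule(
    data_root="data/biomassters",
    batch_size=8,
    num_workers=4,
    sensors=["S1", "S2"],
    as_time_series=True,
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```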

Source code in terratorch/datamodules/biomassters.py
class BioMasstersNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for BioMassters datamodule."""

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        drop_last: bool = True,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the DataModule for the non-geospatial BioMassters datamodule.

        Args:
            data_root (str): Root directory containing the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (dict[str, Sequence[str]] | Sequence[str], optional): Band configuration; either a dict mapping sensors to bands or a list for the first sensor.
                Defaults to BioMasstersNonGeo.all_band_names
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation or normalization to apply. Defaults to normalization if not provided.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            sensors (Sequence[str], optional): List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"].
            as_time_series (bool, optional): Whether to treat data as a time series. Defaults to False.
            metadata_filename (str, optional): Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv".
            max_cloud_percentage (float | None, optional): Maximum allowed cloud percentage. Defaults to None.
            max_red_mean (float | None, optional): Maximum allowed red band mean. Defaults to None.
            include_corrupt (bool, optional): Whether to include corrupt data. Defaults to True.
            subset (float, optional): Fraction of the dataset to use. Defaults to 1.
            seed (int, optional): Random seed for reproducibility. Defaults to 42.
            use_four_frames (bool, optional): Whether to use a four frames configuration. Defaults to False.
            **kwargs: Additional keyword arguments.

        Returns:
            None.
        """
        super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.sensors = sensors
        if isinstance(bands, dict):
            self.bands = bands
        else:
            sens = sensors[0]
            self.bands = {sens: bands}

        self.means = {}
        self.stds = {}
        for sensor in self.sensors:
            self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
            self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

        self.mask_mean = MEANS["AGBM"]
        self.mask_std = STDS["AGBM"]
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        if len(sensors) == 1:
            self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
        else:
            self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
        self.drop_last = drop_last
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )

    def _dataloader_factory(self, split: str):
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
__init__(data_root, batch_size=4, num_workers=0, bands=BioMasstersNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, drop_last=True, sensors=['S1', 'S2'], as_time_series=False, metadata_filename=default_metadata_filename, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True, subset=1, seed=42, use_four_frames=False, **kwargs) #

Initializes the DataModule for the non-geospatial BioMassters datamodule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Root directory containing the dataset. | required |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `bands` | `dict[str, Sequence[str]] \| Sequence[str]` | Band configuration; either a dict mapping sensors to bands or a list for the first sensor. Defaults to BioMasstersNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training data. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation data. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing data. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction data. | `None` |
| `aug` | `AugmentationSequential` | Augmentation or normalization to apply. Defaults to normalization if not provided. | `None` |
| `drop_last` | `bool` | Whether to drop the last incomplete batch. Defaults to True. | `True` |
| `sensors` | `Sequence[str]` | List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"]. | `['S1', 'S2']` |
| `as_time_series` | `bool` | Whether to treat data as a time series. Defaults to False. | `False` |
| `metadata_filename` | `str` | Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv". | `default_metadata_filename` |
| `max_cloud_percentage` | `float \| None` | Maximum allowed cloud percentage. Defaults to None. | `None` |
| `max_red_mean` | `float \| None` | Maximum allowed red band mean. Defaults to None. | `None` |
| `include_corrupt` | `bool` | Whether to include corrupt data. Defaults to True. | `True` |
| `subset` | `float` | Fraction of the dataset to use. Defaults to 1. | `1` |
| `seed` | `int` | Random seed for reproducibility. Defaults to 42. | `42` |
| `use_four_frames` | `bool` | Whether to use a four frames configuration. Defaults to False. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None. |

Source code in terratorch/datamodules/biomassters.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    drop_last: bool = True,
    sensors: Sequence[str] = ["S1", "S2"],
    as_time_series: bool = False,
    metadata_filename: str = default_metadata_filename,
    max_cloud_percentage: float | None = None,
    max_red_mean: float | None = None,
    include_corrupt: bool = True,
    subset: float = 1,
    seed: int = 42,
    use_four_frames: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the DataModule for the non-geospatial BioMassters datamodule.

    Args:
        data_root (str): Root directory containing the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (dict[str, Sequence[str]] | Sequence[str], optional): Band configuration; either a dict mapping sensors to bands or a list for the first sensor.
            Defaults to BioMasstersNonGeo.all_band_names
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential, optional): Augmentation or normalization to apply. Defaults to normalization if not provided.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
        sensors (Sequence[str], optional): List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"].
        as_time_series (bool, optional): Whether to treat data as a time series. Defaults to False.
        metadata_filename (str, optional): Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv".
        max_cloud_percentage (float | None, optional): Maximum allowed cloud percentage. Defaults to None.
        max_red_mean (float | None, optional): Maximum allowed red band mean. Defaults to None.
        include_corrupt (bool, optional): Whether to include corrupt data. Defaults to True.
        subset (float, optional): Fraction of the dataset to use. Defaults to 1.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        use_four_frames (bool, optional): Whether to use a four frames configuration. Defaults to False.
        **kwargs: Additional keyword arguments.

    Returns:
        None.
    """
    super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root
    self.sensors = sensors
    if isinstance(bands, dict):
        self.bands = bands
    else:
        sens = sensors[0]
        self.bands = {sens: bands}

    self.means = {}
    self.stds = {}
    for sensor in self.sensors:
        self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
        self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

    self.mask_mean = MEANS["AGBM"]
    self.mask_std = STDS["AGBM"]
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    if len(sensors) == 1:
        self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
    else:
        self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
    self.drop_last = drop_last
    self.as_time_series = as_time_series
    self.metadata_filename = metadata_filename
    self.max_cloud_percentage = max_cloud_percentage
    self.max_red_mean = max_red_mean
    self.include_corrupt = include_corrupt
    self.subset = subset
    self.seed = seed
    self.use_four_frames = use_four_frames
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | required |

Source code in terratorch/datamodules/biomassters.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )

terratorch.datamodules.burn_intensity #

BurnIntensityNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for BurnIntensity datamodule.
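
A brief usage sketch (the data_root path and the transform list below are illustrative assumptions):

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

from terratorch.datamodules.burn_intensity import BurnIntensityNonGeoDataModule

# Train-time augmentations are passed as a plain list and wrapped in A.Compose
# by the datamodule; data_root is an assumed path.
datamodule = BurnIntensityNonGeoDataModule(
    data_root="data/burn_intensity",
    batch_size=8,
    num_workers=4,
    train_transform=[A.HorizontalFlip(p=0.5), ToTensorV2()],
    use_metadata=False,
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```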

Source code in terratorch/datamodules/burn_intensity.py
class BurnIntensityNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for BurnIntensity datamodule."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the DataModule for the BurnIntensity non-geospatial datamodule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            use_full_data (bool, optional): Whether to use the full dataset or data with less than 25 percent zeros. Defaults to True.
            no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
            no_label_replace (int | None, optional): Value to replace missing labels. Defaults to -1.
            use_metadata (bool): Whether to return metadata info (time and location).
            **kwargs: Additional keyword arguments.
        """
        super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = NormalizeWithTimesteps(means, stds)
        self.use_full_data = use_full_data
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
__init__(data_root, batch_size=4, num_workers=0, bands=BurnIntensityNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, use_full_data=True, no_data_replace=0.0001, no_label_replace=-1, use_metadata=False, **kwargs) #

Initializes the DataModule for the BurnIntensity non-geospatial datamodule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Root directory of the dataset. | required |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `bands` | `Sequence[str]` | List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction. | `None` |
| `use_full_data` | `bool` | Whether to use the full dataset or data with less than 25 percent zeros. Defaults to True. | `True` |
| `no_data_replace` | `float \| None` | Value to replace missing data. Defaults to 0.0001. | `0.0001` |
| `no_label_replace` | `int \| None` | Value to replace missing labels. Defaults to -1. | `-1` |
| `use_metadata` | `bool` | Whether to return metadata info (time and location). | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |

Source code in terratorch/datamodules/burn_intensity.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    use_full_data: bool = True,
    no_data_replace: float | None = 0.0001,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the DataModule for the BurnIntensity non-geospatial datamodule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
        use_full_data (bool, optional): Whether to use the full dataset or data with less than 25 percent zeros. Defaults to True.
        no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
        no_label_replace (int | None, optional): Value to replace missing labels. Defaults to -1.
        use_metadata (bool): Whether to return metadata info (time and location).
        **kwargs: Additional keyword arguments.
    """
    super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = [MEANS[b] for b in bands]
    stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = NormalizeWithTimesteps(means, stds)
    self.use_full_data = use_full_data
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | required |

Source code in terratorch/datamodules/burn_intensity.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )

terratorch.datamodules.carbonflux #

CarbonFluxNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for the Carbon Flux dataset.
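
A brief usage sketch (the data_root path below is an illustrative assumption):

```python
from terratorch.datamodules.carbonflux import CarbonFluxNonGeoDataModule

# With aug left as None, the datamodule applies its default multimodal
# normalization. data_root is an assumed path.
datamodule = CarbonFluxNonGeoDataModule(
    data_root="data/carbonflux",
    batch_size=8,
    num_workers=4,
    no_data_replace=0.0001,
)
datamodule.setup("fit")
val_loader = datamodule.val_dataloader()
```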

Source code in terratorch/datamodules/carbonflux.py
class CarbonFluxNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Carbon FLux dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = CarbonFluxNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the CarbonFluxNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to CarbonFluxNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation sequence; if None, applies multimodal normalization.
            no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
            use_metadata (bool): Whether to return metadata info.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(CarbonFluxNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = {
            m: ([MEANS[m][band] for band in bands] if m == "image" else MEANS[m])
            for m in MEANS.keys()
        }
        stds = {
            m: ([STDS[m][band] for band in bands] if m == "image" else STDS[m])
            for m in STDS.keys()
        }
        self.mask_means = MEANS["mask"]
        self.mask_std = STDS["mask"]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = MultimodalNormalize(means, stds) if aug is None else aug
        self.no_data_replace = no_data_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
__init__(data_root, batch_size=4, num_workers=0, bands=CarbonFluxNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, no_data_replace=0.0001, use_metadata=False, **kwargs) #

Initializes the CarbonFluxNonGeoDataModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Root directory of the dataset. | required |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `bands` | `Sequence[str]` | List of bands to use. Defaults to CarbonFluxNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training data. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation data. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing data. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction data. | `None` |
| `aug` | `AugmentationSequential` | Augmentation sequence; if None, applies multimodal normalization. | `None` |
| `no_data_replace` | `float \| None` | Value to replace missing data. Defaults to 0.0001. | `0.0001` |
| `use_metadata` | `bool` | Whether to return metadata info. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |

Source code in terratorch/datamodules/carbonflux.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = CarbonFluxNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    no_data_replace: float | None = 0.0001,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the CarbonFluxNonGeoDataModule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to CarbonFluxNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential, optional): Augmentation sequence; if None, applies multimodal normalization.
        no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
        use_metadata (bool): Whether to return metadata info.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(CarbonFluxNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = {
        m: ([MEANS[m][band] for band in bands] if m == "image" else MEANS[m])
        for m in MEANS.keys()
    }
    stds = {
        m: ([STDS[m][band] for band in bands] if m == "image" else STDS[m])
        for m in STDS.keys()
    }
    self.mask_means = MEANS["mask"]
    self.mask_std = STDS["mask"]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = MultimodalNormalize(means, stds) if aug is None else aug
    self.no_data_replace = no_data_replace
    self.use_metadata = use_metadata
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | required |

Source code in terratorch/datamodules/carbonflux.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            gpp_mean=self.mask_means,
            gpp_std=self.mask_std,
            no_data_replace=self.no_data_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            gpp_mean=self.mask_means,
            gpp_std=self.mask_std,
            no_data_replace=self.no_data_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            gpp_mean=self.mask_means,
            gpp_std=self.mask_std,
            no_data_replace=self.no_data_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            gpp_mean=self.mask_means,
            gpp_std=self.mask_std,
            no_data_replace=self.no_data_replace,
            use_metadata=self.use_metadata,
        )

terratorch.datamodules.forestnet #

ForestNetNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for the ForestNet dataset.
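
A brief usage sketch (the data_root path and fraction value below are illustrative assumptions):

```python
from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule

# Uses the default label map and bands; fraction subsamples each split.
# data_root is an assumed path.
datamodule = ForestNetNonGeoDataModule(
    data_root="data/forestnet",
    batch_size=8,
    num_workers=4,
    fraction=0.5,
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```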

Source code in terratorch/datamodules/forestnet.py
class ForestNetNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Landslide4Sense dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        label_map: dict[str, int] = ForestNetNonGeo.default_label_map,
        bands: Sequence[str] = ForestNetNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        fraction: float = 1.0,
        aug: AugmentationSequential = None,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the ForestNetNonGeoDataModule.

        Args:
            data_root (str): Directory containing the dataset.
            batch_size (int, optional): Batch size for data loaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            label_map (dict[str, int], optional): Mapping of labels to integers. Defaults to ForestNetNonGeo.default_label_map.
            bands (Sequence[str], optional): List of band names to use. Defaults to ForestNetNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            fraction (float, optional): Fraction of data to use. Defaults to 1.0.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline; if None, uses Normalize.
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(ForestNetNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.label_map = label_map
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = Normalize(self.means, self.stds) if aug is None else aug
        self.fraction = fraction
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.train_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.val_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.test_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.predict_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
__init__(data_root, batch_size=4, num_workers=0, label_map=ForestNetNonGeo.default_label_map, bands=ForestNetNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, fraction=1.0, aug=None, use_metadata=False, **kwargs) #

Initializes the ForestNetNonGeoDataModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Directory containing the dataset. | required |
| `batch_size` | `int` | Batch size for data loaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `label_map` | `dict[str, int]` | Mapping of labels to integers. Defaults to ForestNetNonGeo.default_label_map. | `default_label_map` |
| `bands` | `Sequence[str]` | List of band names to use. Defaults to ForestNetNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training data. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation data. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing data. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction. | `None` |
| `fraction` | `float` | Fraction of data to use. Defaults to 1.0. | `1.0` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline; if None, uses Normalize. | `None` |
| `use_metadata` | `bool` | Whether to return metadata info. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |

Source code in terratorch/datamodules/forestnet.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    label_map: dict[str, int] = ForestNetNonGeo.default_label_map,
    bands: Sequence[str] = ForestNetNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    fraction: float = 1.0,
    aug: AugmentationSequential = None,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the ForestNetNonGeoDataModule.

    Args:
        data_root (str): Directory containing the dataset.
        batch_size (int, optional): Batch size for data loaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        label_map (dict[str, int], optional): Mapping of labels to integers. Defaults to ForestNetNonGeo.default_label_map.
        bands (Sequence[str], optional): List of band names to use. Defaults to ForestNetNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
        fraction (float, optional): Fraction of data to use. Defaults to 1.0.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline; if None, uses Normalize.
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(ForestNetNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    self.means = [MEANS[b] for b in bands]
    self.stds = [STDS[b] for b in bands]
    self.label_map = label_map
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = Normalize(self.means, self.stds) if aug is None else aug
    self.fraction = fraction
    self.use_metadata = use_metadata
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | *required* |
Source code in terratorch/datamodules/forestnet.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            label_map=self.label_map,
            transform=self.train_transform,
            bands=self.bands,
            fraction=self.fraction,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            label_map=self.label_map,
            transform=self.val_transform,
            bands=self.bands,
            fraction=self.fraction,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            label_map=self.label_map,
            transform=self.test_transform,
            bands=self.bands,
            fraction=self.fraction,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            label_map=self.label_map,
            transform=self.predict_transform,
            bands=self.bands,
            fraction=self.fraction,
            use_metadata=self.use_metadata,
        )
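
As a usage illustration (not part of the generated reference itself), the datamodule can be driven directly once the ForestNet data is available under `data_root`. The path, batch size, and transform list below are placeholder assumptions; Albumentations pipelines in terratorch commonly end with `ToTensorV2`:

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2
from terratorch.datamodules import ForestNetNonGeoDataModule

# Placeholder path and illustrative transforms.
datamodule = ForestNetNonGeoDataModule(
    data_root="/data/forestnet",
    batch_size=8,
    num_workers=4,
    train_transform=[A.HorizontalFlip(p=0.5), ToTensorV2()],  # wrapped into A.Compose by wrap_in_compose_is_list
)
datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape)  # assumes the dataset returns an "image" key
```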

terratorch.datamodules.fire_scars #

FireScarsDataModule #

Bases: GeoDataModule

Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsDataModule(GeoDataModule):
    """Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks."""

    def __init__(self, data_root: str, **kwargs: Any) -> None:
        super().__init__(FireScarsSegmentationMask, 4, 224, 100, 0, **kwargs)
        means = list(MEANS.values())
        stds = list(STDS.values())
        self.train_aug = AugmentationSequential(K.RandomCrop(224, 224), K.Normalize(means, stds), data_keys=None)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=None)
        self.data_root = data_root

    def setup(self, stage: str) -> None:
        self.images = FireScarsHLS(
            os.path.join(self.data_root, "training/")
        )
        self.labels = FireScarsSegmentationMask(
            os.path.join(self.data_root, "training/")
        )
        self.dataset = self.images & self.labels
        # Normalize with the dataset statistics (kornia.augmentation exposes Normalize, not normalize()).
        self.train_aug = AugmentationSequential(
            K.RandomCrop(224, 224), K.Normalize(list(MEANS.values()), list(STDS.values())), data_keys=None
        )

        self.images_test = FireScarsHLS(
            os.path.join(self.data_root, "validation/")
        )
        self.labels_test = FireScarsSegmentationMask(
            os.path.join(self.data_root, "validation/")
        )
        self.val_dataset = self.images_test & self.labels_test

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, self.patch_size, self.batch_size, None)
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
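
A minimal sketch of driving this geo datamodule; the root path is a placeholder and must contain the `training/` and `validation/` subfolders read in `setup`:

```python
from terratorch.datamodules import FireScarsDataModule

datamodule = FireScarsDataModule(data_root="/data/fire_scars")  # placeholder path
datamodule.setup("fit")
# Training batches are drawn by RandomBatchGeoSampler over the image & mask intersection dataset.
train_loader = datamodule.train_dataloader()
```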

FireScarsNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Fire Scars dataset.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Fire Scars dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = FireScarsNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the FireScarsNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of band names. Defaults to FireScarsNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            use_metadata (bool): Whether to return metadata info.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=None)
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
__init__(data_root, batch_size=4, num_workers=0, bands=FireScarsNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, no_data_replace=0, no_label_replace=-1, use_metadata=False, **kwargs) #

Initializes the FireScarsNonGeoDataModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Root directory of the dataset. | *required* |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `bands` | `Sequence[str]` | List of band names. Defaults to FireScarsNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction. | `None` |
| `drop_last` | `bool` | Whether to drop the last incomplete batch. Defaults to True. | `True` |
| `no_data_replace` | `float \| None` | Replacement value for missing data. Defaults to 0. | `0` |
| `no_label_replace` | `int \| None` | Replacement value for missing labels. Defaults to -1. | `-1` |
| `use_metadata` | `bool` | Whether to return metadata info. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/fire_scars.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = FireScarsNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    drop_last: bool = True,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the FireScarsNonGeoDataModule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of band names. Defaults to FireScarsNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
        no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
        no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
        use_metadata (bool): Whether to return metadata info.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = [MEANS[b] for b in bands]
    stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=None)
    self.drop_last = drop_last
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | *required* |
Source code in terratorch/datamodules/fire_scars.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
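
A configuration sketch with assumed values; note from `setup` above that the test and predict stages reuse the `val` split:

```python
from terratorch.datamodules import FireScarsNonGeoDataModule

datamodule = FireScarsNonGeoDataModule(
    data_root="/data/fire_scars_nongeo",  # placeholder path
    batch_size=16,
    num_workers=8,
    no_data_replace=0,    # fill value for missing pixels
    no_label_replace=-1,  # replacement (ignore) value for missing labels
    drop_last=True,       # drop the last incomplete training batch
)
datamodule.setup("test")  # builds test_dataset from the "val" split
```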

terratorch.datamodules.landslide4sense #

Landslide4SenseNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Landslide4Sense dataset.

Source code in terratorch/datamodules/landslide4sense.py
class Landslide4SenseNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Landslide4Sense dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = Landslide4SenseNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Landslide4SenseNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for data loaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of band names to use. Defaults to Landslide4SenseNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation pipeline; if None, applies normalization using computed means and stds.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(Landslide4SenseNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = (
            AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=None) if aug is None else aug
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands
            )
__init__(data_root, batch_size=4, num_workers=0, bands=Landslide4SenseNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, **kwargs) #

Initializes the Landslide4SenseNonGeoDataModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Root directory of the dataset. | *required* |
| `batch_size` | `int` | Batch size for data loaders. Defaults to 4. | `4` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `bands` | `Sequence[str]` | List of band names to use. Defaults to Landslide4SenseNonGeo.all_band_names. | `all_band_names` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training data. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation data. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing data. | `None` |
| `predict_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for prediction data. | `None` |
| `aug` | `AugmentationSequential` | Augmentation pipeline; if None, applies normalization using computed means and stds. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/landslide4sense.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = Landslide4SenseNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the Landslide4SenseNonGeoDataModule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for data loaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of band names to use. Defaults to Landslide4SenseNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential, optional): Augmentation pipeline; if None, applies normalization using computed means and stds.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(Landslide4SenseNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    self.means = [MEANS[b] for b in bands]
    self.stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = (
        AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=None) if aug is None else aug
    )
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | Either fit, validate, test, or predict. | *required* |
Source code in terratorch/datamodules/landslide4sense.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands
        )
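
A band-subsetting sketch; it assumes `Landslide4SenseNonGeo` is importable from `terratorch.datasets` and simply slices its `all_band_names`, so the chosen bands are illustrative:

```python
from terratorch.datamodules import Landslide4SenseNonGeoDataModule
from terratorch.datasets import Landslide4SenseNonGeo

bands = list(Landslide4SenseNonGeo.all_band_names)[:4]  # illustrative subset of bands
datamodule = Landslide4SenseNonGeoDataModule(
    data_root="/data/landslide4sense",  # placeholder path
    bands=bands,
    batch_size=8,
)
# Normalization statistics (MEANS/STDS) are selected per requested band in __init__.
```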

terratorch.datamodules.m_eurosat #

MEuroSATNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-EuroSAT dataset.

Source code in terratorch/datamodules/m_eurosat.py
class MEuroSATNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-EuroSAT dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MEuroSATNonGeoDataModule for the MEuroSATNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MEuroSATNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MEuroSATNonGeoDataModule for the MEuroSATNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_eurosat.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MEuroSATNonGeoDataModule for the MEuroSATNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MEuroSATNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )
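
The remaining Geobench wrappers below share this constructor shape; a minimal sketch using MEuroSATNonGeoDataModule with placeholder arguments:

```python
from terratorch.datamodules import MEuroSATNonGeoDataModule

datamodule = MEuroSATNonGeoDataModule(
    data_root="/data/geobench/m-eurosat",  # placeholder path
    batch_size=32,
    num_workers=4,
    partition="default",  # Geobench partition to load
)
datamodule.setup("fit")
```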

terratorch.datamodules.m_bigearthnet #

MBigEarthNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-BigEarthNet dataset.

Source code in terratorch/datamodules/m_bigearthnet.py
class MBigEarthNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-BigEarthNet dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBigEarthNonGeoDataModule for the M-BigEarthNet dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBigEarthNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MBigEarthNonGeoDataModule for the M-BigEarthNet dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_bigearthnet.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MBigEarthNonGeoDataModule for the M-BigEarthNet dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MBigEarthNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )

terratorch.datamodules.m_brick_kiln #

MBrickKilnNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-BrickKiln dataset.

Source code in terratorch/datamodules/m_brick_kiln.py
class MBrickKilnNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-BrickKiln dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBrickKilnNonGeoDataModule for the M-BrickKilnNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBrickKilnNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MBrickKilnNonGeoDataModule for the M-BrickKilnNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_brick_kiln.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MBrickKilnNonGeoDataModule for the M-BrickKilnNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MBrickKilnNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )

terratorch.datamodules.m_forestnet #

MForestNetNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-ForestNet dataset.

Source code in terratorch/datamodules/m_forestnet.py
class MForestNetNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-ForestNet dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MForestNetNonGeoDataModule for the MForestNetNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MForestNetNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs) #

Initializes the MForestNetNonGeoDataModule for the MForestNetNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_forestnet.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    use_metadata: bool = False,  # noqa: FBT002, FBT001
    **kwargs: Any,
) -> None:
    """
    Initializes the MForestNetNonGeoDataModule for the MForestNetNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MForestNetNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        use_metadata=use_metadata,
        **kwargs,
    )

terratorch.datamodules.m_so2sat #

MSo2SatNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-So2Sat dataset.

Source code in terratorch/datamodules/m_so2sat.py
class MSo2SatNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-So2Sat dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MSo2SatNonGeoDataModule for the MSo2SatNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MSo2SatNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MSo2SatNonGeoDataModule for the MSo2SatNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_so2sat.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MSo2SatNonGeoDataModule for the MSo2SatNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MSo2SatNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )

terratorch.datamodules.m_pv4ger #

MPv4gerNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Pv4ger dataset.

Source code in terratorch/datamodules/m_pv4ger.py
class MPv4gerNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Pv4ger dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MPv4gerNonGeoDataModule for the MPv4gerNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MPv4gerNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs) #

Initializes the MPv4gerNonGeoDataModule for the MPv4gerNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size for DataLoaders. Defaults to 8. | `8` |
| `num_workers` | `int` | Number of workers for data loading. Defaults to 0. | `0` |
| `data_root` | `str` | Root directory of the dataset. Defaults to "./". | `'./'` |
| `bands` | `Sequence[str] \| None` | List of bands to use. Defaults to None. | `None` |
| `train_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for training. | `None` |
| `val_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for validation. | `None` |
| `test_transform` | `Compose \| None \| list[BasicTransform]` | Transformations for testing. | `None` |
| `aug` | `AugmentationSequential` | Augmentation/normalization pipeline. Defaults to None. | `None` |
| `partition` | `str` | Partition size. Defaults to "default". | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Source code in terratorch/datamodules/m_pv4ger.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    use_metadata: bool = False,  # noqa: FBT002, FBT001
    **kwargs: Any,
) -> None:
    """
    Initializes the MPv4gerNonGeoDataModule for the MPv4gerNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MPv4gerNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        use_metadata=use_metadata,
        **kwargs,
    )
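
As a quick usage sketch (not part of the library source), the snippet below instantiates the datamodule described above and pulls one training batch. The data_root path and parameter values are placeholders and assume the GEO-Bench m-pv4ger files are already available locally; the exact batch keys depend on the underlying dataset.

from terratorch.datamodules.m_pv4ger import MPv4gerNonGeoDataModule

# Illustrative values only; point data_root at the local copy of the dataset.
datamodule = MPv4gerNonGeoDataModule(
    batch_size=16,
    num_workers=4,
    data_root="/path/to/geobench/m-pv4ger",
)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
# Print the shape (or type) of each entry in the batch dict.
print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})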

terratorch.datamodules.m_cashew_plantation #

MBeninSmallHolderCashewsNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Cashew Plantation dataset.

Source code in terratorch/datamodules/m_cashew_plantation.py
class MBeninSmallHolderCashewsNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Cashew Plantation dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBeninSmallHolderCashewsNonGeoDataModule for the M-BeninSmallHolderCashewsNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBeninSmallHolderCashewsNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs) #

Initializes the MBeninSmallHolderCashewsNonGeoDataModule for the M-BeninSmallHolderCashewsNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| use_metadata | bool | Whether to return metadata info. | False |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_cashew_plantation.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    use_metadata: bool = False,  # noqa: FBT002, FBT001
    **kwargs: Any,
) -> None:
    """
    Initializes the MBeninSmallHolderCashewsNonGeoDataModule for the M-BeninSmallHolderCashewsNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MBeninSmallHolderCashewsNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        use_metadata=use_metadata,
        **kwargs,
    )
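
Because the transform arguments accept a plain list of Albumentations transforms, a training pipeline can be passed directly, ending the list with ToTensorV2 (the convention used elsewhere in terratorch for Albumentations pipelines) so samples come back as torch tensors. A minimal sketch with placeholder values:

import albumentations as A
from albumentations.pytorch import ToTensorV2

from terratorch.datamodules.m_cashew_plantation import MBeninSmallHolderCashewsNonGeoDataModule

# A simple augmentation pipeline for the training split.
train_transform = [
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    ToTensorV2(),
]

datamodule = MBeninSmallHolderCashewsNonGeoDataModule(
    data_root="/path/to/geobench/m-cashew-plantation",  # placeholder path
    batch_size=8,
    train_transform=train_transform,
)
datamodule.setup("fit")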

terratorch.datamodules.m_nz_cattle #

MNzCattleNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-NZCattle dataset.

Source code in terratorch/datamodules/m_nz_cattle.py
class MNzCattleNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-NZCattle dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MNzCattleNonGeoDataModule for the MNzCattleNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MNzCattleNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs) #

Initializes the MNzCattleNonGeoDataModule for the MNzCattleNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| use_metadata | bool | Whether to return metadata info. | False |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_nz_cattle.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    use_metadata: bool = False,  # noqa: FBT002, FBT001
    **kwargs: Any,
) -> None:
    """
    Initializes the MNzCattleNonGeoDataModule for the MNzCattleNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MNzCattleNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        use_metadata=use_metadata,
        **kwargs,
    )
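
For evaluation, call setup with the corresponding stage before requesting a dataloader. A minimal sketch (the path is a placeholder):

from terratorch.datamodules.m_nz_cattle import MNzCattleNonGeoDataModule

datamodule = MNzCattleNonGeoDataModule(
    data_root="/path/to/geobench/m-nz-cattle",  # placeholder path
    batch_size=8,
)

# Validation split
datamodule.setup("validate")
val_loader = datamodule.val_dataloader()

# Test split
datamodule.setup("test")
test_loader = datamodule.test_dataloader()
print(len(val_loader), len(test_loader))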

terratorch.datamodules.m_chesapeake_landcover #

MChesapeakeLandcoverNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-ChesapeakeLandcover dataset.

Source code in terratorch/datamodules/m_chesapeake_landcover.py
class MChesapeakeLandcoverNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-ChesapeakeLandcover dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MChesapeakeLandcoverNonGeoDataModule for the MChesapeakeLandcoverNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MChesapeakeLandcoverNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MChesapeakeLandcoverNonGeoDataModule for the MChesapeakeLandcoverNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_chesapeake_landcover.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MChesapeakeLandcoverNonGeoDataModule for the MChesapeakeLandcoverNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MChesapeakeLandcoverNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )
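
The partition argument selects one of the GEO-Bench partitions of the dataset. The sketch below is illustrative only; "0.01x_train" is assumed here as a typical reduced-training partition name, so check the dataset for the names that are actually available.

from terratorch.datamodules.m_chesapeake_landcover import MChesapeakeLandcoverNonGeoDataModule

datamodule = MChesapeakeLandcoverNonGeoDataModule(
    data_root="/path/to/geobench/m-chesapeake-landcover",  # placeholder path
    partition="0.01x_train",  # assumed partition name; "default" uses the full training set
    batch_size=16,
)
datamodule.setup("fit")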

terratorch.datamodules.m_pv4ger_seg #

MPv4gerSegNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Pv4gerSeg dataset.

Source code in terratorch/datamodules/m_pv4ger_seg.py
class MPv4gerSegNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Pv4gerSeg dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MPv4gerSegNonGeoDataModule for the MPv4gerSegNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MPv4gerSegNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs) #

Initializes the MPv4gerSegNonGeoDataModule for the MPv4gerSegNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| use_metadata | bool | Whether to return metadata info. | False |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_pv4ger_seg.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    use_metadata: bool = False,  # noqa: FBT002, FBT001
    **kwargs: Any,
) -> None:
    """
    Initializes the MPv4gerSegNonGeoDataModule for the MPv4gerSegNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        use_metadata (bool): Whether to return metadata info.
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MPv4gerSegNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        use_metadata=use_metadata,
        **kwargs,
    )
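
When use_metadata is enabled, the underlying dataset attaches metadata to each sample in addition to the image and mask. A sketch of how that surfaces in a batch (key names depend on the dataset and are not guaranteed here):

from terratorch.datamodules.m_pv4ger_seg import MPv4gerSegNonGeoDataModule

datamodule = MPv4gerSegNonGeoDataModule(
    data_root="/path/to/geobench/m-pv4ger-seg",  # placeholder path
    batch_size=4,
    use_metadata=True,
)
datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
# Inspect everything the batch contains, including any metadata entries.
for key, value in batch.items():
    print(key, getattr(value, "shape", type(value)))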

terratorch.datamodules.m_SA_crop_type #

MSACropTypeNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-SA-CropType dataset.

Source code in terratorch/datamodules/m_SA_crop_type.py
class MSACropTypeNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-SA-CropType dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MSACropTypeNonGeoDataModule for the MSACropTypeNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MSACropTypeNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MSACropTypeNonGeoDataModule for the MSACropTypeNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_SA_crop_type.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MSACropTypeNonGeoDataModule for the MSACropTypeNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MSACropTypeNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )
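
A short sketch of looping over the test split once the module is set up; the path is a placeholder and the "image" key follows the usual terratorch sample convention (an assumption here):

from terratorch.datamodules.m_SA_crop_type import MSACropTypeNonGeoDataModule

datamodule = MSACropTypeNonGeoDataModule(
    data_root="/path/to/geobench/m-SA-crop-type",  # placeholder path
    batch_size=8,
    num_workers=2,
)
datamodule.setup("test")

# Count the number of test samples by iterating the dataloader once.
n_samples = 0
for batch in datamodule.test_dataloader():
    n_samples += batch["image"].shape[0]
print(f"test samples: {n_samples}")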

terratorch.datamodules.m_neontree #

MNeonTreeNonGeoDataModule #

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-NeonTree dataset.

Source code in terratorch/datamodules/m_neontree.py
class MNeonTreeNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-NeonTree dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MNeonTreeNonGeoDataModule for the MNeonTreeNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MNeonTreeNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs) #

Initializes the MNeonTreeNonGeoDataModule for the MNeonTreeNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| bands | Sequence[str] \| None | List of bands to use. Defaults to None. | None |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing. | None |
| aug | AugmentationSequential | Augmentation/normalization pipeline. Defaults to None. | None |
| partition | str | Partition size. Defaults to "default". | 'default' |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/m_neontree.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    bands: Sequence[str] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    partition: str = "default",
    **kwargs: Any,
) -> None:
    """
    Initializes the MNeonTreeNonGeoDataModule for the MNeonTreeNonGeo dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
        partition (str, optional): Partition size. Defaults to "default".
        **kwargs (Any): Additional keyword arguments.
    """
    super().__init__(
        MNeonTreeNonGeo,
        MEANS,
        STDS,
        batch_size=batch_size,
        num_workers=num_workers,
        data_root=data_root,
        bands=bands,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        aug=aug,
        partition=partition,
        **kwargs,
    )
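
After setup, the split datasets are available as attributes on the datamodule, so individual samples can be inspected without a dataloader. The attribute and key names below follow the usual TorchGeo-style conventions and are assumptions, not taken from the module source:

from terratorch.datamodules.m_neontree import MNeonTreeNonGeoDataModule

datamodule = MNeonTreeNonGeoDataModule(
    data_root="/path/to/geobench/m-neontree",  # placeholder path
    batch_size=8,
)
datamodule.setup("fit")

# Index a single training sample directly from the underlying dataset.
sample = datamodule.train_dataset[0]
print(sample["image"].shape)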

terratorch.datamodules.multi_temporal_crop_classification #

MultiTemporalCropClassificationDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for multi-temporal crop classification.

Source code in terratorch/datamodules/multi_temporal_crop_classification.py
class MultiTemporalCropClassificationDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for multi-temporal crop classification."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = MultiTemporalCropClassification.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        expand_temporal_dimension: bool = True,
        reduce_zero_label: bool = True,
        use_metadata: bool = False,
        metadata_file_name: str = "chips_df.csv",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MultiTemporalCropClassificationDataModule for multi-temporal crop classification.

        Args:
            data_root (str): Directory containing the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to MultiTemporalCropClassification.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            drop_last (bool, optional): Whether to drop the last incomplete batch during training. Defaults to True.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            expand_temporal_dimension (bool, optional): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to True.
            reduce_zero_label (bool, optional): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to True.
            use_metadata (bool): Whether to return metadata info (time and location).
            metadata_file_name (str, optional): Name of the metadata csv file. Defaults to "chips_df.csv".
            **kwargs: Additional keyword arguments.
        """
        super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = Normalize(self.means, self.stds)
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label
        self.use_metadata = use_metadata
        self.metadata_file_name = metadata_file_name

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
                metadata_file_name=self.metadata_file_name,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
                metadata_file_name=self.metadata_file_name,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
                metadata_file_name=self.metadata_file_name,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
                metadata_file_name=self.metadata_file_name,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
__init__(data_root, batch_size=4, num_workers=0, bands=MultiTemporalCropClassification.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, no_data_replace=0, no_label_replace=-1, expand_temporal_dimension=True, reduce_zero_label=True, use_metadata=False, metadata_file_name='chips_df.csv', **kwargs) #

Initializes the MultiTemporalCropClassificationDataModule for multi-temporal crop classification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Directory containing the dataset. | required |
| batch_size | int | Batch size for DataLoaders. Defaults to 4. | 4 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| bands | Sequence[str] | List of bands to use. Defaults to MultiTemporalCropClassification.all_band_names. | all_band_names |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training data. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation data. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing data. | None |
| predict_transform | Compose \| None \| list[BasicTransform] | Transformations for prediction data. | None |
| drop_last | bool | Whether to drop the last incomplete batch during training. Defaults to True. | True |
| no_data_replace | float \| None | Replacement value for missing data. Defaults to 0. | 0 |
| no_label_replace | int \| None | Replacement value for missing labels. Defaults to -1. | -1 |
| expand_temporal_dimension | bool | Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to True. | True |
| reduce_zero_label | bool | Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to True. | True |
| use_metadata | bool | Whether to return metadata info (time and location). | False |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/multi_temporal_crop_classification.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = MultiTemporalCropClassification.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    drop_last: bool = True,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    expand_temporal_dimension: bool = True,
    reduce_zero_label: bool = True,
    use_metadata: bool = False,
    metadata_file_name: str = "chips_df.csv",
    **kwargs: Any,
) -> None:
    """
    Initializes the MultiTemporalCropClassificationDataModule for multi-temporal crop classification.

    Args:
        data_root (str): Directory containing the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to MultiTemporalCropClassification.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        drop_last (bool, optional): Whether to drop the last incomplete batch during training. Defaults to True.
        no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
        no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
        expand_temporal_dimension (bool, optional): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to True.
        reduce_zero_label (bool, optional): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to True.
        use_metadata (bool): Whether to return metadata info (time and location).
        metadata_file_name (str, optional): Name of the metadata csv file. Defaults to "chips_df.csv".
        **kwargs: Additional keyword arguments.
    """
    super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    self.means = [MEANS[b] for b in bands]
    self.stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = Normalize(self.means, self.stds)
    self.drop_last = drop_last
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label
    self.use_metadata = use_metadata
    self.metadata_file_name = metadata_file_name
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stage | str | Either fit, validate, test, or predict. | required |
Source code in terratorch/datamodules/multi_temporal_crop_classification.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension = self.expand_temporal_dimension,
            reduce_zero_label = self.reduce_zero_label,
            use_metadata=self.use_metadata,
            metadata_file_name=self.metadata_file_name,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension = self.expand_temporal_dimension,
            reduce_zero_label = self.reduce_zero_label,
            use_metadata=self.use_metadata,
            metadata_file_name=self.metadata_file_name,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension = self.expand_temporal_dimension,
            reduce_zero_label = self.reduce_zero_label,
            use_metadata=self.use_metadata,
            metadata_file_name=self.metadata_file_name,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            expand_temporal_dimension = self.expand_temporal_dimension,
            reduce_zero_label = self.reduce_zero_label,
            use_metadata=self.use_metadata,
            metadata_file_name=self.metadata_file_name,
        )
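
To make the temporal layout concrete: with expand_temporal_dimension=True each chip is reshaped from (time*channels, h, w) to (channels, time, h, w), so a training batch comes out with shape (batch, channels, time, height, width). A minimal sketch with a placeholder path; the "image" and "mask" keys are assumed to follow the usual terratorch sample convention.

from terratorch.datamodules.multi_temporal_crop_classification import (
    MultiTemporalCropClassificationDataModule,
)

datamodule = MultiTemporalCropClassificationDataModule(
    data_root="/path/to/multi-temporal-crop-classification",  # placeholder path
    batch_size=4,
    expand_temporal_dimension=True,
    reduce_zero_label=True,
)
datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
print(batch["image"].shape)  # expected: (batch, channels, time, height, width)
print(batch["mask"].shape)   # labels with 1 subtracted because reduce_zero_label=True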

terratorch.datamodules.open_sentinel_map #

OpenSentinelMapDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Open Sentinel Map.

Source code in terratorch/datamodules/open_sentinel_map.py
class OpenSentinelMapDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Open Sentinel Map."""

    def __init__(
        self,
        bands: list[str] | None = None,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        pad_image: int | None = None,
        truncate_image: int | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

        Args:
            bands (list[str] | None, optional): List of bands to use. Defaults to None.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            spatial_interpolate_and_stack_temporally (bool, optional): If True, the bands are interpolated and concatenated over time.
                Default is True.
            pad_image (int | None, optional): Number of timesteps to pad the time dimension of the image.
                If None, no padding is applied.
            truncate_image (int | None, optional):  Number of timesteps to truncate the time dimension of the image.
                If None, no truncation is performed.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            OpenSentinelMap,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.bands = bands
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
        self.pad_image = pad_image
        self.truncate_image = truncate_image
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = OpenSentinelMap(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = OpenSentinelMap(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = OpenSentinelMap(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = OpenSentinelMap(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
__init__(bands=None, batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, spatial_interpolate_and_stack_temporally=True, pad_image=None, truncate_image=None, **kwargs) #

Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bands | list[str] \| None | List of bands to use. Defaults to None. | None |
| batch_size | int | Batch size for DataLoaders. Defaults to 8. | 8 |
| num_workers | int | Number of workers for data loading. Defaults to 0. | 0 |
| data_root | str | Root directory of the dataset. Defaults to "./". | './' |
| train_transform | Compose \| None \| list[BasicTransform] | Transformations for training data. | None |
| val_transform | Compose \| None \| list[BasicTransform] | Transformations for validation data. | None |
| test_transform | Compose \| None \| list[BasicTransform] | Transformations for testing data. | None |
| predict_transform | Compose \| None \| list[BasicTransform] | Transformations for prediction data. | None |
| spatial_interpolate_and_stack_temporally | bool | If True, the bands are interpolated and concatenated over time. Default is True. | True |
| pad_image | int \| None | Number of timesteps to pad the time dimension of the image. If None, no padding is applied. | None |
| truncate_image | int \| None | Number of timesteps to truncate the time dimension of the image. If None, no truncation is performed. | None |
| **kwargs | Any | Additional keyword arguments. | {} |
Source code in terratorch/datamodules/open_sentinel_map.py
def __init__(
    self,
    bands: list[str] | None = None,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
    pad_image: int | None = None,
    truncate_image: int | None = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

    Args:
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        spatial_interpolate_and_stack_temporally (bool, optional): If True, the bands are interpolated and concatenated over time.
            Default is True.
        pad_image (int | None, optional): Number of timesteps to pad the time dimension of the image.
            If None, no padding is applied.
        truncate_image (int | None, optional):  Number of timesteps to truncate the time dimension of the image.
            If None, no truncation is performed.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        OpenSentinelMap,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.bands = bands
    self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
    self.pad_image = pad_image
    self.truncate_image = truncate_image
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.kwargs = kwargs
setup(stage) #

Set up datasets.

Parameters:

- stage (str): Either fit, validate, test, or predict. Required.
Source code in terratorch/datamodules/open_sentinel_map.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = OpenSentinelMap(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally=self.spatial_interpolate_and_stack_temporally,
            pad_image=self.pad_image,
            truncate_image=self.truncate_image,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = OpenSentinelMap(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally=self.spatial_interpolate_and_stack_temporally,
            pad_image=self.pad_image,
            truncate_image=self.truncate_image,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = OpenSentinelMap(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally=self.spatial_interpolate_and_stack_temporally,
            pad_image=self.pad_image,
            truncate_image=self.truncate_image,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = OpenSentinelMap(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally=self.spatial_interpolate_and_stack_temporally,
            pad_image=self.pad_image,
            truncate_image=self.truncate_image,
            **self.kwargs,
        )
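
For orientation, here is a minimal usage sketch. It assumes OpenSentinelMapDataModule is exported from terratorch.datamodules; the data path, the 12-timestep pad/truncate values, and the Albumentations transforms are illustrative placeholders, not recommended settings.

import albumentations as A
from albumentations.pytorch import ToTensorV2

from terratorch.datamodules import OpenSentinelMapDataModule

# Illustrative configuration: pad and truncate the temporal dimension to
# 12 timesteps and convert the arrays to tensors at the end of each pipeline.
datamodule = OpenSentinelMapDataModule(
    data_root="/path/to/open_sentinel_map",   # placeholder path
    batch_size=8,
    num_workers=4,
    pad_image=12,
    truncate_image=12,
    train_transform=[A.HorizontalFlip(p=0.5), ToTensorV2()],
    val_transform=[ToTensorV2()],
    test_transform=[ToTensorV2()],
)

datamodule.setup("fit")                       # builds the train and val datasets
train_loader = datamodule.train_dataloader()

Plain lists of Albumentations transforms are accepted because wrap_in_compose_is_list wraps them into an A.Compose, which keeps the module usable from config files.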

terratorch.datamodules.openearthmap #

OpenEarthMapNonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Open Earth Map.

Source code in terratorch/datamodules/openearthmap.py
class OpenEarthMapNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Open Earth Map."""

    def __init__(
        self, 
        batch_size: int = 8, 
        num_workers: int = 0, 
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        **kwargs: Any
    ) -> None:
        """
        Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation pipeline; if None, defaults to normalization using computed means and stds.
            **kwargs: Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
        """
        super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

        bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
        self.means = torch.tensor([MEANS[b] for b in bands])
        self.stds = torch.tensor([STDS[b] for b in bands])
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=None) if aug is None else aug

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(  
                split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",data_root=self.data_root, transform=self.predict_transform, **self.kwargs
            )
__init__(batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, **kwargs) #

Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

Parameters:

- batch_size (int): Batch size for DataLoaders. Defaults to 8.
- num_workers (int): Number of workers for data loading. Defaults to 0.
- data_root (str): Root directory of the dataset. Defaults to "./".
- train_transform (Compose | None | list[BasicTransform]): Transformations for training data. Defaults to None.
- val_transform (Compose | None | list[BasicTransform]): Transformations for validation data. Defaults to None.
- test_transform (Compose | None | list[BasicTransform]): Transformations for test data. Defaults to None.
- predict_transform (Compose | None | list[BasicTransform]): Transformations for prediction data. Defaults to None.
- aug (AugmentationSequential): Augmentation pipeline; if None, defaults to normalization using computed means and stds. Defaults to None.
- **kwargs (Any): Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
Source code in terratorch/datamodules/openearthmap.py
def __init__(
    self, 
    batch_size: int = 8, 
    num_workers: int = 0, 
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    **kwargs: Any
) -> None:
    """
    Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential, optional): Augmentation pipeline; if None, defaults to normalization using computed means and stds.
        **kwargs: Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
    """
    super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

    bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
    self.means = torch.tensor([MEANS[b] for b in bands])
    self.stds = torch.tensor([STDS[b] for b in bands])
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=None) if aug is None else aug
setup(stage) #

Set up datasets.

Parameters:

- stage (str): Either fit, validate, test, or predict. Required.
Source code in terratorch/datamodules/openearthmap.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(  
            split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",data_root=self.data_root, transform=self.predict_transform, **self.kwargs
        )
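
As a rough usage sketch (the path below is a placeholder and the import assumes OpenEarthMapNonGeoDataModule is exported from terratorch.datamodules): when aug is left as None, the module builds a Kornia normalization step from the per-band means and stds, using the bands passed through **kwargs or OpenEarthMapNonGeo.all_band_names by default.

import albumentations as A
from albumentations.pytorch import ToTensorV2

from terratorch.datamodules import OpenEarthMapNonGeoDataModule

# 'bands' could be passed as an extra keyword argument to restrict the bands;
# here the default (all bands) is used and the matching normalization
# statistics are selected automatically.
datamodule = OpenEarthMapNonGeoDataModule(
    data_root="/path/to/open_earth_map",      # placeholder path
    batch_size=8,
    num_workers=4,
    train_transform=[A.HorizontalFlip(p=0.5), ToTensorV2()],
    val_transform=[ToTensorV2()],
)

datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))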

terratorch.datamodules.pastis #

PASTISDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for PASTIS.

Source code in terratorch/datamodules/pastis.py
class PASTISDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for PASTIS."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        truncate_image: int | None = None,
        pad_image: int | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the PASTISDataModule for the PASTIS dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Directory containing the dataset. Defaults to "./".
            truncate_image (int, optional): Truncate the time dimension of the image to 
                a specified number of timesteps. If None, no truncation is performed.
            pad_image (int, optional): Pad the time dimension of the image to a specified 
                number of timesteps. If None, no padding is applied.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            PASTIS,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = PASTIS(
                folds=[1, 2, 3],
                data_root=self.data_root,
                transform=self.train_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = PASTIS(
                folds=[4],
                data_root=self.data_root,
                transform=self.val_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = PASTIS(
                folds=[5],
                data_root=self.data_root,
                transform=self.test_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = PASTIS(
                folds=[5],
                data_root=self.data_root,
                transform=self.predict_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
__init__(batch_size=8, num_workers=0, data_root='./', truncate_image=None, pad_image=None, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, **kwargs) #

Initializes the PASTISDataModule for the PASTIS dataset.

Parameters:

- batch_size (int): Batch size for DataLoaders. Defaults to 8.
- num_workers (int): Number of workers for data loading. Defaults to 0.
- data_root (str): Directory containing the dataset. Defaults to "./".
- truncate_image (int | None): Truncate the time dimension of the image to a specified number of timesteps. If None, no truncation is performed. Defaults to None.
- pad_image (int | None): Pad the time dimension of the image to a specified number of timesteps. If None, no padding is applied. Defaults to None.
- train_transform (Compose | None | list[BasicTransform]): Transformations for training data. Defaults to None.
- val_transform (Compose | None | list[BasicTransform]): Transformations for validation data. Defaults to None.
- test_transform (Compose | None | list[BasicTransform]): Transformations for testing data. Defaults to None.
- predict_transform (Compose | None | list[BasicTransform]): Transformations for prediction data. Defaults to None.
- **kwargs (Any): Additional keyword arguments.
Source code in terratorch/datamodules/pastis.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    truncate_image: int | None = None,
    pad_image: int | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the PASTISDataModule for the PASTIS dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Directory containing the dataset. Defaults to "./".
        truncate_image (int, optional): Truncate the time dimension of the image to 
            a specified number of timesteps. If None, no truncation is performed.
        pad_image (int, optional): Pad the time dimension of the image to a specified 
            number of timesteps. If None, no padding is applied.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        PASTIS,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.kwargs = kwargs
setup(stage) #

Set up datasets.

Parameters:

- stage (str): Either fit, validate, test, or predict. Required.
Source code in terratorch/datamodules/pastis.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = PASTIS(
            folds=[1, 2, 3],
            data_root=self.data_root,
            transform=self.train_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = PASTIS(
            folds=[4],
            data_root=self.data_root,
            transform=self.val_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = PASTIS(
            folds=[5],
            data_root=self.data_root,
            transform=self.test_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = PASTIS(
            folds=[5],
            data_root=self.data_root,
            transform=self.predict_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
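
A short sketch of how the fold convention above plays out (the path and temporal settings are placeholders; the import assumes PASTISDataModule is exported from terratorch.datamodules): folds 1 to 3 back the training dataset, fold 4 the validation dataset, and fold 5 both the test and predict datasets.

from albumentations.pytorch import ToTensorV2

from terratorch.datamodules import PASTISDataModule

# truncate_image/pad_image control the length of the temporal dimension;
# the value 6 here is purely illustrative.
datamodule = PASTISDataModule(
    data_root="/path/to/pastis",              # placeholder path
    batch_size=8,
    num_workers=4,
    truncate_image=6,
    pad_image=6,
    train_transform=[ToTensorV2()],
    val_transform=[ToTensorV2()],
    test_transform=[ToTensorV2()],
)

datamodule.setup("fit")    # folds [1, 2, 3] -> train, fold [4] -> val
datamodule.setup("test")   # fold [5] -> test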

terratorch.datamodules.sen1floods11 #

Sen1Floods11NonGeoDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Sen1Floods11.

Source code in terratorch/datamodules/sen1floods11.py
class Sen1Floods11NonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Fire Scars."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = Sen1Floods11NonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        constant_scale: float = 0.0001,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Sen1Floods11NonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            constant_scale (float, optional): Scale constant applied to the dataset. Defaults to 0.0001.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            use_metadata (bool): Whether to return metadata info (time and location).
            **kwargs: Additional keyword arguments.
        """
        super().__init__(Sen1Floods11NonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=None)
        self.drop_last = drop_last
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
__init__(data_root, batch_size=4, num_workers=0, bands=Sen1Floods11NonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, constant_scale=0.0001, no_data_replace=0, no_label_replace=-1, use_metadata=False, **kwargs) #

Initializes the Sen1Floods11NonGeoDataModule.

Parameters:

- data_root (str): Root directory of the dataset. Required.
- batch_size (int): Batch size for DataLoaders. Defaults to 4.
- num_workers (int): Number of workers for data loading. Defaults to 0.
- bands (Sequence[str]): List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.
- train_transform (Compose | None | list[BasicTransform]): Transformations for training data. Defaults to None.
- val_transform (Compose | None | list[BasicTransform]): Transformations for validation data. Defaults to None.
- test_transform (Compose | None | list[BasicTransform]): Transformations for test data. Defaults to None.
- predict_transform (Compose | None | list[BasicTransform]): Transformations for prediction data. Defaults to None.
- drop_last (bool): Whether to drop the last incomplete batch. Defaults to True.
- constant_scale (float): Scale constant applied to the dataset. Defaults to 0.0001.
- no_data_replace (float | None): Replacement value for missing data. Defaults to 0.
- no_label_replace (int | None): Replacement value for missing labels. Defaults to -1.
- use_metadata (bool): Whether to return metadata info (time and location). Defaults to False.
- **kwargs (Any): Additional keyword arguments.
Source code in terratorch/datamodules/sen1floods11.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = Sen1Floods11NonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    drop_last: bool = True,
    constant_scale: float = 0.0001,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the Sen1Floods11NonGeoDataModule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
        constant_scale (float, optional): Scale constant applied to the dataset. Defaults to 0.0001.
        no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
        no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
        use_metadata (bool): Whether to return metadata info (time and location).
        **kwargs: Additional keyword arguments.
    """
    super().__init__(Sen1Floods11NonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = [MEANS[b] for b in bands]
    stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=None)
    self.drop_last = drop_last
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
setup(stage) #

Set up datasets.

Parameters:

- stage (str): Either fit, validate, test, or predict. Required.
Source code in terratorch/datamodules/sen1floods11.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
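
A minimal usage sketch (placeholder path; the import assumes Sen1Floods11NonGeoDataModule is exported from terratorch.datamodules). Note that shuffling and drop_last only apply to the training loader, as _dataloader_factory above shows, and that a bands argument can select a subset of Sen1Floods11NonGeo.all_band_names.

from terratorch.datamodules import Sen1Floods11NonGeoDataModule

datamodule = Sen1Floods11NonGeoDataModule(
    data_root="/path/to/sen1floods11",   # placeholder path
    batch_size=4,
    num_workers=4,
    drop_last=True,        # drop the last incomplete training batch
    use_metadata=False,    # set True to also return time and location info
)

datamodule.setup("fit")
train_loader = datamodule.train_dataloader()   # shuffled, incomplete last batch dropped
val_loader = datamodule.val_dataloader()       # not shuffled, all samples kept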

terratorch.datamodules.sen4agrinet #

Sen4AgriNetDataModule #

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Sen4AgriNet.

Source code in terratorch/datamodules/sen4agrinet.py
class Sen4AgriNetDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Sen4AgriNet."""

    def __init__(
        self,
        bands: list[str] | None = None,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        seed: int = 42,
        scenario: str = "random",
        requires_norm: bool = True,
        binary_labels: bool = False,
        linear_encoder: dict = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

        Args:
            bands (list[str] | None, optional): List of bands to use. Defaults to None.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            seed (int, optional): Random seed for reproducibility. Defaults to 42.
            scenario (str): Defines the splitting scenario to use. Options are:
                - 'random': Random split of the data.
                - 'spatial': Split by geographical regions (Catalonia and France).
                - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
            requires_norm (bool, optional): Whether normalization is required. Defaults to True.
            binary_labels (bool, optional): Whether to use binary labels. Defaults to False.
            linear_encoder (dict, optional): Mapping for label encoding. Defaults to None.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            Sen4AgriNet,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.bands = bands
        self.seed = seed
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.scenario = scenario
        self.requires_norm = requires_norm
        self.binary_labels = binary_labels
        self.linear_encoder = linear_encoder
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = Sen4AgriNet(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = Sen4AgriNet(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = Sen4AgriNet(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = Sen4AgriNet(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
__init__(bands=None, batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, seed=42, scenario='random', requires_norm=True, binary_labels=False, linear_encoder=None, **kwargs) #

Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

Parameters:

- bands (list[str] | None): List of bands to use. Defaults to None.
- batch_size (int): Batch size for DataLoaders. Defaults to 8.
- num_workers (int): Number of workers for data loading. Defaults to 0.
- data_root (str): Root directory of the dataset. Defaults to "./".
- train_transform (Compose | None | list[BasicTransform]): Transformations for training data. Defaults to None.
- val_transform (Compose | None | list[BasicTransform]): Transformations for validation data. Defaults to None.
- test_transform (Compose | None | list[BasicTransform]): Transformations for test data. Defaults to None.
- predict_transform (Compose | None | list[BasicTransform]): Transformations for prediction data. Defaults to None.
- seed (int): Random seed for reproducibility. Defaults to 42.
- scenario (str): Defines the splitting scenario to use. Options are 'random' (random split of the data), 'spatial' (split by geographical regions: Catalonia and France), and 'spatio-temporal' (split by region and year: France 2019 and Catalonia 2020). Defaults to 'random'.
- requires_norm (bool): Whether normalization is required. Defaults to True.
- binary_labels (bool): Whether to use binary labels. Defaults to False.
- linear_encoder (dict): Mapping for label encoding. Defaults to None.
- **kwargs (Any): Additional keyword arguments.
Source code in terratorch/datamodules/sen4agrinet.py
def __init__(
    self,
    bands: list[str] | None = None,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    seed: int = 42,
    scenario: str = "random",
    requires_norm: bool = True,
    binary_labels: bool = False,
    linear_encoder: dict = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

    Args:
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        scenario (str): Defines the splitting scenario to use. Options are:
            - 'random': Random split of the data.
            - 'spatial': Split by geographical regions (Catalonia and France).
            - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
        requires_norm (bool, optional): Whether normalization is required. Defaults to True.
        binary_labels (bool, optional): Whether to use binary labels. Defaults to False.
        linear_encoder (dict, optional): Mapping for label encoding. Defaults to None.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        Sen4AgriNet,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.bands = bands
    self.seed = seed
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.scenario = scenario
    self.requires_norm = requires_norm
    self.binary_labels = binary_labels
    self.linear_encoder = linear_encoder
    self.kwargs = kwargs
setup(stage) #

Set up datasets.

Parameters:

- stage (str): Either fit, validate, test, or predict. Required.
Source code in terratorch/datamodules/sen4agrinet.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = Sen4AgriNet(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = Sen4AgriNet(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = Sen4AgriNet(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = Sen4AgriNet(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
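
A brief sketch of the scenario option (placeholder path; the linear_encoder mapping below is made up for illustration, and the import assumes Sen4AgriNetDataModule is exported from terratorch.datamodules).

from terratorch.datamodules import Sen4AgriNetDataModule

# 'spatio-temporal' trains on France 2019 and evaluates on Catalonia 2020,
# per the scenario description above. The encoder dict is illustrative only.
datamodule = Sen4AgriNetDataModule(
    data_root="/path/to/sen4agrinet",        # placeholder path
    batch_size=8,
    num_workers=4,
    scenario="spatio-temporal",
    requires_norm=True,
    linear_encoder={0: 0, 110: 1, 120: 2},   # hypothetical class-id mapping
    seed=42,
)

datamodule.setup("fit")
datamodule.setup("test")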

terratorch.datamodules.sen4map #

Sen4MapLucasDataModule #

Bases: LightningDataModule

NonGeo LightningDataModule implementation for Sen4map.

Source code in terratorch/datamodules/sen4map.py
class Sen4MapLucasDataModule(pl.LightningDataModule):
    """NonGeo LightningDataModule implementation for Sen4map."""

    def __init__(
            self, 
            batch_size,
            num_workers,
            prefetch_factor = 0,
            # dataset_bands:list[HLSBands|int] = None,
            # input_bands:list[HLSBands|int] = None,
            train_hdf5_path = None,
            train_hdf5_keys_path = None,
            test_hdf5_path = None,
            test_hdf5_keys_path = None,
            val_hdf5_path = None,
            val_hdf5_keys_path = None,
            **kwargs
            ):
        """
        Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

        Args:
            batch_size (int): Batch size for DataLoaders.
            num_workers (int): Number of worker processes for data loading.
            prefetch_factor (int, optional): Number of samples to prefetch per worker. Defaults to 0.
            train_hdf5_path (str, optional): Path to the training HDF5 file.
            train_hdf5_keys_path (str, optional): Path to the training HDF5 keys file.
            test_hdf5_path (str, optional): Path to the testing HDF5 file.
            test_hdf5_keys_path (str, optional): Path to the testing HDF5 keys file.
            val_hdf5_path (str, optional): Path to the validation HDF5 file.
            val_hdf5_keys_path (str, optional): Path to the validation HDF5 keys file.
            train_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated train keys.
            test_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated test keys.
            val_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated validation keys.
            shuffle (bool, optional): Global shuffle flag.
            train_shuffle (bool, optional): Shuffle flag for training data; defaults to global shuffle if unset.
            val_shuffle (bool, optional): Shuffle flag for validation data.
            test_shuffle (bool, optional): Shuffle flag for test data.
            train_data_fraction (float, optional): Fraction of training data to use. Defaults to 1.0.
            val_data_fraction (float, optional): Fraction of validation data to use. Defaults to 1.0.
            test_data_fraction (float, optional): Fraction of test data to use. Defaults to 1.0.
            all_hdf5_data_path (str, optional): General HDF5 data path for all splits. If provided, overrides specific paths.
            resize (bool, optional): Whether to resize images. Defaults to False.
            resize_to (int or tuple, optional): Target size for resizing images.
            resize_interpolation (str, optional): Interpolation mode for resizing ('bilinear', 'bicubic', etc.).
            resize_antialiasing (bool, optional): Whether to apply antialiasing during resizing. Defaults to True.
            **kwargs: Additional keyword arguments.
        """
        self.prepare_data_per_node = False
        self._log_hyperparams = None
        self.allow_zero_length_dataloader_with_multiple_devices = False

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor

        self.train_hdf5_path = train_hdf5_path
        self.test_hdf5_path = test_hdf5_path
        self.val_hdf5_path = val_hdf5_path

        self.train_hdf5_keys_path = train_hdf5_keys_path
        self.test_hdf5_keys_path = test_hdf5_keys_path
        self.val_hdf5_keys_path = val_hdf5_keys_path

        if train_hdf5_path and not train_hdf5_keys_path: print(f"Train dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
        if test_hdf5_path and not test_hdf5_keys_path: print(f"Test dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
        if val_hdf5_path and not val_hdf5_keys_path: print(f"Val dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")

        self.train_hdf5_keys_save_path = kwargs.pop("train_hdf5_keys_save_path", None)
        self.test_hdf5_keys_save_path = kwargs.pop("test_hdf5_keys_save_path", None)
        self.val_hdf5_keys_save_path = kwargs.pop("val_hdf5_keys_save_path", None)

        self.shuffle = kwargs.pop("shuffle", None)
        self.train_shuffle = kwargs.pop("train_shuffle", None) or self.shuffle
        self.val_shuffle = kwargs.pop("val_shuffle", None)
        self.test_shuffle = kwargs.pop("test_shuffle", None)

        self.train_data_fraction = kwargs.pop("train_data_fraction", 1.0)
        self.val_data_fraction = kwargs.pop("val_data_fraction", 1.0)
        self.test_data_fraction = kwargs.pop("test_data_fraction", 1.0)

        if self.train_data_fraction != 1.0  and  not train_hdf5_keys_path: raise ValueError(f"train_data_fraction provided as non-unity but train_hdf5_keys_path is unset.")
        if self.val_data_fraction != 1.0  and  not val_hdf5_keys_path: raise ValueError(f"val_data_fraction provided as non-unity but val_hdf5_keys_path is unset.")
        if self.test_data_fraction != 1.0  and  not test_hdf5_keys_path: raise ValueError(f"test_data_fraction provided as non-unity but test_hdf5_keys_path is unset.")

        all_hdf5_data_path = kwargs.pop("all_hdf5_data_path", None)
        if all_hdf5_data_path is not None:
            print(f"all_hdf5_data_path provided, will be interpreted as the general data path for all splits.\nKeys in provided train_hdf5_keys_path assumed to encompass all keys for entire data. Validation and Test keys will be subtracted from Train keys.")
            if self.train_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific train_hdf5_path, remove the train_hdf5_path")
            if self.val_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific val_hdf5_path, remove the val_hdf5_path")
            if self.test_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific test_hdf5_path, remove the test_hdf5_path")
            self.train_hdf5_path = all_hdf5_data_path
            self.val_hdf5_path = all_hdf5_data_path
            self.test_hdf5_path = all_hdf5_data_path
            self.reduce_train_keys = True
        else:
            self.reduce_train_keys = False

        self.resize = kwargs.pop("resize", False)
        self.resize_to = kwargs.pop("resize_to", None)
        if self.resize and self.resize_to is None:
            raise ValueError(f"Config provided resize as True, but resize_to parameter not given")
        self.resize_interpolation = kwargs.pop("resize_interpolation", None)
        if self.resize and self.resize_interpolation is None:
            print(f"Config provided resize as True, but resize_interpolation mode not given. Will assume default bilinear")
            self.resize_interpolation = "bilinear"
        interpolation_dict = {
            "bilinear": InterpolationMode.BILINEAR,
            "bicubic": InterpolationMode.BICUBIC,
            "nearest": InterpolationMode.NEAREST,
            "nearest_exact": InterpolationMode.NEAREST_EXACT
        }
        if self.resize:
            if self.resize_interpolation not in interpolation_dict.keys():
                raise ValueError(f"resize_interpolation provided as {self.resize_interpolation}, but valid options are: {interpolation_dict.keys()}")
            self.resize_interpolation = interpolation_dict[self.resize_interpolation]
        self.resize_antialiasing = kwargs.pop("resize_antialiasing", True)

        self.kwargs = kwargs

    def _load_hdf5_keys_from_path(self, path, fraction=1.0):
        if path is None: return None
        with open(path, "rb") as f:
            keys = pickle.load(f)
            return keys[:int(fraction*len(keys))]

    def setup(self, stage: str):
        """Set up datasets.

        Args:
            stage: Either fit or test.
        """
        if stage == "fit":
            train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
            val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
            if self.reduce_train_keys:
                test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
                train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))
            train_file = h5py.File(self.train_hdf5_path, 'r')
            self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
                train_file, 
                h5data_keys=train_keys,
                resize=self.resize,
                resize_to=self.resize_to,
                resize_interpolation=self.resize_interpolation,
                resize_antialiasing=self.resize_antialiasing,
                save_keys_path=self.train_hdf5_keys_save_path,
                **self.kwargs
            )
            val_file = h5py.File(self.val_hdf5_path, 'r')
            self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
                val_file, 
                h5data_keys=val_keys, 
                resize = self.resize,
                resize_to = self.resize_to,
                resize_interpolation = self.resize_interpolation,
                resize_antialiasing = self.resize_antialiasing,
                save_keys_path = self.val_hdf5_keys_save_path,
                **self.kwargs
            )
        if stage == "test":
            test_file = h5py.File(self.test_hdf5_path, 'r')
            test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
            self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
                test_file, 
                h5data_keys=test_keys, 
                resize = self.resize,
                resize_to = self.resize_to,
                resize_interpolation = self.resize_interpolation,
                resize_antialiasing = self.resize_antialiasing,
                save_keys_path = self.test_hdf5_keys_save_path,
                **self.kwargs
            )

    def train_dataloader(self):
        return DataLoader(self.lucasS2_train, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.train_shuffle)

    def val_dataloader(self):
        return DataLoader(self.lucasS2_val, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.val_shuffle)

    def test_dataloader(self):
        return DataLoader(self.lucasS2_test, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.test_shuffle)
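
The *_hdf5_keys_path files read by _load_hdf5_keys_from_path above are plain pickled lists of sample keys, truncated to the requested data fraction. If such a keys file has to be produced ahead of time, the sketch below is one way to do it; the helper name save_hdf5_keys and the assumption that samples are stored as top-level HDF5 groups are illustrative and not part of the library.

import pickle

import h5py


def save_hdf5_keys(hdf5_path: str, keys_path: str) -> None:
    # Collect the top-level keys of the HDF5 file (assumed here to be one key
    # per sample) and pickle them as a list, which is the format that
    # _load_hdf5_keys_from_path expects.
    with h5py.File(hdf5_path, "r") as f:
        keys = list(f.keys())
    with open(keys_path, "wb") as out:
        pickle.dump(keys, out)


# Hypothetical paths
save_hdf5_keys("data/sen4map_train.h5", "data/train_keys.pkl")

Alternatively, the *_hdf5_keys_save_path arguments let the datasets write out the keys they generate on first use, at the cost of the key-generation time warned about in the constructor.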
__init__(batch_size, num_workers, prefetch_factor=0, train_hdf5_path=None, train_hdf5_keys_path=None, test_hdf5_path=None, test_hdf5_keys_path=None, val_hdf5_path=None, val_hdf5_keys_path=None, **kwargs) #

Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Batch size for DataLoaders. | required |
| num_workers | int | Number of worker processes for data loading. | required |
| prefetch_factor | int | Number of samples to prefetch per worker. | 0 |
| train_hdf5_path | str | Path to the training HDF5 file. | None |
| train_hdf5_keys_path | str | Path to the training HDF5 keys file. | None |
| test_hdf5_path | str | Path to the testing HDF5 file. | None |
| test_hdf5_keys_path | str | Path to the testing HDF5 keys file. | None |
| val_hdf5_path | str | Path to the validation HDF5 file. | None |
| val_hdf5_keys_path | str | Path to the validation HDF5 keys file. | None |
| train_hdf5_keys_save_path | str | (from kwargs) Path to save generated train keys. | None |
| test_hdf5_keys_save_path | str | (from kwargs) Path to save generated test keys. | None |
| val_hdf5_keys_save_path | str | (from kwargs) Path to save generated validation keys. | None |
| shuffle | bool | Global shuffle flag. | None |
| train_shuffle | bool | Shuffle flag for training data; falls back to the global shuffle flag if unset. | None |
| val_shuffle | bool | Shuffle flag for validation data. | None |
| test_shuffle | bool | Shuffle flag for test data. | None |
| train_data_fraction | float | Fraction of training data to use. | 1.0 |
| val_data_fraction | float | Fraction of validation data to use. | 1.0 |
| test_data_fraction | float | Fraction of test data to use. | 1.0 |
| all_hdf5_data_path | str | General HDF5 data path for all splits. If provided, overrides the split-specific paths. | None |
| resize | bool | Whether to resize images. | False |
| resize_to | int or tuple | Target size for resizing images. | None |
| resize_interpolation | str | Interpolation mode for resizing: "bilinear", "bicubic", "nearest", or "nearest_exact". Falls back to bilinear when resize is True and no mode is given. | None |
| resize_antialiasing | bool | Whether to apply antialiasing during resizing. | True |
| **kwargs | Any | Additional keyword arguments, forwarded to Sen4MapDatasetMonthlyComposites. | {} |
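
When the whole dataset lives in a single HDF5 file, all_hdf5_data_path replaces the three split-specific paths and setup() subtracts the validation and test keys from the training keys. A minimal sketch of that mode follows; the file names are hypothetical placeholders, and any extra dataset arguments required by Sen4MapDatasetMonthlyComposites would be passed through **kwargs.

from terratorch.datamodules.sen4map import Sen4MapLucasDataModule

datamodule = Sen4MapLucasDataModule(
    batch_size=16,
    num_workers=4,
    prefetch_factor=2,
    all_hdf5_data_path="data/sen4map_all.h5",   # hypothetical single file for all splits
    train_hdf5_keys_path="data/all_keys.pkl",   # assumed to list every sample key
    val_hdf5_keys_path="data/val_keys.pkl",
    test_hdf5_keys_path="data/test_keys.pkl",
    train_shuffle=True,
)
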
Source code in terratorch/datamodules/sen4map.py
def __init__(
        self, 
        batch_size,
        num_workers,
        prefetch_factor = 0,
        # dataset_bands:list[HLSBands|int] = None,
        # input_bands:list[HLSBands|int] = None,
        train_hdf5_path = None,
        train_hdf5_keys_path = None,
        test_hdf5_path = None,
        test_hdf5_keys_path = None,
        val_hdf5_path = None,
        val_hdf5_keys_path = None,
        **kwargs
        ):
    """
    Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

    Args:
        batch_size (int): Batch size for DataLoaders.
        num_workers (int): Number of worker processes for data loading.
        prefetch_factor (int, optional): Number of samples to prefetch per worker. Defaults to 0.
        train_hdf5_path (str, optional): Path to the training HDF5 file.
        train_hdf5_keys_path (str, optional): Path to the training HDF5 keys file.
        test_hdf5_path (str, optional): Path to the testing HDF5 file.
        test_hdf5_keys_path (str, optional): Path to the testing HDF5 keys file.
        val_hdf5_path (str, optional): Path to the validation HDF5 file.
        val_hdf5_keys_path (str, optional): Path to the validation HDF5 keys file.
        train_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated train keys.
        test_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated test keys.
        val_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated validation keys.
        shuffle (bool, optional): Global shuffle flag.
        train_shuffle (bool, optional): Shuffle flag for training data; defaults to global shuffle if unset.
        val_shuffle (bool, optional): Shuffle flag for validation data.
        test_shuffle (bool, optional): Shuffle flag for test data.
        train_data_fraction (float, optional): Fraction of training data to use. Defaults to 1.0.
        val_data_fraction (float, optional): Fraction of validation data to use. Defaults to 1.0.
        test_data_fraction (float, optional): Fraction of test data to use. Defaults to 1.0.
        all_hdf5_data_path (str, optional): General HDF5 data path for all splits. If provided, overrides specific paths.
        resize (bool, optional): Whether to resize images. Defaults to False.
        resize_to (int or tuple, optional): Target size for resizing images.
        resize_interpolation (str, optional): Interpolation mode for resizing ('bilinear', 'bicubic', etc.).
        resize_antialiasing (bool, optional): Whether to apply antialiasing during resizing. Defaults to True.
        **kwargs: Additional keyword arguments.
    """
    self.prepare_data_per_node = False
    self._log_hyperparams = None
    self.allow_zero_length_dataloader_with_multiple_devices = False

    self.batch_size = batch_size
    self.num_workers = num_workers
    self.prefetch_factor = prefetch_factor

    self.train_hdf5_path = train_hdf5_path
    self.test_hdf5_path = test_hdf5_path
    self.val_hdf5_path = val_hdf5_path

    self.train_hdf5_keys_path = train_hdf5_keys_path
    self.test_hdf5_keys_path = test_hdf5_keys_path
    self.val_hdf5_keys_path = val_hdf5_keys_path

    if train_hdf5_path and not train_hdf5_keys_path: print(f"Train dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
    if test_hdf5_path and not test_hdf5_keys_path: print(f"Test dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
    if val_hdf5_path and not val_hdf5_keys_path: print(f"Val dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")

    self.train_hdf5_keys_save_path = kwargs.pop("train_hdf5_keys_save_path", None)
    self.test_hdf5_keys_save_path = kwargs.pop("test_hdf5_keys_save_path", None)
    self.val_hdf5_keys_save_path = kwargs.pop("val_hdf5_keys_save_path", None)

    self.shuffle = kwargs.pop("shuffle", None)
    self.train_shuffle = kwargs.pop("train_shuffle", None) or self.shuffle
    self.val_shuffle = kwargs.pop("val_shuffle", None)
    self.test_shuffle = kwargs.pop("test_shuffle", None)

    self.train_data_fraction = kwargs.pop("train_data_fraction", 1.0)
    self.val_data_fraction = kwargs.pop("val_data_fraction", 1.0)
    self.test_data_fraction = kwargs.pop("test_data_fraction", 1.0)

    if self.train_data_fraction != 1.0  and  not train_hdf5_keys_path: raise ValueError(f"train_data_fraction provided as non-unity but train_hdf5_keys_path is unset.")
    if self.val_data_fraction != 1.0  and  not val_hdf5_keys_path: raise ValueError(f"val_data_fraction provided as non-unity but val_hdf5_keys_path is unset.")
    if self.test_data_fraction != 1.0  and  not test_hdf5_keys_path: raise ValueError(f"test_data_fraction provided as non-unity but test_hdf5_keys_path is unset.")

    all_hdf5_data_path = kwargs.pop("all_hdf5_data_path", None)
    if all_hdf5_data_path is not None:
        print(f"all_hdf5_data_path provided, will be interpreted as the general data path for all splits.\nKeys in provided train_hdf5_keys_path assumed to encompass all keys for entire data. Validation and Test keys will be subtracted from Train keys.")
        if self.train_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific train_hdf5_path, remove the train_hdf5_path")
        if self.val_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific val_hdf5_path, remove the val_hdf5_path")
        if self.test_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific test_hdf5_path, remove the test_hdf5_path")
        self.train_hdf5_path = all_hdf5_data_path
        self.val_hdf5_path = all_hdf5_data_path
        self.test_hdf5_path = all_hdf5_data_path
        self.reduce_train_keys = True
    else:
        self.reduce_train_keys = False

    self.resize = kwargs.pop("resize", False)
    self.resize_to = kwargs.pop("resize_to", None)
    if self.resize and self.resize_to is None:
        raise ValueError(f"Config provided resize as True, but resize_to parameter not given")
    self.resize_interpolation = kwargs.pop("resize_interpolation", None)
    if self.resize and self.resize_interpolation is None:
        print(f"Config provided resize as True, but resize_interpolation mode not given. Will assume default bilinear")
        self.resize_interpolation = "bilinear"
    interpolation_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST,
        "nearest_exact": InterpolationMode.NEAREST_EXACT
    }
    if self.resize:
        if self.resize_interpolation not in interpolation_dict.keys():
            raise ValueError(f"resize_interpolation provided as {self.resize_interpolation}, but valid options are: {interpolation_dict.keys()}")
        self.resize_interpolation = interpolation_dict[self.resize_interpolation]
    self.resize_antialiasing = kwargs.pop("resize_antialiasing", True)

    self.kwargs = kwargs
setup(stage) #

Set up datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stage | str | Either "fit" or "test". | required |
Source code in terratorch/datamodules/sen4map.py
def setup(self, stage: str):
    """Set up datasets.

    Args:
        stage: Either fit, test.
    """
    if stage == "fit":
        train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
        val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
        if self.reduce_train_keys:
            test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
            train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))
        train_file = h5py.File(self.train_hdf5_path, 'r')
        self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
            train_file, 
            h5data_keys = train_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.train_hdf5_keys_save_path,
            **self.kwargs
        )
        val_file = h5py.File(self.val_hdf5_path, 'r')
        self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
            val_file, 
            h5data_keys=val_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.val_hdf5_keys_save_path,
            **self.kwargs
        )
    if stage == "test":
        test_file = h5py.File(self.test_hdf5_path, 'r')
        test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
        self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
            test_file, 
            h5data_keys=test_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.test_hdf5_keys_save_path,
            **self.kwargs
        )
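
Putting the pieces together, here is a hedged end-to-end sketch with separate files per split. The paths, batch size and resize settings are placeholders, and whatever band or label arguments Sen4MapDatasetMonthlyComposites needs would again go in **kwargs.

from terratorch.datamodules.sen4map import Sen4MapLucasDataModule

datamodule = Sen4MapLucasDataModule(
    batch_size=16,
    num_workers=4,
    prefetch_factor=2,
    train_hdf5_path="data/sen4map_train.h5",    # hypothetical paths
    train_hdf5_keys_path="data/train_keys.pkl",
    val_hdf5_path="data/sen4map_val.h5",
    val_hdf5_keys_path="data/val_keys.pkl",
    test_hdf5_path="data/sen4map_test.h5",
    test_hdf5_keys_path="data/test_keys.pkl",
    train_shuffle=True,
    resize=True,
    resize_to=224,
    resize_interpolation="bilinear",
)

# "fit" builds the train and validation datasets; "test" builds the test dataset.
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
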