
Specific Datasets #

terratorch.datasets.biomassters #

BioMasstersNonGeo #

Bases: BioMassters

BioMassters Dataset for Aboveground Biomass prediction.

Dataset intended for Aboveground Biomass (AGB) prediction over Finnish forests, based on Sentinel-1 and Sentinel-2 data with corresponding target AGB masks derived from Light Detection and Ranging (LiDAR).

Dataset Format:

  • .tif files for Sentinel 1 and 2 data
  • .tif file for pixel-wise AGB target mask
  • .csv files for metadata regarding features and targets

Dataset Features:

  • 13,000 target AGB masks of size (256x256px)
  • 12 months of data per target mask
  • Sentinel 1 and Sentinel 2 data for each location
  • Sentinel 1 available for every month
  • Sentinel 2 available for almost every month (not available for every month due to ESA acquisition halt over the region during particular periods)

If you use this dataset in your research, please cite the following paper:

  • https://nascetti-a.github.io/BioMasster/

New in version 0.5.
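
For orientation, here is a minimal usage sketch. It assumes the dataset and its features metadata CSV have been downloaded to a local directory (./data below is a placeholder) and that the class is re-exported from terratorch.datasets:

from terratorch.datasets import BioMasstersNonGeo

# Placeholder root: a local copy of the BioMassters data and metadata CSV.
dataset = BioMasstersNonGeo(
    root="./data",
    split="train",
    sensors=["S1", "S2"],     # Sentinel-1 and Sentinel-2
    as_time_series=True,      # stack the monthly acquisitions per chip
)
sample = dataset[0]
print(len(dataset), sample.keys())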

Source code in terratorch/datasets/biomassters.py
class BioMasstersNonGeo(BioMassters):
    """[BioMassters Dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/BioMassters) for Aboveground Biomass prediction.

    Dataset intended for Aboveground Biomass (AGB) prediction
    over Finnish forests based on Sentinel 1 and 2 data with
    corresponding target AGB mask values generated by Light Detection
    and Ranging (LiDAR).

    Dataset Format:

    * .tif files for Sentinel 1 and 2 data
    * .tif file for pixel wise AGB target mask
    * .csv files for metadata regarding features and targets

    Dataset Features:

    * 13,000 target AGB masks of size (256x256px)
    * 12 months of data per target mask
    * Sentinel 1 and Sentinel 2 data for each location
    * Sentinel 1 available for every month
    * Sentinel 2 available for almost every month
      (not available for every month due to ESA acquisition halt over the region
      during particular periods)

    If you use this dataset in your research, please cite the following paper:

    * https://nascetti-a.github.io/BioMasster/

    .. versionadded:: 0.5
    """

    S1_BAND_NAMES = ["VV_Asc", "VH_Asc", "VV_Desc", "VH_Desc", "RVI_Asc", "RVI_Desc"]
    S2_BAND_NAMES = [
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    ]

    all_band_names = {
        "S1": S1_BAND_NAMES,
        "S2": S2_BAND_NAMES,
    }

    rgb_bands = {
        "S1": [],
        "S2": ["RED", "GREEN", "BLUE"],
    }

    valid_splits = ("train", "test")
    valid_sensors = ("S1", "S2")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        root = "data",
        split: str = "train",
        bands: dict[str, Sequence[str]] | Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        mask_mean: float | None = 63.4584,
        mask_std: float | None = 72.21242,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False
    ) -> None:
        """Initialize a new instance of BioMassters dataset.

        If ``as_time_series=False`` (the default), each time step becomes its own
        sample with the target being shared across multiple samples.

        Args:
            root: root directory where dataset can be found
            split: train or test split
            sensors: which sensors to consider for the sample, Sentinel 1 and/or
                Sentinel 2 ('S1', 'S2')
            as_time_series: whether or not to return all available
                time-steps or just a single one for a given target location
            metadata_filename: metadata file to be used
            max_cloud_percentage: maximum allowed cloud percentage for images
            max_red_mean: maximum allowed red_mean value for images
            include_corrupt: whether to include images marked as corrupted

        Raises:
            AssertionError: if ``split`` or ``sensors`` is invalid
            DatasetNotFoundError: If dataset is not found.
        """
        self.root = root
        self.sensors = sensors
        self.bands = bands
        assert (
            split in self.valid_splits
        ), f"Please choose one of the valid splits: {self.valid_splits}."
        self.split = split

        assert set(sensors).issubset(
            set(self.valid_sensors)
        ), f"Please choose a subset of valid sensors: {self.valid_sensors}."

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            self.band_indices = [
                self.all_band_names[sens].index(band) for band in self.bands[sens]
            ]
        else:
            self.band_indices = {
                sens: [self.all_band_names[sens].index(band) for band in self.bands[sens]]
                for sens in self.sensors
            }

        self.mask_mean = mask_mean
        self.mask_std = mask_std
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

        self._verify()

        # open metadata csv files
        self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename))

        # Filter sensors
        self.df = self.df[self.df["satellite"].isin(self.sensors)]

        # Filter split
        self.df = self.df[self.df["split"] == self.split]

        # Optional filtering
        self._filter_and_select_data()

        # Optional subsampling
        self._random_subsample()

        # generate numerical month from filename since first month is September
        # and has numerical index of 0
        self.df["num_month"] = (
            self.df["filename"]
            .str.split("_", expand=True)[2]
            .str.split(".", expand=True)[0]
            .astype(int)
        )

        # Set dataframe index depending on the task for easier indexing
        if self.as_time_series:
            self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup()
        else:
            filter_df = (
                self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
            )
            filter_df = filter_df[
                filter_df["satellite"] == len(self.sensors)
            ].drop("satellite", axis=1)
            # Guarantee that each sample has corresponding number of images available
            self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner")

            self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup()

        # Adjust transforms based on the number of sensors
        if len(self.sensors) == 1:
            self.transform = transform if transform else default_transform
        elif transform is None:
            self.transform = MultimodalToTensor(self.sensors)
        else:
            transform = {
                s: transform[s] if s in transform else default_transform
                for s in self.sensors
            }
            self.transform = MultimodalTransforms(transform, shared=False)

        if self.use_four_frames:
            self._select_4_frames()

    def __len__(self) -> int:
        return len(self.df["num_index"].unique())

    def _load_input(self, filenames: list[Path]) -> Tensor:
        """Load the input imagery at the index.

        Args:
            filenames: list of filenames corresponding to input

        Returns:
            input image
        """
        filepaths = [
            os.path.join(self.root, f"{self.split}_features", f) for f in filenames
        ]
        arr_list = [rasterio.open(fp).read() for fp in filepaths]

        if self.as_time_series:
            arr = np.stack(arr_list, axis=0) # (T, C, H, W)
        else:
            arr = np.concatenate(arr_list, axis=0)
        return arr.astype(np.int32)

    def _load_target(self, filename: Path) -> Tensor:
        """Load the target mask at the index.

        Args:
            filename: filename of target to index

        Returns:
            target mask
        """
        with rasterio.open(os.path.join(self.root, f"{self.split}_agbm", filename), "r") as src:
            arr: np.typing.NDArray[np.float64] = src.read()

        return arr

    def _compute_rvi(self, img: np.ndarray, linear: np.ndarray, sens: str) -> np.ndarray:
        """Compute the RVI indices for S1 data."""
        rvi_channels = []
        if self.as_time_series:
            if "RVI_Asc" in self.bands[sens]:
                try:
                    vv_asc_index = self.all_band_names["S1"].index("VV_Asc")
                    vh_asc_index = self.all_band_names["S1"].index("VH_Asc")
                except ValueError as e:
                    msg = f"RVI_Asc needs band: {e}"
                    raise ValueError(msg) from e

                VV = linear[:, vv_asc_index, :, :]
                VH = linear[:, vh_asc_index, :, :]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=1)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                try:
                    vv_desc_index = self.all_band_names["S1"].index("VV_Desc")
                    vh_desc_index = self.all_band_names["S1"].index("VH_Desc")
                except ValueError as e:
                    msg = f"RVI_Desc needs band: {e}"
                    raise ValueError(msg) from e

                VV_desc = linear[:, vv_desc_index, :, :]
                VH_desc = linear[:, vh_desc_index, :, :]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=1)
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=1)
                img = np.concatenate([img, rvi_concat], axis=1)
        else:
            if "RVI_Asc" in self.bands[sens]:
                if linear.shape[0] < 2:
                    msg = f"Not enough bands to calculate RVI_Asc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV = linear[0]
                VH = linear[1]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=0)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                if linear.shape[0] < 4:
                    msg = f"Not enough bands to calculate RVI_Desc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV_desc = linear[2]
                VH_desc = linear[3]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=0) 
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=0)
                img = np.concatenate([linear, rvi_concat], axis=0)
        return img

    def _select_4_frames(self):
        """Filter the dataset to select only 4 frames per sample."""

        if "cloud_percentage" in self.df.columns:
            self.df = self.df.sort_values(by=["chip_id", "cloud_percentage"])
        else:
            self.df = self.df.sort_values(by=["chip_id", "num_month"])

        self.df = (
            self.df.groupby("chip_id")
            .head(4)  # Select the first 4 frames per chip
            .reset_index(drop=True)
        )

    def _process_sensor_images(self, sens: str, sens_filepaths: list[str]) -> np.ndarray:
        """Process images for a given sensor."""
        img = self._load_input(sens_filepaths)
        if sens == "S1":
            img = img.astype(np.float32)
            linear = 10 ** (img / 10)
            img = self._compute_rvi(img, linear, sens)
        if self.as_time_series:
            img = img.transpose(0, 2, 3, 1)  # (T, H, W, C)
        else:
            img = img.transpose(1, 2, 0)  # (H, W, C)
        if len(self.sensors) == 1:
            img = img[..., self.band_indices]
        else:
            img = img[..., self.band_indices[sens]]
        return img

    def __getitem__(self, index: int) -> dict:
        sample_df = self.df[self.df["num_index"] == index].copy()
        # Sort by satellite and month
        sample_df.sort_values(
            by=["satellite", "num_month"], inplace=True, ascending=True
        )

        filepaths = sample_df["filename"].tolist()
        output = {}

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            sens_filepaths = [fp for fp in filepaths if sens in fp]
            img = self._process_sensor_images(sens, sens_filepaths)
            output["image"] = img.astype(np.float32)
        else:
            for sens in self.sensors:
                sens_filepaths = [fp for fp in filepaths if sens in fp]
                img = self._process_sensor_images(sens, sens_filepaths)
                output[sens] = img.astype(np.float32)

        # Load target
        target_filename = sample_df["corresponding_agbm"].unique()[0]
        target = np.array(self._load_target(Path(target_filename)))
        target = target.transpose(1, 2, 0)
        output["mask"] = target
        if self.transform:
            if len(self.sensors) == 1:
                output = self.transform(**output)
            else:
                output = self.transform(output)
        output["mask"] = output["mask"].squeeze().float()
        return output

    def _filter_and_select_data(self):
        if (
            self.max_cloud_percentage is not None
            and "cloud_percentage" in self.df.columns
        ):
            self.df = self.df[self.df["cloud_percentage"] <= self.max_cloud_percentage]

        if self.max_red_mean is not None and "red_mean" in self.df.columns:
            self.df = self.df[self.df["red_mean"] <= self.max_red_mean]

        if not self.include_corrupt and "corrupt_values" in self.df.columns:
            # element-wise comparison; `is False` on a Series is a Python identity check
            self.df = self.df[self.df["corrupt_values"] == False]  # noqa: E712

    def _random_subsample(self):
        if self.split == "train" and self.subset < 1.0:
            num_samples = int(len(self.df["num_index"].unique()) * self.subset)
            if self.seed is not None:
                random.seed(self.seed)
            selected_indices = random.sample(
                list(self.df["num_index"].unique()), num_samples
            )
            self.df = self.df[self.df["num_index"].isin(selected_indices)]
            self.df.reset_index(drop=True, inplace=True)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Determine if the sample contains multiple sensors or a single sensor
        if isinstance(sample["image"], dict):
            ncols = len(self.sensors) + 1
        else:
            ncols = 2  # One for the image and one for the mask

        showing_predictions = "prediction" in sample
        if showing_predictions:
            ncols += 1

        fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10))

        if isinstance(sample["image"], dict):
            # Multiple sensors case
            for idx, sens in enumerate(self.sensors):
                img = sample["image"][sens].numpy()
                if self.as_time_series:
                    # Plot last time step
                    img = img[:, -1, ...]
                if sens == "S2":
                    img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                    img = percentile_normalization(img)
                else:
                    co_polarization = img[0]  # transmit == receive
                    cross_polarization = img[1]  # transmit != receive
                    ratio = co_polarization / (cross_polarization + 1e-6)

                    co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                    cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                    ratio = np.clip(ratio / 25, 0, 1)

                    img = np.stack(
                        (co_polarization, cross_polarization, ratio), axis=0
                    )
                    img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

                axs[idx].imshow(img)
                axs[idx].axis("off")
                if show_titles:
                    axs[idx].set_title(sens)
            mask_idx = len(self.sensors)
        else:
            # Single sensor case
            sens = self.sensors[0]
            img = sample["image"].numpy()
            if self.as_time_series:
                # Plot last time step
                img = img[:, -1, ...]
            if sens == "S2":
                img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                img = percentile_normalization(img)
            else:
                co_polarization = img[0]  # transmit == receive
                cross_polarization = img[1]  # transmit != receive
                ratio = co_polarization / (cross_polarization + 1e-6)

                co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                ratio = np.clip(ratio / 25, 0, 1)

                img = np.stack(
                    (co_polarization, cross_polarization, ratio), axis=0
                )
                img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

            axs[0].imshow(img)
            axs[0].axis("off")
            if show_titles:
                axs[0].set_title(sens)
            mask_idx = 1

        # Plot target mask
        if "mask" in sample:
            target = sample["mask"].squeeze()
            target_im = axs[mask_idx].imshow(target, cmap="YlGn")
            plt.colorbar(target_im, ax=axs[mask_idx], fraction=0.046, pad=0.04)
            axs[mask_idx].axis("off")
            if show_titles:
                axs[mask_idx].set_title("Target")

        # Plot prediction if available
        if showing_predictions:
            pred_idx = mask_idx + 1
            prediction = sample["prediction"].squeeze()
            pred_im = axs[pred_idx].imshow(prediction, cmap="YlGn")
            plt.colorbar(pred_im, ax=axs[pred_idx], fraction=0.046, pad=0.04)
            axs[pred_idx].axis("off")
            if show_titles:
                axs[pred_idx].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
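
The two RVI bands are not stored on disk; _compute_rvi derives them from the VV/VH backscatter after converting from dB to linear scale (see _process_sensor_images). A standalone sketch of that conversion, using random data in place of a real chip:

import numpy as np

# Random stand-in for a (C, H, W) Sentinel-1 chip in dB.
img_db = np.random.uniform(-25.0, 0.0, size=(4, 256, 256)).astype(np.float32)
linear = 10 ** (img_db / 10)  # dB -> linear power, as in _process_sensor_images

# Radar Vegetation Index, as in _compute_rvi; the epsilon avoids division by zero.
vv_asc, vh_asc = linear[0], linear[1]
rvi_asc = 4 * vh_asc / (vv_asc + vh_asc + 1e-6)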
__init__(root='data', split='train', bands=BAND_SETS['all'], transform=None, mask_mean=63.4584, mask_std=72.21242, sensors=['S1', 'S2'], as_time_series=False, metadata_filename=default_metadata_filename, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True, subset=1, seed=42, use_four_frames=False) #

Initialize a new instance of BioMassters dataset.

If as_time_series=False (the default), each time step becomes its own sample with the target being shared across multiple samples.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| root |  | root directory where dataset can be found | 'data' |
| split | str | train or test split | 'train' |
| sensors | Sequence[str] | which sensors to consider for the sample, Sentinel 1 and/or Sentinel 2 ('S1', 'S2') | ['S1', 'S2'] |
| as_time_series | bool | whether or not to return all available time-steps or just a single one for a given target location | False |
| metadata_filename | str | metadata file to be used | default_metadata_filename |
| max_cloud_percentage | float \| None | maximum allowed cloud percentage for images | None |
| max_red_mean | float \| None | maximum allowed red_mean value for images | None |
| include_corrupt | bool | whether to include images marked as corrupted | True |

Raises:

| Type | Description |
| --- | --- |
| AssertionError | if split or sensors is invalid |
| DatasetNotFoundError | if dataset is not found |
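
A sketch of the two indexing modes described above (the root path is a placeholder, as before):

from terratorch.datasets import BioMasstersNonGeo

# One sample per (chip, month) pair; the AGB mask repeats across months.
ds_frames = BioMasstersNonGeo(root="./data", split="train", as_time_series=False)

# One sample per chip; _load_input stacks the months to (T, C, H, W).
ds_series = BioMasstersNonGeo(root="./data", split="train", as_time_series=True)

print(len(ds_frames), len(ds_series))  # the frame-wise dataset is much larger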

plot(sample, show_titles=True, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | a sample returned by __getitem__ | required |
| show_titles | bool | flag indicating whether to show titles above each panel | True |
| suptitle | str \| None | optional suptitle to use for figure | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | a matplotlib Figure with the rendered sample |
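
A short usage sketch (placeholder root path); plot accepts a sample exactly as returned by __getitem__:

import matplotlib.pyplot as plt
from terratorch.datasets import BioMasstersNonGeo

dataset = BioMasstersNonGeo(root="./data", split="train")
fig = dataset.plot(dataset[0], suptitle="BioMassters sample")
plt.show()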


terratorch.datasets.burn_intensity #

BurnIntensityNonGeo #

Bases: NonGeoDataset

Dataset implementation for Burn Intensity classification.
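
A minimal usage sketch, assuming a local copy of the dataset (placeholder path) laid out with the pre/during/post folders, split files, and CSVs that __init__ below expects:

from terratorch.datasets import BurnIntensityNonGeo

dataset = BurnIntensityNonGeo(data_root="./burn_intensity", split="train")
sample = dataset[0]
# Images from the "pre", "during" and "post" folders are stacked on a time
# axis, so each sample carries three time steps plus a class mask.
print(sample["image"].shape, sample["mask"].shape)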

Source code in terratorch/datasets/burn_intensity.py
class BurnIntensityNonGeo(NonGeoDataset):
    """Dataset implementation for [Burn Intensity classification](https://huggingface.co/datasets/ibm-nasa-geospatial/burn_intensity)."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    class_names = (
        "No burn",
        "Unburned to Very Low",
        "Low Severity",
        "Moderate Severity",
        "High Severity"
    )

    CSV_FILES = {
        "limited": "BS_files_with_less_than_25_percent_zeros.csv",
        "full": "BS_files_raw.csv",
    }

    num_classes = 5
    splits = {"train": "train", "val": "val"}
    time_steps = ["pre", "during", "post"]

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Initialize the BurnIntensity dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'.
            bands (Sequence[str]): Bands to output. Defaults to all bands.
            transform (Optional[A.Compose]): Albumentations transform to be applied.
            use_metadata (bool): Whether to return metadata info (location).
            use_full_data (bool): Whether to use full data or data with less than 25 percent zeros.
            no_data_replace (Optional[float]): Value to replace NaNs in images.
            no_label_replace (Optional[int]): Value to replace NaNs in labels.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)

        # Read the CSV file to get the list of cases to include
        csv_file_key = "full" if use_full_data else "limited"
        csv_path = self.data_root / self.CSV_FILES[csv_file_key]
        df = pd.read_csv(csv_path)
        casenames = df["Case_Name"].tolist()

        split_file = self.data_root / f"{split}.txt"
        with open(split_file) as f:
            split_images = [line.strip() for line in f.readlines()]

        split_images = [img for img in split_images if self._extract_casename(img) in casenames]

        # Build the samples list
        self.samples = []
        for image_filename in split_images:
            image_files = []
            for time_step in self.time_steps:
                image_file = self.data_root / time_step / image_filename
                image_files.append(str(image_file))
            mask_filename = image_filename.replace("HLS_", "BS_")
            mask_file = self.data_root / "pre" / mask_filename
            self.samples.append({
                "image_files": image_files,
                "mask_file": str(mask_file),
                "casename": self._extract_casename(image_filename),
            })

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.transform = transform if transform else default_transform

    def _extract_basename(self, filepath: str) -> str:
        """Extract the base filename without extension."""
        return os.path.splitext(os.path.basename(filepath))[0]

    def _extract_casename(self, filename: str) -> str:
        """Extract the casename from the filename."""
        basename = self._extract_basename(filename)
        # Remove 'HLS_' or 'BS_' prefix
        casename = basename.replace("HLS_", "").replace("BS_", "")
        return casename

    def __len__(self) -> int:
        return len(self.samples)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        pixel_scale = image.rio.resolution()
        width, height = image.rio.width, image.rio.height

        left, bottom, right, top = image.rio.bounds()
        tie_point_x, tie_point_y = left, top

        center_col = width / 2
        center_row = height / 2

        center_lon = tie_point_x + (center_col * pixel_scale[0])
        center_lat = tie_point_y - (center_row * pixel_scale[1])

        lat_lon = np.asarray([center_lat, center_lon])
        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        sample = self.samples[index]
        image_files = sample["image_files"]
        mask_file = sample["mask_file"]

        images = []
        for idx, image_file in enumerate(image_files):
            image = self._load_file(Path(image_file), nan_replace=self.no_data_replace)
            if idx == 0 and self.use_metadata:
                location_coords = self._get_coords(image)
            image = image.to_numpy()
            image = np.moveaxis(image, 0, -1)
            image = image[..., self.band_indices]
            images.append(image)

        images = np.stack(images, axis=0)  # (T, H, W, C)

        output = {
            "image": images.astype(np.float32),
            "mask": self._load_file(Path(mask_file), nan_replace=self.no_label_replace).to_numpy()[0]
        }

        if self.transform:
            output = self.transform(**output)

        output["mask"] = output["mask"].long()
        if self.use_metadata:
            output["location_coords"] = location_coords

        return output

    def _load_file(self, path: Path, nan_replace: float | int | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data


    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered sample.
        """
        num_images = len(self.time_steps) + 2
        if "prediction" in sample:
            num_images += 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        images = sample["image"]  # (C, T, H, W)
        mask = sample["mask"].numpy()
        num_classes = len(np.unique(mask))

        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5))

        for i in range(len(self.time_steps)):
            image = images[:, i, :, :]  # (C, H, W)
            image = np.transpose(image, (1, 2, 0))  # (H, W, C)
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"{self.time_steps[i].capitalize()} Image")

        cmap = plt.get_cmap("jet", num_classes)
        norm = Normalize(vmin=0, vmax=num_classes - 1)

        mask_ax_index = len(self.time_steps)
        ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm)
        ax[mask_ax_index].axis("off")
        ax[mask_ax_index].set_title("Ground Truth Mask")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            pred_ax_index = mask_ax_index + 1
            ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm)
            ax[pred_ax_index].axis("off")
            ax[pred_ax_index].set_title("Predicted Mask")

        legend_ax_index = -1
        class_names = sample.get("class_names", self.class_names)
        positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5]

        legend_handles = [
            mpatches.Patch(color=cmap(pos), label=class_names[i])
            for i, pos in enumerate(positions)
        ]
        ax[legend_ax_index].legend(handles=legend_handles, loc="center")
        ax[legend_ax_index].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, use_full_data=True, no_data_replace=0.0001, no_label_replace=-1, use_metadata=False) #

Initialize the BurnIntensity dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train' or 'val'. | 'train' |
| bands | Sequence[str] | Bands to output. Defaults to all bands. | BAND_SETS['all'] |
| transform | Optional[Compose] | Albumentations transform to be applied. | None |
| use_metadata | bool | Whether to return metadata info (location). | False |
| use_full_data | bool | Whether to use full data or data with less than 25 percent zeros. | True |
| no_data_replace | Optional[float] | Value to replace NaNs in images. | 0.0001 |
| no_label_replace | Optional[int] | Value to replace NaNs in labels. | -1 |
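
With use_metadata=True, each sample additionally carries the chip-centre coordinates computed from the raster's geotransform; a sketch under the same placeholder path:

from terratorch.datasets import BurnIntensityNonGeo

dataset = BurnIntensityNonGeo(
    data_root="./burn_intensity", split="val", use_metadata=True
)
sample = dataset[0]
print(sample["location_coords"])  # tensor([center_lat, center_lon])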
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by __getitem__. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Any | A matplotlib Figure with the rendered sample. |
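
plot renders an extra panel whenever the sample carries a "prediction" entry; a sketch with a dummy prediction shaped like the mask (placeholder path as before):

import torch
from terratorch.datasets import BurnIntensityNonGeo

dataset = BurnIntensityNonGeo(data_root="./burn_intensity", split="val")
sample = dataset[0]
# Hypothetical model output; any (H, W) tensor of class indices works here.
sample["prediction"] = torch.randint(0, dataset.num_classes, sample["mask"].shape)
fig = dataset.plot(sample, suptitle="BurnIntensity sample")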


terratorch.datasets.carbonflux #

CarbonFluxNonGeo #

Bases: NonGeoDataset

Dataset for Carbon Flux regression from HLS images and MERRA data.
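
A minimal usage sketch. Note that gpp_mean and gpp_std have no usable defaults: the constructor raises ValueError when either is missing. The path and statistics below are placeholders:

from terratorch.datasets import CarbonFluxNonGeo

dataset = CarbonFluxNonGeo(
    data_root="./carbon_flux",  # placeholder local copy of the dataset
    split="train",
    gpp_mean=4.0,               # placeholder normalization statistics
    gpp_std=2.5,
)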

Source code in terratorch/datasets/carbonflux.py
class CarbonFluxNonGeo(NonGeoDataset):
    """Dataset for [Carbon Flux](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_merra2_gppFlux) regression from HLS images and MERRA data."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    merra_var_names = (
        "T2MIN", "T2MAX", "T2MEAN", "TSMDEWMEAN", "GWETROOT",
        "LHLAND", "SHLAND", "SWLAND", "PARDFLAND", "PRECTOTLAND"
    )

    splits = {"train": "train", "test": "test"}

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    metadata_file = "data_train_hls_37sites_v0_1.csv"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        gpp_mean: float | None = None,
        gpp_std: float | None = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        modalities: Sequence[str] = ("image", "merra_vars")
    ) -> None:
        """Initialize the CarbonFluxNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): 'train' or 'test'.
            bands (Sequence[str]): Bands to use. Defaults to all bands.
            transform (Optional[A.Compose]): Albumentations transform to be applied.
            use_metadata (bool): Whether to return metadata (coordinates and date).
            modalities (Sequence[str]): Modalities included in each sample.
            gpp_mean (float): Mean for GPP normalization.
            gpp_std (float): Standard deviation for GPP normalization.
            no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)

        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(band) for band in bands]

        self.data_root = Path(data_root)

        # Load the CSV file with metadata
        csv_file = self.data_root / self.metadata_file
        df = pd.read_csv(csv_file)

        # Get list of image filenames in the split directory
        image_dir = self.data_root / self.split
        image_files = [f.name for f in image_dir.glob("*.tiff")]

        df["Chip"] = df["Chip"].str.replace(".tif$", ".tiff", regex=True)
        # Filter the DataFrame to include only rows with 'Chip' in image_files
        df = df[df["Chip"].isin(image_files)]

        # Build the samples list
        self.samples = []
        for _, row in df.iterrows():
            image_path = image_dir / row["Chip"]
            # MERRA vectors
            merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
            # GPP target
            gpp = row["GPP"]
            self.samples.append({
                "image_path": str(image_path),
                "merra_vars": merra_vars,
                "gpp": gpp,
            })

        if gpp_mean is None or gpp_std is None:
            msg = "Mean and standard deviation for GPP must be provided."
            raise ValueError(msg)
        self.gpp_mean = gpp_mean
        self.gpp_std = gpp_std

        self.use_metadata = use_metadata
        self.modalities = modalities
        self.no_data_replace = no_data_replace

        if transform is None:
            self.transform = MultimodalToTensor(self.modalities)
        else:
            transform = {m: transform[m] if m in transform else default_transform
                for m in self.modalities}
            self.transform = MultimodalTransforms(transform, shared=False)

    def __len__(self) -> int:
        return len(self.samples)

    def _load_file(self, path: str, nan_replace: float | int | None = None):
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def _get_coords(self, image) -> torch.Tensor:
        """Extract the center coordinates from the image geospatial metadata."""
        pixel_scale = image.rio.resolution()
        width, height = image.rio.width, image.rio.height

        left, bottom, right, top = image.rio.bounds()
        tie_point_x, tie_point_y = left, top

        center_col = width / 2
        center_row = height / 2

        center_lon = tie_point_x + (center_col * pixel_scale[0])
        center_lat = tie_point_y - (center_row * pixel_scale[1])

        src_crs = image.rio.crs
        dst_crs = "EPSG:4326"

        transformer = pyproj.Transformer.from_crs(src_crs, dst_crs, always_xy=True)
        lon, lat = transformer.transform(center_lon, center_lat)

        coords = np.array([lat, lon], dtype=np.float32)
        return torch.from_numpy(coords)

    def _get_date(self, filename: str) -> torch.Tensor:
        """Extract the date from the filename."""
        base_filename = os.path.basename(filename)
        pattern = r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$"
        match = re.match(pattern, base_filename)
        if not match:
            msg = f"Filename {filename} does not match expected pattern."
            raise ValueError(msg)

        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:7])

        date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
        return date_tensor

    def __getitem__(self, idx: int) -> dict[str, Any]:
        sample = self.samples[idx]
        image_path = sample["image_path"]

        image = self._load_file(image_path, nan_replace=self.no_data_replace)

        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(os.path.basename(image_path))

        image = image.to_numpy()  # (C, H, W)
        image = image[self.band_indices, ...]
        image = np.moveaxis(image, 0, -1) # (H, W, C)

        merra_vars = np.array(sample["merra_vars"])
        target = np.array(sample["gpp"])
        target_norm = (target - self.gpp_mean) / self.gpp_std
        target_norm = torch.tensor(target_norm, dtype=torch.float32)
        output = {
            "image": image.astype(np.float32),
            "merra_vars": merra_vars,
        }

        if self.transform:
            output = self.transform(output)

        output = {
            "image": {m: output[m] for m in self.modalities if m in output},
            "mask": target_norm
        }
        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional title for the figure.

        Returns:
            A matplotlib figure with the rendered sample.
        """
        image = sample["image"].numpy()

        image = np.transpose(image, (1, 2, 0))  # (H, W, C)

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        rgb_image = image[..., rgb_indices]

        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
        rgb_image = np.clip(rgb_image, 0, 1)

        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title("Image")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, gpp_mean=None, gpp_std=None, no_data_replace=0.0001, use_metadata=False, modalities=('image', 'merra_vars')) #

Initialize the CarbonFluxNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | 'train' or 'test'. | 'train' |
| bands | Sequence[str] | Bands to use. Defaults to all bands. | BAND_SETS['all'] |
| transform | Optional[Compose] | Albumentations transform to be applied. | None |
| gpp_mean | float | Mean for GPP normalization. | None |
| gpp_std | float | Standard deviation for GPP normalization. | None |
| no_data_replace | Optional[float] | Value to replace NO_DATA values in images. | 0.0001 |
| use_metadata | bool | Whether to return metadata (coordinates and date). | False |
| modalities | Sequence[str] | Modalities included in each sample. | ('image', 'merra_vars') |
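
Each sample groups the modalities under the "image" key and returns the normalized GPP under "mask"; per-modality Albumentations transforms may be passed as a dict keyed by modality, with missing keys falling back to the default transform. A sketch under the same placeholder assumptions:

import albumentations as A
from terratorch.datasets import CarbonFluxNonGeo

# "merra_vars" is omitted, so it falls back to the default transform.
transform = {"image": A.Compose([A.HorizontalFlip(p=0.5)])}
dataset = CarbonFluxNonGeo(
    "./carbon_flux", gpp_mean=4.0, gpp_std=2.5, transform=transform
)
sample = dataset[0]
print(sample["image"].keys())  # dict_keys(['image', 'merra_vars'])
print(sample["mask"])          # normalized GPP value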
Source code in terratorch/datasets/carbonflux.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    gpp_mean: float | None = None,
    gpp_std: float | None = None,
    no_data_replace: float | None = 0.0001,
    use_metadata: bool = False,
    modalities: Sequence[str] = ("image", "merra_vars")
) -> None:
    """Initialize the CarbonFluxNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): 'train' or 'test'.
        bands (Sequence[str]): Bands to use. Defaults to all bands.
        transform (Optional[A.Compose]): Albumentations transform to be applied.
        gpp_mean (float): Mean for GPP normalization.
        gpp_std (float): Standard deviation for GPP normalization.
        no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
        use_metadata (bool): Whether to return metadata (coordinates and date).
        modalities (Sequence[str]): Modality keys used to group inputs and their transforms.
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)

    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(band) for band in bands]

    self.data_root = Path(data_root)

    # Load the CSV file with metadata
    csv_file = self.data_root / self.metadata_file
    df = pd.read_csv(csv_file)

    # Get list of image filenames in the split directory
    image_dir = self.data_root / self.split
    image_files = [f.name for f in image_dir.glob("*.tiff")]

    df["Chip"] = df["Chip"].str.replace(".tif$", ".tiff", regex=True)
    # Filter the DataFrame to include only rows with 'Chip' in image_files
    df = df[df["Chip"].isin(image_files)]

    # Build the samples list
    self.samples = []
    for _, row in df.iterrows():
        image_path = image_dir / row["Chip"]
        # MERRA vectors
        merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
        # GPP target
        gpp = row["GPP"]
        self.samples.append({
            "image_path": str(image_path),
            "merra_vars": merra_vars,
            "gpp": gpp,
        })

    if gpp_mean is None or gpp_std is None:
        msg = "Mean and standard deviation for GPP must be provided."
        raise ValueError(msg)
    self.gpp_mean = gpp_mean
    self.gpp_std = gpp_std

    self.use_metadata = use_metadata
    self.modalities = modalities
    self.no_data_replace = no_data_replace

    if transform is None:
        self.transform = MultimodalToTensor(self.modalities)
    else:
        transform = {m: transform[m] if m in transform else default_transform
            for m in self.modalities}
        self.transform = MultimodalTransforms(transform, shared=False)
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Any] | A sample returned by `__getitem__`. | required |
| suptitle | str \| None | Optional title for the figure. | None |

Returns:

| Type | Description |
| --- | --- |
| Any | A matplotlib figure with the rendered sample. |

Source code in terratorch/datasets/carbonflux.py
def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
    """Plot a sample from the dataset.

    Args:
        sample: A sample returned by `__getitem__`.
        suptitle: Optional title for the figure.

    Returns:
        A matplotlib figure with the rendered sample.
    """
    image = sample["image"].numpy()

    image = np.transpose(image, (1, 2, 0))  # (H, W, C)

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    rgb_image = image[..., rgb_indices]

    rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
    rgb_image = np.clip(rgb_image, 0, 1)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title("Image")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    return fig
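
A minimal usage sketch tying the constructor and `plot` together. The data path and GPP statistics below are illustrative placeholders, not values shipped with the dataset:

```python
from terratorch.datasets.carbonflux import CarbonFluxNonGeo

# Hypothetical root; must contain the metadata CSV and the
# 'train'/'test' directories of .tiff chips.
dataset = CarbonFluxNonGeo(
    data_root="data/carbonflux",   # placeholder path
    split="train",
    gpp_mean=5.0,                  # placeholder statistic
    gpp_std=3.0,                   # placeholder statistic
    use_metadata=True,
)

sample = dataset[0]
# "image" groups the modalities; "mask" is the normalized GPP target.
print(sample["image"]["image"].shape, sample["image"]["merra_vars"].shape, sample["mask"])
```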

terratorch.datasets.forestnet #

ForestNetNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for ForestNet.

Source code in terratorch/datasets/forestnet.py
class ForestNetNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [ForestNet](https://huggingface.co/datasets/ibm-nasa-geospatial/ForestNet)."""

    all_band_names = (
        "RED", "GREEN", "BLUE", "NIR", "SWIR_1", "SWIR_2"
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    splits = ("train", "test", "val")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_label_map = {  # noqa: RUF012
        "Plantation": 0,
        "Smallholder agriculture": 1,
        "Grassland shrubland": 2,
        "Other": 3,
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_map: dict[str, int] = default_label_map,
        transform: A.Compose | None = None,
        fraction: float = 1.0,
        bands: Sequence[str] = BAND_SETS["all"],
        use_metadata: bool = False,
    ) -> None:
        """
        Initialize the ForestNetNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            label_map (Dict[str, int]): Mapping from label names to integer labels.
            transform: Transformations to be applied to the images.
            fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
            bands (Sequence[str]): Bands to output. Defaults to all bands.
            use_metadata (bool): Whether to return metadata (location coords and acquisition dates).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.label_map = label_map

        # Load the CSV file corresponding to the split
        csv_file = self.data_root / f"{split}_filtered.csv"
        original_df = pd.read_csv(csv_file)

        # Apply stratified sampling if fraction < 1.0
        if fraction < 1.0:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
            stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
            self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
        else:
            self.dataset = original_df

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.dataset)

    def _get_coords(self, event_path: Path) -> torch.Tensor:
        auxiliary_path = event_path / "auxiliary"
        osm_json_path = auxiliary_path / "osm.json"

        with open(osm_json_path) as f:
            osm_data = json.load(f)
            lat = float(osm_data["closest_city"]["lat"])
            lon = float(osm_data["closest_city"]["lon"])
            lat_lon = np.asarray([lat, lon])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def _get_dates(self, image_files: list) -> list:
        dates = []
        pattern = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")
        for img_path in image_files:
            match = pattern.search(img_path)
            year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
            date_obj = datetime.datetime(year, month, day)  # noqa: DTZ001
            julian_day = date_obj.timetuple().tm_yday
            date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
            dates.append(date_tensor)
        return torch.stack(dates, dim=0)

    def __getitem__(self, index: int):
        path = self.data_root / self.dataset["example_path"][index]
        label = self.map_label(index)

        visible_images, infrared_images, temporal_coords = self._load_images(path)

        visible_images = np.stack(visible_images, axis=0)
        infrared_images = np.stack(infrared_images, axis=0)
        merged_images = np.concatenate([visible_images, infrared_images], axis=-1)
        merged_images = merged_images[..., self.band_indices] # (T, H, W, 2C)
        output = {
            "image": merged_images.astype(np.float32)
        }

        if self.transform:
            output = self.transform(**output)

        if self.use_metadata:
            location_coords = self._get_coords(path)
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        output["label"] = label

        return output

    def _load_images(self, path: str):
        """Load visible and infrared images from the given event path"""
        visible_image_files = glob.glob(os.path.join(path, "images/visible/*_cloud_*.png"))
        infra_image_files = glob.glob(os.path.join(path, "images/infrared/*_cloud_*.npy"))

        selected_visible_images = self.select_images(visible_image_files)
        selected_infra_images = self.select_images(infra_image_files)

        dates = None
        if self.use_metadata:
            dates = self._get_dates(selected_visible_images)

        vis_images = [np.array(Image.open(img)) for img in selected_visible_images] # (T, H, W, C)
        inf_images = [np.load(img, allow_pickle=True) for img in selected_infra_images] # (T, H, W, C)
        return vis_images, inf_images, dates

    def least_cloudy_image(self, image_files):
        pattern = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")
        lowest_cloud_images = defaultdict(lambda: {"path": None, "cloud_value": float("inf")})

        for path in image_files:
            match = pattern.search(path)
            if match:
                year, cloud_value = match.group(1), int(match.group(2))
                if cloud_value < lowest_cloud_images[year]["cloud_value"]:
                    lowest_cloud_images[year] = {"path": path, "cloud_value": cloud_value}

        return [info["path"] for info in lowest_cloud_images.values()]

    def match_timesteps(self, image_files, selected_images):
        if len(selected_images) < 3:
            extra_imgs = [img for img in image_files if img not in selected_images]
            selected_images += extra_imgs[:3 - len(selected_images)]

        while len(selected_images) < 3:
            selected_images.append(selected_images[-1])
        return selected_images[:3]

    def select_images(self, image_files):
        selected = self.least_cloudy_image(image_files)
        return self.match_timesteps(image_files, selected)

    def map_label(self, index: int) -> torch.Tensor:
        """Map the label name to an integer label."""
        label_name = self.dataset["merged_label"][index]
        label = self.label_map[label_name]
        return label

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None):

        num_images = sample["image"].shape[1] + 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        fig, ax = plt.subplots(1, num_images, figsize=(15, 5))

        for i in range(sample["image"].shape[1]):
            image = sample["image"][:, i, :, :]
            if torch.is_tensor(image):
                image = image.permute(1, 2, 0).numpy()
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"Timestep {i + 1}")

        # label_map maps names to indices, so invert it to recover the class name
        inverse_label_map = {v: k for k, v in self.label_map.items()}
        legend_handles = [Rectangle((0, 0), 1, 1, color="blue")]
        legend_label = [inverse_label_map.get(int(sample["label"]), "Unknown Label")]
        ax[-1].legend(legend_handles, legend_label, loc="center")
        ax[-1].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
__init__(data_root, split='train', label_map=default_label_map, transform=None, fraction=1.0, bands=BAND_SETS['all'], use_metadata=False) #

Initialize the ForestNetNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| label_map | Dict[str, int] | Mapping from label names to integer labels. | default_label_map |
| transform | Compose \| None | Transformations to be applied to the images. | None |
| fraction | float | Fraction of the dataset to use. Defaults to 1.0 (use all data). | 1.0 |
| bands | Sequence[str] | Bands to output. Defaults to all bands. | BAND_SETS['all'] |
| use_metadata | bool | Whether to return metadata (location coords and acquisition dates). | False |
Source code in terratorch/datasets/forestnet.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    label_map: dict[str, int] = default_label_map,
    transform: A.Compose | None = None,
    fraction: float = 1.0,
    bands: Sequence[str] = BAND_SETS["all"],
    use_metadata: bool = False,
) -> None:
    """
    Initialize the ForestNetNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        label_map (Dict[str, int]): Mapping from label names to integer labels.
        transform: Transformations to be applied to the images.
        fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
        bands (Sequence[str]): Bands to output. Defaults to all bands.
        use_metadata (bool): Whether to return metadata (location coords and acquisition dates).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
        raise ValueError(msg)
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.label_map = label_map

    # Load the CSV file corresponding to the split
    csv_file = self.data_root / f"{split}_filtered.csv"
    original_df = pd.read_csv(csv_file)

    # Apply stratified sampling if fraction < 1.0
    if fraction < 1.0:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
        stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
        self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
    else:
        self.dataset = original_df

    self.transform = transform if transform else default_transform
map_label(index) #

Map the label name to an integer label.

Source code in terratorch/datasets/forestnet.py
def map_label(self, index: int) -> torch.Tensor:
    """Map the label name to an integer label."""
    label_name = self.dataset["merged_label"][index]
    label = self.label_map[label_name]
    return label
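
A minimal usage sketch, assuming a hypothetical local copy of the dataset:

```python
from terratorch.datasets.forestnet import ForestNetNonGeo

# Placeholder root; must contain {split}_filtered.csv and the event
# folders referenced by its 'example_path' column.
dataset = ForestNetNonGeo(
    data_root="data/forestnet",  # placeholder path
    split="train",
    fraction=0.5,                # stratified 50% subsample
    use_metadata=True,
)

sample = dataset[0]
print(sample["image"].shape, sample["label"], sample["temporal_coords"].shape)
```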

terratorch.datasets.fire_scars #

FireScarsHLS #

Bases: RasterDataset

RasterDataset implementation for fire scars input images.

Source code in terratorch/datasets/fire_scars.py
class FireScarsHLS(RasterDataset):
    """RasterDataset implementation for fire scars input images."""

    filename_glob = "subsetted*_merged.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4_merged.tif"
    date_format = "%Y%j"
    is_image = True
    separate_files = False
    # Plain class attributes; dataclasses.field(default_factory=...) is not
    # valid outside a dataclass (default_factory must be a callable).
    all_bands = ["B02", "B03", "B04", "B8A", "B11", "B12"]
    rgb_bands = ["B04", "B03", "B02"]

FireScarsNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for fire scars.

Source code in terratorch/datasets/fire_scars.py
class FireScarsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [fire scars](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars)."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 2
    splits = {"train": "training", "val": "validation"}   # Only train and val splits available

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'. Defaults to 'train'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to -1.
            use_metadata (bool): whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        input_dir = self.data_root / split_name
        self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        # If no transform is given, only convert to a torch tensor
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        base_filename = os.path.basename(file_name)

        filename_regex = r"subsetted_512x512_HLS\.S30\.T[0-9A-Z]{5}\.(?P<date>[0-9]+)\.v1\.4_merged\.tif"
        match = re.match(filename_regex, base_filename)
        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:])

        return torch.tensor([[year, julian_day]], dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image_percentile(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=0, no_label_replace=-1, use_metadata=False) #

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train' or 'val'. | 'train' |
| bands | list[str] | Bands that should be output by the dataset. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2(). | None |
| no_data_replace | float \| None | Replace nan values in input images with this value. If None, does no replacement. | 0 |
| no_label_replace | int \| None | Replace nan values in label with this value. If None, does no replacement. | -1 |
| use_metadata | bool | Whether to return metadata info (time and location). | False |
Source code in terratorch/datasets/fire_scars.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'. Defaults to 'train'.
        bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the corresponding data module,
            should not include normalization. Defaults to None, which applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value.
            If None, does no replacement. Defaults to 0.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to -1.
        use_metadata (bool): whether to return metadata info (time and location).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {self.splits}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.data_root = Path(data_root)

    input_dir = self.data_root / split_name
    self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

    self.use_metadata = use_metadata
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    # If no transform is given, only convert to a torch tensor
    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by `__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/fire_scars.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = 4

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    # RGB -> channels-last
    image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
    mask = sample["mask"].numpy()

    image = clip_image_percentile(image)

    if "prediction" in sample:
        prediction = sample["prediction"]
        num_images += 1
    else:
        prediction = None

    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    ax[1].axis("off")
    ax[1].title.set_text("Image")
    ax[1].imshow(image)

    ax[2].axis("off")
    ax[2].title.set_text("Ground Truth Mask")
    ax[2].imshow(mask, cmap="jet", norm=norm)

    ax[3].axis("off")
    ax[3].title.set_text("GT Mask on Image")
    ax[3].imshow(image)
    ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

    if "prediction" in sample:
        ax[4].title.set_text("Predicted Mask")
        ax[4].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")
    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
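
A minimal usage sketch with a hypothetical data root:

```python
from terratorch.datasets.fire_scars import FireScarsNonGeo

# Placeholder root; must contain 'training/' and 'validation/' directories
# with *_merged.tif inputs and *.mask.tif segmentation masks.
dataset = FireScarsNonGeo(data_root="data/fire_scars", split="val", use_metadata=True)

sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)
fig = dataset.plot(sample, suptitle="Fire scar sample")
```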

FireScarsSegmentationMask #

Bases: RasterDataset

RasterDataset implementation for fire scars segmentation mask. Can be easily merged with input images using the & operator.

Source code in terratorch/datasets/fire_scars.py
class FireScarsSegmentationMask(RasterDataset):
    """RasterDataset implementation for fire scars segmentation mask.
    Can be easily merged with input images using the & operator.
    """

    filename_glob = "subsetted*.mask.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1.4.mask.tif"
    date_format = "%Y%j"
    is_image = False
    separate_files = False
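
Because both classes are torchgeo RasterDatasets, the `&` operator mentioned above builds their spatial intersection. A minimal sketch with placeholder directories:

```python
from terratorch.datasets.fire_scars import FireScarsHLS, FireScarsSegmentationMask

# Placeholder paths; each directory holds files matching the class-level
# filename_glob patterns.
images = FireScarsHLS("data/fire_scars/images")
masks = FireScarsSegmentationMask("data/fire_scars/masks")

# torchgeo overloads & to return an IntersectionDataset, so samples drawn
# from it carry both the imagery and the co-registered mask.
dataset = images & masks
```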

terratorch.datasets.landslide4sense #

Landslide4SenseNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for Landslide4Sense.

Source code in terratorch/datasets/landslide4sense.py
class Landslide4SenseNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [Landslide4Sense](https://huggingface.co/datasets/ibm-nasa-geospatial/Landslide4sense)."""
    all_band_names = (
        "COASTAL AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
        "SLOPE",
        "DEM",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")
    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "validation", "test": "test"}


    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
    ) -> None:
        """Initialize the Landslide4Sense dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_directory = Path(data_root)

        images_dir = self.data_directory / "images" / split_name
        annotations_dir = self.data_directory / "annotations" / split_name

        self.image_files = sorted(images_dir.glob("image_*.h5"))
        self.mask_files = sorted(annotations_dir.glob("mask_*.h5"))

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        image_file = self.image_files[index]
        mask_file = self.mask_files[index]

        with h5py.File(image_file, "r") as h5file:
            image = np.array(h5file["img"])[..., self.band_indices]

        with h5py.File(mask_file, "r") as h5file:
            mask = np.array(h5file["mask"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()
        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]

        rgb_image = (rgb_image - rgb_image.min(axis=(0, 1))) / (
            rgb_image.max(axis=(0, 1)) - rgb_image.min(axis=(0, 1)) + 1e-8
        )
        rgb_image = np.clip(rgb_image, 0, 1)

        num_classes = len(np.unique(mask))
        cmap = colormaps["jet"]
        norm = Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"]
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if sample.get("class_names"):
            class_names = sample["class_names"]
            legend_handles = [
                mpatches.Patch(color=cmap(norm(i)), label=class_names[i]) for i in range(num_classes)
            ]
            ax[0].legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc="upper left")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None) #

Initialize the Landslide4Sense dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
Source code in terratorch/datasets/landslide4sense.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
) -> None:
    """Initialize the Landslide4Sense dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_directory = Path(data_root)

    images_dir = self.data_directory / "images" / split_name
    annotations_dir = self.data_directory / "annotations" / split_name

    self.image_files = sorted(images_dir.glob("image_*.h5"))
    self.mask_files = sorted(annotations_dir.glob("mask_*.h5"))

    self.transform = transform if transform else default_transform
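
A minimal usage sketch with a hypothetical data root:

```python
from terratorch.datasets.landslide4sense import Landslide4SenseNonGeo

# Placeholder root; must contain images/<split>/image_*.h5 and
# annotations/<split>/mask_*.h5.
dataset = Landslide4SenseNonGeo(
    data_root="data/landslide4sense",
    split="val",
    bands=["RED", "GREEN", "BLUE", "SLOPE", "DEM"],
)

sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)
```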

terratorch.datasets.m_eurosat #

MEuroSATNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-EuroSAT.

Source code in terratorch/datasets/m_eurosat.py
class MEuroSATNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-EuroSAT](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-eurosat"
    partition_file_template = "{partition}_partition.json"
    label_map_file = "label_map.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        label_map_path = self.data_directory / self.label_map_file
        with open(label_map_path) as file:
            self.label_map = json.load(file)

        self.id_to_class = {img_id: cls for cls, ids in self.label_map.items() for img_id in ids}

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)

        label_class = self.id_to_class[image_id]
        label_index = list(self.label_map.keys()).index(label_class)

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = label_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label_index = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        class_names = list(self.label_map.keys())
        class_name = class_names[label_index]

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {class_name}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
Source code in terratorch/datasets/m_eurosat.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    label_map_path = self.data_directory / self.label_map_file
    with open(label_map_path) as file:
        self.label_map = json.load(file)

    self.id_to_class = {img_id: cls for cls, ids in self.label_map.items() for img_id in ids}

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by `__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_eurosat.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label_index = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    class_names = list(self.label_map.keys())
    class_name = class_names[label_index]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {class_name}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
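
A minimal usage sketch with a hypothetical GEO-Bench root:

```python
from terratorch.datasets.m_eurosat import MEuroSATNonGeo

# Placeholder root; the 'm-eurosat' subdirectory must hold the .hdf5 chips,
# label_map.json and default_partition.json.
dataset = MEuroSATNonGeo(
    data_root="data/geobench",
    split="train",
    bands=MEuroSATNonGeo.BAND_SETS["rgb"],
)

sample = dataset[0]
print(sample["image"].shape, sample["label"])
```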

terratorch.datasets.m_bigearthnet #

MBigEarthNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-BigEarthNet.

Source code in terratorch/datasets/m_bigearthnet.py
class MBigEarthNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BigEarthNet](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-bigearthnet"
    label_map_file = "label_stats.json"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        label_map_path = self.data_directory / self.label_map_file
        with open(label_map_path) as file:
            self.label_map = json.load(file)

        self.num_classes = len(next(iter(self.label_map.values())))

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)

        labels_vector = self.label_map[image_id]
        labels_tensor = torch.tensor(labels_vector, dtype=torch.float)

        output = {"image": image}

        if self.transform:
            output = self.transform(**output)

        output["label"] = labels_tensor
        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

        rgb_image = image[:, :, rgb_indices]

        rgb_image = clip_image(rgb_image)

        active_labels = [i for i, lbl in enumerate(label) if lbl == 1]

        fig, ax = plt.subplots(figsize=(6, 6))

        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Active Labels: {active_labels}")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
Source code in terratorch/datasets/m_bigearthnet.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    label_map_path = self.data_directory / self.label_map_file
    with open(label_map_path) as file:
        self.label_map = json.load(file)

    self.num_classes = len(next(iter(self.label_map.values())))

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by `__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_bigearthnet.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label = sample["label"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

    rgb_image = image[:, :, rgb_indices]

    rgb_image = clip_image(rgb_image)

    active_labels = [i for i, lbl in enumerate(label) if lbl == 1]

    fig, ax = plt.subplots(figsize=(6, 6))

    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Active Labels: {active_labels}")

    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
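
A minimal usage sketch with a hypothetical GEO-Bench root:

```python
from terratorch.datasets.m_bigearthnet import MBigEarthNonGeo

# Placeholder root; the 'm-bigearthnet' subdirectory must hold the .hdf5
# chips, label_stats.json and default_partition.json.
dataset = MBigEarthNonGeo(data_root="data/geobench", split="train")

sample = dataset[0]
# Multi-label target: one float entry per class.
print(sample["image"].shape, sample["label"].shape)
```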

terratorch.datasets.m_brick_kiln #

MBrickKilnNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-BrickKiln.

Source code in terratorch/datasets/m_brick_kiln.py
class MBrickKilnNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BrickKiln](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-brick-kiln"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

        rgb_image = image[:, :, rgb_indices]

        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))

        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
Source code in terratorch/datasets/m_brick_kiln.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
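
For reference, the partition file loaded here maps the GEO-Bench split names (`train`, `valid`, `test`) to lists of sample file stems; the `.hdf5` suffix is appended when `image_files` is built. A hypothetical `default_partition.json`, shown as the Python object it deserializes to (the sample names are illustrative):

```python
# Hypothetical contents of a {partition}_partition.json after json.load();
# the sample names are placeholders.
partitions = {
    "train": ["sample_0001", "sample_0002"],
    "valid": ["sample_0003"],
    "test": ["sample_0004"],
}
```
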
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_brick_kiln.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

    rgb_image = image[:, :, rgb_indices]

    rgb_image = clip_image(rgb_image)

    fig, ax = plt.subplots(figsize=(6, 6))

    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {label}")

    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
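
Putting the pieces together, a short usage sketch; the class name is assumed from the module path `terratorch/datasets/m_brick_kiln.py`, and the data root is assumed to contain an extracted `m-brick-kiln` directory:

```python
# Usage sketch under the assumptions stated above.
from terratorch.datasets.m_brick_kiln import MBrickKilnNonGeo

ds = MBrickKilnNonGeo(data_root="./geobench", split="val", bands=("RED", "GREEN", "BLUE"))
sample = ds[0]  # dict with "image" and "label"
print(len(ds), sample["image"].shape, sample["label"])
fig = ds.plot(sample, suptitle="m-brick-kiln")
fig.savefig("brick_kiln_sample.png")
```
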

terratorch.datasets.m_forestnet #

MForestNetNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-ForestNet.

Source code in terratorch/datasets/m_forestnet.py
class MForestNetNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-ForestNet](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-forestnet"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))  # noqa: S301
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        if self.use_metadata:
            temporal_coords = self._get_date(image_id)
            location_coords = self._get_coords(image_id)

            output["temporal_coords"] = temporal_coords
            output["location_coords"] = location_coords

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID.

        Args:
            image_id (str): The ID of the image.

        Returns:
            torch.Tensor: Tensor containing latitude and longitude.
        """
        lat_str, lon_str, _ = image_id.split("_", 2)
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def _get_date(self, image_id: str) -> torch.Tensor:
        _, _, date_str = image_id.split("_", 2)
        date = pd.to_datetime(date_str, format="%Y_%m_%d")

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False) #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info (time and location). | `False` |
Source code in terratorch/datasets/m_forestnet.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (time and location).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_forestnet.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {label}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
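
When `use_metadata=True`, each sample additionally carries the location and acquisition date that `_get_coords` and `_get_date` parse from the file stem, which the code above expects in the form `<lat>_<lon>_<YYYY_MM_DD>`. A sketch with an illustrative data root:

```python
# Metadata usage sketch; the data_root value is illustrative.
from terratorch.datasets.m_forestnet import MForestNetNonGeo

ds = MForestNetNonGeo(data_root="./geobench", split="train", use_metadata=True)
sample = ds[0]
print(sample["location_coords"])  # tensor([lat, lon])
print(sample["temporal_coords"])  # tensor([[year, day_of_year - 1]])
```
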

terratorch.datasets.m_so2sat #

MSo2SatNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-So2Sat.

Source code in terratorch/datasets/m_so2sat.py
class MSo2SatNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-So2Sat](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "VH_REAL",
        "BLUE",
        "VH_IMAGINARY",
        "GREEN",
        "VV_REAL",
        "RED",
        "VV_IMAGINARY",
        "VH_LEE_FILTERED",
        "RED_EDGE_1",
        "VV_LEE_FILTERED",
        "RED_EDGE_2",
        "VH_LEE_FILTERED_REAL",
        "RED_EDGE_3",
        "NIR_BROAD",
        "VV_LEE_FILTERED_IMAGINARY",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-so2sat"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label_index = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        class_name = str(label_index)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {class_name}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
Source code in terratorch/datasets/m_so2sat.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_so2sat.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label_index = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    class_name = str(label_index)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {class_name}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
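
Because `all_band_names` interleaves Sentinel-1 (SAR) and Sentinel-2 (optical) channels, subsetting by band name is the safest way to select a single sensor. A sketch that keeps only the optical RGB bands (the data root is illustrative):

```python
# Band-subsetting sketch; data_root is illustrative.
from terratorch.datasets.m_so2sat import MSo2SatNonGeo

ds = MSo2SatNonGeo(data_root="./geobench", split="test", bands=MSo2SatNonGeo.BAND_SETS["rgb"])
print(ds.band_indices)  # [5, 3, 1]: positions of RED, GREEN, BLUE in the 18-band list
```
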

terratorch.datasets.m_pv4ger #

MPv4gerNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-PV4GER.

Source code in terratorch/datasets/m_pv4ger.py
class MPv4gerNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-PV4GER](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-pv4ger"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (location coordinates).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))  # noqa: S301
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        if self.use_metadata:
            output["location_coords"] = self._get_coords(image_id)

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID."""
        lat_str, lon_str = image_id.split(",")
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False) #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info (location coordinates). | `False` |
Source code in terratorch/datasets/m_pv4ger.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (location coordinates).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_pv4ger.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {label}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
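
Per `_get_coords` above, the PV4GER file stems are expected to encode the location directly as `<lat>,<lon>`. A sketch with `use_metadata=True` (the data root is illustrative):

```python
# Metadata usage sketch; data_root is illustrative.
from terratorch.datasets.m_pv4ger import MPv4gerNonGeo

ds = MPv4gerNonGeo(data_root="./geobench", split="val", use_metadata=True)
sample = ds[0]
print(sample["location_coords"])  # tensor([lat, lon]) parsed from the file stem
```
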

terratorch.datasets.m_cashew_plantation #

MBeninSmallHolderCashewsNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-BeninSmallHolderCashews.

Source code in terratorch/datasets/m_cashew_plantation.py
class MBeninSmallHolderCashewsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BeninSmallHolderCashews](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-cashew-plant"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def _get_date(self, keys) -> torch.Tensor:
        date_pattern = re.compile(r"\d{4}-\d{2}-\d{2}")

        date_str = None
        for key in keys:
            match = date_pattern.search(key)
            if match:
                date_str = match.group()
                break

        date = torch.zeros((1, 2), dtype=torch.float32)
        if date_str:
            date = pd.to_datetime(date_str, format="%Y-%m-%d")
            date = torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

        return date

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            temporal_coords = self._get_date(h5file)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()
        if self.use_metadata:
            output["temporal_coords"] = temporal_coords

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False) #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info (time). | `False` |
Source code in terratorch/datasets/m_cashew_plantation.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (time).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_cashew_plantation.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
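
Unlike the classification datasets above, this dataset returns a segmentation `mask` rather than a `label`, and `plot` renders a fourth panel whenever a `prediction` key is present in the sample. A sketch (the data root is illustrative; the prediction is a placeholder, not a model output):

```python
# Segmentation plotting sketch; data_root and the prediction are illustrative.
import torch

from terratorch.datasets.m_cashew_plantation import MBeninSmallHolderCashewsNonGeo

ds = MBeninSmallHolderCashewsNonGeo(data_root="./geobench", split="val", use_metadata=True)
sample = ds[0]  # dict with "image", "mask" and, here, "temporal_coords"
sample["prediction"] = torch.zeros_like(sample["mask"])  # placeholder mask
fig = ds.plot(sample, suptitle="m-cashew-plant")
```
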

terratorch.datasets.m_nz_cattle #

MNzCattleNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-NZ-Cattle.

Source code in terratorch/datasets/m_nz_cattle.py
class MNzCattleNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-NZ-Cattle](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-nz-cattle"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        file_name = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())

            data_keys = [key for key in keys if "label" not in key]
            label_keys = [key for key in keys if "label" in key]

            temporal_coords = self._get_date(data_keys[0])

            bands = [np.array(h5file[key]) for key in data_keys]
            image = np.stack(bands, axis=-1)

            mask = np.array(h5file[label_keys[0]])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            location_coords = self._get_coords(file_name)
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _get_coords(self, file_name: str) -> torch.Tensor:
        """Extract spatial coordinates from the file name."""
        match = re.search(r"_(\-?\d+\.\d+),(\-?\d+\.\d+)", file_name)
        if match is None:
            msg = f"Could not parse coordinates from file name '{file_name}'."
            raise ValueError(msg)
        longitude, latitude = map(float, match.groups())

        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def _get_date(self, band_name: str) -> torch.Tensor:
        date_str = band_name.split("_")[-1]
        date = pd.to_datetime(date_str, format="%Y-%m-%d")

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False) #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_root` | `str` | Path to the data root directory. | *required* |
| `split` | `str` | One of 'train', 'val', or 'test'. | `'train'` |
| `bands` | `Sequence[str]` | Bands to be used. Defaults to all bands. | `BAND_SETS['all']` |
| `transform` | `Compose \| None` | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | `None` |
| `partition` | `str` | Partition name for the dataset splits. Defaults to 'default'. | `'default'` |
| `use_metadata` | `bool` | Whether to return metadata info (time and location). | `False` |
Source code in terratorch/datasets/m_nz_cattle.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (time and location).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample` | `dict[str, Tensor]` | A sample returned by `__getitem__`. | *required* |
| `suptitle` | `str \| None` | Optional string to use as a suptitle. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_nz_cattle.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
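
Note the argument order in `_get_coords` above: the file stem is expected to embed `_<longitude>,<latitude>`, which is then returned as `[latitude, longitude]`, while `_get_date` reads a `YYYY-MM-DD` suffix from the first band key. A sketch with `use_metadata=True` (the data root is illustrative):

```python
# Metadata usage sketch; data_root is illustrative.
from terratorch.datasets.m_nz_cattle import MNzCattleNonGeo

ds = MNzCattleNonGeo(data_root="./geobench", split="train", use_metadata=True)
sample = ds[0]
print(sample["location_coords"])  # tensor([lat, lon])
print(sample["temporal_coords"])  # tensor([[year, day_of_year - 1]])
```
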

terratorch.datasets.m_chesapeake_landcover #

MChesapeakeLandcoverNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-ChesapeakeLandcover.

Source code in terratorch/datasets/m_chesapeake_landcover.py
class MChesapeakeLandcoverNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-ChesapeakeLandcover](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "NIR", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-chesapeake"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
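
The GeoBench-style samples read by __getitem__ above are plain HDF5 files with one dataset per band plus a "label" mask. A minimal sketch of reading one sample directly with h5py, under those assumptions; the file path is hypothetical:

import h5py
import numpy as np

# Each sample file holds one HDF5 dataset per band, plus a "label" mask.
with h5py.File("m-chesapeake-landcover/sample_0001.hdf5", "r") as f:
    band_keys = sorted(k for k in f.keys() if k != "label")          # band datasets
    image = np.stack([np.array(f[k]) for k in band_keys], axis=-1)   # (H, W, C)
    mask = np.array(f["label"])                                      # (H, W)
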
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
Source code in terratorch/datasets/m_chesapeake_landcover.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
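
The partition file loaded above maps split names to lists of sample file stems (the .hdf5 extension is appended when building image_files). A minimal sketch of writing such a file, assuming the train/valid/test split naming used by the GeoBench datasets on this page; the sample names are hypothetical:

import json

# Hypothetical partition contents: split names map to lists of .hdf5 file stems.
partitions = {
    "train": ["sample_0001", "sample_0002"],
    "valid": ["sample_0003"],
    "test": ["sample_0004"],
}
with open("m-chesapeake-landcover/default_partition.json", "w") as f:
    json.dump(partitions, f)
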
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by :meth:`__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_chesapeake_landcover.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
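
A minimal usage sketch for the dataset documented above, assuming the class is exported from terratorch.datasets as MChesapeakeLandcoverNonGeo (the name is inferred from the source file path) and that data_root contains the m-chesapeake-landcover folder:

from terratorch.datasets import MChesapeakeLandcoverNonGeo

dataset = MChesapeakeLandcoverNonGeo(data_root="path/to/geobench", split="val")
sample = dataset[0]                      # {"image": tensor, "mask": tensor}
fig = dataset.plot(sample, suptitle="Validation sample")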

terratorch.datasets.m_pv4ger_seg #

MPv4gerSegNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-PV4GER-SEG.

Source code in terratorch/datasets/m_pv4ger_seg.py
class MPv4gerSegNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-PV4GER-SEG](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-pv4ger-seg"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (location coordinates).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = self._get_coords(image_id)

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID."""
        lat_str, lon_str = image_id.split(",")
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)


    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False) #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
| use_metadata | bool | Whether to return metadata info (location coordinates). | False |
Source code in terratorch/datasets/m_pv4ger_seg.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (location coordinates).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by :meth:`__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_pv4ger_seg.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
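
A minimal usage sketch, assuming the class is importable from terratorch.datasets and that data_root contains the m-pv4ger-seg folder. With use_metadata=True, each sample also carries the location parsed from the "lat,lon" file stem:

from terratorch.datasets import MPv4gerSegNonGeo

dataset = MPv4gerSegNonGeo(data_root="path/to/geobench", split="train", use_metadata=True)
sample = dataset[0]
print(sample["location_coords"])  # tensor([latitude, longitude]) from the file stem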

terratorch.datasets.m_SA_crop_type #

MSACropTypeNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-SA-Crop-Type.

Source code in terratorch/datasets/m_SA_crop_type.py
class MSACropTypeNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-SA-Crop-Type](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-SA-crop-type"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
Source code in terratorch/datasets/m_SA_crop_type.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by :meth:`__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_SA_crop_type.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
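
A minimal usage sketch showing band selection, assuming the class is importable from terratorch.datasets and that data_root contains the m-SA-crop-type folder. The channel order of "image" follows the requested band order:

from terratorch.datasets import MSACropTypeNonGeo

dataset = MSACropTypeNonGeo(
    data_root="path/to/geobench",
    split="train",
    bands=["RED", "GREEN", "BLUE", "NIR_NARROW"],
)
sample = dataset[0]  # "image" holds only the four requested bands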

terratorch.datasets.m_neontree #

MNeonTreeNonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for M-NeonTree.

Source code in terratorch/datasets/m_neontree.py
class MNeonTreeNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-NeonTree](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "CANOPY_HEIGHT_MODEL", "GREEN", "NEON", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-NeonTree"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = rgb_bands,
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to RGB bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)
        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=rgb_bands, transform=None, partition='default') #

Initialize the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val', or 'test'. | 'train' |
| bands | Sequence[str] | Bands to be used. Defaults to RGB bands. | rgb_bands |
| transform | Compose \| None | Albumentations transform to be applied. Defaults to None, which applies default_transform(). | None |
| partition | str | Partition name for the dataset splits. Defaults to 'default'. | 'default' |
Source code in terratorch/datasets/m_neontree.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = rgb_bands,
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to RGB bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | A sample returned by :meth:`__getitem__`. | required |
| suptitle | str \| None | Optional string to use as a suptitle. | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | A matplotlib Figure with the rendered sample. |

Source code in terratorch/datasets/m_neontree.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)
    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
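
Note that, unlike the other GeoBench datasets on this page, this class defaults to the RGB bands rather than all bands. A minimal sketch, assuming the class is importable from terratorch.datasets and that data_root contains the m-NeonTree folder:

from terratorch.datasets import MNeonTreeNonGeo

rgb_only = MNeonTreeNonGeo(data_root="path/to/geobench")  # defaults to RGB bands
all_bands = MNeonTreeNonGeo(
    data_root="path/to/geobench",
    bands=MNeonTreeNonGeo.BAND_SETS["all"],  # adds CANOPY_HEIGHT_MODEL and NEON
)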

terratorch.datasets.multi_temporal_crop_classification #

MultiTemporalCropClassification #

Bases: NonGeoDataset

NonGeo dataset implementation for multi-temporal crop classification.

Source code in terratorch/datasets/multi_temporal_crop_classification.py
class MultiTemporalCropClassification(NonGeoDataset):
    """NonGeo dataset implementation for [multi-temporal crop classification](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification)."""

    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    class_names = (
        "Natural Vegetation",
        "Forest",
        "Corn",
        "Soybeans",
        "Wetlands",
        "Developed / Barren",
        "Open Water",
        "Winter Wheat",
        "Alfalfa",
        "Fallow / Idle Cropland",
        "Cotton",
        "Sorghum",
        "Other",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 13
    time_steps = 3
    splits = {"train": "training", "val": "validation"}  # Only train and val splits available
    col_name = "chip_id"
    date_columns = ["first_img_date", "middle_img_date", "last_img_date"]

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = True,
        reduce_zero_label: bool = True,
        use_metadata: bool = False,
        metadata_file_name: str = "chips_df.csv",
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): one of 'train' or 'val'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to True.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to True.
            use_metadata (bool): Whether to return metadata info (time and location).
            metadata_file_name (str): Name of the metadata CSV file. Defaults to "chips_df.csv".
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        data_dir = self.data_root / f"{split_name}_chips"
        self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
        split_file = self.data_root / f"{split_name}_data.txt"

        with open(split_file) as f:
            split = f.readlines()
        valid_files = {substring.strip() for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )

        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension
        self.use_metadata = use_metadata
        self.metadata = None
        self.metadata_file_name = metadata_file_name
        if self.use_metadata:
            metadata_file = self.data_root / self.metadata_file_name
            self.metadata = pd.read_csv(metadata_file)
            self._build_image_metadata_mapping()

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform

    def _build_image_metadata_mapping(self):
        """Build a mapping from image filenames to metadata indices."""
        self.image_to_metadata_index = dict()

        for idx, image_file in enumerate(self.image_files):
            image_filename = Path(image_file).name
            image_id = image_filename.replace("_merged.tif", "").replace(".tif", "")
            metadata_indices = self.metadata.index[self.metadata[self.col_name] == image_id].tolist()
            self.image_to_metadata_index[idx] = metadata_indices[0]

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, row: pd.Series) -> torch.Tensor:
        """Extract and format temporal coordinates (T, date) from metadata."""
        temporal_coords = []
        for col in self.date_columns:
            date_str = row[col]
            date = pd.to_datetime(date_str)
            temporal_coords.append([date.year, date.dayofyear - 1])

        return torch.tensor(temporal_coords, dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            metadata_idx = self.image_to_metadata_index.get(index, None)
            if metadata_idx is not None:
                row = self.metadata.iloc[metadata_idx]
                temporal_coords = self._get_date(row)

        # to channels last
        image = image.to_numpy()
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.bands))
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = self.time_steps + 2

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        images = sample["image"]
        images = images[rgb_indices, ...]  # select the RGB channels

        processed_images = []
        for t in range(self.time_steps):
            img = images[t]
            img = img.permute(1, 2, 0)
            img = img.numpy()
            img = clip_image(img)
            processed_images.append(img)

        mask = sample["mask"].numpy()
        if "prediction" in sample:
            num_images += 1
        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        for i, img in enumerate(processed_images):
            ax[i + 1].axis("off")
            ax[i + 1].title.set_text(f"T{i}")
            ax[i + 1].imshow(img)

        ax[self.time_steps + 1].axis("off")
        ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
        ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)

        if "prediction" in sample:
            prediction = sample["prediction"]
            ax[self.time_steps + 2].axis("off")
            ax[self.time_steps + 2].title.set_text("Predicted Mask")
            ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=True, reduce_zero_label=True, use_metadata=False, metadata_file_name='chips_df.csv') #

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train' or 'val'. | 'train' |
| bands | list[str] | Bands that should be output by the dataset. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose \| None | Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2(). | None |
| no_data_replace | float \| None | Replace nan values in input images with this value. If None, does no replacement. | None |
| no_label_replace | int \| None | Replace nan values in label with this value. If None, does no replacement. | None |
| expand_temporal_dimension | bool | Go from shape (time*channels, h, w) to (channels, time, h, w). | True |
| reduce_zero_label | bool | Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. | True |
| use_metadata | bool | Whether to return metadata info (time and location). | False |
| metadata_file_name | str | Name of the metadata CSV file. | 'chips_df.csv' |
Source code in terratorch/datasets/multi_temporal_crop_classification.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = True,
    reduce_zero_label: bool = True,
    use_metadata: bool = False,
    metadata_file_name: str = "chips_df.csv",
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): one of 'train' or 'val'.
        bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the corresponding data module,
            should not include normalization. Defaults to None, which applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value.
            If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to True.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to True.
        use_metadata (bool): Whether to return metadata info (time and location).
        metadata_file_name (str): Name of the metadata CSV file. Defaults to "chips_df.csv".
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {self.splits}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.data_root = Path(data_root)

    data_dir = self.data_root / f"{split_name}_chips"
    self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
    split_file = self.data_root / f"{split_name}_data.txt"

    with open(split_file) as f:
        split = f.readlines()
    valid_files = {substring.strip() for substring in split}
    self.image_files = filter_valid_files(
        self.image_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )
    self.segmentation_mask_files = filter_valid_files(
        self.segmentation_mask_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )

    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.reduce_zero_label = reduce_zero_label
    self.expand_temporal_dimension = expand_temporal_dimension
    self.use_metadata = use_metadata
    self.metadata = None
    self.metadata_file_name = metadata_file_name
    if self.use_metadata:
        metadata_file = self.data_root / self.metadata_file_name
        self.metadata = pd.read_csv(metadata_file)
        self._build_image_metadata_mapping()

    # If no transform is given, only convert the sample to torch tensors
    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | a sample returned by :meth:`__getitem__` | required |
| suptitle | str \| None | optional string to use as a suptitle | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | a matplotlib Figure with the rendered sample |

Source code in terratorch/datasets/multi_temporal_crop_classification.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = self.time_steps + 2

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    images = sample["image"]
    images = images[rgb_indices, ...]  # select the RGB channels

    processed_images = []
    for t in range(self.time_steps):
        img = images[t]
        img = img.permute(1, 2, 0)
        img = img.numpy()
        img = clip_image(img)
        processed_images.append(img)

    mask = sample["mask"].numpy()
    if "prediction" in sample:
        num_images += 1
    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")
    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    for i, img in enumerate(processed_images):
        ax[i + 1].axis("off")
        ax[i + 1].title.set_text(f"T{i}")
        ax[i + 1].imshow(img)

    ax[self.time_steps + 1].axis("off")
    ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
    ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)

    if "prediction" in sample:
        prediction = sample["prediction"]
        ax[self.time_steps + 2].axis("off")
        ax[self.time_steps + 2].title.set_text("Predicted Mask")
        ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")

    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
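
A minimal usage sketch, assuming the class is importable from terratorch.datasets and that data_root holds the training_chips/validation_chips folders, the *_data.txt split files, and chips_df.csv. With use_metadata=True, each sample carries the chip-center coordinates and one [year, zero-based day-of-year] pair per time step:

from terratorch.datasets import MultiTemporalCropClassification

dataset = MultiTemporalCropClassification(
    data_root="path/to/dataset",
    split="train",
    use_metadata=True,
)
sample = dataset[0]
print(sample["temporal_coords"].shape)  # (3, 2): one [year, day-of-year] pair per acquisition
print(sample["location_coords"])        # tensor([latitude, longitude]) of the chip center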

terratorch.datasets.open_sentinel_map #

OpenSentinelMap #

Bases: NonGeoDataset

PyTorch Dataset class to load samples from the OpenSentinelMap dataset, supporting multiple bands and temporal sampling strategies.
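
A minimal usage sketch, assuming the class is importable from terratorch.datasets and that data_root holds the osm_sentinel_imagery and osm_label_images_v10 folders plus spatial_cell_info.csv and osm_categories.json. With the defaults, two random acquisitions are picked per cell and stacked over time:

from terratorch.datasets import OpenSentinelMap

dataset = OpenSentinelMap(data_root="path/to/OpenSentinelMap", split="train", bands=["gsd_10"])
sample = dataset[0]
print(sample["image"].shape)  # (time, H, W, channels) after temporal stacking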

Source code in terratorch/datasets/open_sentinel_map.py
class OpenSentinelMap(NonGeoDataset):
    """
    PyTorch Dataset class to load samples from the [OpenSentinelMap](https://visionsystemsinc.github.io/open-sentinel-map/) dataset, supporting
    multiple bands and temporal sampling strategies.
    """

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: list[str] | None = None,
        transform: A.Compose | None = None,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        pad_image: int | None = None,
        truncate_image: int | None = None,
        target: int = 0,
        pick_random_pair: bool = True,  # noqa: FBT002, FBT001
    ) -> None:
        """

        Args:
            data_root (str): Path to the root directory of the dataset.
            split (str): Dataset split to load. Options are 'train', 'val', or 'test'. Defaults to 'train'.
            bands (list of str, optional): List of band names to load. Defaults to ['gsd_10', 'gsd_20', 'gsd_60'].
            transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
            spatial_interpolate_and_stack_temporally (bool): If True, the bands are interpolated and concatenated over time.
                Default is True.
            pad_image (int, optional): Number of timesteps to pad the time dimension of the image.
                If None, no padding is applied.
            truncate_image (int, optional): Number of timesteps to truncate the time dimension of the image.
                If None, no truncation is performed.
            target (int): Specifies which target class to use from the mask. Default is 0.
            pick_random_pair (bool): If True, selects two random images from the temporal sequence. Default is True.
        """
        split = "test"
        if bands is None:
            bands = ["gsd_10", "gsd_20", "gsd_60"]

        allowed_bands = {"gsd_10", "gsd_20", "gsd_60"}
        for band in bands:
            if band not in allowed_bands:
                msg = f"Band '{band}' is not recognized. Available values are: {', '.join(allowed_bands)}"
                raise ValueError(msg)

        if split not in ["train", "val", "test"]:
            msg = f"Split '{split}' not recognized. Use 'train', 'val', or 'test'."
            raise ValueError(msg)

        self.data_root = Path(data_root)
        split_mapping = {"train": "training", "val": "validation", "test": "testing"}
        split = split_mapping[split]
        self.imagery_root = self.data_root / "osm_sentinel_imagery"
        self.label_root = self.data_root / "osm_label_images_v10"
        self.auxiliary_data = pd.read_csv(self.data_root / "spatial_cell_info.csv")
        self.auxiliary_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
        self.bands = bands
        self.transform = transform if transform else lambda **batch: to_tensor(batch)
        self.label_mappings = self._load_label_mappings()
        self.split_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
        self.pad_image = pad_image
        self.truncate_image = truncate_image
        self.target = target
        self.pick_random_pair = pick_random_pair

        self.image_files = []
        self.label_files = []

        for _, row in self.split_data.iterrows():
            mgrs_tile = row["MGRS_tile"]
            spatial_cell = str(row["cell_id"])

            label_file = self.label_root / mgrs_tile / f"{spatial_cell}.png"

            if label_file.exists():
                self.image_files.append((mgrs_tile, spatial_cell))
                self.label_files.append(label_file)

    def _load_label_mappings(self):
        with open(self.data_root / "osm_categories.json") as f:
            return json.load(f)

    def _extract_date_from_filename(self, filename: str) -> str:
        match = re.search(r"(\d{8})", filename)
        if match:
            return match.group(1)
        else:
            msg = f"Date not found in filename {filename}"
            raise ValueError(msg)

    def __len__(self) -> int:
        return len(self.image_files)

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None, show_axes: bool | None = False) -> Figure:
        if "gsd_10" not in self.bands:
            return None

        num_images = len([key for key in sample if key.startswith("image")])
        images = []

        for i in range(1, num_images + 1):
            image_dict = sample[f"image{i}"]
            image = image_dict["gsd_10"]
            if isinstance(image, Tensor):
                image = image.numpy()

            image = image.take(range(3), axis=2)
            image = image.squeeze()
            image = (image - image.min(axis=(0, 1))) / (image.max(axis=(0, 1)) - image.min(axis=(0, 1)))
            image = np.clip(image, 0, 1)
            images.append(image)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        return self._plot_sample(images, label_mask, suptitle=suptitle, show_axes=show_axes)

    def _plot_sample(
        self,
        images: list[np.ndarray],
        label: np.ndarray,
        suptitle: str | None = None,
        show_axes: bool = False,
    ) -> Figure:
        num_images = len(images)
        fig, ax = plt.subplots(1, num_images + 1, figsize=(15, 5))
        axes_visibility = "on" if show_axes else "off"

        for i, image in enumerate(images):
            ax[i].imshow(image)
            ax[i].set_title(f"Image {i + 1}")
            ax[i].axis(axes_visibility)

        ax[-1].imshow(label, cmap="gray")
        ax[-1].set_title("Ground Truth Mask")
        ax[-1].axis(axes_visibility)

        if suptitle:
            plt.suptitle(suptitle)

        return fig

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        mgrs_tile, spatial_cell = self.image_files[index]
        spatial_cell_path = self.imagery_root / mgrs_tile / spatial_cell

        npz_files = list(spatial_cell_path.glob("*.npz"))
        npz_files.sort(key=lambda x: self._extract_date_from_filename(x.stem))

        if self.pick_random_pair:
            npz_files = random.sample(npz_files, 2)
            npz_files.sort(key=lambda x: self._extract_date_from_filename(x.stem))

        output = {}

        if self.spatial_interpolate_and_stack_temporally:
            images_over_time = []
            for _, npz_file in enumerate(npz_files):
                data = np.load(npz_file)
                interpolated_bands = []
                for band in self.bands:
                    band_frame = data[band]
                    band_frame = torch.from_numpy(band_frame).float()
                    band_frame = band_frame.permute(2, 0, 1)
                    interpolated = F.interpolate(
                        band_frame.unsqueeze(0), size=MAX_TEMPORAL_IMAGE_SIZE, mode="bilinear", align_corners=False
                    ).squeeze(0)
                    interpolated_bands.append(interpolated)
                concatenated_bands = torch.cat(interpolated_bands, dim=0)
                images_over_time.append(concatenated_bands)

            images = torch.stack(images_over_time, dim=0).numpy()
            if self.truncate_image:
                images = images[-self.truncate_image :]
            if self.pad_image:
                images = pad_numpy(images, self.pad_image)

            output["image"] = images.transpose(0, 2, 3, 1)
        else:
            image_dict = {band: [] for band in self.bands}
            for _, npz_file in enumerate(npz_files):
                data = np.load(npz_file)
                for band in self.bands:
                    band_frames = data[band]
                    band_frames = band_frames.astype(np.float32)
                    band_frames = np.transpose(band_frames, (2, 0, 1))
                    image_dict[band].append(band_frames)

            final_image_dict = {}
            for band in self.bands:
                band_images = image_dict[band]
                if self.truncate_image:
                    band_images = band_images[-self.truncate_image :]
                if self.pad_image:
                    band_images = [pad_numpy(img, self.pad_image) for img in band_images]
                band_images = np.stack(band_images, axis=0)
                final_image_dict[band] = band_images

            output["image"] = final_image_dict

        label_file = self.label_files[index]
        mask = np.array(Image.open(label_file)).astype(int)

        # Map 'unlabel' (254) and 'none' (255) to unused classes 15 and 16 for processing
        mask[mask == 254] = 15  # noqa: PLR2004
        mask[mask == 255] = 16  # noqa: PLR2004
        output["mask"] = mask[:, :, self.target]

        if self.transform:
            output = self.transform(**output)

        return output
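
Before the transform is applied, __getitem__ remaps the special mask values 254 ('unlabel') and 255 ('none') to the unused class ids 15 and 16, as in the code above. A tiny standalone illustration of that remapping (the toy array is made up):

import numpy as np

mask = np.array([[0, 3, 254],
                 [255, 7, 254]])
mask[mask == 254] = 15  # 'unlabel' -> unused class 15
mask[mask == 255] = 16  # 'none'    -> unused class 16
print(mask)
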
__init__(data_root, split='train', bands=None, transform=None, spatial_interpolate_and_stack_temporally=True, pad_image=None, truncate_image=None, target=0, pick_random_pair=True) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the root directory of the dataset. | required |
| split | str | Dataset split to load. Options are 'train', 'val', or 'test'. Defaults to 'train'. | 'train' |
| bands | list of str | List of band names to load. Defaults to ['gsd_10', 'gsd_20', 'gsd_60']. | None |
| transform | Compose | Albumentations transformations to apply to the data. | None |
| spatial_interpolate_and_stack_temporally | bool | If True, the bands are interpolated and concatenated over time. Default is True. | True |
| pad_image | int | Number of timesteps to pad the time dimension of the image. If None, no padding is applied. | None |
| truncate_image | int | Number of timesteps to truncate the time dimension of the image. If None, no truncation is performed. | None |
| target | int | Specifies which target class to use from the mask. Default is 0. | 0 |
| pick_random_pair | bool | If True, selects two random images from the temporal sequence. Default is True. | True |
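
A minimal usage sketch. The class name OpenSentinelMapNonGeo and the import path are assumptions based on the module name open_sentinel_map.py; the data_root layout (osm_sentinel_imagery/, osm_label_images_v10/, spatial_cell_info.csv) follows the constructor above:

from terratorch.datasets.open_sentinel_map import OpenSentinelMapNonGeo  # assumed class name

dataset = OpenSentinelMapNonGeo(
    data_root="/data/open_sentinel_map",  # placeholder path
    split="val",
    bands=["gsd_10"],       # only 'gsd_10', 'gsd_20', 'gsd_60' are accepted
    truncate_image=4,       # keep at most the last 4 timesteps
    pad_image=4,            # pad shorter sequences up to 4 timesteps
    pick_random_pair=True,  # sample two acquisition dates per spatial cell
)
sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)
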
Source code in terratorch/datasets/open_sentinel_map.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: list[str] | None = None,
    transform: A.Compose | None = None,
    spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
    pad_image: int | None = None,
    truncate_image: int | None = None,
    target: int = 0,
    pick_random_pair: bool = True,  # noqa: FBT002, FBT001
) -> None:
    """

    Args:
        data_root (str): Path to the root directory of the dataset.
        split (str): Dataset split to load. Options are 'train', 'val', or 'test'. Defaults to 'train'.
        bands (list of str, optional): List of band names to load. Defaults to ['gsd_10', 'gsd_20', 'gsd_60'].
        transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
        spatial_interpolate_and_stack_temporally (bool): If True, the bands are interpolated and concatenated over time.
            Default is True.
        pad_image (int, optional): Number of timesteps to pad the time dimension of the image.
            If None, no padding is applied.
        truncate_image (int, optional): Number of timesteps to truncate the time dimension of the image.
            If None, no truncation is performed.
        target (int): Specifies which target class to use from the mask. Default is 0.
        pick_random_pair (bool): If True, selects two random images from the temporal sequence. Default is True.
    """
    if bands is None:
        bands = ["gsd_10", "gsd_20", "gsd_60"]

    allowed_bands = {"gsd_10", "gsd_20", "gsd_60"}
    for band in bands:
        if band not in allowed_bands:
            msg = f"Band '{band}' is not recognized. Available values are: {', '.join(allowed_bands)}"
            raise ValueError(msg)

    if split not in ["train", "val", "test"]:
        msg = f"Split '{split}' not recognized. Use 'train', 'val', or 'test'."
        raise ValueError(msg)

    self.data_root = Path(data_root)
    split_mapping = {"train": "training", "val": "validation", "test": "testing"}
    split = split_mapping[split]
    self.imagery_root = self.data_root / "osm_sentinel_imagery"
    self.label_root = self.data_root / "osm_label_images_v10"
    self.auxiliary_data = pd.read_csv(self.data_root / "spatial_cell_info.csv")
    self.auxiliary_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
    self.bands = bands
    self.transform = transform if transform else lambda **batch: to_tensor(batch)
    self.label_mappings = self._load_label_mappings()
    self.split_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
    self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
    self.pad_image = pad_image
    self.truncate_image = truncate_image
    self.target = target
    self.pick_random_pair = pick_random_pair

    self.image_files = []
    self.label_files = []

    for _, row in self.split_data.iterrows():
        mgrs_tile = row["MGRS_tile"]
        spatial_cell = str(row["cell_id"])

        label_file = self.label_root / mgrs_tile / f"{spatial_cell}.png"

        if label_file.exists():
            self.image_files.append((mgrs_tile, spatial_cell))
            self.label_files.append(label_file)

terratorch.datasets.openearthmap #

OpenEarthMapNonGeo #

Bases: NonGeoDataset

OpenEarthMapNonGeo Dataset for non-georeferenced imagery.

This dataset class handles non-georeferenced image data from the OpenEarthMap dataset. It supports configurable band sets and transformations, and performs cropping operations to ensure that the images conform to the required input dimensions. The dataset is split into "train", "test", and "val" subsets based on the provided split parameter.
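
A minimal usage sketch (the data_root path is a placeholder; as _get_file_paths below shows, the root must contain a train.txt/val.txt/test.txt index listing the file names to load):

from terratorch.datasets.openearthmap import OpenEarthMapNonGeo

dataset = OpenEarthMapNonGeo(
    data_root="/data/open_earth_map",  # placeholder; must contain the split index files
    split="val",
    crop_size=256,      # every sample is cropped to crop_size x crop_size
    random_crop=False,  # center crop for deterministic evaluation
)
sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)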

Source code in terratorch/datasets/openearthmap.py
class OpenEarthMapNonGeo(NonGeoDataset):
    """
    [OpenEarthMapNonGeo](https://open-earth-map.org/) Dataset for non-georeferenced imagery.

    This dataset class handles non-georeferenced image data from the OpenEarthMap dataset.
    It supports configurable band sets and transformations, and performs cropping operations
    to ensure that the images conform to the required input dimensions. The dataset is split
    into "train", "test", and "val" subsets based on the provided split parameter.
    """


    all_band_names = ("BLUE","GREEN","RED")

    rgb_bands = ("RED","GREEN","BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    def __init__(self, data_root: str,
                 bands: Sequence[str] = BAND_SETS["all"],
                 transform: A.Compose | None = None,
                 split="train",
                 crop_size: int = 256,
                 random_crop: bool = True) -> None:
        """
        Initialize a new instance of the OpenEarthMapNonGeo dataset.

        Args:
            data_root (str): The root directory containing the dataset files.
            bands (Sequence[str], optional): A list of band names to be used. Default is BAND_SETS["all"].
            transform (A.Compose or None, optional): A transformation pipeline to be applied to the data.
                If None, a default transform converting the data to a tensor is applied.
            split (str, optional): The dataset split to use ("train", "test", or "val"). Default is "train".
            crop_size (int, optional): The size (in pixels) of the crop to apply to images. Must be greater than 0.
                Default is 256.
            random_crop (bool, optional): If True, performs a random crop; otherwise, performs a center crop.
                Default is True.

        Raises:
            Exception: If the provided split is not one of "train", "test", or "val".
            AssertionError: If crop_size is not greater than 0.
        """
        super().__init__()
        if split not in ["train", "test", "val"]:
            msg = "Split must be one of train, test, val."
            raise Exception(msg)

        self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False)
        self.split = split
        self.data_root = data_root

        # images in openearthmap are not all 1024x1024 and must be cropped
        self.crop_size = crop_size
        self.random_crop = random_crop

        assert self.crop_size > 0, "Crop size must be greater than 0"

        self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt"))

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        image_path, label_path = self.image_files[index]

        with rasterio.open(image_path) as src:
            image = src.read()
        with rasterio.open(label_path) as src:
            mask = src.read()

        # some images in the dataset are not perfect squares
        # cropping to fit to the prepare_features_for_image_model call
        if self.random_crop:
            image, mask = self._random_crop(image, mask)
        else:
            image, mask = self._center_crop(image, mask)

        output =  {
            "image": image.astype(np.float32),
            "mask": mask
        }

        output = self.transform(**output)
        output['mask'] = output['mask'].long()

        return output

    def _parse_file_name(self, file_name: str):
        underscore_pos = file_name.rfind('_')
        folder_name = file_name[:underscore_pos]
        region_path = Path(self.data_root, folder_name)
        image_path = Path(region_path, "images", file_name)
        label_path = Path(region_path, "labels", file_name)
        return image_path, label_path

    def _get_file_paths(self, text_file_path: str):
        with open(text_file_path, 'r') as file:
            lines = file.readlines()
            file_paths = [self._parse_file_name(line.strip()) for line in lines]
        return file_paths

    def __len__(self):
        return len(self.image_files)

    def _random_crop(self, image, mask):
        h, w = image.shape[1:]
        top = np.random.randint(0, h - self.crop_size)
        left = np.random.randint(0, w - self.crop_size)

        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

        return image, mask

    def _center_crop(self, image, mask):
        h, w = image.shape[1:]
        top = (h - self.crop_size) // 2
        left = (w - self.crop_size) // 2

        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

        return image, mask

    def plot(self, arg, suptitle: str | None = None) -> None:
        pass

    def plot_sample(self, sample, prediction=None, suptitle: str | None = None, class_names=None):
        pass
__init__(data_root, bands=BAND_SETS['all'], transform=None, split='train', crop_size=256, random_crop=True) #

Initialize a new instance of the OpenEarthMapNonGeo dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | The root directory containing the dataset files. | required |
| bands | Sequence[str] | A list of band names to be used. Default is BAND_SETS["all"]. | BAND_SETS['all'] |
| transform | Compose or None | A transformation pipeline to be applied to the data. If None, a default transform converting the data to a tensor is applied. | None |
| split | str | The dataset split to use ("train", "test", or "val"). Default is "train". | 'train' |
| crop_size | int | The size (in pixels) of the crop to apply to images. Must be greater than 0. Default is 256. | 256 |
| random_crop | bool | If True, performs a random crop; otherwise, performs a center crop. Default is True. | True |

Raises:

| Type | Description |
| --- | --- |
| Exception | If the provided split is not one of "train", "test", or "val". |
| AssertionError | If crop_size is not greater than 0. |
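
To make the index-file convention concrete, here is the name resolution performed by _parse_file_name below, replayed on one hypothetical entry of train.txt (the file name and root are made up):

from pathlib import Path

data_root = "/data/open_earth_map"  # placeholder
file_name = "aachen_1.tif"          # hypothetical line from train.txt

folder_name = file_name[:file_name.rfind("_")]  # -> "aachen"
image_path = Path(data_root, folder_name, "images", file_name)
label_path = Path(data_root, folder_name, "labels", file_name)
print(image_path)  # /data/open_earth_map/aachen/images/aachen_1.tif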

Source code in terratorch/datasets/openearthmap.py
def __init__(self, data_root: str,
             bands: Sequence[str] = BAND_SETS["all"],
             transform: A.Compose | None = None,
             split="train",
             crop_size: int = 256,
             random_crop: bool = True) -> None:
    """
    Initialize a new instance of the OpenEarthMapNonGeo dataset.

    Args:
        data_root (str): The root directory containing the dataset files.
        bands (Sequence[str], optional): A list of band names to be used. Default is BAND_SETS["all"].
        transform (A.Compose or None, optional): A transformation pipeline to be applied to the data.
            If None, a default transform converting the data to a tensor is applied.
        split (str, optional): The dataset split to use ("train", "test", or "val"). Default is "train".
        crop_size (int, optional): The size (in pixels) of the crop to apply to images. Must be greater than 0.
            Default is 256.
        random_crop (bool, optional): If True, performs a random crop; otherwise, performs a center crop.
            Default is True.

    Raises:
        Exception: If the provided split is not one of "train", "test", or "val".
        AssertionError: If crop_size is not greater than 0.
    """
    super().__init__()
    if split not in ["train", "test", "val"]:
        msg = "Split must be one of train, test, val."
        raise Exception(msg)

    self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False)
    self.split = split
    self.data_root = data_root

    # images in openearthmap are not all 1024x1024 and must be cropped
    self.crop_size = crop_size
    self.random_crop = random_crop

    assert self.crop_size > 0, "Crop size must be greater than 0"

    self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt"))

terratorch.datasets.pastis #

PASTIS #

Bases: NonGeoDataset

" Pytorch Dataset class to load samples from the PASTIS dataset, for semantic and panoptic segmentation.

Source code in terratorch/datasets/pastis.py
class PASTIS(NonGeoDataset):
    """ "
    Pytorch Dataset class to load samples from the [PASTIS](https://github.com/VSainteuf/pastis-benchmark) dataset,
    for semantic and panoptic segmentation.
    """

    def __init__(
        self,
        data_root,
        norm=True,  # noqa: FBT002
        target="semantic",
        folds=None,
        reference_date="2018-09-01",
        date_interval=(-200, 600),
        class_mapping=None,
        transform=None,
        truncate_image=None,
        pad_image=None,
        satellites=["S2"],  # noqa: B006
    ):
        """

        Args:
            data_root (str): Path to the dataset.
            norm (bool): If true, images are standardised using pre-computed
                channel-wise means and standard deviations.
            reference_date (str, Format : 'YYYY-MM-DD'): Defines the reference date
                based on which all observation dates are expressed. Along with the image
                time series and the target tensor, this dataloader yields the sequence
                of observation dates (in terms of number of days since the reference
                date). This sequence of dates is used for instance for the positional
                encoding in attention based approaches.
            target (str): 'semantic' or 'instance'. Defines which type of target is
                returned by the dataloader.
                * If 'semantic' the target tensor is a tensor containing the class of
                each pixel.
                * If 'instance' the target tensor is the concatenation of several
                signals, necessary to train the Parcel-as-Points module:
                    - the centerness heatmap,
                    - the instance ids,
                    - the voronoi partitioning of the patch with regards to the parcels'
                    centers,
                    - the (height, width) size of each parcel,
                    - the semantic label of each parcel,
                    - the semantic label of each pixel.
            folds (list, optional): List of ints specifying which of the 5 official
                folds to load. By default (when None is specified), all folds are loaded.
            class_mapping (dict, optional): A dictionary to define a mapping between the
                default 18 class nomenclature and another class grouping. If not provided,
                the default class mapping is used.
            transform (callable, optional): A transform to apply to the loaded data
                (images, dates, and masks). By default, no transformation is applied.
            truncate_image (int, optional): Truncate the time dimension of the image to
                a specified number of timesteps. If None, no truncation is performed.
            pad_image (int, optional): Pad the time dimension of the image to a specified
                number of timesteps. If None, no padding is applied.
            satellites (list): Defines the satellites to use. If you are using PASTIS-R, you
                have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending
                and Descending orbits, respectively S2, S1A, and S1D. For example, use
                satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series,
                or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using
                PASTIS, only S2 observations are available.
        """
        if target not in ["semantic", "instance"]:
            msg = f"Target '{target}' not recognized. Use 'semantic', or 'instance'."
            raise ValueError(msg)
        valid_satellites = {"S2", "S1A", "S1D"}
        for sat in satellites:
            if sat not in valid_satellites:
                msg = f"Satellite '{sat}' not recognized. Valid options are {valid_satellites}."
                raise ValueError(msg)

        super().__init__()
        self.data_root = data_root
        self.norm = norm
        self.reference_date = datetime(*map(int, reference_date.split("-")), tzinfo=timezone.utc)
        self.class_mapping = np.vectorize(lambda x: class_mapping[x]) if class_mapping is not None else class_mapping
        self.target = target
        self.satellites = satellites
        self.transform = transform
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        # loads patches metadata
        self.meta_patch = gpd.read_file(os.path.join(data_root, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID_PATCH"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        # stores a date table for each satellite
        self.date_tables = {s: None for s in satellites}
        # date interval used in the PASTIS benchmark paper.
        date_interval_begin, date_interval_end = date_interval
        self.date_range = np.array(range(date_interval_begin, date_interval_end))
        for s in satellites:
            # maps each patch to its observation dates
            dates = self.meta_patch[f"dates-{s}"]
            date_table = pd.DataFrame(index=self.meta_patch.index, columns=self.date_range, dtype=int)
            for pid, date_seq in dates.items():
                if type(date_seq) is str:
                    date_seq = json.loads(date_seq)  # noqa: PLW2901
                # convert each date to days since the reference date
                d = pd.DataFrame().from_dict(date_seq, orient="index")
                d = d[0].apply(
                    lambda x: (
                        datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]), tzinfo=timezone.utc)
                        - self.reference_date
                    ).days
                )
                date_table.loc[pid, d.values] = 1
            date_table = date_table.fillna(0)
            self.date_tables[s] = {
                index: np.array(list(d.values())) for index, d in date_table.to_dict(orient="index").items()
            }

        # selects patches corresponding to the selected folds
        if folds is not None:
            self.meta_patch = pd.concat([self.meta_patch[self.meta_patch["Fold"] == f] for f in folds])

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index

        # loads normalization values
        if norm:
            self.norm = {}
            for s in self.satellites:
                with open(os.path.join(data_root, f"NORM_{s}_patch.json")) as file:
                    normvals = json.loads(file.read())
                selected_folds = folds if folds is not None else range(1, 6)
                means = [normvals[f"Fold_{f}"]["mean"] for f in selected_folds]
                stds = [normvals[f"Fold_{f}"]["std"] for f in selected_folds]
                self.norm[s] = np.stack(means).mean(axis=0), np.stack(stds).mean(axis=0)
                self.norm[s] = (
                    self.norm[s][0],
                    self.norm[s][1],
                )
        else:
            self.norm = None

    def __len__(self):
        return self.len

    def get_dates(self, id_patch, sat):
        return self.date_range[np.where(self.date_tables[sat][id_patch] == 1)[0]]

    def __getitem__(self, item):
        id_patch = self.id_patches[item]
        output = {}
        satellites = {}
        for satellite in self.satellites:
            data = np.load(
                os.path.join(
                    self.data_root,
                    f"DATA_{satellite}",
                    f"{satellite}_{id_patch}.npy",
                )
            ).astype(np.float32)

            if self.norm is not None:
                data = data - self.norm[satellite][0][None, :, None, None]
                data = data / self.norm[satellite][1][None, :, None, None]

            if self.truncate_image and data.shape[0] > self.truncate_image:
                data = data[-self.truncate_image :]

            if self.pad_image and data.shape[0] < self.pad_image:
                data = pad_numpy(data, self.pad_image)

            satellites[satellite] = data.astype(np.float32)

        if self.target == "semantic":
            target = np.load(os.path.join(self.data_root, "ANNOTATIONS", f"TARGET_{id_patch}.npy"))
            target = target[0].astype(int)
            if self.class_mapping is not None:
                target = self.class_mapping(target)
        elif self.target == "instance":
            heatmap = np.load(os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"HEATMAP_{id_patch}.npy"))
            instance_ids = np.load(os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"INSTANCES_{id_patch}.npy"))
            zones_path = os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"ZONES_{id_patch}.npy")
            pixel_to_object_mapping = np.load(zones_path)
            pixel_semantic_annotation = np.load(os.path.join(self.data_root, "ANNOTATIONS", f"TARGET_{id_patch}.npy"))

            if self.class_mapping is not None:
                pixel_semantic_annotation = self.class_mapping(pixel_semantic_annotation[0])
            else:
                pixel_semantic_annotation = pixel_semantic_annotation[0]

            size = np.zeros((*instance_ids.shape, 2))
            object_semantic_annotation = np.zeros(instance_ids.shape)
            for instance_id in np.unique(instance_ids):
                if instance_id != 0:
                    h = (instance_ids == instance_id).any(axis=-1).sum()
                    w = (instance_ids == instance_id).any(axis=-2).sum()
                    size[pixel_to_object_mapping == instance_id] = (h, w)
                    semantic_value = pixel_semantic_annotation[instance_ids == instance_id][0]
                    object_semantic_annotation[pixel_to_object_mapping == instance_id] = semantic_value

            target = np.concatenate(
                [
                    heatmap[:, :, None],
                    instance_ids[:, :, None],
                    pixel_to_object_mapping[:, :, None],
                    size,
                    object_semantic_annotation[:, :, None],
                    pixel_semantic_annotation[:, :, None],
                ],
                axis=-1,
            ).astype(np.float32)

        dates = {}
        for satellite in self.satellites:
            date = np.array(self.get_dates(id_patch, satellite))

            if self.truncate_image and len(date) > self.truncate_image:
                date = date[-self.truncate_image :]

            if self.pad_image and len(date) < self.pad_image:
                date = pad_dates_numpy(date, self.pad_image)

            dates[satellite] = torch.from_numpy(date)

        output["image"] = satellites["S2"].transpose(0, 2, 3, 1)
        output["mask"] = target

        if self.transform:
            output = self.transform(**output)

        output.update(satellites)
        output["dates"] = dates

        return output

    def plot(self, sample, suptitle=None, show_axes=False):
        dates = sample["dates"]
        target = sample["target"]

        if "S2" not in sample:
            warnings.warn("No RGB image.", stacklevel=2)
            return None

        image_data = sample["S2"]
        date_data = dates["S2"]

        rgb_images = []
        for i in range(image_data.shape[0]):
            rgb_image = image_data[i, :3, :, :].numpy().transpose(1, 2, 0)

            rgb_min = rgb_image.min(axis=(0, 1), keepdims=True)
            rgb_max = rgb_image.max(axis=(0, 1), keepdims=True)
            denom = rgb_max - rgb_min
            denom[denom == 0] = 1
            rgb_image = (rgb_image - rgb_min) / denom

            rgb_images.append(np.clip(rgb_image, 0, 1))

        return self._plot_sample(rgb_images, date_data, target, suptitle=suptitle, show_axes=show_axes)

    def _plot_sample(
        self,
        images: list[np.ndarray],
        dates: torch.Tensor,
        target: torch.Tensor | None,
        suptitle: str | None = None,
        show_axes: bool | None = False,
    ):
        num_images = len(images)
        cols = 5
        rows = (num_images + cols) // cols

        fig, ax = plt.subplots(rows, cols, figsize=(20, 4 * rows))
        axes_visibility = "on" if show_axes else "off"

        for i, image in enumerate(images):
            ax[i // cols, i % cols].imshow(image)
            ax[i // cols, i % cols].set_title(f"Image {i + 1} - Day {dates[i].item()}")
            ax[i // cols, i % cols].axis(axes_visibility)

        if target is not None:
            if rows * cols > num_images:
                target_ax = ax[(num_images) // cols, (num_images) % cols]
            else:
                fig.add_subplot(rows + 1, 1, 1)
                target_ax = fig.gca()

            target_ax.imshow(target.numpy(), cmap="tab20")
            target_ax.set_title("Target")
            target_ax.axis(axes_visibility)

        for k in range(num_images + 1, rows * cols):
            ax[k // cols, k % cols].axis(axes_visibility)

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
__init__(data_root, norm=True, target='semantic', folds=None, reference_date='2018-09-01', date_interval=(-200, 600), class_mapping=None, transform=None, truncate_image=None, pad_image=None, satellites=['S2']) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the dataset. | required |
| norm | bool | If true, images are standardised using pre-computed channel-wise means and standard deviations. | True |
| reference_date | str, format 'YYYY-MM-DD' | Defines the reference date based on which all observation dates are expressed. Along with the image time series and the target tensor, this dataloader yields the sequence of observation dates (in terms of number of days since the reference date). This sequence of dates is used for instance for the positional encoding in attention based approaches. | '2018-09-01' |
| target | str | 'semantic' or 'instance'. Defines which type of target is returned by the dataloader. If 'semantic', the target tensor contains the class of each pixel. If 'instance', the target tensor is the concatenation of several signals necessary to train the Parcel-as-Points module: the centerness heatmap, the instance ids, the voronoi partitioning of the patch with regard to the parcels' centers, the (height, width) size of each parcel, the semantic label of each parcel, and the semantic label of each pixel. | 'semantic' |
| folds | list, optional | List of ints specifying which of the 5 official folds to load. By default (when None is specified), all folds are loaded. | None |
| class_mapping | dict, optional | A dictionary to define a mapping between the default 18 class nomenclature and another class grouping. If not provided, the default class mapping is used. | None |
| transform | callable, optional | A transform to apply to the loaded data (images, dates, and masks). By default, no transformation is applied. | None |
| truncate_image | int, optional | Truncate the time dimension of the image to a specified number of timesteps. If None, no truncation is performed. | None |
| pad_image | int, optional | Pad the time dimension of the image to a specified number of timesteps. If None, no padding is applied. | None |
| satellites | list | Defines the satellites to use. If you are using PASTIS-R, you have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending and Descending orbits, respectively S2, S1A, and S1D. For example, use satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series, or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using PASTIS, only S2 observations are available. | ['S2'] |
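
The reference_date bookkeeping can be made concrete with a short sketch: an acquisition stored as 20181015 (YYYYMMDD) is converted to days elapsed since reference_date, mirroring the date-table construction in the source below:

from datetime import datetime, timezone

reference_date = datetime(2018, 9, 1, tzinfo=timezone.utc)
acquisition = 20181015  # YYYYMMDD, as stored in the metadata

x = str(acquisition)
days = (datetime(int(x[:4]), int(x[4:6]), int(x[6:]), tzinfo=timezone.utc) - reference_date).days
print(days)  # 44
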
Source code in terratorch/datasets/pastis.py
def __init__(
    self,
    data_root,
    norm=True,  # noqa: FBT002
    target="semantic",
    folds=None,
    reference_date="2018-09-01",
    date_interval=(-200, 600),
    class_mapping=None,
    transform=None,
    truncate_image=None,
    pad_image=None,
    satellites=["S2"],  # noqa: B006
):
    """

    Args:
        data_root (str): Path to the dataset.
        norm (bool): If true, images are standardised using pre-computed
            channel-wise means and standard deviations.
        reference_date (str, Format : 'YYYY-MM-DD'): Defines the reference date
            based on which all observation dates are expressed. Along with the image
            time series and the target tensor, this dataloader yields the sequence
            of observation dates (in terms of number of days since the reference
            date). This sequence of dates is used for instance for the positional
            encoding in attention based approaches.
        target (str): 'semantic' or 'instance'. Defines which type of target is
            returned by the dataloader.
            * If 'semantic' the target tensor is a tensor containing the class of
            each pixel.
            * If 'instance' the target tensor is the concatenation of several
            signals, necessary to train the Parcel-as-Points module:
                - the centerness heatmap,
                - the instance ids,
                - the voronoi partitioning of the patch with regards to the parcels'
                centers,
                - the (height, width) size of each parcel,
                - the semantic label of each parcel,
                - the semantic label of each pixel.
        folds (list, optional): List of ints specifying which of the 5 official
            folds to load. By default (when None is specified), all folds are loaded.
        class_mapping (dict, optional): A dictionary to define a mapping between the
            default 18 class nomenclature and another class grouping. If not provided,
            the default class mapping is used.
        transform (callable, optional): A transform to apply to the loaded data
            (images, dates, and masks). By default, no transformation is applied.
        truncate_image (int, optional): Truncate the time dimension of the image to
            a specified number of timesteps. If None, no truncation is performed.
        pad_image (int, optional): Pad the time dimension of the image to a specified
            number of timesteps. If None, no padding is applied.
        satellites (list): Defines the satellites to use. If you are using PASTIS-R, you
            have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending
            and Descending orbits, respectively S2, S1A, and S1D. For example, use
            satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series,
            or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using
            PASTIS, only S2 observations are available.
    """
    if target not in ["semantic", "instance"]:
        msg = f"Target '{target}' not recognized. Use 'semantic', or 'instance'."
        raise ValueError(msg)
    valid_satellites = {"S2", "S1A", "S1D"}
    for sat in satellites:
        if sat not in valid_satellites:
            msg = f"Satellite '{sat}' not recognized. Valid options are {valid_satellites}."
            raise ValueError(msg)

    super().__init__()
    self.data_root = data_root
    self.norm = norm
    self.reference_date = datetime(*map(int, reference_date.split("-")), tzinfo=timezone.utc)
    self.class_mapping = np.vectorize(lambda x: class_mapping[x]) if class_mapping is not None else class_mapping
    self.target = target
    self.satellites = satellites
    self.transform = transform
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    # loads patches metadata
    self.meta_patch = gpd.read_file(os.path.join(data_root, "metadata.geojson"))
    self.meta_patch.index = self.meta_patch["ID_PATCH"].astype(int)
    self.meta_patch.sort_index(inplace=True)
    # stores a date table for each satellite
    self.date_tables = {s: None for s in satellites}
    # date interval used in the PASTIS benchmark paper.
    date_interval_begin, date_interval_end = date_interval
    self.date_range = np.array(range(date_interval_begin, date_interval_end))
    for s in satellites:
        # maps each patch to its observation dates
        dates = self.meta_patch[f"dates-{s}"]
        date_table = pd.DataFrame(index=self.meta_patch.index, columns=self.date_range, dtype=int)
        for pid, date_seq in dates.items():
            if type(date_seq) is str:
                date_seq = json.loads(date_seq)  # noqa: PLW2901
            # convert each date to days since the reference date
            d = pd.DataFrame().from_dict(date_seq, orient="index")
            d = d[0].apply(
                lambda x: (
                    datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]), tzinfo=timezone.utc)
                    - self.reference_date
                ).days
            )
            date_table.loc[pid, d.values] = 1
        date_table = date_table.fillna(0)
        self.date_tables[s] = {
            index: np.array(list(d.values())) for index, d in date_table.to_dict(orient="index").items()
        }

    # selects patches corresponding to the selected folds
    if folds is not None:
        self.meta_patch = pd.concat([self.meta_patch[self.meta_patch["Fold"] == f] for f in folds])

    self.len = self.meta_patch.shape[0]
    self.id_patches = self.meta_patch.index

    # loads normalization values
    if norm:
        self.norm = {}
        for s in self.satellites:
            with open(os.path.join(data_root, f"NORM_{s}_patch.json")) as file:
                normvals = json.loads(file.read())
            selected_folds = folds if folds is not None else range(1, 6)
            means = [normvals[f"Fold_{f}"]["mean"] for f in selected_folds]
            stds = [normvals[f"Fold_{f}"]["std"] for f in selected_folds]
            self.norm[s] = np.stack(means).mean(axis=0), np.stack(stds).mean(axis=0)
            self.norm[s] = (
                self.norm[s][0],
                self.norm[s][1],
            )
    else:
        self.norm = None

terratorch.datasets.sen1floods11 #

Sen1Floods11NonGeo #

Bases: NonGeoDataset

NonGeo dataset implementation for sen1floods11.
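
A minimal usage sketch (the data_root path is a placeholder; it must contain the v1.1/ directory tree referenced by the class attributes below):

from terratorch.datasets.sen1floods11 import Sen1Floods11NonGeo

dataset = Sen1Floods11NonGeo(
    data_root="/data/sen1floods11",  # placeholder path
    split="val",
    bands=["RED", "GREEN", "BLUE", "NIR_NARROW"],
    use_metadata=True,  # also return acquisition date and location
)
sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)
print(sample["location_coords"], sample["temporal_coords"])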

Source code in terratorch/datasets/sen1floods11.py
class Sen1Floods11NonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [sen1floods11](https://github.com/cloudtostreet/Sen1Floods11)."""

    all_band_names = (
            "COASTAL_AEROSOL",
            "BLUE",
            "GREEN",
            "RED",
            "RED_EDGE_1",
            "RED_EDGE_2",
            "RED_EDGE_3",
            "NIR_BROAD",
            "NIR_NARROW",
            "WATER_VAPOR",
            "CIRRUS",
            "SWIR_1",
            "SWIR_2",
    )
    rgb_bands = ("RED", "GREEN", "BLUE")
    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
    num_classes = 2
    splits = {"train": "train", "val": "valid", "test": "test"}
    data_dir = "v1.1/data/flood_events/HandLabeled/S2Hand"
    label_dir = "v1.1/data/flood_events/HandLabeled/LabelHand"
    split_dir = "v1.1/splits/flood_handlabeled"
    metadata_file = "v1.1/Sen1Floods11_Metadata.geojson"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        constant_scale: float = 0.0001,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): one of 'train', 'val' or 'test'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). Defaults to None, which applies ToTensorV2().
            constant_scale (float): Factor to multiply image values by. Defaults to 0.0001.
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace nan values in label with this value.
                If none, does no replacement. Defaults to -1.
            use_metadata (bool): whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.constant_scale = constant_scale
        self.data_root = Path(data_root)

        data_dir = self.data_root / self.data_dir
        label_dir = self.data_root / self.label_dir

        self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_S2Hand.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_dir, "*_LabelHand.tif")))

        split_file = self.data_root / self.split_dir / f"flood_{split_name}_data.txt"
        with open(split_file) as f:
            split = f.readlines()
        valid_files = {rf"{substring.strip()}" for substring in split}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )

        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata
        self.metadata = None
        if self.use_metadata:
            self.metadata = geopandas.read_file(self.data_root / self.metadata_file)

        # If no transform is given, only convert the data to torch tensors
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        location = os.path.basename(file_name).split("_")[0]
        if self.metadata[self.metadata["location"] == location].shape[0] != 1:
            date = pd.to_datetime("13-10-1998", dayfirst=True)
        else:
            date = pd.to_datetime(self.metadata[self.metadata["location"] == location]["s2_date"].item())

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)  # (n_timesteps, coords)

    def _get_coords(self, image: DataArray) -> torch.Tensor:

        center_lat = image.y[image.y.shape[0] // 2]
        center_lon = image.x[image.x.shape[0] // 2]
        lat_lon = np.asarray([center_lat, center_lon])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, constant_scale=0.0001, no_data_replace=0, no_label_replace=-1, use_metadata=False) #

Constructor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_root | str | Path to the data root directory. | required |
| split | str | One of 'train', 'val' or 'test'. | 'train' |
| bands | list[str] | Bands that should be output by the dataset. Defaults to all bands. | BAND_SETS['all'] |
| transform | Compose or None | Albumentations transform to be applied. Should end with ToTensorV2(). Defaults to None, which applies ToTensorV2(). | None |
| constant_scale | float | Factor to multiply image values by. Defaults to 0.0001. | 0.0001 |
| no_data_replace | float or None | Replace nan values in input images with this value. If None, does no replacement. Defaults to 0. | 0 |
| no_label_replace | int or None | Replace nan values in label with this value. If None, does no replacement. Defaults to -1. | -1 |
| use_metadata | bool | Whether to return metadata info (time and location). | False |
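
Because the default transform ends with ToTensorV2(), a custom pipeline should do the same. A small sketch using standard albumentations components (the augmentations chosen here are illustrative):

import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomCrop(256, 256),
    ToTensorV2(),  # must come last: converts image and mask to torch tensors
])
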
Source code in terratorch/datasets/sen1floods11.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    constant_scale: float = 0.0001,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,  # noqa: FBT001, FBT002
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): one of 'train', 'val' or 'test'.
        bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). Defaults to None, which applies ToTensorV2().
        constant_scale (float): Factor to multiply image values by. Defaults to 0.0001.
        no_data_replace (float | None): Replace nan values in input images with this value.
            If None, does no replacement. Defaults to 0.
        no_label_replace (int | None): Replace nan values in label with this value.
            If none, does no replacement. Defaults to -1.
        use_metadata (bool): whether to return metadata info (time and location).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {self.splits}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.constant_scale = constant_scale
    self.data_root = Path(data_root)

    data_dir = self.data_root / self.data_dir
    label_dir = self.data_root / self.label_dir

    self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_S2Hand.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_dir, "*_LabelHand.tif")))

    split_file = self.data_root / self.split_dir / f"flood_{split_name}_data.txt"
    with open(split_file) as f:
        split = f.readlines()
    valid_files = {rf"{substring.strip()}" for substring in split}
    self.image_files = filter_valid_files(
        self.image_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )
    self.segmentation_mask_files = filter_valid_files(
        self.segmentation_mask_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )

    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
    self.metadata = None
    if self.use_metadata:
        self.metadata = geopandas.read_file(self.data_root / self.metadata_file)

    # If no transform is given, only convert the data to torch tensors
    self.transform = transform if transform else default_transform
plot(sample, suptitle=None) #

Plot a sample from the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample | dict[str, Tensor] | a sample returned by __getitem__ | required |
| suptitle | str or None | optional string to use as a suptitle | None |

Returns:

| Type | Description |
| --- | --- |
| Figure | a matplotlib Figure with the rendered sample |
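
A typical call, assuming a dataset instance whose bands include RED, GREEN, and BLUE (plot raises a ValueError if any RGB band is missing):

import matplotlib.pyplot as plt

sample = dataset[0]  # dataset: a Sen1Floods11NonGeo instance as constructed above
fig = dataset.plot(sample, suptitle="Sen1Floods11 sample")
plt.show()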

Source code in terratorch/datasets/sen1floods11.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = 4

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    # RGB -> channels-last
    image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
    mask = sample["mask"].numpy()

    image = clip_image(image)

    if "prediction" in sample:
        prediction = sample["prediction"]
        num_images += 1
    else:
        prediction = None

    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    ax[1].axis("off")
    ax[1].title.set_text("Image")
    ax[1].imshow(image)

    ax[2].axis("off")
    ax[2].title.set_text("Ground Truth Mask")
    ax[2].imshow(mask, cmap="jet", norm=norm)

    ax[3].axis("off")
    ax[3].title.set_text("GT Mask on Image")
    ax[3].imshow(image)
    ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

    if "prediction" in sample:
        ax[4].title.set_text("Predicted Mask")
        ax[4].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")
    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.sen4agrinet #

Sen4AgriNet #

Bases: NonGeoDataset

Source code in terratorch/datasets/sen4agrinet.py
class Sen4AgriNet(NonGeoDataset):
    def __init__(
        self,
        data_root: str,
        bands: list[str] | None = None,
        scenario: str = "random",
        split: str = "train",
        transform: A.Compose = None,
        truncate_image: int | None = 4,
        pad_image: int | None = 4,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        seed: int = 42,
    ):
        """
        Pytorch Dataset class to load samples from the [Sen4AgriNet](https://github.com/Orion-AI-Lab/S4A) dataset, supporting
        multiple scenarios for splitting the data.

        Args:
            data_root (str): Root directory of the dataset.
            bands (list of str, optional): List of band names to load. Defaults to all available bands.
            scenario (str): Defines the splitting scenario to use. Options are:
                - 'random': Random split of the data.
                - 'spatial': Split by geographical regions (Catalonia and France).
                - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
            split (str): Specifies the dataset split. Options are 'train', 'val', or 'test'.
            transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
            truncate_image (int, optional): Number of timesteps to truncate the time dimension of the image.
                If None, no truncation is applied. Default is 4.
            pad_image (int, optional): Number of timesteps to pad the time dimension of the image.
                If None, no padding is applied. Default is 4.
            spatial_interpolate_and_stack_temporally (bool): Whether to interpolate bands and concatenate them over time. Default is True.
            seed (int): Random seed used for data splitting.
        """
        self.data_root = Path(data_root) / "data"
        self.transform = transform if transform else lambda **batch: to_tensor(batch)
        self.scenario = scenario
        self.seed = seed
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally

        if bands is None:
            bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B10", "B11", "B12", "B8A"]
        self.bands = bands

        self.image_files = list(self.data_root.glob("**/*.nc"))

        self.train_files, self.val_files, self.test_files = self.split_data()

        if split == "train":
            self.image_files = self.train_files
        elif split == "val":
            self.image_files = self.val_files
        elif split == "test":
            self.image_files = self.test_files

    def __len__(self):
        return len(self.image_files)

    def split_data(self):
        random.seed(self.seed)

        if self.scenario == "random":
            random.shuffle(self.image_files)
            total_files = len(self.image_files)
            train_split = int(0.6 * total_files)
            val_split = int(0.8 * total_files)

            train_files = self.image_files[:train_split]
            val_files = self.image_files[train_split:val_split]
            test_files = self.image_files[val_split:]

        elif self.scenario == "spatial":
            catalonia_files = [f for f in self.image_files if any(tile in f.stem for tile in CAT_TILES)]
            france_files = [f for f in self.image_files if any(tile in f.stem for tile in FR_TILES)]

            val_split_cat = int(0.2 * len(catalonia_files))
            train_files = catalonia_files[val_split_cat:]
            val_files = catalonia_files[:val_split_cat]
            test_files = france_files

        elif self.scenario == "spatio-temporal":
            france_files = [f for f in self.image_files if any(tile in f.stem for tile in FR_TILES)]
            catalonia_files = [f for f in self.image_files if any(tile in f.stem for tile in CAT_TILES)]

            france_2019_files = [f for f in france_files if "2019" in f.stem]
            catalonia_2020_files = [f for f in catalonia_files if "2020" in f.stem]

            val_split_france_2019 = int(0.2 * len(france_2019_files))
            train_files = france_2019_files[val_split_france_2019:]
            val_files = france_2019_files[:val_split_france_2019]
            test_files = catalonia_2020_files

        return train_files, val_files, test_files

    def __getitem__(self, index: int):
        patch_file = self.image_files[index]

        with h5py.File(patch_file, "r") as patch_data:
            output = {}
            images_over_time = []
            for band in self.bands:
                band_group = patch_data[band]
                band_data = band_group[f"{band}"][:]
                time_vector = band_group["time"][:]

                sorted_indices = np.argsort(time_vector)
                band_data = band_data[sorted_indices].astype(np.float32)

                if self.truncate_image:
                    band_data = band_data[-self.truncate_image :]
                if self.pad_image:
                    band_data = pad_numpy(band_data, self.pad_image)

                if self.spatial_interpolate_and_stack_temporally:
                    band_data = torch.from_numpy(band_data)
                    band_data = band_data.clone().detach()

                    interpolated = F.interpolate(
                        band_data.unsqueeze(0), size=MAX_TEMPORAL_IMAGE_SIZE, mode="bilinear", align_corners=False
                    ).squeeze(0)
                    images_over_time.append(interpolated)
                else:
                    output[band] = band_data

            if self.spatial_interpolate_and_stack_temporally:
                images = torch.stack(images_over_time, dim=0).numpy()
                output["image"] = images

            labels = patch_data["labels"]["labels"][:].astype(int)
            parcels = patch_data["parcels"]["parcels"][:].astype(int)

        output["mask"] = labels

        image_shape = output["image"].shape[-2:]
        mask_shape = output["mask"].shape

        if image_shape != mask_shape:
            diff_h = mask_shape[0] - image_shape[0]
            diff_w = mask_shape[1] - image_shape[1]

            output["image"] = np.pad(
                output["image"],
                [(0, 0), (0, 0), (diff_h // 2, diff_h - diff_h // 2), (diff_w // 2, diff_w - diff_w // 2)],
                mode="constant",
                constant_values=0,
            )

        linear_encoder = {val: i + 1 for i, val in enumerate(sorted(SELECTED_CLASSES))}
        linear_encoder[0] = 0

        output["image"] = output["image"].transpose(0, 2, 3, 1)
        output["mask"] = self.map_mask_to_discrete_classes(output["mask"], linear_encoder)

        if self.transform:
            output = self.transform(**output)

        output["parcels"] = parcels

        return output

    def plot(self, sample, suptitle=None, show_axes=False):
        rgb_bands = ["B04", "B03", "B02"]

        if not all(band in sample for band in rgb_bands):
            warnings.warn("No RGB image.")  # noqa: B028
            return None

        rgb_images = []
        for t in range(sample["B04"].shape[0]):
            rgb_image = torch.stack([sample[band][t] for band in rgb_bands])

            # Normalization
            rgb_min = rgb_image.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
            rgb_max = rgb_image.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
            denom = rgb_max - rgb_min
            denom[denom == 0] = 1
            rgb_image = (rgb_image - rgb_min) / denom

            rgb_image = rgb_image.permute(1, 2, 0).numpy()
            rgb_images.append(np.clip(rgb_image, 0, 1))

        dates = torch.arange(sample["B04"].shape[0])

        return self._plot_sample(rgb_images, dates, sample.get("labels"), suptitle=suptitle, show_axes=show_axes)

    def _plot_sample(self, images, dates, labels=None, suptitle=None, show_axes=False):
        num_images = len(images)
        cols = 5
        rows = (num_images + cols - 1) // cols

        fig, ax = plt.subplots(rows, cols, figsize=(20, 4 * rows))
        axes_visibility = "on" if show_axes else "off"

        for i, image in enumerate(images):
            ax[i // cols, i % cols].imshow(image)
            ax[i // cols, i % cols].set_title(f"T{i+1} - Day {dates[i].item()}")
            ax[i // cols, i % cols].axis(axes_visibility)

        if labels is not None:
            if rows * cols > num_images:
                target_ax = ax[(num_images) // cols, (num_images) % cols]
            else:
                fig.add_subplot(rows + 1, 1, 1)
                target_ax = fig.gca()

            target_ax.imshow(labels.numpy(), cmap="tab20")
            target_ax.set_title("Labels")
            target_ax.axis(axes_visibility)

        for k in range(num_images, rows * cols):
            ax[k // cols, k % cols].axis(axes_visibility)

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        plt.show()

    def map_mask_to_discrete_classes(self, mask, encoder):
        map_func = np.vectorize(lambda x: encoder.get(x, 0))
        return map_func(mask)
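For reference, the 'random' scenario in split_data corresponds to a 60/20/20 train/val/test partition. A minimal sketch of the same index arithmetic, using a hypothetical file list:

# Illustrative only: hypothetical file names, same 60/20/20 arithmetic as split_data
files = [f"patch_{i}.nc" for i in range(100)]
train_end = int(0.6 * len(files))  # 60
val_end = int(0.8 * len(files))    # 80
train, val, test = files[:train_end], files[train_end:val_end], files[val_end:]
print(len(train), len(val), len(test))  # 60 20 20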
__init__(data_root, bands=None, scenario='random', split='train', transform=None, truncate_image=4, pad_image=4, spatial_interpolate_and_stack_temporally=True, seed=42) #

PyTorch Dataset class to load samples from the Sen4AgriNet dataset, supporting multiple scenarios for splitting the data.

Parameters:

* data_root (str, required): Root directory of the dataset.
* bands (list of str, optional, default None): List of band names to load. Defaults to all available bands.
* scenario (str, default 'random'): Defines the splitting scenario to use. Options are:
    - 'random': Random split of the data.
    - 'spatial': Split by geographical regions (Catalonia and France).
    - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
* split (str, default 'train'): Specifies the dataset split. Options are 'train', 'val', or 'test'.
* transform (albumentations.Compose, optional, default None): Albumentations transformations to apply to the data.
* truncate_image (int, optional, default 4): Number of timesteps to truncate the time dimension of the image to. If None, no truncation is applied.
* pad_image (int, optional, default 4): Number of timesteps to pad the time dimension of the image to. If None, no padding is applied.
* spatial_interpolate_and_stack_temporally (bool, default True): If True, spatially interpolate all bands to a common size and stack them into a single multi-temporal array under the "image" key; otherwise, return each band separately.
* seed (int, default 42): Random seed used for data splitting.
Source code in terratorch/datasets/sen4agrinet.py
def __init__(
    self,
    data_root: str,
    bands: list[str] | None = None,
    scenario: str = "random",
    split: str = "train",
    transform: A.Compose = None,
    truncate_image: int | None = 4,
    pad_image: int | None = 4,
    spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
    seed: int = 42,
):
    """
    PyTorch Dataset class to load samples from the [Sen4AgriNet](https://github.com/Orion-AI-Lab/S4A) dataset, supporting
    multiple scenarios for splitting the data.

    Args:
        data_root (str): Root directory of the dataset.
        bands (list of str, optional): List of band names to load. Defaults to all available bands.
        scenario (str): Defines the splitting scenario to use. Options are:
            - 'random': Random split of the data.
            - 'spatial': Split by geographical regions (Catalonia and France).
            - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
        split (str): Specifies the dataset split. Options are 'train', 'val', or 'test'.
        transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
        truncate_image (int, optional): Number of timesteps to truncate the time dimension of the image to.
            If None, no truncation is applied. Default is 4.
        pad_image (int, optional): Number of timesteps to pad the time dimension of the image to.
            If None, no padding is applied. Default is 4.
        spatial_interpolate_and_stack_temporally (bool): If True, spatially interpolate all bands to a common
            size and stack them into a single multi-temporal array under the "image" key; otherwise, return
            each band separately. Default is True.
        seed (int): Random seed used for data splitting.
    """
    self.data_root = Path(data_root) / "data"
    self.transform = transform if transform else lambda **batch: to_tensor(batch)
    self.scenario = scenario
    self.seed = seed
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally

    if bands is None:
        bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B10", "B11", "B12", "B8A"]
    self.bands = bands

    self.image_files = list(self.data_root.glob("**/*.nc"))

    self.train_files, self.val_files, self.test_files = self.split_data()

    if split == "train":
        self.image_files = self.train_files
    elif split == "val":
        self.image_files = self.val_files
    elif split == "test":
        self.image_files = self.test_files

terratorch.datasets.sen4map #

Sen4MapDatasetMonthlyComposites #

Bases: Dataset

Sen4Map Dataset for Monthly Composites.

Dataset intended for land-cover and crop classification tasks based on monthly composites derived from multi-temporal satellite data stored in HDF5 files.

Dataset Format:

* HDF5 files containing multi-temporal acquisitions with spectral bands (e.g., B2, B3, …, B12)
* Composite images computed as the median across available acquisitions for each month.
* Classification labels provided via HDF5 attributes (e.g., 'lc1') with mappings defined for:
    - Land-cover: using land_cover_classification_map
    - Crops: using crop_classification_map

Dataset Features:

* Supports two classification tasks: "land-cover" (default) and "crops".
* Pre-processing options include center cropping, reverse tiling, and resizing.
* Option to save the HDF5 keys for later filtering.
* Input channel selection via a mapping between available bands and input bands.
Source code in terratorch/datasets/sen4map.py
class Sen4MapDatasetMonthlyComposites(Dataset):
    """[Sen4Map](https://gitlab.jsc.fz-juelich.de/sdlrs/sen4map-benchmark-dataset) Dataset for Monthly Composites.

    Dataset intended for land-cover and crop classification tasks based on monthly composites
    derived from multi-temporal satellite data stored in HDF5 files.

    Dataset Format:

    * HDF5 files containing multi-temporal acquisitions with spectral bands (e.g., B2, B3, …, B12)
    * Composite images computed as the median across available acquisitions for each month.
    * Classification labels provided via HDF5 attributes (e.g., 'lc1') with mappings defined for:
        - Land-cover: using `land_cover_classification_map`
        - Crops: using `crop_classification_map`

    Dataset Features:

    * Supports two classification tasks: "land-cover" (default) and "crops".
    * Pre-processing options include center cropping, reverse tiling, and resizing.
    * Option to save the HDF5 keys for later filtering.
    * Input channel selection via a mapping between available bands and input bands.


    """
    # This dictionary maps the LUCAS codes to land-cover classes.
    land_cover_classification_map = {
        'A10': 0, 'A11': 0, 'A12': 0, 'A13': 0, 'A20': 0, 'A21': 0, 'A30': 0,
        'A22': 1, 'F10': 1, 'F20': 1, 'F30': 1, 'F40': 1,
        'E10': 2, 'E20': 2, 'E30': 2, 'B50': 2, 'B51': 2, 'B52': 2, 'B53': 2, 'B54': 2, 'B55': 2,
        'B10': 3, 'B11': 3, 'B12': 3, 'B13': 3, 'B14': 3, 'B15': 3, 'B16': 3, 'B17': 3, 'B18': 3, 'B19': 3,
        'B20': 3, 'B21': 3, 'B22': 3, 'B23': 3, 'B30': 3, 'B31': 3, 'B32': 3, 'B33': 3, 'B34': 3, 'B35': 3,
        'B36': 3, 'B37': 3, 'B40': 3, 'B41': 3, 'B42': 3, 'B43': 3, 'B44': 3, 'B45': 3,
        'B70': 3, 'B71': 3, 'B72': 3, 'B73': 3, 'B74': 3, 'B75': 3, 'B76': 3, 'B77': 3,
        'B80': 3, 'B81': 3, 'B82': 3, 'B83': 3, 'B84': 3, 'BX1': 3, 'BX2': 3,
        'C10': 4,
        'C20': 5, 'C21': 5, 'C22': 5, 'C23': 5, 'C30': 5, 'C31': 5, 'C32': 5, 'C33': 5,
        'CXX1': 5, 'CXX2': 5, 'CXX3': 5, 'CXX4': 5, 'CXX5': 5, 'CXX6': 5, 'CXX7': 5, 'CXX8': 5, 'CXX9': 5,
        'CXXA': 5, 'CXXB': 5, 'CXXC': 5, 'CXXD': 5, 'CXXE': 5,
        'D10': 6, 'D20': 6,
        'G10': 7, 'G11': 7, 'G12': 7, 'G20': 7, 'G21': 7, 'G22': 7, 'G30': 7, 'G40': 7, 'G50': 7,
        'H10': 8, 'H11': 8, 'H12': 8, 'H20': 8, 'H21': 8, 'H22': 8, 'H23': 8,
        '': 9}
    #  This dictionary maps the LUCAS classes to crop classes.
    crop_classification_map = {
        "B11":0, "B12":0, "B13":0, "B14":0, "B15":0, "B16":0, "B17":0, "B18":0, "B19":0,  # Cereals
        "B21":1, "B22":1, "B23":1,  # Root Crops
        "B31":2, "B32":2, "B33":2, "B34":2, "B35":2, "B36":2, "B37":2,  # Nonpermanent Industrial Crops
        "B41":3, "B42":3, "B43":3, "B44":3, "B45":3,  # Dry Pulses, Vegetables and Flowers
        "B51":4, "B52":4, "B53":4, "B54":4,  # Fodder Crops
        "F10":5, "F20":5, "F30":5, "F40":5,  # Bareland
        "B71":6, "B72":6, "B73":6, "B74":6, "B75":6, "B76":6, "B77":6, 
        "B81":6, "B82":6, "B83":6, "B84":6, "C10":6, "C21":6, "C22":6, "C23":6, "C31":6, "C32":6, "C33":6, "D10":6, "D20":6,  # Woodland and Shrubland
        "B55":7, "E10":7, "E20":7, "E30":7,  # Grassland
    }

    def __init__(
            self,
            h5py_file_object:h5py.File,
            h5data_keys = None,
            crop_size:None|int = None,
            dataset_bands:list[HLSBands|int]|None = None,
            input_bands:list[HLSBands|int]|None = None,
            resize = False,
            resize_to = [224, 224],
            resize_interpolation = InterpolationMode.BILINEAR,
            resize_antialiasing = True,
            reverse_tile = False,
            reverse_tile_size = 3,
            save_keys_path = None,
            classification_map = "land-cover"
            ):
        """Initialize a new instance of Sen4MapDatasetMonthlyComposites.

        This dataset loads data from an HDF5 file object containing multi-temporal satellite data and computes
        monthly composite images by aggregating acquisitions (via median).

        Args:
            h5py_file_object: An open h5py.File object containing the dataset.
            h5data_keys: Optional list of keys to select a subset of data samples from the HDF5 file.
                If None, all keys are used.
            crop_size: Optional integer specifying the square crop size for the output image.
            dataset_bands: Optional list of bands available in the dataset.
            input_bands: Optional list of bands to be used as input channels.
                Must be provided along with `dataset_bands`.
            resize: Boolean flag indicating whether the image should be resized. Default is False.
            resize_to: Target dimensions [height, width] for resizing. Default is [224, 224].
            resize_interpolation: Interpolation mode used for resizing. Default is InterpolationMode.BILINEAR.
            resize_antialiasing: Boolean flag to apply antialiasing during resizing. Default is True.
            reverse_tile: Boolean flag indicating whether to apply reverse tiling to the image. Default is False.
            reverse_tile_size: Kernel size for the reverse tiling operation. Must be an odd number >= 3. Default is 3.
            save_keys_path: Optional file path to save the list of dataset keys.
            classification_map: String specifying the classification mapping to use ("land-cover" or "crops").
                Default is "land-cover".

        Raises:
            ValueError: If `input_bands` is provided without specifying `dataset_bands`.
            ValueError: If an invalid `classification_map` is provided.
        """
        self.h5data = h5py_file_object
        if h5data_keys is None:
            if classification_map == "crops": print(f"Crop classification task chosen but no keys supplied. Will fail unless dataset hdf5 files have been filtered. Either filter dataset files or create a filtered set of keys.")
            self.h5data_keys = list(self.h5data.keys())
            if save_keys_path is not None:
                with open(save_keys_path, "wb") as file:
                    pickle.dump(self.h5data_keys, file)
        else:
            self.h5data_keys = h5data_keys
        self.crop_size = crop_size
        if input_bands and not dataset_bands:
            raise ValueError(f"input_bands was provided without specifying the dataset_bands")
        # self.dataset_bands = dataset_bands
        # self.input_bands = input_bands
        if input_bands and dataset_bands:
            self.input_channels = [dataset_bands.index(band_ind) for band_ind in input_bands if band_ind in dataset_bands]
        else: self.input_channels = None

        classification_maps = {"land-cover": Sen4MapDatasetMonthlyComposites.land_cover_classification_map,
                               "crops": Sen4MapDatasetMonthlyComposites.crop_classification_map}
        if classification_map not in classification_maps.keys():
            raise ValueError(f"Provided classification_map of: {classification_map}, is not from the list of valid ones: {classification_maps}")
        self.classification_map = classification_maps[classification_map]

        self.resize = resize
        self.resize_to = resize_to
        self.resize_interpolation = resize_interpolation
        self.resize_antialiasing = resize_antialiasing

        self.reverse_tile = reverse_tile
        self.reverse_tile_size = reverse_tile_size

    def __getitem__(self, index):
        # we can call dataset with an index, eg. dataset[0]
        im = self.h5data[self.h5data_keys[index]]
        Image, Label = self.get_data(im)
        Image = self.min_max_normalize(Image, [67.0, 122.0, 93.27, 158.5, 160.77, 174.27, 162.27, 149.0, 84.5, 66.27 ],
                                    [2089.0, 2598.45, 3214.5, 3620.45, 4033.61, 4613.0, 4825.45, 4945.72, 5140.84, 4414.45])

        Image = Image.clip(0,1)
        Label = torch.LongTensor(Label)
        if self.input_channels:
            Image = Image[self.input_channels, ...]

        return {"image":Image, "label":Label}

    def __len__(self):
        return len(self.h5data_keys)

    def get_data(self, im):
        mask = im['SCL'] < 9

        B2= np.where(mask==1, im['B2'], 0)
        B3= np.where(mask==1, im['B3'], 0)
        B4= np.where(mask==1, im['B4'], 0)
        B5= np.where(mask==1, im['B5'], 0)
        B6= np.where(mask==1, im['B6'], 0)
        B7= np.where(mask==1, im['B7'], 0)
        B8= np.where(mask==1, im['B8'], 0)
        B8A= np.where(mask==1, im['B8A'], 0)
        B11= np.where(mask==1, im['B11'], 0)
        B12= np.where(mask==1, im['B12'], 0)
        Image = np.stack((B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12), axis=0, dtype="float32")
        Image = np.moveaxis(Image, [0],[1])
        Image = torch.from_numpy(Image)

        # Composites: group acquisition indices by month (Image_ID entries contain 'YYYYMM')
        image_ids = im.attrs['Image_ID'].tolist()
        monthly_indices = [[i for i, s in enumerate(image_ids) if f"2018{month:02d}" in s] for month in range(1, 13)]
        # Months with no acquisitions fall back to the median over all acquisitions of the year
        all_indices = [i for indices in monthly_indices for i in indices]
        month_indices = [indices if indices else all_indices for indices in monthly_indices]

        month_medians = [Image[indices].median(dim=0).values for indices in month_indices]


        Image = torch.stack(month_medians, dim=0)
        Image = torch.moveaxis(Image, 0, 1)

        if self.crop_size: Image = self.crop_center(Image, self.crop_size, self.crop_size)
        if self.reverse_tile:
            Image = self.reverse_tiling_pytorch(Image, kernel_size=self.reverse_tile_size)
        if self.resize:
            Image = resize(Image, size=self.resize_to, interpolation=self.resize_interpolation, antialias=self.resize_antialiasing)

        Label = im.attrs['lc1']
        Label = self.classification_map[Label]
        Label = np.array(Label)
        Label = Label.astype('float32')

        return Image, Label

    def crop_center(self, img_b:torch.Tensor, cropx, cropy) -> torch.Tensor:
        c, t, y, x = img_b.shape
        startx = x//2-(cropx//2)
        starty = y//2-(cropy//2)    
        return img_b[0:c, 0:t, starty:starty+cropy, startx:startx+cropx]


    def reverse_tiling_pytorch(self, img_tensor: torch.Tensor, kernel_size: int=3):
        """
        Upscales an image where every pixel is expanded into `kernel_size`*`kernel_size` pixels.
        Used to test whether the benefit of resizing images to the pre-trained size comes from the bilinearly interpolated pixels,
        or if the same would be realized with no interpolated pixels.
        """
        assert kernel_size % 2 == 1
        assert kernel_size >= 3
        padding = (kernel_size - 1) // 2
        # img_tensor shape: (batch_size, channels, H, W)
        batch_size, channels, H, W = img_tensor.shape
        # Unfold: extract kernel_size x kernel_size patches; replicate padding covers the borders
        img_tensor = F.pad(img_tensor, pad=(padding,padding,padding,padding), mode="replicate")
        patches = F.unfold(img_tensor, kernel_size=kernel_size, padding=0)  # Shape: (batch_size, channels*kernel_size**2, H*W)
        # Reshape to organize the kernel_size**2 values from each neighborhood
        patches = patches.view(batch_size, channels, kernel_size*kernel_size, H, W)  # Shape: (batch_size, channels, kernel_size**2, H, W)
        # Rearrange the patches into (batch_size, channels, kernel_size, kernel_size, H, W)
        patches = patches.view(batch_size, channels, kernel_size, kernel_size, H, W)
        # Permute to interleave the spatial and kernel dimensions
        patches = patches.permute(0, 1, 4, 2, 5, 3)  # Shape: (batch_size, channels, H, kernel_size, W, kernel_size)
        # Reshape to get the final expanded image of shape (batch_size, channels, H*kernel_size, W*kernel_size)
        expanded_img = patches.reshape(batch_size, channels, H * kernel_size, W * kernel_size)
        return expanded_img
        return expanded_img

    def min_max_normalize(self, tensor:torch.Tensor, q_low:list[float], q_hi:list[float]) -> torch.Tensor:
        dtype = tensor.dtype
        q_low = torch.as_tensor(q_low, dtype=dtype, device=tensor.device)
        q_hi = torch.as_tensor(q_hi, dtype=dtype, device=tensor.device)
        # Small epsilon (exp(-12) ~ 6e-6) keeps the denominator non-zero when q_hi == q_low
        eps = torch.exp(torch.tensor(-12.0))
        tensor.sub_(q_low[:, None, None, None]).div_((q_hi[:, None, None, None].sub_(q_low[:, None, None, None])).add(eps))
        return tensor
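To make the compositing step in get_data concrete, here is a small self-contained sketch of per-month median aggregation, with random stand-in data and a hypothetical month assignment per acquisition:

import torch

# 8 hypothetical acquisitions of a 10-band, 15x15 patch
acquisitions = torch.rand(8, 10, 15, 15)
months = [1, 1, 3, 3, 3, 7, 7, 12]  # hypothetical acquisition months (1-12)

composites = []
for m in range(1, 13):
    idx = [i for i, month in enumerate(months) if month == m]
    if not idx:  # months without data fall back to the median over all acquisitions
        idx = list(range(len(months)))
    composites.append(acquisitions[idx].median(dim=0).values)

stacked = torch.stack(composites, dim=1)  # (bands, months, H, W)
print(stacked.shape)  # torch.Size([10, 12, 15, 15])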
__init__(h5py_file_object, h5data_keys=None, crop_size=None, dataset_bands=None, input_bands=None, resize=False, resize_to=[224, 224], resize_interpolation=InterpolationMode.BILINEAR, resize_antialiasing=True, reverse_tile=False, reverse_tile_size=3, save_keys_path=None, classification_map='land-cover') #

Initialize a new instance of Sen4MapDatasetMonthlyComposites.

This dataset loads data from an HDF5 file object containing multi-temporal satellite data and computes monthly composite images by aggregating acquisitions (via median).

Parameters:

* h5py_file_object (h5py.File, required): An open h5py.File object containing the dataset.
* h5data_keys (default None): Optional list of keys to select a subset of data samples from the HDF5 file. If None, all keys are used.
* crop_size (None | int, default None): Optional integer specifying the square crop size for the output image.
* dataset_bands (list[HLSBands | int] | None, default None): Optional list of bands available in the dataset.
* input_bands (list[HLSBands | int] | None, default None): Optional list of bands to be used as input channels. Must be provided along with dataset_bands.
* resize (default False): Boolean flag indicating whether the image should be resized.
* resize_to (default [224, 224]): Target dimensions [height, width] for resizing.
* resize_interpolation (default InterpolationMode.BILINEAR): Interpolation mode used for resizing.
* resize_antialiasing (default True): Boolean flag to apply antialiasing during resizing.
* reverse_tile (default False): Boolean flag indicating whether to apply reverse tiling to the image.
* reverse_tile_size (default 3): Kernel size for the reverse tiling operation. Must be an odd number >= 3.
* save_keys_path (default None): Optional file path to save the list of dataset keys.
* classification_map (default 'land-cover'): String specifying the classification mapping to use ("land-cover" or "crops").

Raises:

* ValueError: If input_bands is provided without specifying dataset_bands.
* ValueError: If an invalid classification_map is provided.

Source code in terratorch/datasets/sen4map.py
def __init__(
        self,
        h5py_file_object:h5py.File,
        h5data_keys = None,
        crop_size:None|int = None,
        dataset_bands:list[HLSBands|int]|None = None,
        input_bands:list[HLSBands|int]|None = None,
        resize = False,
        resize_to = [224, 224],
        resize_interpolation = InterpolationMode.BILINEAR,
        resize_antialiasing = True,
        reverse_tile = False,
        reverse_tile_size = 3,
        save_keys_path = None,
        classification_map = "land-cover"
        ):
    """Initialize a new instance of Sen4MapDatasetMonthlyComposites.

    This dataset loads data from an HDF5 file object containing multi-temporal satellite data and computes
    monthly composite images by aggregating acquisitions (via median).

    Args:
        h5py_file_object: An open h5py.File object containing the dataset.
        h5data_keys: Optional list of keys to select a subset of data samples from the HDF5 file.
            If None, all keys are used.
        crop_size: Optional integer specifying the square crop size for the output image.
        dataset_bands: Optional list of bands available in the dataset.
        input_bands: Optional list of bands to be used as input channels.
            Must be provided along with `dataset_bands`.
        resize: Boolean flag indicating whether the image should be resized. Default is False.
        resize_to: Target dimensions [height, width] for resizing. Default is [224, 224].
        resize_interpolation: Interpolation mode used for resizing. Default is InterpolationMode.BILINEAR.
        resize_antialiasing: Boolean flag to apply antialiasing during resizing. Default is True.
        reverse_tile: Boolean flag indicating whether to apply reverse tiling to the image. Default is False.
        reverse_tile_size: Kernel size for the reverse tiling operation. Must be an odd number >= 3. Default is 3.
        save_keys_path: Optional file path to save the list of dataset keys.
        classification_map: String specifying the classification mapping to use ("land-cover" or "crops").
            Default is "land-cover".

    Raises:
        ValueError: If `input_bands` is provided without specifying `dataset_bands`.
        ValueError: If an invalid `classification_map` is provided.
    """
    self.h5data = h5py_file_object
    if h5data_keys is None:
        if classification_map == "crops": print(f"Crop classification task chosen but no keys supplied. Will fail unless dataset hdf5 files have been filtered. Either filter dataset files or create a filtered set of keys.")
        self.h5data_keys = list(self.h5data.keys())
        if save_keys_path is not None:
            with open(save_keys_path, "wb") as file:
                pickle.dump(self.h5data_keys, file)
    else:
        self.h5data_keys = h5data_keys
    self.crop_size = crop_size
    if input_bands and not dataset_bands:
        raise ValueError(f"input_bands was provided without specifying the dataset_bands")
    # self.dataset_bands = dataset_bands
    # self.input_bands = input_bands
    if input_bands and dataset_bands:
        self.input_channels = [dataset_bands.index(band_ind) for band_ind in input_bands if band_ind in dataset_bands]
    else: self.input_channels = None

    classification_maps = {"land-cover": Sen4MapDatasetMonthlyComposites.land_cover_classification_map,
                           "crops": Sen4MapDatasetMonthlyComposites.crop_classification_map}
    if classification_map not in classification_maps.keys():
        raise ValueError(f"Provided classification_map of: {classification_map}, is not from the list of valid ones: {classification_maps}")
    self.classification_map = classification_maps[classification_map]

    self.resize = resize
    self.resize_to = resize_to
    self.resize_interpolation = resize_interpolation
    self.resize_antialiasing = resize_antialiasing

    self.reverse_tile = reverse_tile
    self.reverse_tile_size = reverse_tile_size
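A minimal usage sketch, assuming a hypothetical train.h5 file that follows the Sen4Map layout:

import h5py
from terratorch.datasets.sen4map import Sen4MapDatasetMonthlyComposites

# "train.h5" is a hypothetical file path
with h5py.File("train.h5", "r") as f:
    dataset = Sen4MapDatasetMonthlyComposites(f, crop_size=15, classification_map="land-cover")
    sample = dataset[0]
    print(sample["image"].shape, sample["label"])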
reverse_tiling_pytorch(img_tensor, kernel_size=3) #

Upscales an image where every pixel is expanded into kernel_size*kernel_size pixels. Used to test whether the benefit of resizing images to the pre-trained size comes from the bilinearly interpolated pixels, or if the same would be realized with no interpolated pixels.

Source code in terratorch/datasets/sen4map.py
def reverse_tiling_pytorch(self, img_tensor: torch.Tensor, kernel_size: int=3):
    """
    Upscales an image where every pixel is expanded into `kernel_size`*`kernel_size` pixels.
    Used to test whether the benefit of resizing images to the pre-trained size comes from the bilinearly interpolated pixels,
    or if the same would be realized with no interpolated pixels.
    """
    assert kernel_size % 2 == 1
    assert kernel_size >= 3
    padding = (kernel_size - 1) // 2
    # img_tensor shape: (batch_size, channels, H, W)
    batch_size, channels, H, W = img_tensor.shape
    # Unfold: extract kernel_size x kernel_size patches; replicate padding covers the borders
    img_tensor = F.pad(img_tensor, pad=(padding,padding,padding,padding), mode="replicate")
    patches = F.unfold(img_tensor, kernel_size=kernel_size, padding=0)  # Shape: (batch_size, channels*kernel_size**2, H*W)
    # Reshape to organize the kernel_size**2 values from each neighborhood
    patches = patches.view(batch_size, channels, kernel_size*kernel_size, H, W)  # Shape: (batch_size, channels, kernel_size**2, H, W)
    # Rearrange the patches into (batch_size, channels, kernel_size, kernel_size, H, W)
    patches = patches.view(batch_size, channels, kernel_size, kernel_size, H, W)
    # Permute to interleave the spatial and kernel dimensions
    patches = patches.permute(0, 1, 4, 2, 5, 3)  # Shape: (batch_size, channels, H, kernel_size, W, kernel_size)
    # Reshape to get the final expanded image of shape (batch_size, channels, H*kernel_size, W*kernel_size)
    expanded_img = patches.reshape(batch_size, channels, H * kernel_size, W * kernel_size)
    return expanded_img
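For intuition, a standalone replica of the same unfold/permute/reshape pipeline on a tiny input: with kernel_size=3, each pixel of a 2x2 image becomes its 3x3 (replicate-padded) neighborhood, giving a 6x6 output:

import torch
import torch.nn.functional as F

img = torch.arange(4.0).reshape(1, 1, 2, 2)
k = 3
pad = (k - 1) // 2
padded = F.pad(img, (pad, pad, pad, pad), mode="replicate")
patches = F.unfold(padded, kernel_size=k)  # (1, 1*k*k, H*W)
patches = patches.view(1, 1, k, k, 2, 2).permute(0, 1, 4, 2, 5, 3)
out = patches.reshape(1, 1, 2 * k, 2 * k)
print(out.shape)  # torch.Size([1, 1, 6, 6])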