# Prithvi EO Models
Code examples and more details are available in the Prithvi-EO 2.0 GitHub repo.
## Model Versions
Available model names:
Models with the `_tl` suffix support additional time and location metadata inputs. See Metadata Inputs.
These models were pre-trained on the following bands: `BLUE`, `GREEN`, `RED`, `NIR_NARROW`, `SWIR_1`, and `SWIR_2`.
## Usage
You can build the backbone using the `BACKBONE_REGISTRY`.
Optionally, pass a list of band names via `bands` to select a subset of the pre-training bands or to add new ones. Bands that were not used in pre-training get randomly initialized patch-embedding weights.
For multi-temporal tasks, set the number of input frames with `num_frames`.
```python
from terratorch.registry import BACKBONE_REGISTRY

model = BACKBONE_REGISTRY.build(
    "prithvi_eo_v2_300",
    pretrained=True,
    bands=["RED", "GREEN", "BLUE", "NEW"],  # Optional, specify bands
    num_frames=1,  # Optional, number of time steps (default: 1)
)
```
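Conceptually, the `bands` option maps each requested band to a slice of the patch-embedding weights: known bands reuse their pre-trained slice, unknown bands start from random values. The sketch below illustrates only that idea with plain Python lists; it is not TerraTorch's implementation, and `init_patch_embed_rows`, the toy per-band weight rows, and the `0.02` standard deviation are hypothetical choices for illustration.

```python
import random

# The six pre-training bands listed on this page.
PRETRAINED_BANDS = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]


def init_patch_embed_rows(pretrained_rows, requested_bands, std=0.02, seed=0):
    """Hypothetical sketch: build one weight row per requested input band.

    pretrained_rows maps a band name to its (toy) patch-embedding weights.
    Known bands reuse their pre-trained row; unknown bands get random values.
    """
    rng = random.Random(seed)
    width = len(next(iter(pretrained_rows.values())))
    rows, sources = [], []
    for band in requested_bands:
        if band in pretrained_rows:
            rows.append(list(pretrained_rows[band]))  # copy pre-trained weights
            sources.append("pretrained")
        else:
            rows.append([rng.gauss(0.0, std) for _ in range(width)])  # random init
            sources.append("random")
    return rows, sources


# Toy weights: each pre-trained band gets a constant row of length 4.
toy_rows = {band: [float(i)] * 4 for i, band in enumerate(PRETRAINED_BANDS)}
rows, sources = init_patch_embed_rows(toy_rows, ["RED", "GREEN", "BLUE", "NEW"])
print(sources)  # ['pretrained', 'pretrained', 'pretrained', 'random']
```

In this sketch the rows follow the order of `requested_bands`, so reordering known bands only permutes the copied weights, while a new name such as `"NEW"` is the only row that starts untrained.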
## Fine-tuning
Use Prithvi EO as a backbone in TerraTorch's `EncoderDecoderFactory`.
The backbone output is a list of tensors with shape `[batch, token, embedding]`, which includes the CLS token.
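The token dimension follows from the input size. A minimal sketch, assuming a 16×16 spatial patch size, one frame per temporal patch, and an embedding width of 1024 for the 300M model; these values are assumptions here, so check the model configuration before relying on them. `backbone_token_shape` is a hypothetical helper.

```python
def backbone_token_shape(batch, num_frames, height, width,
                         patch_size=16, tubelet_size=1, embed_dim=1024):
    """Hypothetical helper: shape of one backbone output, [batch, token, embedding].

    Defaults are assumptions (16x16 patches, 1 frame per temporal patch,
    1024-dim embeddings), not values read from TerraTorch.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    if num_frames % tubelet_size:
        raise ValueError("num_frames must be divisible by tubelet_size")
    grid = (num_frames // tubelet_size) * (height // patch_size) * (width // patch_size)
    return (batch, grid + 1, embed_dim)  # +1 for the CLS token


print(backbone_token_shape(batch=2, num_frames=1, height=224, width=224))
# (2, 197, 1024)
print(backbone_token_shape(batch=2, num_frames=4, height=224, width=224))
# (2, 785, 1024)
```

Only the tokens after the CLS token correspond to image patches, which is why the tokens have to be reshaped into a spatial grid before a convolutional decoder can use them.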
For hierarchical decoders such as UNet, use the following necks:
```yaml
model_args:
  ...
  necks:
    - name: ReshapeTokensToImage  # Reshape 1D tokens to 2D grid
    - name: SelectIndices  # Select three intermediate layer outputs and the final one
      # indices: [2, 5, 8, 11]  # 100M model
      indices: [5, 11, 17, 23]  # 300M model
      # indices: [7, 15, 23, 31]  # 600M model
    - name: LearnedInterpolateToPyramidal  # Upscale outputs for hierarchical decoders
  ...
```
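The commented `indices` values follow a simple pattern: for a backbone with D transformer blocks, they pick the last block of each quarter, i.e. D/4 - 1, D/2 - 1, 3D/4 - 1, and D - 1 (0-based). The depths 12, 24, and 32 used below are inferred from the last index of each list. `quarter_indices` is a hypothetical helper that reproduces them:

```python
def quarter_indices(depth, num_outputs=4):
    """0-based index of the last block in each of num_outputs equal chunks."""
    if depth % num_outputs:
        raise ValueError("depth must be divisible by num_outputs")
    step = depth // num_outputs
    return [step * (k + 1) - 1 for k in range(num_outputs)]


for name, depth in [("100M", 12), ("300M", 24), ("600M", 32)]:
    print(name, quarter_indices(depth))
# 100M [2, 5, 8, 11]
# 300M [5, 11, 17, 23]
# 600M [7, 15, 23, 31]
```

Other evenly spaced choices are possible; the helper only reproduces the values listed in the config above.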
Full example: burn_scars.yaml
## Metadata Inputs
Metadata is optional and supported only by `_tl` models. During pre-training, metadata was dropped in 10% of the samples, so the model is robust to missing metadata.
Specify metadata usage with:
During inference, pass the metadata inputs like so:
```python
output = model(
    data_tensor,
    temporal_coords=time_data,  # Shape: [B, T, 2], (year, day of year 0–364)
    location_coords=loc_data,  # Shape: [B, 2], (longitude, latitude)
)
```
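Metadata example using `pandas` and `torch`:
```python
import pandas as pd
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
date = pd.to_datetime("2024-06-15")
time_data = torch.tensor([[[date.year, date.dayofyear - 1]]], dtype=torch.float32, device=device)  # [1, 1, 2]
loc_data = torch.tensor([[47.309, 8.544]], dtype=torch.float32, device=device)  # [1, 2]
```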
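The pandas example above builds one frame for one sample. For several frames per sample, the `[B, T, 2]` layout of `temporal_coords` can be sketched with only the standard library. `temporal_coords_list` is a hypothetical helper; its nested lists can then be converted with `torch.tensor(coords, dtype=torch.float32)`.

```python
from datetime import date


def temporal_coords_list(dates_per_sample):
    """Hypothetical helper: nested [B][T][2] list of (year, 0-based day of year).

    dates_per_sample holds one list of datetime.date objects per sample,
    and every sample must have the same number of frames T.
    """
    num_frames = len(dates_per_sample[0])
    if any(len(frames) != num_frames for frames in dates_per_sample):
        raise ValueError("every sample needs the same number of frames")
    return [
        [[d.year, d.timetuple().tm_yday - 1] for d in frames]  # tm_yday is 1-based
        for frames in dates_per_sample
    ]


coords = temporal_coords_list([[date(2024, 6, 15), date(2024, 7, 1)]])
print(coords)  # [[[2024, 166], [2024, 182]]]
```

`tm_yday - 1` yields the same 0-based day as `dayofyear - 1` in the pandas version.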
**Warning:** Metadata is currently not supported by the generic data modules. You need to use a custom data module and dataset class, e.g., by modifying one listed in Datamodules.
The TerraTorch task automatically passes all additional values in the batch dict to the model. In your custom dataset class, add the metadata as additional values to the dict:
```python
def __getitem__(self, idx):
    ...
    # Load metadata from ...
    date: str = "2024-06-15"  # Example for a date
    lon, lat = 47.309, 8.544  # Example for a location
    date = pd.to_datetime(date)
    time_data = torch.Tensor([[date.year, date.dayofyear - 1]])  # Shape [T, 2]
    loc_data = torch.Tensor([lon, lat])  # Shape [2]
    ...
    sample = {
        "image": data,
        "mask": mask,
        "temporal_coords": time_data,
        "location_coords": loc_data,
    }
    return sample
```
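The per-sample shapes omit the batch dimension because the data loader adds it when it collates samples. Assuming your `DataLoader` uses PyTorch's default collate function, which stacks each dict entry along a new first dimension, `temporal_coords` of shape `[T, 2]` becomes `[B, T, 2]` and `location_coords` of shape `[2]` becomes `[B, 2]`, matching the inference call. The plain-Python sketch below mimics that grouping; `collate_dicts` and `nested_shape` are hypothetical helpers.

```python
def collate_dicts(samples):
    """Group per-sample values by key, adding a leading batch dimension."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


def nested_shape(value):
    """Shape of a rectangular nested list, e.g. [[1, 2]] -> [1, 2]."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return shape


samples = [
    {"temporal_coords": [[2024, 166]], "location_coords": [47.309, 8.544]},
    {"temporal_coords": [[2023, 10]], "location_coords": [47.309, 8.544]},
]
batch = collate_dicts(samples)
print(nested_shape(batch["temporal_coords"]))  # [2, 1, 2]
print(nested_shape(batch["location_coords"]))  # [2, 2]
```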