# DataModule for the multi-temporal crop classification dataset.
datamodule = MultiTemporalCropClassificationDataModule(
    batch_size=8,
    num_workers=2,
    data_root=dataset_path,
    train_transform=[
        terratorch.datasets.transforms.FlattenTemporalIntoChannels(),  # Required for temporal data
        albumentations.D4(),  # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
        terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,
    expand_temporal_dimension=True,
    use_metadata=False,  # The crop dataset has metadata for location and time
    reduce_zero_label=True,
)

# Setup the prediction dataset. (The original comment said "train and val
# datasets", which contradicted the "predict" stage actually passed here.)
datamodule.setup("predict")
# Keep the checkpoint with the highest validation Jaccard index.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/multicrop/checkpoints/",
    mode="max",
    monitor="val/Multiclass_Jaccard_Index",  # Variable to monitor
    filename="best-{epoch:02d}",
)

# Lightning trainer configured for a short notebook demo run.
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1,  # Lightning multi-gpu often fails in notebooks
    precision="bf16-mixed",  # Speed up training
    num_nodes=1,
    logger=True,  # Uses TensorBoard by default
    max_epochs=1,  # For demos
    log_every_n_steps=5,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/multicrop",
)
# Semantic-segmentation task wrapping a Prithvi EO v2 backbone with a
# UNet decoder, built through TerraTorch's EncoderDecoderFactory.
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_eo_v2_300",
        "backbone_pretrained": True,
        "backbone_num_frames": 3,
        "backbone_bands": [
            "BLUE",
            "GREEN",
            "RED",
            "NIR_NARROW",
            "SWIR_1",
            "SWIR_2",
        ],
        "backbone_coords_encoding": [],  # use ["time", "location"] for time and location metadata
        # Necks
        "necks": [
            {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
            {"name": "ReshapeTokensToImage", "effective_time_dim": 3},
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        # Head
        "head_dropout": 0.1,
        "num_classes": 13,
    },
    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=True,  # only the decoder/head are trained
    freeze_decoder=False,
    plot_on_val=True,
)
# Predicting for some samples in the prediction dataset.
# NOTE(review): best_ckpt_100_epoch_path is not defined in this chunk —
# presumably it points to a pretrained/downloaded checkpoint; confirm upstream.
preds = trainer.predict(model, datamodule=datamodule, ckpt_path=best_ckpt_100_epoch_path)

# Get the first batch from the prediction dataloader for visualization.
data_loader = trainer.predict_dataloaders
batch = next(iter(data_loader))

# Plot each sample alongside its prediction. Iterate over the actual number
# of predictions in the first batch instead of a hard-coded BATCH_SIZE=8,
# which would raise an IndexError if the batch is smaller.
num_samples = len(preds[0][0])
for i in range(num_samples):
    sample = {key: batch[key][i] for key in batch}
    sample["prediction"] = preds[0][0][i].cpu().numpy()
    datamodule.predict_dataset.plot(sample)