---
language: en
license: apache-2.0
tags:
- earth-observation
- geospatial
- flood-detection
- prithvi
- pytorch
- semantic-segmentation
- sentinel-2
datasets:
- sen1floods11
metrics:
- iou
- accuracy
- f1
model-index:
- name: Prithvi-2.0 300M - Flood Detection (Sen1Floods11)
  results:
  - task:
      type: semantic-segmentation
      name: Semantic Segmentation
    dataset:
      name: Sen1Floods11
      type: sen1floods11
      split: test
    metrics:
    - type: iou
      value: 0.7196
      name: Flood IoU
    - type: f1
      value: 0.837
      name: Flood F1
    - type: accuracy
      value: 0.9633
      name: Overall Accuracy
---
# Prithvi-2.0 300M - Fine-tuned for Flood Detection (Sen1Floods11)
This model is a fine-tuned version of the Prithvi-2.0 300M foundation model, specialized for binary flood detection using Sentinel-2 optical imagery.
## Model Description
- Developed by: Tushar Thokdar
- Model Type: Semantic Segmentation
- Backbone: Prithvi-2.0 300M (ViT-Large)
- Segmentation Head: UPerNet
- Input Resolution: 224x224
- Input Bands: 6 (Red, Green, Blue, Narrow NIR, SWIR 1, SWIR 2)
- Fine-tuned on: Sen1Floods11 dataset
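The six input bands correspond to Sentinel-2 bands B04 (Red), B03 (Green), B02 (Blue), B8A (Narrow NIR), B11 (SWIR 1) and B12 (SWIR 2). The sketch below assembles a model input from a raw chip, assuming the chip holds all 13 Sentinel-2 bands in the standard order (B01 to B12, with B8A after B08), which is assumed here to match the Sen1Floods11 `S2Hand` files, and assuming the band order listed above is the order the model expects. `select_model_bands` is a hypothetical helper, not part of `godel-train`; no normalization is applied because this card does not state the statistics used during fine-tuning.

```python
import numpy as np

# Assumed 13-band order of a Sen1Floods11 Sentinel-2 chip (full Level-1C band set).
S2_BAND_ORDER = ["B01", "B02", "B03", "B04", "B05", "B06", "B07",
                 "B08", "B8A", "B09", "B10", "B11", "B12"]

# Model input order from this card: Red, Green, Blue, Narrow NIR, SWIR 1, SWIR 2.
MODEL_BANDS = ["B04", "B03", "B02", "B8A", "B11", "B12"]


def select_model_bands(chip: np.ndarray) -> np.ndarray:
    """Pick the six model bands from a (13, H, W) chip; returns (6, H, W) float32."""
    if chip.ndim != 3 or chip.shape[0] != len(S2_BAND_ORDER):
        raise ValueError(f"expected a (13, H, W) chip, got {chip.shape}")
    idx = [S2_BAND_ORDER.index(band) for band in MODEL_BANDS]
    return chip[idx].astype(np.float32)


# Usage: a random stand-in for a 512 x 512 chip gives a (6, 512, 512) array.
x = select_model_bands(np.random.rand(13, 512, 512))
```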
## Performance Metrics (Official Test Split)
Metrics are computed on the official Sen1Floods11 test split after 80 epochs of fine-tuning.
### Model Metrics
| Metric | Fine-Tuned (80 Epochs) | Baseline | Relative Gain (%) |
|---|---|---|---|
| Flood IoU | 0.7196 | 0.1339 | +437.3% |
| Flood F1 | 0.8370 | 0.2362 | +254.3% |
| Flood Precision | 0.8902 | 0.1638 | +443.4% |
| Flood Recall | 0.7897 | 0.4234 | +86.5% |
| Mean IoU | 0.8396 | 0.3953 | +112.3% |
| Overall Accuracy | 0.9633 | 0.6741 | +42.9% |
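The gain column reads as relative improvement over the baseline, (fine-tuned - baseline) / baseline × 100. This interpretation is an assumption, but it reproduces every row to within rounding of the four-decimal values. For a single class, IoU and F1 are also tied by F1 = 2 · IoU / (1 + IoU), which gives a quick consistency check on the table. The helper names below are illustrative.

```python
def relative_gain(fine_tuned: float, baseline: float) -> float:
    """Relative improvement in percent: (fine_tuned - baseline) / baseline * 100."""
    return (fine_tuned - baseline) / baseline * 100.0


def f1_from_iou(iou: float) -> float:
    """For a single class, F1 (Dice) = 2 * IoU / (1 + IoU)."""
    return 2.0 * iou / (1.0 + iou)


# Values copied from the table above: (fine-tuned, baseline).
rows = {
    "Flood Recall": (0.7897, 0.4234),
    "Overall Accuracy": (0.9633, 0.6741),
}

for name, (ft, base) in rows.items():
    print(f"{name}: +{relative_gain(ft, base):.1f}%")
# Flood Recall: +86.5%
# Overall Accuracy: +42.9%

print(round(f1_from_iou(0.7196), 4))  # 0.8369, vs. 0.8370 reported (rounding)
```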
### Per-Class Metrics (Fine-Tuned)
- No Flood IoU: 0.9595
- No Flood F1: 0.9793
- Flood IoU: 0.7196
- Flood F1: 0.8370
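Mean IoU is the unweighted average of the two per-class IoUs: (0.9595 + 0.7196) / 2 ≈ 0.8396. A NumPy sketch of computing per-class IoU and F1 from predicted and ground-truth masks through a confusion matrix follows. `per_class_metrics` is a hypothetical helper; skipping pixels labelled -1 assumes Sen1Floods11's no-data convention, and whether the reported numbers accumulate one confusion matrix over the whole test split or average per image is not stated here.

```python
import numpy as np


def per_class_metrics(pred: np.ndarray, target: np.ndarray, num_classes: int = 2):
    """Per-class IoU and F1 from integer label maps (0 = No Flood, 1 = Flood)."""
    valid = target >= 0                       # drop no-data pixels (assumed labelled -1)
    t = target[valid].astype(np.int64)
    p = pred[valid].astype(np.int64)

    # Confusion matrix: rows = ground truth, columns = prediction.
    cm = np.bincount(t * num_classes + p, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes)

    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp                  # predicted as class c, labelled otherwise
    fn = cm.sum(axis=1) - tp                  # labelled class c, predicted otherwise
    iou = tp / np.maximum(tp + fp + fn, 1)
    f1 = 2 * tp / np.maximum(2 * tp + fp + fn, 1)
    return iou, f1


target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
pred = np.array([[0, 0, 1, 0],
                 [0, 1, 1, 1]])
iou, f1 = per_class_metrics(pred, target)
print("per-class IoU:", iou, "mean IoU:", iou.mean())
```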
## Training Configuration
- Epochs: 80
- Batch Size: 16
- Learning Rate: 5e-5
- Loss: Dice + Focal, weighted 0.5 each
- Data Splits: Official (252 Train / 89 Val / 90 Test)
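The loss combines a Dice term and a Focal term with equal weights. A NumPy sketch of that combination for 2-class logits follows; it is meant for checking loss values, not for training (gradients would need a PyTorch version). The Focal exponent `gamma=2.0`, the class-averaged soft Dice and the function names are assumptions, since the card only gives the 0.5 / 0.5 weighting; targets are assumed to contain only valid class indices (no-data pixels removed).

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    z = logits - logits.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dice_focal_loss(logits: np.ndarray, target: np.ndarray,
                    dice_weight: float = 0.5, focal_weight: float = 0.5,
                    gamma: float = 2.0, eps: float = 1e-6) -> float:
    """Weighted soft Dice + Focal loss for (N, C, H, W) logits and (N, H, W) labels.

    gamma=2.0 and the class-averaged Dice are assumptions, not stated in the card.
    """
    c = logits.shape[1]
    probs = softmax(logits, axis=1)                       # (N, C, H, W)
    one_hot = np.eye(c)[target].transpose(0, 3, 1, 2)     # (N, C, H, W)

    # Soft Dice per class (summed over batch and pixels), averaged over classes.
    inter = (probs * one_hot).sum(axis=(0, 2, 3))
    denom = probs.sum(axis=(0, 2, 3)) + one_hot.sum(axis=(0, 2, 3))
    dice_loss = 1.0 - ((2.0 * inter + eps) / (denom + eps)).mean()

    # Focal loss: -(1 - p_t)^gamma * log(p_t), averaged over pixels.
    p_t = (probs * one_hot).sum(axis=1)                   # probability of the true class
    focal = -((1.0 - p_t) ** gamma) * np.log(np.clip(p_t, eps, 1.0))
    focal_loss = focal.mean()

    return float(dice_weight * dice_loss + focal_weight * focal_loss)
```

With equal weights, neither term dominates: Dice counters the strong class imbalance (few flood pixels), while Focal down-weights easy, already well-classified pixels.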
## Inference Performance
- Throughput: 20.66 samples/sec (NVIDIA T4)
- Inference Time (Avg): 0.048s per sample
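The two figures are consistent with each other: at 20.66 samples/sec, one sample takes 1 / 20.66 ≈ 0.048 s. A generic timing sketch that yields both numbers follows. The batch size, warm-up count and number of timed batches behind the reported T4 figures are not stated, so the defaults here are assumptions, and `measure_throughput` is a hypothetical helper. On a GPU, `run_batch` should call `torch.cuda.synchronize()` before returning, because CUDA kernels run asynchronously and the timer would otherwise stop early.

```python
import time
from typing import Callable


def measure_throughput(run_batch: Callable[[], None], batch_size: int,
                       n_batches: int = 50, warmup: int = 5) -> tuple[float, float]:
    """Return (samples per second, seconds per sample) for a one-batch callable."""
    for _ in range(warmup):               # warm-up runs are excluded from timing
        run_batch()
    start = time.perf_counter()
    for _ in range(n_batches):
        run_batch()
    elapsed = time.perf_counter() - start
    throughput = n_batches * batch_size / elapsed
    return throughput, 1.0 / throughput


# Usage with a dummy workload standing in for a model forward pass.
throughput, per_sample = measure_throughput(lambda: time.sleep(0.01),
                                            batch_size=1, n_batches=10)
print(f"{throughput:.1f} samples/sec, {per_sample:.4f} s/sample")
```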
## Usage Instructions
To use this model with the `godel-train` library:
```python
import torch
from godel_train.models.factory import ModelFactory

# 1. Initialize the model with the appropriate config
model = ModelFactory.segmentation(
    backbone="prithvi_eo_v2_300m",
    num_classes=2,
    checkpoint_path="pytorch_model.bin",  # Local or downloaded file
)
model.eval()

# 2. Prepare a sample input (Batch, 6 Bands, 224, 224)
# Bands: Red, Green, Blue, Narrow NIR, SWIR 1, SWIR 2
sample_input = torch.randn(1, 6, 224, 224)

# 3. Run inference
with torch.no_grad():
    prediction = model(sample_input)  # shape: [1, 2, 224, 224] (logits)

mask = torch.argmax(prediction, dim=1)  # shape: [1, 224, 224] (0: No Flood, 1: Flood)
```
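Sen1Floods11 chips are 512 × 512 pixels, larger than the 224 × 224 model input. One common approach, sketched here under stated assumptions, is sliding-window inference with overlapping tiles whose logits are averaged before the argmax. `window_starts`, `predict_large` and `predict_fn` are hypothetical helpers, not part of `godel-train`; `predict_fn` is assumed to map a (B, 6, 224, 224) float array to (B, 2, 224, 224) logits, for example by wrapping the model from the snippet above.

```python
import numpy as np


def window_starts(size: int, tile: int) -> list[int]:
    """Offsets covering [0, size) with length-`tile` windows; the last one is flush with the edge."""
    if size <= tile:
        return [0]
    starts = list(range(0, size - tile, tile))
    starts.append(size - tile)
    return starts


def predict_large(image: np.ndarray, predict_fn, tile: int = 224,
                  num_classes: int = 2) -> np.ndarray:
    """Sliding-window inference on a (6, H, W) image; returns an (H, W) label mask."""
    _, h, w = image.shape
    if h < tile or w < tile:
        raise ValueError("image must be at least tile x tile pixels")
    logits_sum = np.zeros((num_classes, h, w), dtype=np.float64)
    counts = np.zeros((h, w), dtype=np.float64)
    for y in window_starts(h, tile):
        for x in window_starts(w, tile):
            crop = image[None, :, y:y + tile, x:x + tile]   # (1, 6, tile, tile)
            logits_sum[:, y:y + tile, x:x + tile] += predict_fn(crop)[0]
            counts[y:y + tile, x:x + tile] += 1
    # Average overlapping logits, then pick the class (0: No Flood, 1: Flood).
    return np.argmax(logits_sum / counts, axis=0)


# A wrapper around the model above might look like this (not executed here):
# def predict_fn(batch):
#     with torch.no_grad():
#         return model(torch.from_numpy(batch).float()).numpy()

def fake_predict(batch: np.ndarray) -> np.ndarray:
    """Stand-in model: 'flood' wherever the first band is positive."""
    x = batch[:, 0]
    return np.stack([-x, x], axis=1)


mask = predict_large(np.random.default_rng(0).normal(size=(6, 512, 512)), fake_predict)
```

For a 512-pixel side the windows start at 0, 224 and 288, so the last tile overlaps the middle one instead of requiring padding.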
## Data and Credits
- Dataset: Sen1Floods11
- Fine-tuning: Performed by Tushar Thokdar
- Foundation Model: IBM/NASA Prithvi-2.0