| ---
|
| language: en
|
| license: apache-2.0
|
| tags:
|
| - earth-observation
|
| - geospatial
|
| - flood-detection
|
| - prithvi
|
| - pytorch
|
| - semantic-segmentation
|
| - sentinel-2
|
| datasets:
|
| - sen1floods11
|
| metrics:
|
| - iou
|
| - accuracy
|
| - f1
|
| model-index:
|
| - name: Prithvi-2.0 300M - Flood Detection (Sen1Floods11)
|
|   results:
|
|   - task:
|
|       type: semantic-segmentation
|
|       name: Semantic Segmentation
|
|     dataset:
|
|       name: Sen1Floods11
|
|       type: sen1floods11
|
|       split: test
|
|     metrics:
|
|     - type: iou
|
|       value: 0.7196
|
|       name: Flood IoU
|
|     - type: f1
|
|       value: 0.8370
|
|       name: Flood F1
|
|     - type: accuracy
|
|       value: 0.9633
|
|       name: Overall Accuracy
|
| ---
|
|
|
| # Prithvi-2.0 300M - Fine-tuned for Flood Detection (Sen1Floods11)
|
|
|
| This model is a fine-tuned version of the **Prithvi-2.0 300M** foundation model, specialized for binary flood detection using Sentinel-2 optical imagery.
|
|
|
| ## Model Description
|
|
|
| - **Developed by:** Tushar Thokdar
|
| - **Model Type:** Semantic Segmentation
|
| - **Backbone:** Prithvi-2.0 300M (ViT-Large)
|
| - **Segmentation Head:** UPerNet
|
| - **Input Resolution:** 224x224
|
| - **Input Bands:** 6 (Red, Green, Blue, Narrow NIR, SWIR 1, SWIR 2)
|
| - **Fine-tuned on:** Sen1Floods11 dataset
|
|
|
| ## Performance Metrics (Official Test Split)
|
|
|
| All metrics were computed on the official Sen1Floods11 test split after fine-tuning for 80 epochs.
|
|
|
| ### Model Metrics
|
|
|
| | Metric | Fine-Tuned (80 Epochs) | Baseline | Gain (%) |
|
| | -------------------- | ---------------------- | -------- | -------- |
|
| | **Flood IoU** | 0.7196 | 0.1339 | +437.3% |
|
| | **Flood F1** | 0.8370 | 0.2362 | +254.3% |
|
| | **Flood Precision** | 0.8902 | 0.1638 | +443.4% |
|
| | **Flood Recall** | 0.7897 | 0.4234 | +86.5% |
|
| | **Mean IoU** | 0.8396 | 0.3953 | +112.3% |
|
| | **Overall Accuracy** | 0.9633 | 0.6741 | +42.9% |
|
|
|
| ### Per-Class Metrics (Fine-Tuned)
|
|
|
| - **No Flood IoU:** 0.9595
|
| - **No Flood F1:** 0.9793
|
| - **Flood IoU:** 0.7196
|
| - **Flood F1:** 0.8370
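
The reported figures are related by standard identities, so they can be cross-checked against each other. The sketch below recomputes Flood F1 from precision and recall, Flood IoU from F1, Mean IoU from the two per-class IoUs, and one relative gain. The helper names are hypothetical, and the inputs are the rounded values from the tables above, so the results may differ from the reported numbers in the last digit.

```python
def f1_from_precision_recall(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


def iou_from_f1(f1: float) -> float:
    """For a single class, IoU = F1 / (2 - F1)."""
    return f1 / (2 - f1)


def relative_gain_pct(fine_tuned: float, baseline: float) -> float:
    """Relative improvement over the baseline, in percent."""
    return (fine_tuned - baseline) / baseline * 100


# Rounded values copied from the tables above.
flood_precision, flood_recall = 0.8902, 0.7897
flood_iou, no_flood_iou = 0.7196, 0.9595

flood_f1 = f1_from_precision_recall(flood_precision, flood_recall)
flood_iou_check = iou_from_f1(0.8370)
mean_iou = (flood_iou + no_flood_iou) / 2
recall_gain = relative_gain_pct(0.7897, 0.4234)

print(f"Flood F1 from precision/recall: {flood_f1:.4f}")
print(f"Flood IoU from F1:              {flood_iou_check:.4f}")
print(f"Mean IoU from per-class IoU:    {mean_iou:.4f}")
print(f"Relative gain in recall:        {recall_gain:.1f}%")
```

Agreement within rounding error suggests the table entries are mutually consistent. It says nothing about how the baseline itself was obtained.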
|
|
|
| ## Training Configuration
|
|
|
| - **Epochs:** 80
|
| - **Batch Size:** 16
|
| - **Learning Rate:** 5e-5
|
| - **Loss:** Dice (0.5) + Focal (0.5)
|
| - **Data Splits:** Official (252 Train / 89 Val / 90 Test)
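
The loss entry above can be read as an equally weighted sum of a Dice term and a focal term. A minimal PyTorch sketch under that reading follows. The function name `dice_focal_loss`, the focal exponent `gamma=2.0`, and the smoothing constant `eps` are assumptions for illustration and are not taken from the `godel-train` implementation. The sketch also assumes every label is 0 or 1, so pixels carrying an ignore or no-data label would need to be masked out first.

```python
import torch
import torch.nn.functional as F


def dice_focal_loss(logits, target, dice_weight=0.5, focal_weight=0.5,
                    gamma=2.0, eps=1e-6):
    """Weighted sum of a soft Dice loss and a focal loss.

    logits: float tensor [B, C, H, W]; target: long tensor [B, H, W].
    gamma and eps are illustrative defaults, not values from the model card.
    """
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

    # Soft Dice per class: 1 - 2|P∩G| / (|P| + |G|), then averaged over classes.
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = 1 - ((2 * intersection + eps) / (cardinality + eps)).mean()

    # Focal: per-pixel cross-entropy scaled by (1 - p_t)^gamma,
    # where p_t is the predicted probability of the true class.
    ce = F.cross_entropy(logits, target, reduction="none")  # [B, H, W]
    p_t = torch.exp(-ce)
    focal = ((1 - p_t) ** gamma * ce).mean()

    return dice_weight * dice + focal_weight * focal


# Example call on random data shaped like a small training batch.
torch.manual_seed(0)
example_logits = torch.randn(2, 2, 224, 224)
example_target = torch.randint(0, 2, (2, 224, 224))
example_loss = dice_focal_loss(example_logits, example_target)
```

Combining the two terms is a common choice for imbalanced segmentation: the Dice term tracks region overlap, while the focal term reduces the contribution of easily classified pixels.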
|
|
|
| ## Inference Performance
|
|
|
| - **Throughput:** 20.66 samples/sec (NVIDIA T4)
|
| - **Inference Time (Avg):** 0.048s per sample
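
Throughput and average time per sample are reciprocals of each other, so either figure can be derived from the other. The sketch below shows one way such numbers could be measured. `measure_throughput` is a hypothetical helper, and the batch size and iteration counts are assumptions, since the benchmark settings are not stated here. On a GPU, the callable should end with `torch.cuda.synchronize()` so that asynchronous kernels are included in the timing.

```python
import time


def measure_throughput(run_batch, batch_size, n_batches=50, n_warmup=5):
    """Return (samples_per_sec, seconds_per_sample) for a callable.

    run_batch is any zero-argument callable that processes one batch.
    """
    for _ in range(n_warmup):  # warm-up iterations are not timed
        run_batch()
    start = time.perf_counter()
    for _ in range(n_batches):
        run_batch()
    elapsed = time.perf_counter() - start
    n_samples = batch_size * n_batches
    return n_samples / elapsed, elapsed / n_samples


# The reported throughput implies a per-sample latency of 1 / throughput.
reported_throughput = 20.66  # samples/sec, from the list above
latency_from_throughput = 1 / reported_throughput
print(f"Implied latency: {latency_from_throughput:.3f} s per sample")
```

To time the model itself, `run_batch` could be a small function that calls `model(sample_input)` inside `torch.no_grad()`.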
|
|
|
| ## Usage Instructions
|
|
|
| To use this model with the `godel-train` library:
|
|
|
| ```python
|
| import torch
|
| from godel_train.models.factory import ModelFactory
|
|
|
| # 1. Initialize model with appropriate config
|
| model = ModelFactory.segmentation(
|
|     backbone="prithvi_eo_v2_300m",
|
|     num_classes=2,
|
|     checkpoint_path="pytorch_model.bin"  # Local or downloaded file
|
| )
|
| model.eval()
|
|
|
| # 2. Prepare sample input (Batch, 6 Bands, 224, 224)
|
| # Bands: Red, Green, Blue, Narrow NIR, SWIR 1, SWIR 2
|
| sample_input = torch.randn(1, 6, 224, 224)
|
|
|
| # 3. Run Inference
|
| with torch.no_grad():
|
|     prediction = model(sample_input)
|
| # prediction shape: [1, 2, 224, 224] (Logits)
|
|
|
| mask = torch.argmax(prediction, dim=1)
|
| # mask shape: [1, 224, 224] (0: No Flood, 1: Flood)
|
| ```
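
The model takes 224x224 inputs, so a larger scene (for example a 512x512 chip) has to be split into windows before inference. The sketch below computes window corners that cover the whole scene, aligning the last window to the edge instead of padding. `tile_origins` and `tile_windows` are hypothetical helpers and are not part of `godel-train`.

```python
def tile_origins(length, tile=224, stride=224):
    """Start offsets of windows of size `tile` that cover [0, length)."""
    if length <= tile:
        # Scenes smaller than one tile need padding, which is not handled here.
        return [0]
    origins = list(range(0, length - tile + 1, stride))
    if origins[-1] + tile < length:
        origins.append(length - tile)  # final window aligned to the edge
    return origins


def tile_windows(height, width, tile=224, stride=224):
    """(top, left) corners of all tiles for a height x width scene."""
    return [(top, left)
            for top in tile_origins(height, tile, stride)
            for left in tile_origins(width, tile, stride)]


windows = tile_windows(512, 512)
print(f"{len(windows)} windows, first={windows[0]}, last={windows[-1]}")
```

Each window can then be cut from a `[B, 6, H, W]` tensor with `scene[:, :, top:top + 224, left:left + 224]`. Where windows overlap, the predictions need to be merged, for instance by averaging the logits.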
|
|
|
| ## Data and Credits
|
|
|
| - **Dataset:** Sen1Floods11
|
| - **Fine-tuning:** Performed by Tushar Thokdar
|
| - **Foundation Model:** IBM/NASA Prithvi-2.0
|
|
|