# Sat-JEPA-Diff (Caption-Guided Zero-RGB)

**Caption-Guided Zero-RGB Satellite Image Forecasting via Self-Supervised Diffusion**

*Submitted to IEEE Geoscience and Remote Sensing Letters (GRSL)*
## Model Description
This is the Caption-Guided Zero-RGB version of Sat-JEPA-Diff. It advances the original model by eliminating the 32×32 coarse RGB input entirely, replacing it with a Triple-Captioning system that drives temporal changes through text.
The model predicts future satellite images (t → t+1) through a dual-stream mechanism:
- **I-JEPA spatial anchor** encodes the current image and predicts future latent embeddings — dictating *where* changes manifest
- **Triple-Captioning system** provides hierarchical text conditioning — describing *what* changes:
  - Informative caption (t): dense pixel-level description
  - Geometric caption (t): spatial layout and topology
  - Semantic caption (t+1): forecasted future state via a learned Caption Forecaster
- **Multi-Stream Conditioning Adapter** fuses all four signals (3 caption embeddings + I-JEPA tokens) via cross-attention
- **Frozen Stable Diffusion 3.5 Medium + LoRA** generates the final high-fidelity RGB imagery
This Zero-RGB paradigm bypasses physical sensor blindness — the model never sees past pixels during generation, relying entirely on semantic and latent representations.
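To make the four-signal fusion concrete, here is a minimal sketch of a multi-stream cross-attention adapter. This is **not** the released implementation: the class name, dimensions, and residual design below are illustrative assumptions, chosen only to show how three caption embeddings and I-JEPA tokens can be fused with cross-attention.

```python
import torch
import torch.nn as nn


class MultiStreamAdapterSketch(nn.Module):
    """Illustrative adapter: I-JEPA tokens query a concatenation of the
    three caption embedding streams via cross-attention."""

    def __init__(self, token_dim=768, text_dim=4096, fused_dim=1024, num_heads=8):
        super().__init__()
        self.token_proj = nn.Linear(token_dim, fused_dim)   # project I-JEPA tokens
        self.text_proj = nn.Linear(text_dim, fused_dim)     # project caption embeddings
        self.cross_attn = nn.MultiheadAttention(fused_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(fused_dim)

    def forward(self, ijepa_tokens, cap_informative, cap_geometric, cap_semantic):
        # Queries: spatial tokens (where); keys/values: caption streams (what)
        q = self.token_proj(ijepa_tokens)
        kv = self.text_proj(torch.cat([cap_informative, cap_geometric, cap_semantic], dim=1))
        fused, _ = self.cross_attn(q, kv, kv)
        return self.norm(q + fused)  # residual fusion of spatial and textual signals
```

In this sketch the spatial stream stays the query so the output keeps one fused token per image patch, which is a natural shape for conditioning a diffusion backbone.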
## What Changed from the Original Model
| Aspect | Original (coarse-RGB) | This model (Zero-RGB) |
|---|---|---|
| RGB conditioning | 32×32 downsampled input | None |
| Text conditioning | None | Triple-caption (informative + geometric + semantic) |
| Caption forecasting | N/A | Cross-attention forecaster predicts semantic caption at t+1 |
| Adapter | Single-stream | Multi-stream (4-signal fusion + temporal attention) |
| LPIPS ↓ | 0.4449 | 0.3781 (15% relative improvement) |
| GSSIM ↑ | 0.8984 | 0.8974 (maintained) |
## Key Results
| Model | L1 ↓ | MSE ↓ | PSNR ↑ | SSIM ↑ | GSSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|---|---|
| **Deterministic Baselines** | | | | | | | |
| Default | 0.0131 | 0.0008 | 37.52 | 0.9361 | 0.7858 | 0.0708 | 0.6959 |
| PredRNN | 0.0117 | 0.0005 | 38.38 | 0.9476 | 0.7836 | 0.0726 | 9.9720 |
| SimVP v2 | 0.0131 | 0.0006 | 37.63 | 0.9391 | 0.7719 | 0.0928 | 18.7208 |
| **Generative Models** | | | | | | | |
| SD 3.5 | 0.0175 | 0.0005 | 32.98 | 0.8398 | 0.8711 | 0.4528 | 0.1533 |
| MCVD | 0.0314 | 0.0031 | 31.28 | 0.8637 | 0.7665 | 0.1890 | 0.1956 |
| Sat-JEPA-Diff (coarse RGB) | 0.0158 | 0.0004 | 33.81 | 0.8672 | 0.8984 | 0.4449 | 0.1475 |
| Sat-JEPA-Diff (Caption Guided) | 0.0267 | 0.0025 | 28.95 | 0.7806 | 0.8974 | 0.3781 | 0.2642 |
Despite using zero optical inputs, the caption-guided model achieves 15% better perceptual realism (LPIPS) while maintaining topological integrity (GSSIM).
## Architecture Details
| Component | Specification |
|---|---|
| IJEPA Encoder | ViT-Base, patch size 8, input 128×128 |
| IJEPA Predictor | 6-layer transformer, embed dim 384 |
| Projection Head | Linear, 768 → 64 |
| Caption Forecaster | 3-layer cross-attention, hidden dim 1024, gated residual delta |
| Multi-Stream Adapter | 4-signal fusion + temporal attention, ~25M params |
| Diffusion Backbone | SD 3.5 Medium (frozen) + LoRA (rank 16, alpha 32) |
| Text Encoders | CLIP ×2 + T5 (frozen, used offline for precomputation) |
| VAE | 8× spatial compression, 16 latent channels |
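To make the Caption Forecaster row concrete, here is a hedged sketch that matches the listed specification (3 cross-attention layers, hidden dim 1024, gated residual delta) and the constructor signature used later in this card. The internal wiring is my assumption, not the released code.

```python
import torch
import torch.nn as nn


class CaptionForecasterSketch(nn.Module):
    """Illustrative forecaster: caption(t) tokens cross-attend to I-JEPA
    tokens, then a gated residual delta yields the caption(t+1) embedding."""

    def __init__(self, ijepa_dim=768, text_dim=4096, hidden_dim=1024,
                 num_layers=3, num_heads=8):
        super().__init__()
        self.text_in = nn.Linear(text_dim, hidden_dim)
        self.ctx_in = nn.Linear(ijepa_dim, hidden_dim)
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        self.delta_out = nn.Linear(hidden_dim, text_dim)
        self.gate = nn.Sequential(nn.Linear(hidden_dim, text_dim), nn.Sigmoid())

    def forward(self, caption_t, ijepa_tokens):
        h = self.text_in(caption_t)         # (B, T_cap, hidden)
        ctx = self.ctx_in(ijepa_tokens)     # (B, T_img, hidden)
        for attn in self.layers:
            upd, _ = attn(h, ctx, ctx)      # caption queries attend to image context
            h = h + upd
        # Gated residual delta: caption(t+1) = caption(t) + gate * delta
        return caption_t + self.gate(h) * self.delta_out(h)
```

The gated residual keeps the forecast anchored to the current caption embedding, so the network only has to learn the *change* in semantics between t and t+1.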
## How to Use
> **Note:** This is a custom PyTorch checkpoint, not a standard `transformers` model. You need the source code from the GitHub repository to load and run inference.
### 1. Setup
```bash
git clone https://github.com/VU-AIML/SAT-JEPA-DIFF.git
cd SAT-JEPA-DIFF
git checkout caption-guided

conda create -n satjepa python=3.12
conda activate satjepa

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install diffusers transformers peft accelerate
pip install rasterio matplotlib pyyaml lpips
```
You need a Hugging Face token with access to Stable Diffusion 3.5 Medium:

```bash
export HF_TOKEN=hf_your_token_here
```
### 2. Download the Checkpoint
```python
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="kursatkomurcu/SAT-JEPA-DIFF-Caption-Guided",
    filename="s2_future_jepa-best.pth.tar",
)
```
Or via CLI:

```bash
huggingface-cli download kursatkomurcu/SAT-JEPA-DIFF-Caption-Guided \
  s2_future_jepa-best.pth.tar --local-dir ./checkpoints
```
### 3. Run Inference
```bash
cd SAT-JEPA-DIFF/src
python inference.py \
  --checkpoint /path/to/s2_future_jepa-best.pth.tar \
  --output_dir ./results \
  --diffusion_steps 20 \
  --noise_strength 0.35
```
### 4. Programmatic Usage
```python
import torch
from helper import init_model
from sd_models import load_sd_model, encode_caption_batch
from sd_joint_loss import diffusion_sample
from caption_forecaster import CaptionForecaster

device = torch.device("cuda")

# Load all components from the checkpoint
encoder, predictor, sd_state, embed_dim = init_model(
    checkpoint_path="path/to/s2_future_jepa-best.pth.tar",
    device=device,
)

# Load the caption forecaster
caption_forecaster = CaptionForecaster(
    ijepa_dim=768, text_dim=4096, hidden_dim=1024, num_layers=3
).to(device)

# Load a Sentinel-2 GeoTIFF
# (load_and_resize_tif and predict_next_frame are utilities in the repository source)
rgb_t = load_and_resize_tif("path/to/sentinel2_image.tif", target_size=128)

# Predict with caption conditioning
rgb_t1_pred = predict_next_frame(
    rgb_t=rgb_t,
    encoder=encoder,
    predictor=predictor,
    sd_state=sd_state,
    caption_forecaster=caption_forecaster,
    device=device,
    num_diffusion_steps=20,
    noise_strength=0.35,
)
# rgb_t1_pred: (3, 128, 128) tensor in [0, 1]
```
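The predicted tensor can be written to disk as an 8-bit PNG. The helper below is a small illustrative addition (not part of the repository), assuming only the documented output shape and `[0, 1]` range:

```python
import torch
from PIL import Image


def save_prediction(rgb_t1_pred: torch.Tensor, path: str = "prediction.png") -> None:
    """Write a (3, H, W) float tensor in [0, 1] to disk as an 8-bit RGB PNG."""
    arr = (rgb_t1_pred.clamp(0, 1) * 255).byte()   # float [0,1] -> uint8 [0,255]
    arr = arr.permute(1, 2, 0).cpu().numpy()       # CHW -> HWC for PIL
    Image.fromarray(arr).save(path)
```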
## Checkpoint Contents

The `.pth.tar` file contains:

| Key | Description |
|---|---|
| `encoder` | IJEPA ViT encoder state dict |
| `predictor` | IJEPA predictor state dict |
| `target_encoder` | EMA target encoder state dict |
| `proj_head` | Projection head (768 → 64) state dict |
| `cond_adapter` | Multi-stream caption conditioning adapter state dict |
| `caption_forecaster` | Caption forecaster (cross-attention + gated delta) state dict |
| `lora_state_dict` | LoRA weights for the SD 3.5 backbone |
| `prompt_embeds` | Fallback text prompt embeddings |
| `pooled_prompt_embeds` | Fallback pooled prompt embeddings |
| `config` | Training configuration dict |
| `epoch` | Training epoch |
| `optimizer` | Optimizer state |
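To sanity-check a downloaded checkpoint against this key layout without loading any model code, you can summarize its top-level contents. This helper is illustrative, assuming only the standard `torch.save` format described above:

```python
import torch


def summarize_checkpoint(path: str) -> dict:
    """Map each top-level checkpoint key to a parameter count (for state
    dicts) or a type name (for everything else)."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    summary = {}
    for key, value in ckpt.items():
        if isinstance(value, dict) and all(torch.is_tensor(v) for v in value.values()):
            summary[key] = sum(v.numel() for v in value.values())  # state dict
        else:
            summary[key] = type(value).__name__                    # config, epoch, ...
    return summary
```

`weights_only=False` is needed because the file also stores the `config` dict and optimizer state, not just tensors; only use it on checkpoints you trust.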
## Caption Pipeline

The model requires precomputed caption embeddings. The full pipeline is documented in the GitHub repository:

1. Generate captions with EarthDial (`generate_captions.py`)
2. Precompute embeddings with frozen CLIP+T5 (`precompute_caption_embeddings.py`)
3. Train or run inference — text encoders are never loaded during training/inference
## Training
See the GitHub repository for full training instructions. The model jointly optimizes three losses:
ℒ_total = ℒ_IJEPA + λ_sd · ℒ_diff + λ_cap · ℒ_caption
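As a sketch, the combination reads as follows in code; the λ defaults below are placeholders, and the trained weights live in the repository's training configuration:

```python
def total_loss(loss_ijepa: float, loss_diff: float, loss_caption: float,
               lambda_sd: float = 1.0, lambda_cap: float = 0.1) -> float:
    """L_total = L_IJEPA + lambda_sd * L_diff + lambda_cap * L_caption.

    The lambda defaults here are illustrative placeholders, not the
    published hyperparameters.
    """
    return loss_ijepa + lambda_sd * loss_diff + lambda_cap * loss_caption
```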
Training takes approximately 5 days on a single NVIDIA RTX 5090 (~24 GB VRAM).
## Dataset
Sentinel-2 RGB imagery (10m GSD) paired with Alpha Earth Foundation Model embeddings and EarthDial captions across 100 global Regions of Interest (2017–2024). Available on Zenodo.
## Requirements
- Python 3.12+
- PyTorch 2.0+ with CUDA
- ~24GB GPU VRAM (RTX 3090/4090/5090 or A100)
- Stable Diffusion 3.5 Medium access via HF token
## Citation

```bibtex
@article{komurcu2026satjepadiff_caption,
  title={Sat-{JEPA}-Diff: Caption-Guided Zero-{RGB} Satellite Image Forecasting via Self-Supervised Diffusion},
  author={K{\"o}m{\"u}rc{\"u}, K{\"u}r{\c{s}}at and Petkevicius, Linas},
  journal={IEEE Geoscience and Remote Sensing Letters},
  year={2026},
  note={Submitted}
}

@inproceedings{komurcu2026satjepadiff,
  title={Sat-{JEPA}-Diff: Bridging Self-Supervised Learning and Generative Diffusion for Remote Sensing},
  author={Kursat Komurcu and Linas Petkevicius},
  booktitle={4th ICLR Workshop on Machine Learning for Remote Sensing (Main Track)},
  year={2026},
  url={https://openreview.net/forum?id=WBHfQLbgZR}
}
```
## Acknowledgments
This project was funded by the European Union (project No S-MIP-23-45) under the agreement with the Research Council of Lithuania (LMTLT). The I-JEPA implementation is based on Meta's I-JEPA. Diffusion backbone: SD 3.5 Medium. Captions: EarthDial.