Sat-JEPA-Diff (Caption-Guided Zero-RGB)

Caption-Guided Zero-RGB Satellite Image Forecasting via Self-Supervised Diffusion

Submitted to IEEE Geoscience and Remote Sensing Letters (GRSL)


Model Description

This is the Caption-Guided Zero-RGB version of Sat-JEPA-Diff. It advances the original model by eliminating the 32×32 coarse-RGB input entirely, replacing it with a Triple-Captioning system that drives temporal change through text.

The model predicts future satellite images (t → t+1) through a dual-stream mechanism:

  1. I-JEPA spatial anchor encodes the current image and predicts future latent embeddings — dictating where changes manifest
  2. Triple-Captioning system provides hierarchical text conditioning — describing what changes:
    • Informative caption (t): dense pixel-level description
    • Geometric caption (t): spatial layout and topology
    • Semantic caption (t+1): forecasted future state via a learned Caption Forecaster
  3. Multi-Stream Conditioning Adapter fuses all four signals (3 caption embeddings + I-JEPA tokens) via cross-attention
  4. Frozen Stable Diffusion 3.5 Medium + LoRA generates the final high-fidelity RGB imagery

This Zero-RGB paradigm bypasses physical sensor blindness — the model never sees past pixels during generation, relying entirely on semantic and latent representations.
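The where/what split above can be sketched as a cross-attention fusion over the four conditioning streams. This is a minimal illustration, not the repository's adapter: the class name, dimensions, head count, and token counts are all assumptions.

```python
import torch
import torch.nn as nn

class MultiStreamFusion(nn.Module):
    """Illustrative stand-in for the multi-stream conditioning adapter:
    diffusion-side tokens cross-attend to the concatenation of the three
    caption embeddings and the I-JEPA predicted tokens."""
    def __init__(self, dim: int = 768, heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, queries, cap_info, cap_geo, cap_sem, ijepa_tokens):
        # Concatenate all four conditioning signals along the sequence axis
        context = torch.cat([cap_info, cap_geo, cap_sem, ijepa_tokens], dim=1)
        out, _ = self.attn(queries, context, context)
        return self.norm(queries + out)  # residual + post-norm

fusion = MultiStreamFusion()
queries = torch.randn(2, 16, 768)                  # diffusion-side tokens
caps = [torch.randn(2, 8, 768) for _ in range(3)]  # three caption embeddings
jepa = torch.randn(2, 256, 768)                    # predicted latent tokens
fused = fusion(queries, *caps, jepa)
print(fused.shape)  # torch.Size([2, 16, 768])
```

The key point is that the RGB stream is gone: the query tokens only ever see text embeddings and latent I-JEPA tokens.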

What Changed from the Original Model

| | Original (coarse-RGB) | This model (Zero-RGB) |
|---|---|---|
| RGB conditioning | 32×32 downsampled input | None |
| Text conditioning | None | Triple-caption (informative + geometric + semantic) |
| Caption forecasting | N/A | Cross-attention forecaster predicts the semantic caption at t+1 |
| Adapter | Single-stream | Multi-stream (4-signal fusion + temporal attention) |
| LPIPS ↓ | 0.4449 | 0.3781 (~15% perceptual improvement) |
| GSSIM ↑ | 0.8984 | 0.8974 (maintained) |

Key Results

| Model | L1 ↓ | MSE ↓ | PSNR ↑ | SSIM ↑ | GSSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|---|---|
| *Deterministic baselines* | | | | | | | |
| Default | 0.0131 | 0.0008 | 37.52 | 0.9361 | 0.7858 | 0.0708 | 0.6959 |
| PredRNN | 0.0117 | 0.0005 | 38.38 | 0.9476 | 0.7836 | 0.0726 | 9.9720 |
| SimVP v2 | 0.0131 | 0.0006 | 37.63 | 0.9391 | 0.7719 | 0.0928 | 18.7208 |
| *Generative models* | | | | | | | |
| SD 3.5 | 0.0175 | 0.0005 | 32.98 | 0.8398 | 0.8711 | 0.4528 | 0.1533 |
| MCVD | 0.0314 | 0.0031 | 31.28 | 0.8637 | 0.7665 | 0.1890 | 0.1956 |
| Sat-JEPA-Diff (coarse RGB) | 0.0158 | 0.0004 | 33.81 | 0.8672 | 0.8984 | 0.4449 | 0.1475 |
| Sat-JEPA-Diff (caption-guided) | 0.0267 | 0.0025 | 28.95 | 0.7806 | 0.8974 | 0.3781 | 0.2642 |

Despite receiving zero optical input, the caption-guided model achieves ~15% better perceptual realism than the coarse-RGB variant (LPIPS 0.3781 vs. 0.4449) while maintaining topological integrity (GSSIM 0.8974 vs. 0.8984).

Architecture Details

| Component | Specification |
|---|---|
| I-JEPA encoder | ViT-Base, patch size 8, input 128×128 |
| I-JEPA predictor | 6-layer transformer, embed dim 384 |
| Projection head | Linear, 768 → 64 |
| Caption forecaster | 3-layer cross-attention, hidden dim 1024, gated residual delta |
| Multi-stream adapter | 4-signal fusion + temporal attention, ~25M params |
| Diffusion backbone | SD 3.5 Medium (frozen) + LoRA (rank 16, alpha 32) |
| Text encoders | CLIP ×2 + T5 (frozen, used offline for precomputation) |
| VAE | 8× spatial compression, 16 latent channels |
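The LoRA row translates to a small trainable delta per wrapped projection. A hand-rolled sketch at the rank/alpha listed above (the repository presumably attaches adapters via `peft`; this standalone version only illustrates the arithmetic, and the 768-wide layer is an assumption):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W x + (alpha / r) * B A x,
    with the base weight frozen and only A, B trainable."""
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # backbone stays frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero-init: no-op at start
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(768, 768))
x = torch.randn(2, 768)
print(layer(x).shape)  # torch.Size([2, 768])

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 24576 = 16*768 + 768*16
```

Zero-initializing `B` means the adapter starts as an identity perturbation, so fine-tuning begins from the frozen backbone's behavior.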

How to Use

Note: This is a custom PyTorch checkpoint, not a standard transformers model. You need the source code from the GitHub repository to load and run inference.

1. Setup

```bash
git clone https://github.com/VU-AIML/SAT-JEPA-DIFF.git
cd SAT-JEPA-DIFF
git checkout caption-guided

conda create -n satjepa python=3.12
conda activate satjepa

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install diffusers transformers peft accelerate
pip install rasterio matplotlib pyyaml lpips
```

You need a Hugging Face token with access to Stable Diffusion 3.5 Medium:

```bash
export HF_TOKEN=hf_your_token_here
```

2. Download the Checkpoint

```python
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="kursatkomurcu/SAT-JEPA-DIFF-Caption-Guided",
    filename="s2_future_jepa-best.pth.tar",
)
```

Or via CLI:

```bash
huggingface-cli download kursatkomurcu/SAT-JEPA-DIFF-Caption-Guided \
    s2_future_jepa-best.pth.tar --local-dir ./checkpoints
```

3. Run Inference

```bash
cd SAT-JEPA-DIFF/src

python inference.py \
    --checkpoint /path/to/s2_future_jepa-best.pth.tar \
    --output_dir ./results \
    --diffusion_steps 20 \
    --noise_strength 0.35
```

4. Programmatic Usage

```python
import torch
# Repo utilities; the snippet originally imported init_model but called
# load_model — the helper-module import below assumes load_model is the
# intended entry point (module paths may differ in the repo).
from helper import load_model, load_and_resize_tif, predict_next_frame
from caption_forecaster import CaptionForecaster

device = torch.device("cuda")

# Load all components from the checkpoint
encoder, predictor, sd_state, embed_dim = load_model(
    checkpoint_path="path/to/s2_future_jepa-best.pth.tar",
    device=device,
)

# Load the caption forecaster
caption_forecaster = CaptionForecaster(
    ijepa_dim=768, text_dim=4096, hidden_dim=1024, num_layers=3
).to(device)

# Load a Sentinel-2 GeoTIFF
rgb_t = load_and_resize_tif("path/to/sentinel2_image.tif", target_size=128)

# Predict with caption conditioning
rgb_t1_pred = predict_next_frame(
    rgb_t=rgb_t,
    encoder=encoder,
    predictor=predictor,
    sd_state=sd_state,
    caption_forecaster=caption_forecaster,
    device=device,
    num_diffusion_steps=20,
    noise_strength=0.35,
)
# rgb_t1_pred: (3, 128, 128) tensor in [0, 1]
```
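The returned tensor can be dumped to an 8-bit PNG for quick visual inspection. A sketch (a random tensor stands in for `rgb_t1_pred`, and the output filename is illustrative):

```python
import torch
from PIL import Image

# Stand-in for rgb_t1_pred: a (3, 128, 128) tensor in [0, 1]
rgb_t1_pred = torch.rand(3, 128, 128)

# Scale to 8-bit, move channels last, and write a PNG
arr = rgb_t1_pred.clamp(0, 1).mul(255).byte().permute(1, 2, 0).numpy()
Image.fromarray(arr).save("prediction_t1.png")
print(arr.shape)  # (128, 128, 3)
```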

Checkpoint Contents

The .pth.tar file contains:

| Key | Description |
|---|---|
| `encoder` | I-JEPA ViT encoder state dict |
| `predictor` | I-JEPA predictor state dict |
| `target_encoder` | EMA target encoder state dict |
| `proj_head` | Projection head (768 → 64) state dict |
| `cond_adapter` | Multi-stream caption-conditioning adapter state dict |
| `caption_forecaster` | Caption forecaster (cross-attention + gated delta) state dict |
| `lora_state_dict` | LoRA weights for the SD 3.5 backbone |
| `prompt_embeds` | Fallback text prompt embeddings |
| `pooled_prompt_embeds` | Fallback pooled prompt embeddings |
| `config` | Training configuration dict |
| `epoch` | Training epoch |
| `optimizer` | Optimizer state |
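The layout above can be verified before wiring up the modules. A sketch of the inspection step (an in-memory buffer with placeholder entries stands in for the real `.pth.tar`):

```python
import io
import torch

# Placeholder checkpoint mirroring the keys in the table above
ckpt = {
    "encoder": {}, "predictor": {}, "target_encoder": {}, "proj_head": {},
    "cond_adapter": {}, "caption_forecaster": {}, "lora_state_dict": {},
    "config": {"note": "placeholder"}, "epoch": 100,
}
buf = io.BytesIO()
torch.save(ckpt, buf)

# Loading: map to CPU first, then inspect the top-level keys
buf.seek(0)
loaded = torch.load(buf, map_location="cpu", weights_only=True)
print(sorted(loaded.keys()))
```

With the real file, replace the buffer with the downloaded path and pass each `*_state_dict`/submodule entry to the matching module's `load_state_dict`.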

Caption Pipeline

The model requires precomputed caption embeddings. The full pipeline is documented in the GitHub repository:

  1. Generate captions with EarthDial (generate_captions.py)
  2. Precompute embeddings with frozen CLIP+T5 (precompute_caption_embeddings.py)
  3. Train or run inference — text encoders are never loaded during training/inference
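Step 2 amounts to a one-off encode-and-cache pass. A minimal sketch (the file name, key scheme, and embedding shape are illustrative; the real logic lives in `precompute_caption_embeddings.py`):

```python
import torch

# Placeholder for the frozen CLIP+T5 forward pass (illustrative shape)
def encode_caption(text: str) -> torch.Tensor:
    return torch.randn(77, 4096)  # (tokens, embed dim)

captions = {
    "roi_001_t_informative": "dense pixel-level description ...",
    "roi_001_t_geometric": "spatial layout and topology ...",
}

# Encode once, cache to disk
cache = {key: encode_caption(text) for key, text in captions.items()}
torch.save(cache, "caption_embeddings.pt")

# Training/inference later reads the cache; no text encoder is loaded
reloaded = torch.load("caption_embeddings.pt")
print(len(reloaded))  # 2
```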

Training

See the GitHub repository for full training instructions. The model jointly optimizes three losses:

ℒ_total = ℒ_IJEPA + λ_sd · ℒ_diff + λ_cap · ℒ_caption
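
The combination is a plain weighted sum per training step. A worked sketch with placeholder per-term losses and weights (the actual λ values come from the training config, not from here):

```python
import torch

# Placeholder per-term losses for one step
loss_ijepa = torch.tensor(0.12)
loss_diff = torch.tensor(0.34)
loss_caption = torch.tensor(0.05)

lambda_sd, lambda_cap = 1.0, 0.5  # illustrative weights

# L_total = L_IJEPA + lambda_sd * L_diff + lambda_cap * L_caption
loss_total = loss_ijepa + lambda_sd * loss_diff + lambda_cap * loss_caption
loss_total_value = round(float(loss_total), 6)
print(loss_total_value)  # 0.485
```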

Training takes approximately 5 days on a single NVIDIA RTX 5090 (24GB).

Dataset

Sentinel-2 RGB imagery (10m GSD) paired with Alpha Earth Foundation Model embeddings and EarthDial captions across 100 global Regions of Interest (2017–2024). Available on Zenodo.


Citation

@article{komurcu2026satjepadiff_caption,
  title={Sat-{JEPA}-Diff: Caption-Guided Zero-{RGB} Satellite Image Forecasting via Self-Supervised Diffusion},
  author={K{\"o}m{\"u}rc{\"u}, K{\"u}r{\c{s}}at and Petkevicius, Linas},
  journal={IEEE Geoscience and Remote Sensing Letters},
  year={2026},
  note={Submitted}
}

@inproceedings{komurcu2026satjepadiff,
  title={Sat-{JEPA}-Diff: Bridging Self-Supervised Learning and Generative Diffusion for Remote Sensing},
  author={Kursat Komurcu and Linas Petkevicius},
  booktitle={4th ICLR Workshop on Machine Learning for Remote Sensing (Main Track)},
  year={2026},
  url={https://openreview.net/forum?id=WBHfQLbgZR}
}

Acknowledgments

This project was funded by the European Union (project No S-MIP-23-45) under the agreement with the Research Council of Lithuania (LMTLT). The I-JEPA implementation is based on Meta's I-JEPA. Diffusion backbone: SD 3.5 Medium. Captions: EarthDial.
