PG2-Grasp: Text-Grounded Robot Grasp Prediction
Fine-tuned from google/paligemma2-10b-mix-224 to predict robot grasp centers from natural language prompts and RGB images.
This repo contains bf16 safetensors weights (fine-tuned, not base model).
Usage
BF16 (full precision)
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from PIL import Image, ImageOps
import torch
model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
).to("cuda").eval()
# Prepare image (square-pad to max(w, h))
image = Image.open("scene.png").convert("RGB")
w, h = image.size
max_dim = max(w, h)
pad_left = (max_dim - w) // 2
pad_top = (max_dim - h) // 2
padded = ImageOps.pad(image, (max_dim, max_dim), color=0)
# Prompt format: "<image>Pick the <object>."
inputs = processor(
images=padded,
text="<image>Pick the apple.",
return_tensors="pt",
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
# Extract <loc> logits (last token, 1024-way grid classification)
loc_logits = outputs.logits[0, -1, 256000:257024]
pred_token = loc_logits.argmax().item()
# Decode to pixel coordinates (32×32 grid)
LOC_GRID = 32
row = pred_token // LOC_GRID
col = pred_token % LOC_GRID
cx_px = (col + 0.5) / LOC_GRID * max_dim - pad_left
cy_px = (row + 0.5) / LOC_GRID * max_dim - pad_top
print(f"Grasp center: ({cx_px:.0f}, {cy_px:.0f}) px in original image")
VRAM: ~20 GB (bf16, batch_size=1)
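The argmax decode above discards the model's uncertainty. A small helper (a sketch; `decode_grasp` is an assumed name, not part of this repo) that wraps the same decoding and also reports the softmax confidence of the chosen grid cell:

```python
import torch

def decode_grasp(loc_logits, max_dim, pad_left, pad_top, grid=32):
    """Decode 1024-way <loc> logits into a pixel grasp center.

    loc_logits: 1-D tensor of length grid*grid (the 256000:257024 slice).
    Returns (cx, cy, confidence) in original-image pixel coordinates.
    """
    probs = torch.softmax(loc_logits.float(), dim=-1)
    conf, tok = probs.max(dim=-1)
    row, col = divmod(tok.item(), grid)          # row-major grid layout
    cx = (col + 0.5) / grid * max_dim - pad_left
    cy = (row + 0.5) / grid * max_dim - pad_top
    return cx, cy, conf.item()
```

With the variables from the snippet above: `cx, cy, conf = decode_grasp(loc_logits, max_dim, pad_left, pad_top)`; a low `conf` can be used to reject ambiguous prompts.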
INT8 Quantization
Requires bitsandbytes and accelerate.
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from PIL import Image, ImageOps
import torch
model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map="cuda",
).eval()
image = Image.open("scene.png").convert("RGB")
w, h = image.size
max_dim = max(w, h)
pad_left = (max_dim - w) // 2
pad_top = (max_dim - h) // 2
padded = ImageOps.pad(image, (max_dim, max_dim), color=0)
inputs = processor(
images=padded,
text="<image>Pick the apple.",
return_tensors="pt",
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
loc_logits = outputs.logits[0, -1, 256000:257024]
pred_token = loc_logits.argmax().item()
# Same decoding as BF16
LOC_GRID = 32
row = pred_token // LOC_GRID
col = pred_token % LOC_GRID
cx_px = (col + 0.5) / LOC_GRID * max_dim - pad_left
cy_px = (row + 0.5) / LOC_GRID * max_dim - pad_top
VRAM: ~10-11 GB (INT8 quantized)
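Both snippets repeat the same square-padding steps; they can be factored into a helper (a sketch; `square_pad` is an assumed name, not part of this repo):

```python
from PIL import Image

def square_pad(image, fill=0):
    """Pad an image to a centered square canvas with black borders.

    Returns the padded image, the square side length, and the
    (pad_left, pad_top) offsets needed to map predicted coordinates
    back to the original image space.
    """
    w, h = image.size
    max_dim = max(w, h)
    pad_left = (max_dim - w) // 2
    pad_top = (max_dim - h) // 2
    canvas = Image.new("RGB", (max_dim, max_dim), fill)
    canvas.paste(image, (pad_left, pad_top))
    return canvas, max_dim, pad_left, pad_top
```

`padded, max_dim, pad_left, pad_top = square_pad(image)` then replaces the inline padding in either snippet.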
INT4 NF4 Quantization
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import torch
model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
),
device_map="cuda",
).eval()
# Same usage as above
VRAM: ~5-6 GB (INT4 NF4 quantized)
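The three VRAM figures quoted in this card suggest a simple precision picker. A sketch (the thresholds mirror the ~20 / ~10-11 / ~5-6 GB numbers above with some headroom; `pick_precision` is a hypothetical helper, and real usage varies with batch size and sequence length):

```python
def pick_precision(free_vram_gb):
    """Choose a load precision from free GPU memory, per the figures above."""
    if free_vram_gb >= 22:      # bf16 needs ~20 GB
        return "bf16"
    if free_vram_gb >= 12:      # int8 needs ~10-11 GB
        return "int8"
    return "nf4"                # int4 NF4 needs ~5-6 GB
```

Map the returned tag to the corresponding `from_pretrained` call shown above (e.g. `"int8"` selects `BitsAndBytesConfig(load_in_8bit=True)`).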
Model Details
- Base: PaliGemma2-10B (9.66B params)
- SigLIP-SO400M vision encoder (robot-adapted from piRECAP05)
- Gemma2-9B language model (42 decoder layers)
- Fine-tuning: selective unfreeze on robotics grasp datasets
- Training: FSDP on 8×A100-80GB
- Differential LR: base layers frozen, top layers + vision backbone trainable
- Output: single token (1024-way softmax over 32×32 grid)
- Input resolution: 224×224 (square-padded with black borders)
Eval Results (validation set)
| Metric | Score |
|---|---|
| canonical_median | 45.1 px |
| best_grasp_median | 14.1 px |
| W100 (within 100px) | 80.0% |
| grounding margin | 9.04 |
| grounding top1 | 89.8% |
Notes
- The repository contains bf16 weights only. Quantization happens at load time via BitsAndBytesConfig; the weights on disk remain bf16.
- Prompt format is strict: "<image>Pick the <object>." (with trailing period).
- Image padding is required: square-pad with black, and keep the pad offsets to map predictions back to the original image space.
- LOC tokens: indices 256000–257023 (1024 loc tokens for the 32×32 grid).
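The mapping between loc token ids and grid cells implied by the note above is plain arithmetic; a sketch (the helper names are assumptions, not repo code):

```python
LOC_BASE = 256000   # vocabulary id of the first <loc> token
GRID = 32           # 32x32 grid, 1024 loc tokens total

def loc_token_id(row, col):
    """Vocabulary id of the <loc> token for grid cell (row, col), row-major."""
    return LOC_BASE + row * GRID + col

def loc_cell(token_id):
    """Inverse: grid cell (row, col) for a <loc> vocabulary id."""
    return divmod(token_id - LOC_BASE, GRID)
```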
Model tree for nagaaato/pg2grasp
- Base model: google/paligemma2-10b-mix-224