
PG2-Grasp: Text-Grounded Robot Grasp Prediction

Fine-tuned from google/paligemma2-10b-mix-224 to predict robot grasp centers from natural language prompts and RGB images.

This repo contains bf16 safetensors weights (fine-tuned, not base model).

Usage

BF16 (unquantized)

from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from PIL import Image, ImageOps
import torch

model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to("cuda").eval()

# Square-pad to max(w, h) with black borders; keep the offsets so
# predictions can be mapped back into the original image.
image = Image.open("scene.png").convert("RGB")
w, h = image.size
max_dim = max(w, h)
pad_left = (max_dim - w) // 2
pad_top = (max_dim - h) // 2
padded = ImageOps.pad(image, (max_dim, max_dim), color=0)  # centered padding

# Prompt format: "<image>Pick the <object>."
inputs = processor(
    images=padded,
    text="<image>Pick the apple.",
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    outputs = model(**inputs)
    # Extract <loc> logits (last token, 1024-way grid classification)
    loc_logits = outputs.logits[0, -1, 256000:257024]
    pred_token = loc_logits.argmax().item()

# Decode to pixel coordinates (32ร—32 grid)
LOC_GRID = 32
row = pred_token // LOC_GRID
col = pred_token % LOC_GRID
cx_px = (col + 0.5) / LOC_GRID * max_dim - pad_left
cy_px = (row + 0.5) / LOC_GRID * max_dim - pad_top
print(f"Grasp center: ({cx_px:.0f}, {cy_px:.0f}) px in original image")

VRAM: ~20 GB (bf16, batch_size=1)

INT8 Quantization

Requires bitsandbytes and accelerate.
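A minimal environment sketch for the quantized paths (package names from the note above; versions unpinned):

```shell
pip install bitsandbytes accelerate
```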

from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from PIL import Image, ImageOps
import torch

model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="cuda",
).eval()

image = Image.open("scene.png").convert("RGB")
w, h = image.size
max_dim = max(w, h)
pad_left = (max_dim - w) // 2
pad_top = (max_dim - h) // 2
padded = ImageOps.pad(image, (max_dim, max_dim), color=0)

inputs = processor(
    images=padded,
    text="<image>Pick the apple.",
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    outputs = model(**inputs)
    loc_logits = outputs.logits[0, -1, 256000:257024]
    pred_token = loc_logits.argmax().item()

# Same decoding as BF16
LOC_GRID = 32
row = pred_token // LOC_GRID
col = pred_token % LOC_GRID
cx_px = (col + 0.5) / LOC_GRID * max_dim - pad_left
cy_px = (row + 0.5) / LOC_GRID * max_dim - pad_top

VRAM: ~10-11 GB (INT8 quantized)

INT4 NF4 Quantization

from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import torch

model_id = "nagaaato/pg2grasp"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ),
    device_map="cuda",
).eval()

# Same usage as above

VRAM: ~5-6 GB (INT4 NF4 quantized)

Model Details

  • Base: PaliGemma2-10B (9.66B params)
    • SigLIP-SO400M vision encoder (robot-adapted from piRECAP05)
    • Gemma2-9B language model (42 decoder layers)
  • Fine-tuning: selective unfreeze on robotics grasp datasets
    • Training: FSDP on 8×A100-80GB
    • Differential LR: base layers frozen, top layers + vision backbone trainable
  • Output: single token (1024-way softmax over 32×32 grid)
  • Input resolution: 224×224 (square-padded with black borders)
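The single-token output decodes on the 32×32 grid exactly as in the usage snippets above. A small round-trip sketch (pure Python, no model needed) checks the math; `pixel_to_token` is a hypothetical inverse helper for testing, not part of the model API:

```python
LOC_GRID = 32  # the fine-tuned head predicts one of 32*32 = 1024 grid cells

def pixel_to_token(x, y, side):
    """Map a pixel in the square-padded image to its grid-cell index (hypothetical helper)."""
    col = min(int(x / side * LOC_GRID), LOC_GRID - 1)
    row = min(int(y / side * LOC_GRID), LOC_GRID - 1)
    return row * LOC_GRID + col

def token_to_pixel(token, side):
    """Decode a grid-cell index to the cell-center pixel, as in the usage snippets."""
    row, col = divmod(token, LOC_GRID)
    return ((col + 0.5) / LOC_GRID * side, (row + 0.5) / LOC_GRID * side)

side = 224
t = pixel_to_token(100.0, 50.0, side)
cx, cy = token_to_pixel(t, side)
# The decoded center lies within half a cell (224 / 32 / 2 = 3.5 px) of the input.
assert abs(cx - 100.0) <= side / LOC_GRID / 2
assert abs(cy - 50.0) <= side / LOC_GRID / 2
```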

Eval Results (validation set)

| Metric | Score |
|---|---|
| canonical_median | 45.1 px |
| best_grasp_median | 14.1 px |
| W100 (within 100 px) | 80.0% |
| grounding margin | 9.04 |
| grounding top1 | 89.8% |
Notes

  • The repository contains bf16 weights only. Quantization happens at load time via BitsAndBytesConfig; weights remain bf16 on disk.
  • Prompt format is strict: "<image>Pick the <object>." (including the trailing period).
  • Image padding is required: square-pad with black, compute offset to decode back to original space.
  • LOC tokens: indices 256000–257023 (1024 loc tokens for the 32×32 grid).
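The notes above (LOC token range, square-padding) can be folded into one decoding step. A convenience sketch under those assumptions; `decode_grasp` is a hypothetical name, not a repo API:

```python
LOC_BASE = 256000   # first <loc> vocab id (indices 256000-257023 per the notes)
LOC_GRID = 32

def decode_grasp(vocab_id, orig_w, orig_h):
    """Map a predicted <loc> vocab id to (x, y) pixels in the original, unpadded image."""
    cell = vocab_id - LOC_BASE
    assert 0 <= cell < LOC_GRID * LOC_GRID, "not a <loc> token"
    row, col = divmod(cell, LOC_GRID)
    side = max(orig_w, orig_h)          # square-pad side length
    pad_left = (side - orig_w) // 2     # same centered offsets as the usage snippets
    pad_top = (side - orig_h) // 2
    x = (col + 0.5) / LOC_GRID * side - pad_left
    y = (row + 0.5) / LOC_GRID * side - pad_top
    # Clamp: cell centers near the black borders can fall just outside the image.
    return (min(max(x, 0.0), orig_w - 1), min(max(y, 0.0), orig_h - 1))
```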