magboola's picture
deploy model-onnx-exporter-zerogpu
54190c0 verified
"""ONNX export ZeroGPU Space for custom YOLO models.
Paired with `model-validator-zerogpu`. When a user activates a custom model
the backend asks this Space to convert the stored `.pt` weights to a
browser-friendly `.onnx` file that the frontend WebGPU detector can cache
in IndexedDB. The backend handles:
1. Generating a short-lived signed GET URL for the `.pt` (Azure Blob)
2. Generating a short-lived signed PUT URL for the `.onnx` destination
3. Calling this Space with both URLs + the backend callback URL
This Space:
1. Downloads the .pt
2. Loads it with Ultralytics' YOLO and exports to ONNX (simplified,
optionally quantised to int8 for mobile / bandwidth-sensitive clients)
3. Uploads the resulting .onnx via the signed PUT URL
4. POSTs `{model_id, checksum, size_mb}` back to the backend callback (when a
callback URL is supplied) so the `custom_models.onnx_key` + `onnx_checksum`
columns are updated
Mirrors the sam3 / parakeet / model-validator pattern — Gradio app with an
`api_*` function exposed as the Space's remote API, run under
`@spaces.GPU` so each export gets a fresh GPU slice.
"""
from __future__ import annotations
try:
    import spaces
except ImportError:
    # The `spaces` package could not be imported: install a minimal stand-in
    # so that `@spaces.GPU(duration=...)` still decorates functions cleanly.
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """No-op replacement exposing only the `GPU` decorator factory."""

        @staticmethod
        def GPU(duration=60):
            """Return a decorator that hands the function back unchanged."""
            return lambda func: func

    spaces = _spaces_stub()
else:
    ZEROGPU_AVAILABLE = True
import gc
import hashlib
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any
import gradio as gr
import httpx
# Upper bound on the size of a downloaded `.pt` checkpoint: 200 MB.
MAX_MODEL_BYTES = 200 * 1024 ** 2
# Seconds of GPU time requested per export; generous because an export may
# take a while on cold start and with large models.
GPU_DURATION_S = 180
# Export image size used when the caller does not supply one.
DEFAULT_IMG_SIZE = 640
def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Stream the `.pt` checkpoint at ``source_url`` into a temporary file.

    The transfer is aborted as soon as more than ``MAX_MODEL_BYTES`` have been
    received, and the temporary file is deleted whenever the download does not
    complete, so a failed call leaves nothing behind on disk.

    Args:
        source_url: URL fetched with GET (redirects are followed).
        timeout_s: timeout handed to httpx, in seconds.

    Returns:
        Path of the temporary ``.pt`` file. The caller owns it and is
        responsible for deleting it.

    Raises:
        ValueError: if the payload is larger than ``MAX_MODEL_BYTES``.
        httpx.HTTPError: on transport failures or a non-success status code.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    tmp.close()

    completed = False
    try:
        received = 0
        with httpx.stream(
            "GET", source_url, timeout=timeout_s, follow_redirects=True
        ) as response:
            response.raise_for_status()
            with tmp_path.open("wb") as f:
                for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                    # Count bytes as they arrive so an oversized payload is
                    # rejected without first being written to disk in full.
                    received += len(chunk)
                    if received > MAX_MODEL_BYTES:
                        raise ValueError(
                            f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit"
                        )
                    f.write(chunk)
        completed = True
    finally:
        # The caller only learns the path on success, so any failure must
        # clean up here or the temp file would be leaked.
        if not completed:
            tmp_path.unlink(missing_ok=True)
    return tmp_path
def _upload_onnx(put_url: str, onnx_path: Path, timeout_s: float = 120.0) -> None:
    """Upload the exported ONNX file with a single PUT to a signed URL.

    Two headers are always sent: `Content-Type: application/octet-stream`
    and `x-ms-blob-type: BlockBlob`. Azure Blob's PUT API expects the latter;
    S3 doesn't. Backends that signed the URL with a specific Content-Type
    will reject mismatches, so the caller must align their signing policy
    with the Content-Type sent here.

    Args:
        put_url: signed destination URL.
        onnx_path: local file whose bytes form the request body.
        timeout_s: timeout handed to httpx, in seconds.

    Raises:
        httpx.HTTPError: on transport failures or a non-success status code.
    """
    request_headers = {
        "Content-Type": "application/octet-stream",
        "x-ms-blob-type": "BlockBlob",
    }
    body = onnx_path.read_bytes()
    with httpx.Client(timeout=timeout_s) as session:
        reply = session.put(put_url, content=body, headers=request_headers)
        reply.raise_for_status()
@spaces.GPU(duration=GPU_DURATION_S)
def _export_to_onnx(
    weights_path: str,
    *,
    img_size: int,
    half: bool,
    simplify: bool,
    dynamic: bool,
    opset: int,
    int8: bool,
) -> dict[str, Any]:
    """Run the Ultralytics ONNX export for one checkpoint on the GPU worker.

    Args:
        weights_path: local path of the `.pt` file to load.
        img_size: forwarded to the exporter as `imgsz`.
        half: request fp16 export; ignored when `int8` is set.
        simplify: forwarded to the exporter as `simplify`.
        dynamic: forwarded to the exporter as `dynamic`.
        opset: forwarded to the exporter as `opset`.
        int8: request int8 export; takes precedence over `half`.

    Returns:
        On success a dict with `ok=True`, `onnx_path`, `size_mb`, `task` and
        `num_classes`. If the exporter returns a path that does not exist, a
        dict with `ok=False` and an `error` message.

    Exceptions raised while loading or exporting the model are not caught
    here and propagate to the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    checkpoint = Path(weights_path)
    try:
        yolo = YOLO(str(checkpoint))

        options: dict[str, Any] = dict(
            format="onnx",
            imgsz=img_size,
            simplify=simplify,
            opset=opset,
            dynamic=dynamic,
        )
        # Ultralytics only accepts one of half / int8 at a time; int8 wins.
        if int8:
            options["int8"] = True
        elif half:
            options["half"] = True

        exported = Path(yolo.export(**options))
        if exported.exists():
            return {
                "ok": True,
                "onnx_path": str(exported),
                "size_mb": round(exported.stat().st_size / (1024 * 1024), 2),
                "task": getattr(yolo, "task", "unknown"),
                "num_classes": len(getattr(yolo, "names", {}) or {}),
            }
        return {
            "ok": False,
            "error": f"Ultralytics reported export success but {exported} is missing",
        }
    finally:
        # Release whatever the export allocated; emptying the CUDA cache is
        # best-effort and any failure there is deliberately ignored.
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
def _sha256_of_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of ``path``, read in ``chunk_size`` pieces."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def api_export_weights(
    weights_url: str | None = None,
    upload_put_url: str | None = None,
    callback_url: str | None = None,
    callback_token: str | None = None,
    model_id: str | None = None,
    img_size: int = DEFAULT_IMG_SIZE,
    half: bool = False,
    simplify: bool = True,
    dynamic: bool = False,
    opset: int = 17,
    int8: bool = False,
) -> dict[str, Any]:
    """API endpoint invoked by the backend: download, export, upload, notify.

    Args:
        weights_url: signed GET URL for the .pt file
        upload_put_url: signed PUT URL where the resulting .onnx is uploaded
        callback_url: backend endpoint to POST completion metadata to (optional)
        callback_token: bearer token forwarded to the callback request
        model_id: opaque id echoed in the callback payload
        img_size: export image size (default 640); coerced to int
        half / simplify / dynamic / int8: Ultralytics export flags
        opset: ONNX opset passed to the exporter; coerced to int

    Returns:
        On success: ``{"ok": True, "checksum", "size_mb", "task",
        "num_classes", "elapsed_s"}`` where ``checksum`` is the SHA-256 hex
        digest of the uploaded file. On any failure: ``{"ok": False,
        "error": ...}``, plus a ``"traceback"`` entry for unexpected errors.
        This function does not raise for download/export/upload failures.
    """
    if not weights_url:
        return {"ok": False, "error": "weights_url is required"}
    if not upload_put_url:
        return {"ok": False, "error": "upload_put_url is required"}

    # Values coming from Gradio `Number` inputs may be floats (640.0) or None;
    # normalise them once here so the exporter always receives plain ints.
    try:
        img_size = int(img_size)
        opset = int(opset)
    except (TypeError, ValueError, OverflowError):
        return {"ok": False, "error": "img_size and opset must be integers"}
    if img_size <= 0 or opset <= 0:
        return {"ok": False, "error": "img_size and opset must be positive"}

    started = time.monotonic()
    pt_path: Path | None = None
    onnx_path: Path | None = None
    try:
        pt_path = _download_weights(weights_url)
        size_bytes = pt_path.stat().st_size
        if size_bytes == 0:
            return {"ok": False, "error": "Downloaded .pt file is empty"}
        if size_bytes > MAX_MODEL_BYTES:
            return {
                "ok": False,
                "error": f".pt file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
            }

        export_result = _export_to_onnx(
            str(pt_path),
            img_size=img_size,
            half=half,
            simplify=simplify,
            dynamic=dynamic,
            opset=opset,
            int8=int8,
        )
        if not export_result.get("ok"):
            return {
                "ok": False,
                "error": export_result.get("error", "ONNX export failed"),
                "traceback": export_result.get("traceback"),
            }

        onnx_path = Path(export_result["onnx_path"])
        # Hash in chunks so the export is not held in memory just for the digest.
        checksum = _sha256_of_file(onnx_path)
        _upload_onnx(upload_put_url, onnx_path)

        result: dict[str, Any] = {
            "ok": True,
            "checksum": checksum,
            "size_mb": export_result["size_mb"],
            "task": export_result["task"],
            "num_classes": export_result["num_classes"],
            "elapsed_s": round(time.monotonic() - started, 3),
        }
        if callback_url:
            _fire_callback(
                callback_url,
                token=callback_token,
                payload={
                    "model_id": model_id,
                    "checksum": checksum,
                    "size_mb": export_result["size_mb"],
                },
            )
        return result
    except httpx.HTTPError as exc:
        return {"ok": False, "error": f"HTTP failure: {exc}"}
    except ValueError as exc:
        return {"ok": False, "error": str(exc)}
    except Exception as exc:  # noqa: BLE001
        return {
            "ok": False,
            "error": f"Export failed: {exc}",
            "traceback": traceback.format_exc(),
        }
    finally:
        leftovers: list[Path | None] = [pt_path, onnx_path]
        if pt_path is not None:
            # NOTE(review): presumably Ultralytics writes the export next to
            # the weights with an `.onnx` suffix; removing that sibling covers
            # an export that failed after creating the file. The name derives
            # from our unique temp file, so nothing else can own it.
            leftovers.append(pt_path.with_suffix(".onnx"))
        for path in leftovers:
            if path is None:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                # Raising from `finally` would replace the result being
                # returned, so a failed cleanup is only logged.
                print(f"[onnx-exporter] Could not remove {path}: {exc}")
def _fire_callback(url: str, *, token: str | None, payload: dict[str, Any]) -> None:
    """POST ``payload`` as JSON to the backend callback, best-effort.

    Args:
        url: callback endpoint.
        token: when truthy, sent as ``Authorization: Bearer <token>``.
        payload: JSON-serialisable body.

    Failures (transport errors, a malformed URL, or a non-success status
    code) are printed and swallowed; this function never raises for them.
    """
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(url, json=payload, headers=headers)
            # Without this a 4xx/5xx reply from the backend would pass
            # silently; raising routes it into the log line below.
            response.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # The backend can still poll /models/{id}/onnx; callback failure is
        # non-fatal. InvalidURL is not an HTTPError subclass, so it is named
        # explicitly to keep a malformed callback URL non-fatal as well.
        print(f"[onnx-exporter] Callback failed: {exc}")
# ─── Gradio wiring ────────────────────────────────────────────────────────────
# Build the Space UI. The export is triggered by an explicit button rather
# than by the URL textbox's `change` event, which would start a download and
# GPU export on every edit of that field. The remote API name and its input
# order are unchanged.
with gr.Blocks(title="Cadayn ONNX Exporter") as demo:
    gr.Markdown(
        """
        # Cadayn Custom Model ONNX Exporter
        Converts an Ultralytics YOLO `.pt` to ONNX for browser-side WebGPU
        inference. Invoked by the Cadayn backend when a custom model is
        activated; uploads the resulting `.onnx` via a signed PUT URL and
        optionally POSTs completion metadata back to a callback URL.
        **API endpoint:** `/api/api_export_weights`
        """
    )
    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_upload_put_url = gr.Textbox(label="Signed .onnx PUT URL")
    with gr.Row():
        api_callback_url = gr.Textbox(label="Callback URL (optional)")
        api_callback_token = gr.Textbox(
            label="Callback token (optional)", type="password"
        )
        api_model_id = gr.Textbox(label="Model id (echoed in callback)")
    with gr.Row():
        # precision=0 makes Gradio hand these to the function as ints.
        api_img_size = gr.Number(
            label="Image size", value=DEFAULT_IMG_SIZE, precision=0
        )
        api_half = gr.Checkbox(label="Half precision (fp16)", value=False)
        api_simplify = gr.Checkbox(label="Simplify graph", value=True)
        api_dynamic = gr.Checkbox(label="Dynamic axes", value=False)
        api_opset = gr.Number(label="ONNX opset", value=17, precision=0)
        api_int8 = gr.Checkbox(label="INT8 quantisation", value=False)
    api_run = gr.Button("Export to ONNX")
    api_output = gr.JSON(label="Export result")
    api_run.click(
        fn=api_export_weights,
        inputs=[
            api_weights_url,
            api_upload_put_url,
            api_callback_url,
            api_callback_token,
            api_model_id,
            api_img_size,
            api_half,
            api_simplify,
            api_dynamic,
            api_opset,
            api_int8,
        ],
        outputs=api_output,
        api_name="api_export_weights",
    )

# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()