"""ONNX export ZeroGPU Space for custom YOLO models.

Paired with `model-validator-zerogpu`. When a user activates a custom model
the backend asks this Space to convert the stored `.pt` weights to a
browser-friendly `.onnx` file that the frontend WebGPU detector can cache in
IndexedDB.

The backend handles:
  1. Generating a short-lived signed GET URL for the `.pt` (Azure Blob)
  2. Generating a short-lived signed PUT URL for the `.onnx` destination
  3. Calling this Space with both URLs + the backend callback URL

This Space:
  1. Downloads the .pt
  2. Loads it with Ultralytics' YOLO and exports to ONNX (simplified,
     optionally quantised to int8 for mobile / bandwidth-sensitive clients)
  3. Uploads the resulting .onnx via the signed PUT URL
  4. POSTs `{ok, model_id, checksum, bytes, size_mb}` back to the backend
     callback so the `custom_models.onnx_key` + `onnx_checksum` columns are
     updated

Mirrors the sam3 / parakeet / model-validator pattern — Gradio app with an
`api_*` function exposed as the Space's remote API, run under `@spaces.GPU`
so each export gets a fresh GPU slice.
"""

from __future__ import annotations

# `spaces` is imported before anything else, matching the original import
# order; outside a ZeroGPU Space a no-op stand-in keeps the decorator usable.
try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """Minimal stand-in for the `spaces` package when it is not installed."""

        @staticmethod
        def GPU(duration=60):
            """Return a decorator that leaves the wrapped function unchanged."""

            def decorator(func):
                return func

            return decorator

    spaces = _spaces_stub()

import contextlib
import gc
import hashlib
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any

import gradio as gr
import httpx

MAX_MODEL_BYTES = 200 * 1024 * 1024  # 200 MB cap
GPU_DURATION_S = 180  # export may take a while on cold start + large models
DEFAULT_IMG_SIZE = 640
DEFAULT_OPSET = 17
_CHUNK_BYTES = 1024 * 1024  # streaming / hashing chunk size


def _remove_quietly(path: Path | None) -> None:
    """Delete `path` if it exists; never raise (used from cleanup paths)."""
    if path is None:
        return
    # FileNotFoundError is an OSError; other OS-level failures are swallowed
    # too so that cleanup cannot replace the real result or exception.
    with contextlib.suppress(OSError):
        path.unlink()


def _sha256_file(path: Path, chunk_size: int = _CHUNK_BYTES) -> str:
    """Return the hex SHA-256 of `path`, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Stream the `.pt` weights at `source_url` into a temp file.

    Args:
        source_url: URL to GET (redirects are followed).
        timeout_s: httpx timeout in seconds.

    Returns:
        Path of the downloaded temp file. The caller owns it and must delete it.

    Raises:
        httpx.HTTPError: on transport errors or a non-2xx response.
        ValueError: as soon as the body grows past `MAX_MODEL_BYTES`.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    try:
        with tmp, httpx.stream(
            "GET", source_url, timeout=timeout_s, follow_redirects=True
        ) as response:
            response.raise_for_status()
            received = 0
            for chunk in response.iter_bytes(chunk_size=_CHUNK_BYTES):
                received += len(chunk)
                # Enforce the cap while streaming so an oversized body is
                # rejected before it is written out in full.
                if received > MAX_MODEL_BYTES:
                    raise ValueError(
                        f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit"
                    )
                tmp.write(chunk)
    except BaseException:
        # The caller never sees the path on failure, so remove it here.
        _remove_quietly(tmp_path)
        raise
    return tmp_path


def _upload_onnx(put_url: str, onnx_path: Path, timeout_s: float = 120.0) -> None:
    """PUT the ONNX file to the backend-provided signed URL.

    Azure Blob's PUT API expects `x-ms-blob-type: BlockBlob`; S3 doesn't. Both
    headers below are always sent — backends that signed the URL with a
    specific Content-Type will reject mismatches, so the caller must align
    their signing policy with the Content-Type we send here.

    Raises:
        httpx.HTTPError: on transport errors or a non-2xx response.
    """
    headers = {
        "Content-Type": "application/octet-stream",
        "x-ms-blob-type": "BlockBlob",
    }
    # Sent as a single in-memory body, same as before.
    payload = onnx_path.read_bytes()
    with httpx.Client(timeout=timeout_s) as client:
        response = client.put(put_url, content=payload, headers=headers)
        response.raise_for_status()


@spaces.GPU(duration=GPU_DURATION_S)
def _export_to_onnx(
    weights_path: str,
    *,
    img_size: int,
    half: bool,
    simplify: bool,
    dynamic: bool,
    opset: int,
    int8: bool,
) -> dict[str, Any]:
    """Load the YOLO model and export to ONNX on the GPU worker.

    Returns:
        On success: `{"ok": True, "onnx_path", "size_mb", "task",
        "num_classes"}`. If the exporter returns a path that does not exist:
        `{"ok": False, "error": ...}`. Exceptions from loading or exporting
        propagate to the caller.
    """
    import torch  # deferred so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    local_path = Path(weights_path)
    try:
        model = YOLO(str(local_path))
        export_kwargs: dict[str, Any] = {
            "format": "onnx",
            "imgsz": img_size,
            "simplify": simplify,
            "opset": opset,
            "dynamic": dynamic,
        }
        if int8:
            export_kwargs["int8"] = True
        elif half:
            # Ultralytics only accepts one of half / int8 at a time.
            export_kwargs["half"] = True

        onnx_path = Path(model.export(**export_kwargs))
        if not onnx_path.exists():
            return {
                "ok": False,
                "error": f"Ultralytics reported export success but {onnx_path} is missing",
            }
        return {
            "ok": True,
            "onnx_path": str(onnx_path),
            "size_mb": round(onnx_path.stat().st_size / (1024 * 1024), 2),
            "task": getattr(model, "task", "unknown"),
            "num_classes": len(getattr(model, "names", {}) or {}),
        }
    finally:
        gc.collect()
        # Freeing the CUDA cache is best-effort; it must not mask the result.
        with contextlib.suppress(Exception):
            torch.cuda.empty_cache()


def api_export_weights(
    weights_url: str | None = None,
    upload_put_url: str | None = None,
    callback_url: str | None = None,
    callback_token: str | None = None,
    model_id: str | None = None,
    img_size: int = DEFAULT_IMG_SIZE,
    half: bool = False,
    simplify: bool = True,
    dynamic: bool = False,
    opset: int = DEFAULT_OPSET,
    int8: bool = False,
) -> dict[str, Any]:
    """API endpoint invoked by the backend.

    Args:
        weights_url: signed GET URL for the .pt file
        upload_put_url: signed PUT URL where we should upload the resulting .onnx
        callback_url: backend endpoint to POST completion metadata to
        callback_token: bearer token the backend trusts (passed through in
            Auth header)
        model_id: opaque id the backend uses to correlate the callback
        img_size: export image size (YOLO26 default 640)
        half / simplify / dynamic / opset / int8: Ultralytics export flags

    Returns:
        `{"ok": True, "checksum", "size_mb", "task", "num_classes",
        "elapsed_s"}` on success, otherwise `{"ok": False, "error": ...}`
        (plus `"traceback"` for unexpected failures). The callback is only
        fired after a successful upload; its failure is non-fatal.
    """
    if not weights_url:
        return {"ok": False, "error": "weights_url is required"}
    if not upload_put_url:
        return {"ok": False, "error": "upload_put_url is required"}

    # Gradio Number inputs can deliver floats (640.0, 17.0); the exporter
    # is given plain ints.
    try:
        img_size = DEFAULT_IMG_SIZE if img_size is None else int(img_size)
        opset = DEFAULT_OPSET if opset is None else int(opset)
    except (TypeError, ValueError):
        return {"ok": False, "error": "img_size and opset must be integers"}
    if img_size <= 0 or opset <= 0:
        return {"ok": False, "error": "img_size and opset must be positive integers"}

    started = time.monotonic()
    pt_path: Path | None = None
    onnx_path: Path | None = None

    try:
        pt_path = _download_weights(weights_url)
        size_bytes = pt_path.stat().st_size
        if size_bytes == 0:
            return {"ok": False, "error": "Downloaded .pt file is empty"}
        if size_bytes > MAX_MODEL_BYTES:
            return {
                "ok": False,
                "error": f".pt file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
            }

        export_result = _export_to_onnx(
            str(pt_path),
            img_size=img_size,
            half=bool(half),
            simplify=bool(simplify),
            dynamic=bool(dynamic),
            opset=opset,
            int8=bool(int8),
        )
        if not export_result.get("ok"):
            return {
                "ok": False,
                "error": export_result.get("error", "ONNX export failed"),
                "traceback": export_result.get("traceback"),
            }

        onnx_path = Path(export_result["onnx_path"])
        # Hash in chunks; the upload helper does its own read of the file.
        checksum = _sha256_file(onnx_path)
        onnx_size_bytes = onnx_path.stat().st_size
        _upload_onnx(upload_put_url, onnx_path)

        result: dict[str, Any] = {
            "ok": True,
            "checksum": checksum,
            "size_mb": export_result["size_mb"],
            "task": export_result["task"],
            "num_classes": export_result["num_classes"],
            "elapsed_s": round(time.monotonic() - started, 3),
        }

        if callback_url:
            _fire_callback(
                callback_url,
                token=callback_token,
                payload={
                    "ok": True,
                    "model_id": model_id,
                    "checksum": checksum,
                    "bytes": onnx_size_bytes,
                    "size_mb": export_result["size_mb"],
                },
            )

        return result
    except httpx.HTTPError as exc:
        return {"ok": False, "error": f"HTTP failure: {exc}"}
    except ValueError as exc:
        return {"ok": False, "error": str(exc)}
    except Exception as exc:  # noqa: BLE001
        return {
            "ok": False,
            "error": f"Export failed: {exc}",
            "traceback": traceback.format_exc(),
        }
    finally:
        leftovers = {pt_path, onnx_path}
        if pt_path is not None:
            # Ultralytics normally writes the export next to the weights;
            # sweep that path too in case the export raised after creating it.
            leftovers.add(pt_path.with_suffix(".onnx"))
        for leftover in leftovers:
            _remove_quietly(leftover)


def _fire_callback(url: str, *, token: str | None, payload: dict[str, Any]) -> None:
    """POST `payload` as JSON to the backend callback; failures are only logged.

    Args:
        url: callback endpoint.
        token: optional bearer token, sent as `Authorization: Bearer <token>`.
        payload: JSON-serialisable completion metadata.
    """
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(url, json=payload, headers=headers)
            # Surface 4xx/5xx replies in the log instead of dropping them.
            response.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # The backend can still poll /models/{id}/onnx; callback failure is
        # non-fatal. InvalidURL is not an HTTPError, so it is listed as well.
        print(f"[onnx-exporter] Callback failed: {exc}")


# ─── Gradio wiring ────────────────────────────────────────────────────────────

with gr.Blocks(title="Cadayn ONNX Exporter") as demo:
    gr.Markdown(
        """
        # Cadayn Custom Model ONNX Exporter

        Converts an Ultralytics YOLO `.pt` to ONNX for browser-side WebGPU
        inference. Invoked by the Cadayn backend when a custom model is
        activated; uploads the resulting `.onnx` via a signed PUT URL and
        optionally POSTs completion metadata back to a callback URL.

        **API endpoint:** `/api/api_export_weights`
        """
    )

    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_upload_put_url = gr.Textbox(label="Signed .onnx PUT URL")
    with gr.Row():
        api_callback_url = gr.Textbox(label="Callback URL (optional)")
        api_callback_token = gr.Textbox(
            label="Callback token (optional)", type="password"
        )
        api_model_id = gr.Textbox(label="Model id (echoed in callback)")
    with gr.Row():
        # precision=0 makes Gradio hand these over as ints rather than floats.
        api_img_size = gr.Number(
            label="Image size", value=DEFAULT_IMG_SIZE, precision=0
        )
        api_half = gr.Checkbox(label="Half precision (fp16)", value=False)
        api_simplify = gr.Checkbox(label="Simplify graph", value=True)
        api_dynamic = gr.Checkbox(label="Dynamic axes", value=False)
        api_opset = gr.Number(label="ONNX opset", value=DEFAULT_OPSET, precision=0)
        api_int8 = gr.Checkbox(label="INT8 quantisation", value=False)

    api_output = gr.JSON(label="Export result")
    api_run = gr.Button("Export")

    # Bound to an explicit button so editing the URL field in the UI does not
    # start an export on every change; the API name and input order are
    # unchanged for remote callers.
    api_run.click(
        fn=api_export_weights,
        inputs=[
            api_weights_url,
            api_upload_put_url,
            api_callback_url,
            api_callback_token,
            api_model_id,
            api_img_size,
            api_half,
            api_simplify,
            api_dynamic,
            api_opset,
            api_int8,
        ],
        outputs=api_output,
        api_name="api_export_weights",
    )


if __name__ == "__main__":
    demo.launch()