File size: 10,995 Bytes
54190c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""ONNX export ZeroGPU Space for custom YOLO models.

Paired with `model-validator-zerogpu`. When a user activates a custom model
the backend asks this Space to convert the stored `.pt` weights to a
browser-friendly `.onnx` file that the frontend WebGPU detector can cache
in IndexedDB. The backend handles:

  1. Generating a short-lived signed GET URL for the `.pt` (Azure Blob)
  2. Generating a short-lived signed PUT URL for the `.onnx` destination
  3. Calling this Space with both URLs + the backend callback URL

This Space:

  1. Downloads the .pt
  2. Loads it with Ultralytics' YOLO and exports to ONNX (simplified,
     optionally quantised to int8 for mobile / bandwidth-sensitive clients)
  3. Uploads the resulting .onnx via the signed PUT URL
  4. POSTs `{model_id, checksum, size_mb}` back to the backend callback so the
     `custom_models.onnx_key` + `onnx_checksum` columns are updated

Mirrors the sam3 / parakeet / model-validator pattern — Gradio app with an
`api_*` function exposed as the Space's remote API, run under
`@spaces.GPU` so each export gets a fresh GPU slice.
"""

from __future__ import annotations

try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    # The `spaces` package is not installed (e.g. local development): fall
    # back to a no-op stand-in so the `@spaces.GPU` decorators still resolve.
    ZEROGPU_AVAILABLE = False

    class _spaces_stub:
        """Minimal no-op replacement for the `spaces` module."""

        @staticmethod
        def GPU(func=None, duration=60, **_kwargs):
            """Return the decorated function unchanged.

            Supports both decorator spellings:

              * ``@spaces.GPU`` (bare) -- ``func`` is the decorated callable
                and is returned as-is.
              * ``@spaces.GPU(duration=...)`` or ``@spaces.GPU(120)`` -- a
                pass-through decorator is returned.

            ``duration`` and any extra keyword arguments are accepted and
            ignored, since there is no GPU scheduling to configure here.
            """
            if callable(func):
                return func

            def decorator(wrapped):
                return wrapped

            return decorator

    spaces = _spaces_stub()

import gc
import hashlib
import os
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any

import gradio as gr
import httpx

# Upper bound on the size of a downloaded `.pt`; larger downloads are rejected.
MAX_MODEL_BYTES = 200 * 1024 * 1024  # 200 MB cap
# Seconds passed as `duration` to `@spaces.GPU` for a single export call.
GPU_DURATION_S = 180  # export may take a while on cold start + large models
# Export image size used when the caller does not supply `img_size`.
DEFAULT_IMG_SIZE = 640


def _download_weights(source_url: str, timeout_s: float = 120.0) -> Path:
    """Stream the `.pt` weights at *source_url* into a temporary file.

    Args:
        source_url: URL to GET; redirects are followed.
        timeout_s: httpx timeout (seconds) applied to the request.

    Returns:
        Path of the temporary `.pt` file. The caller owns the file and is
        responsible for deleting it.

    Raises:
        httpx.HTTPError: on transport failures or a non-2xx response.
        ValueError: if the body is larger than ``MAX_MODEL_BYTES``.

    On any failure the temporary file is removed before the exception
    propagates, because the caller never receives its path and therefore
    could not clean it up itself.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    tmp_path = Path(tmp.name)
    tmp.close()

    received = 0
    try:
        with httpx.stream(
            "GET", source_url, timeout=timeout_s, follow_redirects=True
        ) as response:
            response.raise_for_status()
            with tmp_path.open("wb") as f:
                for chunk in response.iter_bytes(chunk_size=1024 * 1024):
                    # Count bytes ourselves instead of stat()-ing a file that
                    # is still open for (buffered) writing, and stop before
                    # writing anything past the cap.
                    received += len(chunk)
                    if received > MAX_MODEL_BYTES:
                        raise ValueError(
                            f"Weights exceed {MAX_MODEL_BYTES // (1024 * 1024)}MB limit"
                        )
                    f.write(chunk)
    except BaseException:
        # Best-effort removal of the partial download; the original error is
        # always re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path


def _upload_onnx(put_url: str, onnx_path: Path, timeout_s: float = 120.0) -> None:
    """Upload the ONNX file at *onnx_path* with one PUT to the signed URL.

    The request always carries `Content-Type: application/octet-stream` and
    `x-ms-blob-type: BlockBlob`. Azure Blob's PUT API expects the latter; S3
    doesn't. Backends that signed the URL with a specific Content-Type will
    reject mismatches, so the caller must align their signing policy with the
    Content-Type sent here.

    Args:
        put_url: signed destination URL provided by the backend.
        onnx_path: local file to upload; it is read fully into memory.
        timeout_s: httpx client timeout in seconds.

    Raises:
        httpx.HTTPError: on transport failures or a non-2xx response.
    """
    request_headers = {
        "Content-Type": "application/octet-stream",
        "x-ms-blob-type": "BlockBlob",
    }
    body = onnx_path.read_bytes()
    with httpx.Client(timeout=timeout_s) as http:
        reply = http.put(put_url, content=body, headers=request_headers)
        reply.raise_for_status()


@spaces.GPU(duration=GPU_DURATION_S)
def _export_to_onnx(
    weights_path: str,
    *,
    img_size: int,
    half: bool,
    simplify: bool,
    dynamic: bool,
    opset: int,
    int8: bool,
) -> dict[str, Any]:
    """Export the YOLO checkpoint at *weights_path* to ONNX.

    Decorated with `spaces.GPU`, so on ZeroGPU it runs with a GPU allocated
    for up to ``GPU_DURATION_S`` seconds.

    Args:
        weights_path: local path of the `.pt` checkpoint.
        img_size: value forwarded as Ultralytics' ``imgsz``.
        half: request fp16 export; ignored when ``int8`` is set.
        simplify: forwarded as ``simplify``.
        dynamic: forwarded as ``dynamic``.
        opset: forwarded as ``opset``.
        int8: request int8 export; takes precedence over ``half``.

    Returns:
        On success ``{"ok": True, "onnx_path", "size_mb", "task",
        "num_classes"}``; if the reported output file does not exist,
        ``{"ok": False, "error"}``. Exceptions raised while loading or
        exporting propagate to the caller.
    """
    import torch  # imported lazily so ZeroGPU handles CUDA import order
    from ultralytics import YOLO

    checkpoint = Path(weights_path)
    try:
        # NOTE(review): Ultralytics `.pt` checkpoints are presumably loaded via
        # pickle (torch.load); loading a user-supplied file can execute
        # arbitrary code. Confirm this Space is sandboxed accordingly.
        yolo = YOLO(str(checkpoint))

        # Ultralytics only accepts one of half / int8 at a time; int8 wins.
        # NOTE(review): confirm the installed Ultralytics release accepts
        # `int8` for format="onnx".
        if int8:
            precision_flags: dict[str, Any] = {"int8": True}
        elif half:
            precision_flags = {"half": True}
        else:
            precision_flags = {}

        exported = Path(
            yolo.export(
                format="onnx",
                imgsz=img_size,
                simplify=simplify,
                opset=opset,
                dynamic=dynamic,
                **precision_flags,
            )
        )

        if exported.exists():
            return {
                "ok": True,
                "onnx_path": str(exported),
                "size_mb": round(exported.stat().st_size / (1024 * 1024), 2),
                "task": getattr(yolo, "task", "unknown"),
                "num_classes": len(getattr(yolo, "names", {}) or {}),
            }

        return {
            "ok": False,
            "error": f"Ultralytics reported export success but {exported} is missing",
        }
    finally:
        # Release Python garbage and cached CUDA memory whether or not the
        # export succeeded; freeing the CUDA cache is strictly best-effort.
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass


def _coerce_positive_int(name: str, value: Any, default: int) -> int:
    """Return *value* as a positive ``int``, or *default* when it is ``None``.

    Raises:
        ValueError: if *value* is not a whole number or is not positive.
    """
    if value is None:
        return default
    if isinstance(value, float) and not value.is_integer():
        # Also rejects NaN / infinity, for which is_integer() is False.
        raise ValueError(f"{name} must be a whole number, got {value!r}")
    try:
        number = int(value)
    except (TypeError, ValueError):
        raise ValueError(f"{name} must be an integer, got {value!r}") from None
    if number <= 0:
        raise ValueError(f"{name} must be positive, got {number}")
    return number


def api_export_weights(
    weights_url: str | None = None,
    upload_put_url: str | None = None,
    callback_url: str | None = None,
    callback_token: str | None = None,
    model_id: str | None = None,
    img_size: int = DEFAULT_IMG_SIZE,
    half: bool = False,
    simplify: bool = True,
    dynamic: bool = False,
    opset: int = 17,
    int8: bool = False,
) -> dict[str, Any]:
    """API endpoint invoked by the backend: download, export, upload, notify.

    Args:
        weights_url: signed GET URL for the .pt file (required)
        upload_put_url: signed PUT URL where the resulting .onnx is uploaded
            (required)
        callback_url: backend endpoint to POST completion metadata to; no
            callback is sent when empty
        callback_token: bearer token the backend trusts (passed through in
            the Authorization header of the callback)
        model_id: opaque id echoed in the callback payload
        img_size: export image size; floats with a whole value (e.g. 640.0)
            are accepted and ``None`` falls back to ``DEFAULT_IMG_SIZE``
        half / simplify / dynamic / int8: Ultralytics export flags
        opset: ONNX opset; coerced like ``img_size``, ``None`` falls back to 17

    Returns:
        ``{"ok": True, "checksum", "size_mb", "task", "num_classes",
        "elapsed_s"}`` on success, where ``checksum`` is the SHA-256 hex
        digest of the uploaded file. On failure ``{"ok": False, "error"}``,
        plus ``"traceback"`` for unexpected exceptions. The function does not
        raise for download, export or upload failures.
    """
    if not weights_url:
        return {"ok": False, "error": "weights_url is required"}
    if not upload_put_url:
        return {"ok": False, "error": "upload_put_url is required"}

    started = time.monotonic()
    pt_path: Path | None = None
    onnx_path: Path | None = None

    try:
        # Gradio Number inputs and JSON clients may deliver these as floats or
        # None; normalise before they reach the exporter. A ValueError here is
        # reported through the ValueError handler below.
        img_size = _coerce_positive_int("img_size", img_size, DEFAULT_IMG_SIZE)
        opset = _coerce_positive_int("opset", opset, 17)

        pt_path = _download_weights(weights_url)
        size_bytes = pt_path.stat().st_size
        if size_bytes == 0:
            return {"ok": False, "error": "Downloaded .pt file is empty"}
        if size_bytes > MAX_MODEL_BYTES:
            return {
                "ok": False,
                "error": f".pt file exceeds {MAX_MODEL_BYTES // (1024 * 1024)}MB limit",
            }

        export_result = _export_to_onnx(
            str(pt_path),
            img_size=img_size,
            half=half,
            simplify=simplify,
            dynamic=dynamic,
            opset=opset,
            int8=int8,
        )
        if not export_result.get("ok"):
            return {
                "ok": False,
                "error": export_result.get("error", "ONNX export failed"),
                "traceback": export_result.get("traceback"),
            }

        onnx_path = Path(export_result["onnx_path"])

        # Hash in chunks so the whole file is not held in memory here while
        # the uploader reads it a second time.
        digest = hashlib.sha256()
        with onnx_path.open("rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                digest.update(chunk)
        checksum = digest.hexdigest()

        _upload_onnx(upload_put_url, onnx_path)

        result: dict[str, Any] = {
            "ok": True,
            "checksum": checksum,
            "size_mb": export_result["size_mb"],
            "task": export_result["task"],
            "num_classes": export_result["num_classes"],
            "elapsed_s": round(time.monotonic() - started, 3),
        }

        if callback_url:
            _fire_callback(
                callback_url,
                token=callback_token,
                payload={
                    "model_id": model_id,
                    "checksum": checksum,
                    "size_mb": export_result["size_mb"],
                },
            )

        return result
    except httpx.HTTPError as exc:
        return {"ok": False, "error": f"HTTP failure: {exc}"}
    except ValueError as exc:
        return {"ok": False, "error": str(exc)}
    except Exception as exc:  # noqa: BLE001
        return {
            "ok": False,
            "error": f"Export failed: {exc}",
            "traceback": traceback.format_exc(),
        }
    finally:
        # Remove both temp files on every exit path. Cleanup is best-effort:
        # an OSError here must not replace the result being returned.
        for path in (pt_path, onnx_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                print(f"[onnx-exporter] Could not remove {path}: {exc}")


def _fire_callback(url: str, *, token: str | None, payload: dict[str, Any]) -> None:
    """POST *payload* as JSON to the backend callback; failures are non-fatal.

    Args:
        url: callback endpoint.
        token: when truthy, sent as ``Authorization: Bearer <token>``.
        payload: JSON-serialisable body.

    Transport errors, non-2xx responses and malformed URLs are logged and
    swallowed: the backend can still poll /models/{id}/onnx, so a failed
    callback must not fail an otherwise successful export.
    """
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(url, json=payload, headers=headers)
            # Surface 4xx/5xx replies in the log instead of treating them as
            # a delivered callback.
            response.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # httpx.InvalidURL is not an HTTPError subclass; catch it explicitly
        # so a malformed callback URL stays non-fatal as well.
        print(f"[onnx-exporter] Callback failed: {exc}")


# ─── Gradio wiring ────────────────────────────────────────────────────────────

# Gradio UI plus the remote API endpoint. The export is bound to an explicit
# button click rather than to edits of the URL textbox, so using the UI does
# not start an export on every change of that field. The `api_name` and the
# order of `inputs` are unchanged, so remote API callers are unaffected.
with gr.Blocks(title="Cadayn ONNX Exporter") as demo:
    gr.Markdown(
        """
        # Cadayn Custom Model ONNX Exporter

        Converts an Ultralytics YOLO `.pt` to ONNX for browser-side WebGPU
        inference. Invoked by the Cadayn backend when a custom model is
        activated; uploads the resulting `.onnx` via a signed PUT URL and
        optionally POSTs completion metadata back to a callback URL.

        **API endpoint:** `/api/api_export_weights`
        """
    )

    with gr.Row():
        api_weights_url = gr.Textbox(label="Signed .pt URL")
        api_upload_put_url = gr.Textbox(label="Signed .onnx PUT URL")

    with gr.Row():
        api_callback_url = gr.Textbox(label="Callback URL (optional)")
        api_callback_token = gr.Textbox(
            label="Callback token (optional)", type="password"
        )
        api_model_id = gr.Textbox(label="Model id (echoed in callback)")

    with gr.Row():
        # precision=0 makes Gradio deliver these as ints rather than floats.
        api_img_size = gr.Number(
            label="Image size", value=DEFAULT_IMG_SIZE, precision=0
        )
        api_half = gr.Checkbox(label="Half precision (fp16)", value=False)
        api_simplify = gr.Checkbox(label="Simplify graph", value=True)
        api_dynamic = gr.Checkbox(label="Dynamic axes", value=False)
        api_opset = gr.Number(label="ONNX opset", value=17, precision=0)
        api_int8 = gr.Checkbox(label="INT8 quantisation", value=False)

    api_run = gr.Button("Export to ONNX")
    api_output = gr.JSON(label="Export result")

    api_run.click(
        fn=api_export_weights,
        inputs=[
            api_weights_url,
            api_upload_put_url,
            api_callback_url,
            api_callback_token,
            api_model_id,
            api_img_size,
            api_half,
            api_simplify,
            api_dynamic,
            api_opset,
            api_int8,
        ],
        outputs=api_output,
        api_name="api_export_weights",
    )


# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()