waltgrace committed commit 37ddbc1 (verified) · 1 parent: 6060fb2

feat(identify): open-set image retrieval subpackage
CLAUDE.md CHANGED
@@ -190,6 +190,30 @@ QWEN_URL=http://192.168.1.244:8291 data_label_factory status
 
 ---
 
+## Optional: open-set identification (`data_label_factory.identify`)
+
+If a user wants to **identify** which one of N known things they're holding
+up to a webcam (rather than detect arbitrary objects), point them at the
+identify subpackage. It's a CLIP retrieval index — needs only 1 image per
+class, no training required.
+
+```bash
+pip install -e ".[identify]"
+python3 -m data_label_factory.identify index --refs ~/my-things/ --out my.npz
+python3 -m data_label_factory.identify verify --index my.npz
+# (optional) python3 -m data_label_factory.identify train --refs ~/my-things/ --out my-proj.pt
+python3 -m data_label_factory.identify serve --index my.npz --refs ~/my-things/
+# → web/canvas/live talks to it via FALCON_URL=http://localhost:8500/api/falcon
+```
+
+The full blueprint for any image set is at
+`data_label_factory/identify/README.md`. **This is the right tool for
+"trading cards / products / album covers / parts catalog identification"
+use cases. The base data_label_factory pipeline is for closed-set bbox
+detection.**
+
+---
+
 ## Optional GPU path
 
 If a user has more than ~10k images and wants the run to finish in minutes
README.md CHANGED
@@ -272,6 +272,41 @@ runpod is just an option.
 
 ---
 
+## Optional: open-set image identification
+
+The base pipeline produces COCO labels for training a closed-set **detector**.
+The opt-in `data_label_factory.identify` subpackage produces a CLIP retrieval
+**index** for open-set identification — given a known set of N reference images,
+identify which one a webcam frame is showing. **Use it when you have 1 image
+per class and want zero training time.**
+
+```bash
+pip install -e ".[identify]"
+
+# Build an index from a folder of references
+python3 -m data_label_factory.identify index --refs ~/my-cards/ --out my.npz
+
+# Optional: contrastive fine-tune for fine-grained accuracy (~5 min on M4 MPS)
+python3 -m data_label_factory.identify train --refs ~/my-cards/ --out my-proj.pt
+python3 -m data_label_factory.identify index --refs ~/my-cards/ --out my.npz --projection my-proj.pt
+
+# Self-test the index
+python3 -m data_label_factory.identify verify --index my.npz
+
+# Serve as a mac_tensor-shaped /api/falcon endpoint
+python3 -m data_label_factory.identify serve --index my.npz --refs ~/my-cards/
+# → web/canvas/live can hit it with FALCON_URL=http://localhost:8500/api/falcon
+```
+
+Built-in **rarity / variant detection** for free — if your filenames encode a
+suffix like `_pscr`, `_scr`, `_ur`, the matched filename's suffix becomes a
+separate `rarity` field on the response. See
+[`data_label_factory/identify/README.md`](data_label_factory/identify/README.md)
+for the full blueprint and concrete examples (trading cards, album covers,
+industrial parts, plant species, …).
+
+---
+
 ## Configuration reference
 
 ### Environment variables
data_label_factory/identify/README.md ADDED
@@ -0,0 +1,288 @@
# `data_label_factory.identify` — open-set image retrieval

The companion to the main labeling pipeline. Where the base
`data_label_factory` produces COCO labels for training a closed-set
**detector**, this subpackage produces a CLIP-based **retrieval index** for
open-set **identification** — given a known set of N reference images,
identify which one a webcam frame is showing.

**Use this when:**

- You have **1 image per class** (a product catalog, a card collection, an
  art portfolio, a parts diagram, …) and want a "what is this thing I'm
  holding up?" tool.
- You want **zero training time** by default and the option to fine-tune for
  more accuracy.
- You want to **add new items in seconds** by dropping a JPG in a folder
  and re-indexing.
- You want **rarity / variant detection** for free — different prints of
  the same item indexed under filenames that encode the variant.

**Use the base pipeline instead when:**

- You need to detect multiple object instances per image with bounding boxes
- Your objects appear in cluttered scenes and need a real detector
- You have many images per class and want a closed-set classifier

---

## The 4-step blueprint (works for ANY image set)

This is the entire workflow. Replace `~/my-collection/` with your reference
folder and you're done.

### Step 0 — install (one-time, ~1 min)

```bash
pip install -e ".[identify]"
# This pulls torch, pillow, clip, fastapi, ultralytics, and uvicorn
```

### Step 1 — gather references (5–30 min depending on source)

You need **one image per class**. The filename becomes the label, so be
deliberate:

```
~/my-collection/
├── blue_eyes_white_dragon.jpg
├── dark_magician.jpg
├── exodia_the_forbidden_one.jpg
└── ...
```

**Naming rules:**

- The filename stem (minus extension) becomes the displayed label.
- Optional set-code prefixes are auto-stripped: `LOCH-JP001_dark_magician.jpg`
  → `Dark Magician`.
- Optional rarity suffixes are extracted as a separate field if they match
  one of: `pscr`, `scr`, `ur`, `sr`, `op`, `utr`, `cr`, `ea`, `gmr`. Example:
  `dark_magician_pscr.jpg` → name=`Dark Magician`, rarity=`PScR`.
- Underscores become spaces, then title-cased.
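The naming rules above can be sketched as a tiny parser. This is an illustrative helper, not part of the package (`parse_ref_filename` is a hypothetical name; the set-code regex mirrors the one used by the indexer, and rarity is stripped before title-casing so the result matches the examples above):

```python
import re

# Rarity suffixes from the naming rules above, mapped to display form
RARITY = {"pscr": "PScR", "scr": "ScR", "ur": "UR", "sr": "SR",
          "op": "OP", "utr": "UtR", "cr": "CR", "ea": "EA", "gmr": "GMR"}

def parse_ref_filename(fname: str) -> tuple[str, str]:
    """Filename -> (display name, rarity) per the naming convention."""
    stem = fname.rsplit(".", 1)[0]
    stem = re.sub(r"^[A-Z]+-[A-Z]+\d+_", "", stem)   # strip set-code prefix
    parts = stem.split("_")
    rarity = ""
    if parts and parts[-1].lower() in RARITY:        # optional rarity suffix
        rarity = RARITY[parts[-1].lower()]
        parts = parts[:-1]
    return " ".join(parts).title(), rarity           # underscores -> spaces, title-case
```

So `LOCH-JP001_dark_magician_pscr.jpg` parses to `("Dark Magician", "PScR")`.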
**Where to get reference images:**

| Domain | Source |
|---|---|
| Trading cards | ygoprodeck (Yu-Gi-Oh!), Pokémon TCG API, Scryfall (MTG), yugipedia |
| Products | Amazon listing main image, manufacturer site |
| Art / paintings | Wikimedia Commons, museum APIs |
| Industrial parts | Manufacturer catalog scrapes |
| Faces | Selfies (with permission!) |
| Album covers | MusicBrainz cover art archive |
| Movie posters | TMDB API |

**You can mix sources** — e.g. include both English and Japanese versions of
the same card under different filenames. The retrieval system treats them as
separate references but the cosine match will pick whichever is closer to
your live input.

### Step 2 — build the index (10 sec)

```bash
python3 -m data_label_factory.identify index \
  --refs ~/my-collection/ \
  --out my-index.npz
```

This CLIP-encodes every image and saves the embeddings to a single `.npz`
file (~300 KB for 150 references). On Apple Silicon MPS this is ~50 ms per
image — 150 images take about 8 seconds.

**Output**: `my-index.npz` containing `embeddings`, `names`, `filenames`.
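Consuming that file takes a few lines. A minimal sketch assuming only the three arrays described above (`top_match` is an illustrative helper, not package API; dot product equals cosine because the rows are L2-normalized):

```python
import numpy as np

def top_match(index_path: str, query: np.ndarray) -> tuple[str, float]:
    """Return (name, cosine score) of the best match for a normalized query embedding."""
    npz = np.load(index_path, allow_pickle=True)   # names are object arrays
    sims = npz["embeddings"] @ query               # (N,) cosine similarities
    best = int(np.argmax(sims))
    return str(npz["names"][best]), float(sims[best])
```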
### Step 3 — verify the index (5 sec)

```bash
python3 -m data_label_factory.identify verify --index my-index.npz
```

Self-tests every reference: each one should match itself as the top-1
result. Reports:

- **Top-1 self-identification rate** (should be 100%)
- **Most-confusable pairs** — references with high mutual similarity
  (visually similar items the model might confuse at runtime)
- **Margin analysis** — the gap between "correct match" and "best wrong
  match" cosine scores. **This is the strongest predictor of live accuracy.**

**Margin guidelines:**

| Median margin | What it means | Action |
|---|---|---|
| **> 0.3** | Strong separation, live accuracy will be excellent | Ship it |
| **0.1 – 0.3** | Medium separation, expect some confusion on visually similar items | Consider Step 4 |
| **< 0.1** | References look too similar to off-the-shelf CLIP | **Run Step 4** (fine-tune) |
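The margin statistic can be reproduced from the embedding matrix alone. A sketch (the helper name is hypothetical, but the definition follows the description above: self-similarity, 1.0 for L2-normalized rows, minus the best wrong-match cosine):

```python
import numpy as np

def self_test_margins(emb: np.ndarray) -> np.ndarray:
    """Per-reference margin = cosine with itself minus best cosine with any other reference."""
    sims = emb @ emb.T                  # (N, N) pairwise cosines
    off = sims.copy()
    np.fill_diagonal(off, -np.inf)      # exclude self from "best wrong match"
    return np.diag(sims) - off.max(axis=1)
```

`np.median(self_test_margins(emb))` is the number the table above thresholds on.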
### Step 4 (OPTIONAL) — fine-tune the retrieval head (5–15 min)

If the verify output shows margin < 0.1, your domain (yugioh cards, MTG
cards, similar-looking product variants, …) confuses generic CLIP. Fix it
with a contrastive fine-tune:

```bash
python3 -m data_label_factory.identify train \
  --refs ~/my-collection/ \
  --out my-projection.pt \
  --epochs 12
```

**What this does:**

- Loads frozen CLIP ViT-B/32
- Trains a small **projection head** (~400k params) on top of CLIP features
- Uses **K-cards-per-batch sampling** (16 distinct classes × 4 augmentations
  = 64-image batches)
- Loss: **SupCon** (Khosla et al. 2020) — pulls augmentations of the same
  class together, pushes different classes apart
- Augmentations: random crop, rotation ±20°, color jitter, perspective warp,
  Gaussian blur, occasional grayscale
- Output: a **1.5 MB `.pt` file** containing the projection head weights
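The SupCon objective in the list above can be sketched in a few lines. This is a minimal version of the Khosla et al. loss for illustration, not the package's actual `train.py`:

```python
import torch

def supcon_loss(z: torch.Tensor, labels: torch.Tensor, temp: float = 0.07) -> torch.Tensor:
    """Supervised contrastive loss over L2-normalized projections z (B, D)."""
    sim = z @ z.T / temp                                        # pairwise cosine / temperature
    eye = torch.eye(len(z), dtype=torch.bool, device=z.device)
    pos = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye   # positives: same class, not self
    sim = sim.masked_fill(eye, float("-inf"))                   # drop self from the denominator
    log_prob = sim - sim.logsumexp(dim=1, keepdim=True)         # log-softmax over the others
    # mean log-likelihood of each anchor's positives, negated
    loss = -(log_prob.masked_fill(~pos, 0.0).sum(1) / pos.sum(1).clamp(min=1))
    return loss.mean()
```

Augmentations of the same class act as the positives, which is why the batch sampler draws several views per class.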
**Reference run** (150-class set, M4 Mac mini, MPS): 12 epochs in ~6 min.
Margin improvement: 0.07 → 0.36 (5× wider).

Then re-build the index with the projection head:

```bash
python3 -m data_label_factory.identify index \
  --refs ~/my-collection/ \
  --out my-index.npz \
  --projection my-projection.pt
```

And re-verify to confirm the margin actually widened:

```bash
python3 -m data_label_factory.identify verify --index my-index.npz
```

### Step 5 — serve it as an HTTP endpoint (instant)

```bash
python3 -m data_label_factory.identify serve \
  --index my-index.npz \
  --refs ~/my-collection/ \
  --projection my-projection.pt \
  --port 8500
```

This starts a FastAPI server with:

- `POST /api/falcon` — multipart `image` + `query` → JSON response in the
  same shape as `mac_tensor`'s `/api/falcon` endpoint, so it's a drop-in
  replacement for any client that talks to mac_tensor (including the
  data-label-factory `web/canvas/live` UI).
- `GET /refs/<filename>` — serves your reference images as a static mount
  so a browser UI can display "this is what the model thinks you're showing".
- `GET /health` — JSON status with index size, projection state, request
  counter, etc.

**Point the live tracker UI at it:**

```bash
# In web/.env.local
FALCON_URL=http://localhost:8500/api/falcon
```

Then open `http://localhost:3030/canvas/live` and click **Use Webcam**.

---

## Concrete examples

### Trading cards (the original use case)

```bash
# Step 1: download reference images via the gather command
data_label_factory gather --project projects/yugioh.yaml --max-per-query 1
# → produces ~/data-label-factory/yugioh/positive/cards/*.jpg

# Steps 2–5: build, verify, train, serve
python3 -m data_label_factory.identify index --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh.npz
python3 -m data_label_factory.identify verify --index yugioh.npz
python3 -m data_label_factory.identify train --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh_proj.pt
python3 -m data_label_factory.identify index --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh.npz --projection yugioh_proj.pt
python3 -m data_label_factory.identify serve --index yugioh.npz --refs ~/data-label-factory/yugioh/positive/cards/ --projection yugioh_proj.pt
```

### Album covers ("Shazam for vinyl")

```bash
# Get reference images from MusicBrainz cover art archive (one per album)
mkdir ~/my-vinyl
# ... drop in jpgs named after the album ...
python3 -m data_label_factory.identify index --refs ~/my-vinyl --out vinyl.npz
python3 -m data_label_factory.identify serve --index vinyl.npz --refs ~/my-vinyl
# Hold up a record sleeve to your webcam → get the album back
```

### Industrial parts catalog ("which screw is this?")

```bash
mkdir ~/parts
# Drop in one studio shot per part: m3_bolt_10mm.jpg, hex_nut_5mm.jpg, ...
python3 -m data_label_factory.identify index --refs ~/parts --out parts.npz
python3 -m data_label_factory.identify train --refs ~/parts --out parts_proj.pt --epochs 20
python3 -m data_label_factory.identify index --refs ~/parts --out parts.npz --projection parts_proj.pt
python3 -m data_label_factory.identify serve --index parts.npz --refs ~/parts --projection parts_proj.pt
```

### Plant species ID

Same loop with reference images keyed by species name. You don't need PlantNet's
scale to be useful for **your** garden.

---

## The data-label-factory loop, applied to retrieval

```
gather            (web search / API / phone photos)

label             (the filename IS the label — naming convention does the work)

verify            (data_label_factory.identify verify — self-test)

train (optional)  (data_label_factory.identify train — fine-tune projection head)

deploy            (data_label_factory.identify serve — HTTP endpoint)

review            (data-label-factory web/canvas/live — sees this server as a falcon backend)
```

Same loop, same conventions, just **retrieval instead of detection**.

---

## Files in this folder

```
identify/
├── __init__.py       package marker + lazy import
├── __main__.py       enables `python3 -m data_label_factory.identify <cmd>`
├── cli.py            argparse dispatcher for the four commands
├── train.py          Step 4: contrastive fine-tune
├── build_index.py    Step 2: CLIP encode + save index
├── verify_index.py   Step 3: self-test + margin analysis
├── serve.py          Step 5: FastAPI HTTP endpoint
└── README.md         you are here
```

---

## Why this is **lazy-loaded** (not always-on)

The base `data_label_factory` package only depends on `pyyaml`, `pillow`, and
`requests` — kept lightweight so users running the labeling pipeline don't
pay any ML import cost. The `identify` subpackage adds heavy deps (torch,
clip, ultralytics, fastapi) and is only loaded when explicitly invoked via
`python3 -m data_label_factory.identify <command>`. Same opt-in pattern as
the `runpod` subpackage.

Install the heavy deps with the optional extra:

```bash
pip install -e ".[identify]"
```
data_label_factory/identify/__init__.py ADDED
@@ -0,0 +1,47 @@
"""data_label_factory.identify — open-set retrieval / card identification.

The companion to the bbox-grounding pipeline. Where the main `data_label_factory`
CLI produces COCO labels for training a closed-set detector, this subpackage
produces a CLIP-based retrieval index for open-set identification.

Use it when you have a known set of N reference images (cards, products, parts,
artworks, etc) and want to identify which one a webcam frame is showing — with
a single reference image per class and zero training time.

Pipeline stages
---------------

    references/        ← user provides 1 image per class
        ↓
    train_identifier   ← optional: contrastive fine-tune of a small
        ↓                projection head on top of frozen CLIP
    clip_proj.pt
        ↓
    build_index        ← CLIP-encode each reference + apply projection
        ↓                head, save embeddings to .npz
    card_index.npz
        ↓
    verify_index       ← self-test: each reference should match itself
        ↓                as top-1 with high cosine similarity
    serve_identifier   ← HTTP server (mac_tensor /api/falcon-shaped)
        ↓                that the live tracker UI talks to
    /api/falcon

This is the data-label-factory loop applied to retrieval instead of detection.

CLI
---

    python3 -m data_label_factory.identify train --refs limit-over-pack/ --out clip_proj.pt
    python3 -m data_label_factory.identify index --refs limit-over-pack/ --proj clip_proj.pt --out card_index.npz
    python3 -m data_label_factory.identify verify --index card_index.npz --refs limit-over-pack/
    python3 -m data_label_factory.identify serve --index card_index.npz --refs limit-over-pack/ --port 8500
"""

__all__ = ["main"]


def main():
    """Lazy entry point — only imports the heavy ML deps if user invokes the CLI."""
    from .cli import main as _main
    return _main()
data_label_factory/identify/__main__.py ADDED
@@ -0,0 +1,7 @@
"""Enables `python3 -m data_label_factory.identify <command>`."""

import sys

from .cli import main

if __name__ == "__main__":
    sys.exit(main())
data_label_factory/identify/build_index.py ADDED
@@ -0,0 +1,133 @@
"""Build a CLIP retrieval index from a folder of reference images.

Each image's filename becomes its display label (with set-code prefixes
stripped and rarity suffixes preserved). Optionally applies a fine-tuned
projection head produced by `data_label_factory.identify train`.

The output `.npz` contains three arrays:
    embeddings  (N, D)  L2-normalized
    names       (N,)    cleaned display names
    filenames   (N,)    original filenames (so the server can serve refs)
"""

from __future__ import annotations

import argparse
import os
import re
import sys
from pathlib import Path


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        prog="data_label_factory.identify index",
        description=(
            "Encode every image in a reference folder with CLIP (optionally "
            "passed through a fine-tuned projection head) and save the embeddings "
            "as a searchable .npz index."
        ),
    )
    parser.add_argument("--refs", required=True, help="Folder of reference images")
    parser.add_argument("--out", default="card_index.npz", help="Output .npz path")
    parser.add_argument("--projection", default=None,
                        help="Optional fine-tuned projection head .pt (from `train`)")
    parser.add_argument("--clip-model", default="ViT-B/32")
    args = parser.parse_args(argv)

    try:
        import numpy as np
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from PIL import Image
        import clip
    except ImportError as e:
        raise SystemExit(
            f"missing dependency: {e}\n"
            "install with:\n"
            "  pip install torch pillow git+https://github.com/openai/CLIP.git"
        )

    DEVICE = ("mps" if torch.backends.mps.is_available()
              else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"[index] device={DEVICE}", flush=True)

    refs = Path(args.refs)
    if not refs.is_dir():
        raise SystemExit(f"refs folder not found: {refs}")

    print(f"[index] loading CLIP {args.clip_model} …", flush=True)
    model, preprocess = clip.load(args.clip_model, device=DEVICE)
    model.eval()

    head = None
    if args.projection and os.path.exists(args.projection):
        print(f"[index] loading projection head from {args.projection}", flush=True)

        class ProjectionHead(nn.Module):
            def __init__(self, in_dim=512, hidden=512, out_dim=256):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

            def forward(self, x):
                return F.normalize(self.net(x), dim=-1)

        ckpt = torch.load(args.projection, map_location=DEVICE)
        sd = ckpt.get("state_dict", ckpt)
        head = ProjectionHead(
            in_dim=ckpt.get("in_dim", 512),
            hidden=ckpt.get("hidden", 512),
            out_dim=ckpt.get("out_dim", 256),
        ).to(DEVICE)
        head.load_state_dict(sd)
        head.eval()
        print(f"[index] out_dim={ckpt.get('out_dim', 256)}", flush=True)

    files = sorted(f for f in os.listdir(refs)
                   if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")))
    if not files:
        raise SystemExit(f"no images in {refs}")
    print(f"[index] {len(files)} reference images", flush=True)

    embeddings, names, filenames = [], [], []
    for i, fname in enumerate(files, 1):
        path = refs / fname
        # Strip set-code prefix (e.g. "LOCH-JP001_") and clean up underscores
        stem = os.path.splitext(fname)[0]
        stem = re.sub(r"^[A-Z]+-[A-Z]+\d+_", "", stem)
        name = stem.replace("_", " ").title()
        # "Pharaoh S Servant" → "Pharaoh's Servant"
        name = re.sub(r"\b(\w+) S\b", r"\1's", name)
        try:
            img = Image.open(path).convert("RGB")
        except Exception as e:
            print(f"[index] skip {fname}: {e}", flush=True)
            continue
        with torch.no_grad():
            tensor = preprocess(img).unsqueeze(0).to(DEVICE)
            feat = model.encode_image(tensor).float()
            feat = feat / feat.norm(dim=-1, keepdim=True)
            if head is not None:
                feat = head(feat)
        embeddings.append(feat.cpu().numpy()[0].astype(np.float32))
        names.append(name)
        filenames.append(fname)
        if i % 25 == 0 or i == len(files):
            print(f"[index] [{i:3d}/{len(files)}] {name[:50]}", flush=True)

    emb = np.stack(embeddings, axis=0)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    np.savez(out,
             embeddings=emb,
             names=np.array(names, dtype=object),
             filenames=np.array(filenames, dtype=object))
    print(f"\n[index] ✓ wrote {out} ({emb.shape[0]} refs × {emb.shape[1]} dims, "
          f"{out.stat().st_size / 1024:.1f} KB)", flush=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())
data_label_factory/identify/cli.py ADDED
@@ -0,0 +1,62 @@
"""CLI dispatcher for `python3 -m data_label_factory.identify <command>`.

Subcommands:
    index   → build_index.main
    verify  → verify_index.main
    train   → train.main
    serve   → serve.main

Each is lazy-loaded so users only pay the import cost for the command they
actually invoke.
"""

from __future__ import annotations

import sys


HELP = """\
data_label_factory.identify — open-set image retrieval

usage: python3 -m data_label_factory.identify <command> [options]

commands:
  index    Build a CLIP retrieval index from a folder of reference images
  verify   Self-test an index and report margin / confusable pairs
  train    Contrastive fine-tune a projection head (improves accuracy)
  serve    Run an HTTP server that exposes the index as /api/falcon

run any command with --help for its options. The full blueprint is in
data_label_factory/identify/README.md.
"""


def main(argv: list[str] | None = None) -> int:
    args = list(argv) if argv is not None else sys.argv[1:]
    if not args or args[0] in ("-h", "--help", "help"):
        print(HELP)
        return 0

    cmd = args[0]
    rest = args[1:]

    if cmd == "index":
        from .build_index import main as _main
        return _main(rest)
    if cmd == "verify":
        from .verify_index import main as _main
        return _main(rest)
    if cmd == "train":
        from .train import main as _main
        return _main(rest)
    if cmd == "serve":
        from .serve import main as _main
        return _main(rest)

    print(f"unknown command: {cmd}\n", file=sys.stderr)
    print(HELP, file=sys.stderr)
    return 1


if __name__ == "__main__":
    sys.exit(main())
data_label_factory/identify/serve.py ADDED
@@ -0,0 +1,309 @@
"""HTTP server that serves a CLIP retrieval index over a mac_tensor-shaped
/api/falcon endpoint. Compatible with the existing data-label-factory web UI
(`web/canvas/live`) without any client changes.

Architecture per request:
  1. YOLOv8-World detects "card-shaped" regions (open-vocab "card" class)
  2. Each region is cropped, CLIP-encoded, optionally projection-headed
  3. Cosine-matched against the loaded index → top match per region
  4. If YOLO finds nothing, falls back to classifying the center crop
  5. Returns mac_tensor /api/falcon-shaped JSON so the existing proxy works

Also serves the reference images at /refs/<filename> so the live tracker UI
can show "this is what the model thinks you're holding" alongside the webcam.

Configurable via env vars:
  CARD_INDEX             path to .npz from `index` (default: card_index.npz)
  CLIP_PROJ              optional path to projection head .pt (default: clip_proj.pt)
  REFS_DIR               folder of reference images served at /refs/ (default: limit-over-pack)
  YOLO_CONF              YOLO confidence threshold (default: 0.05)
  CLIP_SIM_THRESHOLD     minimum cosine to accept a match (default: 0.70)
  CLIP_MARGIN_THRESHOLD  minimum top1−top2 cosine gap to be 'confident' (default: 0.04)
  PORT                   HTTP port (default: 8500)
"""

from __future__ import annotations

import argparse
import io
import os
import sys
import threading
import time
import traceback
from typing import Any


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        prog="data_label_factory.identify serve",
        description=(
            "Run a mac_tensor-shaped /api/falcon HTTP server that serves a CLIP "
            "retrieval index. Compatible with the existing data-label-factory "
            "web/canvas/live UI without client changes."
        ),
    )
    parser.add_argument("--index", default=os.environ.get("CARD_INDEX", "card_index.npz"),
                        help="Path to the .npz index built by `index`")
    parser.add_argument("--projection", default=os.environ.get("CLIP_PROJ", "clip_proj.pt"),
                        help="Path to the .pt projection head from `train` (optional)")
    parser.add_argument("--refs", default=os.environ.get("REFS_DIR", "limit-over-pack"),
                        help="Folder of reference images, served at /refs/")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", "8500")))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--sim-threshold", type=float,
                        default=float(os.environ.get("CLIP_SIM_THRESHOLD", "0.70")))
    parser.add_argument("--margin-threshold", type=float,
                        default=float(os.environ.get("CLIP_MARGIN_THRESHOLD", "0.04")))
    parser.add_argument("--yolo-conf", type=float,
                        default=float(os.environ.get("YOLO_CONF", "0.05")))
    parser.add_argument("--no-yolo", action="store_true",
                        help="Skip YOLO detection entirely; always classify the center crop")
    args = parser.parse_args(argv)

    try:
        import numpy as np
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from PIL import Image
        from fastapi import FastAPI, UploadFile, File, Form, HTTPException
        from fastapi.responses import JSONResponse, PlainTextResponse
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.staticfiles import StaticFiles
        import uvicorn
        import clip
    except ImportError as e:
        raise SystemExit(
            f"missing dependency: {e}\n"
            "install with:\n"
            "  pip install fastapi 'uvicorn[standard]' python-multipart pillow torch "
            "git+https://github.com/openai/CLIP.git\n"
            "  (and `pip install ultralytics` if you want YOLO detection)"
        )

    DEVICE = ("mps" if torch.backends.mps.is_available()
              else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"[serve] device={DEVICE}", flush=True)

    # ---------- load CLIP + projection head ----------
    print("[serve] loading CLIP ViT-B/32 …", flush=True)
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE)
    clip_model.eval()

    proj_head = None
    if args.projection and os.path.exists(args.projection):
        class ProjectionHead(nn.Module):
            def __init__(self, in_dim=512, hidden=512, out_dim=256):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

            def forward(self, x):
                return F.normalize(self.net(x), dim=-1)

        ckpt = torch.load(args.projection, map_location=DEVICE)
        sd = ckpt.get("state_dict", ckpt)
        proj_head = ProjectionHead(
            in_dim=ckpt.get("in_dim", 512),
            hidden=ckpt.get("hidden", 512),
            out_dim=ckpt.get("out_dim", 256),
        ).to(DEVICE)
        proj_head.load_state_dict(sd)
        proj_head.eval()
        print(f"[serve] loaded fine-tuned projection head from {args.projection}", flush=True)
    else:
        print("[serve] no projection head — using raw CLIP features", flush=True)

    # ---------- load index ----------
    if not os.path.exists(args.index):
        raise SystemExit(f"index not found: {args.index}\n"
                         f"build one with: data_label_factory.identify index --refs <folder>")
    npz = np.load(args.index, allow_pickle=True)
    CARD_EMB = npz["embeddings"]
    CARD_NAMES = list(npz["names"])
    CARD_FILES = list(npz["filenames"]) if "filenames" in npz.files else ["" for _ in CARD_NAMES]
    print(f"[serve] loaded {len(CARD_NAMES)} refs from {args.index}", flush=True)

    # ---------- optional YOLO for multi-card detection ----------
    yolo = None
    if not args.no_yolo:
        try:
            from ultralytics import YOLO
            print("[serve] loading YOLOv8s-world for card detection …", flush=True)
            yolo = YOLO("yolov8s-world.pt")
            yolo.set_classes(["card", "trading card", "playing card"])
            print(f"[serve] yolo ready (device={yolo.device})", flush=True)
        except Exception as e:
            print(f"[serve] YOLO unavailable ({e}); using whole-frame mode only", flush=True)

    # ---------- helpers ----------
    RARITY_SUFFIXES = {
        "pscr": "PScR", "scr": "ScR", "ur": "UR", "sr": "SR",
        "op": "OP", "utr": "UtR", "cr": "CR", "ea": "EA", "gmr": "GMR",
    }

    def _split_name_and_rarity(full: str) -> tuple[str, str]:
        parts = full.split()
        if parts and parts[-1].lower() in RARITY_SUFFIXES:
            return " ".join(parts[:-1]), RARITY_SUFFIXES[parts[-1].lower()]
        return full, ""

    def _embed_pil(pil) -> "np.ndarray":
        with torch.no_grad():
            t = clip_preprocess(pil).unsqueeze(0).to(DEVICE)
            f = clip_model.encode_image(t).float()
            f = f / f.norm(dim=-1, keepdim=True)
            if proj_head is not None:
                f = proj_head(f)
        return f.cpu().numpy()[0].astype(np.float32)

    def _identify_crop(crop, top_k: int = 3) -> dict:
        q = _embed_pil(crop)
        sims = CARD_EMB @ q
        order = np.argsort(-sims)[:top_k]
        top = [{
            "name": CARD_NAMES[i],
            "filename": CARD_FILES[i] if i < len(CARD_FILES) else "",
            "score": float(sims[i]),
        } for i in order]
        margin = top[0]["score"] - top[1]["score"] if len(top) > 1 else top[0]["score"]
        return {
            "top": top,
            "best_name": top[0]["name"],
            "best_filename": top[0]["filename"],
            "best_score": top[0]["score"],
            "margin": float(margin),
            "confident": float(margin) >= args.margin_threshold,
        }

    # ---------- FastAPI app ----------
    app = FastAPI(title="data-label-factory identify worker")
    app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

    if os.path.isdir(args.refs):
        app.mount("/refs", StaticFiles(directory=args.refs), name="refs")
        print(f"[serve] mounted /refs/ from {args.refs}", flush=True)

    _state = {"requests": 0, "last_query": ""}
+ _lock = threading.Lock()
190
+
191
+ @app.get("/")
192
+ def root() -> PlainTextResponse:
193
+ return PlainTextResponse(
194
+ f"data-label-factory identify · index={len(CARD_NAMES)} refs · "
195
+ f"requests={_state['requests']} · last_query={_state['last_query']!r}\n"
196
+ f"POST /api/falcon (multipart: image, query) — mac_tensor-shaped\n"
197
+ f"GET /refs/<filename> — reference images\n"
198
+ f"GET /health — JSON status\n"
199
+ )
200
+
201
+ @app.get("/health")
202
+ def health() -> dict:
203
+ return {
204
+ "phase": "ready",
205
+ "model_loaded": True,
206
+ "device": DEVICE,
207
+ "index_size": len(CARD_NAMES),
208
+ "has_projection": proj_head is not None,
209
+ "has_yolo": yolo is not None,
210
+ "sim_threshold": args.sim_threshold,
211
+ "margin_threshold": args.margin_threshold,
212
+ "requests_served": _state["requests"],
213
+ "last_query": _state["last_query"],
214
+ }
215
+
216
+ @app.post("/api/falcon")
217
+ async def falcon(image: UploadFile = File(...), query: str = Form(...)) -> JSONResponse:
218
+ t0 = time.time()
219
+ try:
220
+ pil = Image.open(io.BytesIO(await image.read())).convert("RGB")
221
+ except Exception as e:
222
+ raise HTTPException(400, f"bad image: {e}")
223
+ W, H = pil.size
224
+
225
+ with _lock:
226
+ _state["last_query"] = query
227
+
228
+ masks: list[dict] = []
229
+
230
+ # 1. YOLO multi-card detection (if available)
231
+ if yolo is not None:
232
+ try:
233
+ results = yolo.predict(pil, conf=args.yolo_conf, iou=0.5, verbose=False)
234
+ if results:
235
+ boxes = getattr(results[0], "boxes", None)
236
+ if boxes is not None and boxes.xyxy is not None:
237
+ for x1, y1, x2, y2 in boxes.xyxy.cpu().numpy().tolist():
238
+ bx1, by1 = max(0, int(x1)), max(0, int(y1))
239
+ bx2, by2 = min(W, int(x2)), min(H, int(y2))
240
+ if bx2 - bx1 < 20 or by2 - by1 < 20:
241
+ continue
242
+ crop = pil.crop((bx1, by1, bx2, by2))
243
+ info = _identify_crop(crop)
244
+ if info["best_score"] < args.sim_threshold:
245
+ continue
246
+ name, rarity = _split_name_and_rarity(info["best_name"])
247
+ display = f"{name} ({rarity})" if rarity else name
248
+ if not info["confident"]:
249
+ display = f"{display}?"
250
+ masks.append({
251
+ "bbox_norm": {
252
+ "x1": float(x1) / W, "y1": float(y1) / H,
253
+ "x2": float(x2) / W, "y2": float(y2) / H,
254
+ },
255
+ "area_fraction": float((x2 - x1) * (y2 - y1)) / max(W * H, 1),
256
+ "label": display,
257
+ "name": name,
258
+ "rarity": rarity,
259
+ "score": info["best_score"],
260
+ "top_k": info["top"],
261
+ "margin": info["margin"],
262
+ "confident": info["confident"],
263
+ "ref_filename": info["best_filename"],
264
+ })
265
+ except Exception as e:
266
+ print(f"[serve] yolo error: {e}", flush=True)
267
+
268
+ # 2. Whole-frame fallback (single-card workflow)
269
+ if not masks:
270
+ cx1, cy1 = int(W * 0.10), int(H * 0.05)
271
+ cx2, cy2 = int(W * 0.90), int(H * 0.95)
272
+ center = pil.crop((cx1, cy1, cx2, cy2))
273
+ info = _identify_crop(center)
274
+ if info["best_score"] >= args.sim_threshold and info["confident"]:
275
+ name, rarity = _split_name_and_rarity(info["best_name"])
276
+ display = f"{name} ({rarity})" if rarity else name
277
+ masks.append({
278
+ "bbox_norm": {
279
+ "x1": cx1 / W, "y1": cy1 / H, "x2": cx2 / W, "y2": cy2 / H,
280
+ },
281
+ "area_fraction": (cx2 - cx1) * (cy2 - cy1) / max(W * H, 1),
282
+ "label": display,
283
+ "name": name,
284
+ "rarity": rarity,
285
+ "score": info["best_score"],
286
+ "top_k": info["top"],
287
+ "margin": info["margin"],
288
+ "confident": True,
289
+ "ref_filename": info["best_filename"],
290
+ })
291
+
292
+ with _lock:
293
+ _state["requests"] += 1
294
+
295
+ return JSONResponse(content={
296
+ "image_size": [W, H],
297
+ "count": len(masks),
298
+ "masks": masks,
299
+ "query": query,
300
+ "elapsed_seconds": round(time.time() - t0, 3),
301
+ })
302
+
303
+ print(f"\n[serve] listening on http://{args.host}:{args.port}", flush=True)
304
+ uvicorn.run(app, host=args.host, port=args.port, log_level="warning")
305
+ return 0
306
+
307
+
308
+ if __name__ == "__main__":
309
+ sys.exit(main())
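For callers of the `/api/falcon` endpoint above, the only non-obvious step is mapping `bbox_norm` back to pixels. A minimal client-side parsing sketch, assuming only the response shape `serve.py` emits — the sample payload, label, and helper names are invented for illustration:

```python
import json


def denorm_bbox(bbox_norm: dict, width: int, height: int) -> tuple[int, int, int, int]:
    """Convert a bbox_norm dict ({x1,y1,x2,y2} in 0..1) to integer pixel coords."""
    return (
        round(bbox_norm["x1"] * width),
        round(bbox_norm["y1"] * height),
        round(bbox_norm["x2"] * width),
        round(bbox_norm["y2"] * height),
    )


def parse_response(payload: str) -> list[tuple[str, tuple[int, int, int, int]]]:
    """Turn a /api/falcon JSON body into (label, pixel bbox) pairs."""
    data = json.loads(payload)
    w, h = data["image_size"]
    return [(m.get("label", ""), denorm_bbox(m["bbox_norm"], w, h))
            for m in data["masks"]]


if __name__ == "__main__":
    # Invented sample payload mirroring the serve.py response shape.
    sample = json.dumps({
        "image_size": [640, 480],
        "count": 1,
        "masks": [{"bbox_norm": {"x1": 0.25, "y1": 0.1, "x2": 0.75, "y2": 0.9},
                   "area_fraction": 0.4, "label": "Blue-Eyes (UR)"}],
    })
    print(parse_response(sample))
```

The same helper works against the RunPod `pod_falcon_server.py` below, since both endpoints are mac_tensor-shaped.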
data_label_factory/identify/train.py ADDED
@@ -0,0 +1,206 @@
+ """Contrastive fine-tune of a small projection head on top of frozen CLIP.
+
+ Wraps the proven training loop that took the 150-card index from cosine
+ margin 0.074 → 0.36 (5x improvement). The CLIP backbone stays frozen; only
+ a tiny ~400k-param projection MLP is trained, so this runs on Apple Silicon
+ MPS in ~5 minutes for a 150-class set.
+
+ Data generation: K cards × M augmentations per batch (default 16 × 4 = 64).
+ Loss: SupCon (Khosla et al. 2020).
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import random
+ import sys
+ import time
+ from pathlib import Path
+
+ # Lazy heavy imports — only triggered when this module is actually invoked.
+
+ DEFAULT_PALETTE_HINT = "ViT-B/32 + 512→512→256 projection"
+
+
+ def main(argv: list[str] | None = None) -> int:
+     parser = argparse.ArgumentParser(
+         prog="data_label_factory.identify train",
+         description=(
+             "Contrastive fine-tune a small projection head on top of frozen CLIP. "
+             "Use this when off-the-shelf CLIP retrieval is too noisy for your "
+             f"reference set. Architecture: {DEFAULT_PALETTE_HINT}."
+         ),
+     )
+     parser.add_argument("--refs", required=True,
+                         help="Folder of reference images (1 per class). Filenames become labels.")
+     parser.add_argument("--out", default="clip_proj.pt",
+                         help="Output path for the trained projection head .pt")
+     parser.add_argument("--epochs", type=int, default=12)
+     parser.add_argument("--k-cards", type=int, default=16,
+                         help="Distinct classes per training batch.")
+     parser.add_argument("--m-augs", type=int, default=4,
+                         help="Augmentations per class per batch.")
+     parser.add_argument("--steps-per-epoch", type=int, default=80)
+     parser.add_argument("--lr", type=float, default=5e-4)
+     parser.add_argument("--temperature", type=float, default=0.1)
+     parser.add_argument("--clip-model", default="ViT-B/32")
+     args = parser.parse_args(argv)
+
+     try:
+         import numpy as np
+         import torch
+         import torch.nn as nn
+         import torch.nn.functional as F
+         from torch.utils.data import Dataset, DataLoader, Sampler
+         from torchvision import transforms
+         from PIL import Image
+         import clip
+     except ImportError as e:
+         raise SystemExit(
+             f"missing dependency: {e}\n"
+             "install with:\n"
+             "  pip install torch torchvision pillow git+https://github.com/openai/CLIP.git"
+         )
+
+     DEVICE = ("mps" if torch.backends.mps.is_available()
+               else "cuda" if torch.cuda.is_available() else "cpu")
+     print(f"[train] device={DEVICE}", flush=True)
+
+     refs = Path(args.refs)
+     if not refs.is_dir():
+         raise SystemExit(f"refs folder not found: {refs}")
+
+     print(f"[train] loading CLIP {args.clip_model} …", flush=True)
+     clip_model, clip_preprocess = clip.load(args.clip_model, device=DEVICE)
+     clip_model.eval()
+     for p in clip_model.parameters():
+         p.requires_grad = False
+
+     class CardDataset(Dataset):
+         def __init__(self, folder: Path, augs_per_card: int):
+             files = sorted(f for f in os.listdir(folder)
+                            if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")))
+             if not files:
+                 raise SystemExit(f"no images in {folder}")
+             self.images = []
+             for f in files:
+                 self.images.append(Image.open(folder / f).convert("RGB"))
+             self.aug_per_card = augs_per_card
+             self.aug = transforms.Compose([
+                 transforms.RandomResizedCrop(256, scale=(0.6, 1.0), ratio=(0.7, 1.4)),
+                 transforms.RandomRotation(20),
+                 transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
+                 transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
+                 transforms.RandomApply([transforms.GaussianBlur(5, sigma=(0.1, 2.0))], p=0.3),
+                 transforms.RandomGrayscale(p=0.05),
+             ])
+
+         def __len__(self):
+             return len(self.images) * self.aug_per_card
+
+         def __getitem__(self, idx):
+             card_idx = idx % len(self.images)
+             return clip_preprocess(self.aug(self.images[card_idx])), card_idx
+
+     class KCardsSampler(Sampler):
+         def __init__(self, dataset, k_cards: int, m_augs: int, steps: int):
+             self.n = len(dataset.images)
+             self.k = k_cards
+             self.m = m_augs
+             self.steps = steps
+
+         def __iter__(self):
+             for _ in range(self.steps):
+                 cards = random.sample(range(self.n), self.k)
+                 batch = []
+                 for c in cards:
+                     for _ in range(self.m):
+                         batch.append(c)
+                 random.shuffle(batch)
+                 yield from batch
+
+         def __len__(self):
+             return self.steps * self.k * self.m
+
+     class ProjectionHead(nn.Module):
+         def __init__(self, in_dim=512, hidden=512, out_dim=256):
+             super().__init__()
+             self.net = nn.Sequential(
+                 nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))
+
+         def forward(self, x):
+             return F.normalize(self.net(x), dim=-1)
+
+     def supcon_loss(features: "torch.Tensor", labels: "torch.Tensor", temperature: float) -> "torch.Tensor":
+         device = features.device
+         bsz = features.size(0)
+         labels = labels.contiguous().view(-1, 1)
+         mask = torch.eq(labels, labels.T).float().to(device)
+         sim = torch.matmul(features, features.T) / temperature
+         sim_max, _ = torch.max(sim, dim=1, keepdim=True)
+         logits = sim - sim_max.detach()
+         self_mask = torch.scatter(
+             torch.ones_like(mask), 1,
+             torch.arange(bsz, device=device).view(-1, 1), 0)
+         pos_mask = mask * self_mask
+         exp_logits = torch.exp(logits) * self_mask
+         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
+         pos_count = pos_mask.sum(1)
+         pos_count = torch.where(pos_count == 0, torch.ones_like(pos_count), pos_count)
+         return -((pos_mask * log_prob).sum(1) / pos_count).mean()
+
+     print(f"[train] dataset from {refs}", flush=True)
+     ds = CardDataset(refs, augs_per_card=args.m_augs)
+     print(f"[train] {len(ds.images)} reference images", flush=True)
+
+     sampler = KCardsSampler(ds, k_cards=args.k_cards, m_augs=args.m_augs,
+                             steps=args.steps_per_epoch)
+     loader = DataLoader(ds, batch_size=args.k_cards * args.m_augs,
+                         sampler=sampler, num_workers=0, drop_last=True)
+
+     head = ProjectionHead(in_dim=512, hidden=512, out_dim=256).to(DEVICE)
+     print(f"[train] projection head: {sum(p.numel() for p in head.parameters()):,} params", flush=True)
+
+     optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+         optimizer, T_max=args.epochs * args.steps_per_epoch)
+
+     print(f"\n[train] {args.epochs} epochs · {args.steps_per_epoch} steps · "
+           f"batch={args.k_cards * args.m_augs} (K={args.k_cards}×M={args.m_augs})\n", flush=True)
+     t0 = time.time()
+     for epoch in range(args.epochs):
+         head.train()
+         epoch_loss, n_batches = 0.0, 0
+         for imgs, labels in loader:
+             imgs = imgs.to(DEVICE)
+             labels = labels.to(DEVICE)
+             with torch.no_grad():
+                 feats = clip_model.encode_image(imgs).float()
+                 feats = feats / feats.norm(dim=-1, keepdim=True)
+             proj = head(feats)
+             loss = supcon_loss(proj, labels, temperature=args.temperature)
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             scheduler.step()
+             epoch_loss += loss.item()
+             n_batches += 1
+         print(f"[train] epoch {epoch + 1:2d}/{args.epochs} loss={epoch_loss / max(n_batches, 1):.4f} "
+               f"({time.time() - t0:.0f}s)", flush=True)
+
+     out_path = Path(args.out)
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+     torch.save({
+         "state_dict": head.state_dict(),
+         "in_dim": 512, "hidden": 512, "out_dim": 256,
+         "model": args.clip_model,
+         "epochs": args.epochs, "k_cards": args.k_cards, "m_augs": args.m_augs,
+         "ref_count": len(ds.images),
+     }, out_path)
+     print(f"\n[train] ✓ saved {out_path} ({out_path.stat().st_size / 1024:.0f} KB)", flush=True)
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
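The `supcon_loss` above can be sanity-checked outside torch. A NumPy mirror of the same computation — illustrative only, not part of the package — showing that embeddings already clustered by label score a much lower loss than embeddings where positives point in different directions:

```python
import numpy as np


def supcon_loss_np(features: np.ndarray, labels: np.ndarray, temperature: float = 0.1) -> float:
    """NumPy mirror of supcon_loss: pull same-label rows together, push others apart.

    features: (B, D) L2-normalized embeddings; labels: (B,) class ids.
    """
    bsz = features.shape[0]
    mask = (labels[:, None] == labels[None, :]).astype(np.float64)
    sim = features @ features.T / temperature
    logits = sim - sim.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    self_mask = 1.0 - np.eye(bsz)                   # exclude i == j pairs
    pos_mask = mask * self_mask
    exp_logits = np.exp(logits) * self_mask
    log_prob = logits - np.log(exp_logits.sum(axis=1, keepdims=True) + 1e-12)
    pos_count = np.maximum(pos_mask.sum(axis=1), 1.0)
    return float(-((pos_mask * log_prob).sum(axis=1) / pos_count).mean())


if __name__ == "__main__":
    labels = np.array([0, 0, 1, 1])
    # Same-class rows identical, classes orthogonal → near-zero loss.
    clustered = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
    # Same-class rows orthogonal, cross-class rows identical → large loss.
    mixed = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    print(supcon_loss_np(clustered, labels), supcon_loss_np(mixed, labels))
```

During training the same gradient signal shapes the projection head: augmentations of one card become the same-label rows, the other K−1 cards in the batch become negatives.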
data_label_factory/identify/verify_index.py ADDED
@@ -0,0 +1,102 @@
+ """Self-test the index for top-1 accuracy + report confusable pairs.
+
+ For each reference image, embed it and verify that its top-1 match in the
+ index is itself. Reports the cosine margin between correct and best-wrong
+ matches — the most useful number for predicting live accuracy.
+
+ Run this immediately after building an index to catch bad data BEFORE
+ deploying it to a live tracker.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+
+ def main(argv: list[str] | None = None) -> int:
+     parser = argparse.ArgumentParser(
+         prog="data_label_factory.identify verify",
+         description=(
+             "Self-test a built index. Each reference image should match itself "
+             "as top-1; a wide cosine margin between correct and best-wrong matches "
+             "is the strongest predictor of live accuracy."
+         ),
+     )
+     parser.add_argument("--index", default="card_index.npz", help="Path to .npz from `index`")
+     parser.add_argument("--top-confusables", type=int, default=5,
+                         help="How many of the most-confusable pairs to print")
+     args = parser.parse_args(argv)
+
+     try:
+         import numpy as np
+     except ImportError:
+         raise SystemExit("numpy required: pip install numpy")
+
+     npz = np.load(args.index, allow_pickle=True)
+     EMB = npz["embeddings"]
+     NAMES = list(npz["names"])
+     print(f"[verify] index: {len(NAMES)} refs × {EMB.shape[1]} dims")
+
+     # Pairwise similarity matrix (small N, fits in memory)
+     sims = EMB @ EMB.T
+     np.fill_diagonal(sims, -1.0)
+
+     # Top confusable pairs
+     print("\nMost-confusable pairs (highest cosine sim between DIFFERENT refs):")
+     flat_idx = np.argpartition(sims.flatten(), -args.top_confusables * 2)[-args.top_confusables * 2:]
+     seen = set()
+     shown = 0
+     for fi in flat_idx[np.argsort(sims.flatten()[flat_idx])[::-1]]:
+         i, j = divmod(int(fi), len(NAMES))
+         if (j, i) in seen:
+             continue
+         seen.add((i, j))
+         print(f"  {sims[i, j]:.3f}  {NAMES[i][:42]} ↔ {NAMES[j][:42]}")
+         shown += 1
+         if shown >= args.top_confusables:
+             break
+
+     # Restore diagonal for self-test
+     np.fill_diagonal(sims, 1.0)
+
+     # Self-identity test: each ref's top-1 in EMB @ EMB[i] should be i
+     correct = 0
+     failures = []
+     for i in range(len(NAMES)):
+         row = EMB @ EMB[i]
+         top = int(np.argmax(row))
+         if top == i:
+             correct += 1
+         else:
+             failures.append((NAMES[i], NAMES[top], float(row[top]), float(row[i])))
+
+     pct = correct / len(NAMES) * 100
+     print(f"\nself-identity test: {correct}/{len(NAMES)} = {pct:.1f}% top-1 self-id")
+     for name, mismatch, score_wrong, score_right in failures[:10]:
+         print(f"  ✗ {name[:42]} → matched {mismatch[:42]} "
+               f"(top={score_wrong:.3f} vs self={score_right:.3f})")
+
+     # Margin analysis: gap between "I matched myself" and "best wrong match"
+     correct_scores, best_wrong_scores = [], []
+     for i in range(len(NAMES)):
+         row = EMB @ EMB[i]
+         correct_scores.append(row[i])
+         row[i] = -1
+         best_wrong_scores.append(row.max())
+
+     median_correct = float(np.median(correct_scores))
+     median_wrong = float(np.median(best_wrong_scores))
+     margin = median_correct - median_wrong
+     print("\nthreshold analysis:")
+     print(f"  median correct match score:    {median_correct:.3f}")
+     print(f"  median best-wrong-match score: {median_wrong:.3f}")
+     print(f"  gap (margin):                  {margin:.3f}")
+     suggested = max(0.5, median_wrong + 0.05)
+     print(f"  → recommended SIM_THRESHOLD = {suggested:.2f}")
+     return 0 if pct >= 99 else 1
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
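The threshold analysis in `verify_index.py` boils down to a few lines of array arithmetic. A toy sketch of the same margin computation — the identity-matrix embeddings are invented for illustration; a real index holds L2-normalized CLIP vectors:

```python
import numpy as np


def margin_report(emb: np.ndarray) -> tuple[float, float, float]:
    """Median self-match score, median best-wrong-match score, and their gap.

    emb: (N, D) L2-normalized reference embeddings, mirroring verify_index.py.
    """
    sims = emb @ emb.T
    correct = np.diag(sims).copy()     # self-similarity (≈ 1.0 when normalized)
    np.fill_diagonal(sims, -1.0)       # exclude self when finding the best wrong match
    best_wrong = sims.max(axis=1)
    median_correct = float(np.median(correct))
    median_wrong = float(np.median(best_wrong))
    return median_correct, median_wrong, median_correct - median_wrong


if __name__ == "__main__":
    # Three perfectly separated unit vectors → the widest possible margin.
    emb = np.eye(3)
    mc, mw, gap = margin_report(emb)
    print(f"median correct={mc:.3f} best-wrong={mw:.3f} margin={gap:.3f}")
    print(f"recommended SIM_THRESHOLD = {max(0.5, mw + 0.05):.2f}")
```

A shrinking gap here is the early warning that two references will be confused on the live camera, which is exactly what the `train` step above is for.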
data_label_factory/runpod/pod_falcon_server.py ADDED
@@ -0,0 +1,411 @@
+ #!/usr/bin/env python3
+ """
+ pod_falcon_server.py — single-file Falcon Perception HTTP server for a RunPod pod.
+
+ Designed to be curl-installed via the pod's dockerStartCmd. Two phases:
+
+ 1. **Boot phase (instant):** start a FastAPI server on 0.0.0.0:8000 with two
+    endpoints: `/health` (always responds) and `/api/falcon` (returns 503 until
+    the model is loaded). A background thread starts heavy installation.
+
+ 2. **Install phase (~5-10 min):** install pip deps, install falcon-perception
+    with --no-deps, download the Falcon model from Hugging Face, instantiate
+    the inference engine. As soon as it's ready, `/api/falcon` flips to live.
+
+ The endpoint shape MATCHES mac_tensor's /api/falcon so the existing
+ `web/app/api/falcon-frame/route.ts` proxy works against it without changes:
+
+     Request:  multipart/form-data with `image` (file) + `query` (string)
+     Response: {
+         "image_size": [w, h],
+         "count": int,
+         "masks": [{"bbox_norm": {x1, y1, x2, y2}, "area_fraction": float}, ...],
+         "elapsed_seconds": float,
+         "cold_start": bool
+     }
+
+ You can poll progress via:
+     curl https://<pod-id>-8000.proxy.runpod.net/health
+ """
+
+ import io
+ import os
+ import subprocess
+ import sys
+ import threading
+ import time
+ import traceback
+ from typing import Any
+
+ # ============================================================
+ # Boot phase — keep imports minimal so the server starts FAST
+ # ============================================================
+
+ print("[server] starting boot phase…", flush=True)
+ BOOT_T0 = time.time()
+
+ # Install fastapi + uvicorn synchronously since we need them for the boot server.
+ # These are tiny (~30 MB) so this takes ~10 seconds.
+ def _pip(args, retries=3):
+     for attempt in range(retries):
+         r = subprocess.run(
+             [sys.executable, "-m", "pip", "install", "--quiet", "--no-cache-dir"] + args,
+             capture_output=True, text=True,
+         )
+         if r.returncode == 0:
+             return True
+         print(f"[pip] attempt {attempt + 1} failed: {r.stderr[:300]}", flush=True)
+         time.sleep(3)
+     return False
+
+ print("[server] installing fastapi + uvicorn + multipart…", flush=True)
+ if not _pip(["fastapi==0.115.6", "uvicorn[standard]==0.32.1", "python-multipart==0.0.20", "pillow"]):
+     print("[server] CRITICAL: failed to install fastapi", flush=True)
+     sys.exit(1)
+
+ # Now we can import fastapi
+ from fastapi import FastAPI, Request, UploadFile, File, Form, HTTPException
+ from fastapi.responses import JSONResponse, PlainTextResponse
+ from PIL import Image
+
+ app = FastAPI(title="data-label-factory falcon worker")
+
+ STATE: dict[str, Any] = {
+     "phase": "boot",
+     "boot_started": BOOT_T0,
+     "model_loaded": False,
+     "install_log": [],
+     "error": None,
+     "model_id": None,
+     "device": None,
+     "load_seconds": None,
+     "cold_start_used": False,
+     "requests_served": 0,
+ }
+
+ def _log(msg: str) -> None:
+     line = f"[{time.time() - BOOT_T0:6.1f}s] {msg}"
+     print(line, flush=True)
+     STATE["install_log"].append(line)
+     # cap log to last 200 lines
+     if len(STATE["install_log"]) > 200:
+         STATE["install_log"] = STATE["install_log"][-200:]
+
+
+ # ============================================================
+ # Endpoints
+ # ============================================================
+
+ @app.get("/")
+ def root() -> PlainTextResponse:
+     return PlainTextResponse(
+         f"data-label-factory falcon worker · phase={STATE['phase']} "
+         f"loaded={STATE['model_loaded']} requests={STATE['requests_served']}\n"
+         "see /health for full status, POST /api/falcon for inference\n"
+     )
+
+
+ @app.get("/health")
+ def health() -> dict:
+     return {
+         "phase": STATE["phase"],
+         "model_loaded": STATE["model_loaded"],
+         "model_id": STATE.get("model_id"),
+         "device": STATE.get("device"),
+         "load_seconds": STATE.get("load_seconds"),
+         "uptime_seconds": round(time.time() - BOOT_T0, 1),
+         "requests_served": STATE["requests_served"],
+         "error": STATE["error"],
+         "recent_log": STATE["install_log"][-30:],
+     }
+
+
+ @app.post("/api/falcon")
+ async def falcon(image: UploadFile = File(...), query: str = Form(...)) -> JSONResponse:
+     if not STATE["model_loaded"]:
+         return JSONResponse(
+             status_code=503,
+             content={
+                 "error": "model not loaded yet",
+                 "phase": STATE["phase"],
+                 "loaded": False,
+                 "uptime": round(time.time() - BOOT_T0, 1),
+                 "recent": STATE["install_log"][-5:],
+             },
+         )
+
+     t0 = time.time()
+     img_bytes = await image.read()
+     try:
+         pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+     except Exception as e:
+         return JSONResponse(status_code=400, content={"error": f"bad image: {e}"})
+
+     cold = not STATE["cold_start_used"]
+     STATE["cold_start_used"] = True
+
+     try:
+         result = _run_inference(pil, query)
+     except Exception as e:
+         return JSONResponse(
+             status_code=500,
+             content={"error": str(e), "trace": traceback.format_exc().splitlines()[-6:]},
+         )
+
+     STATE["requests_served"] += 1
+     return JSONResponse(content={
+         "image_size": [pil.width, pil.height],
+         "count": result["count"],
+         "masks": result["masks"],
+         "query": query,
+         "elapsed_seconds": round(time.time() - t0, 3),
+         "cold_start": cold,
+     })
+
+
+ # ============================================================
+ # Heavy install + inference (loaded in background thread)
+ # ============================================================
+
+ _engine = None
+ _tokenizer = None
+ _image_processor = None
+ _model = None
+ _model_args = None
+ _sampling_params = None
+ _torch = None  # cached torch module reference
+
+
+ def _run_inference(pil_img: "Image.Image", query: str) -> dict:
+     """Single-image Falcon Perception forward pass.
+
+     Uses task='segmentation' per the prior session learning ('detection mode
+     returns empty bboxes'). Extracts bboxes from each segmentation mask via
+     pycocotools mask decoding.
+     """
+     if _engine is None:
+         raise RuntimeError("model not loaded")
+
+     from falcon_perception import build_prompt_for_task  # type: ignore
+     from falcon_perception.paged_inference import Sequence  # type: ignore
+
+     W, H = pil_img.size
+     task = "segmentation" if getattr(_model_args, "do_segmentation", False) else "detection"
+     prompt = build_prompt_for_task(query, task)
+
+     sequences = [Sequence(
+         text=prompt,
+         image=pil_img,
+         min_image_size=256,
+         max_image_size=1024,
+         task=task,
+     )]
+     with _torch.inference_mode():
+         _engine.generate(
+             sequences,
+             sampling_params=_sampling_params,
+             use_tqdm=False,
+             print_stats=False,
+         )
+     seq = sequences[0]
+     aux = seq.output_aux
+
+     masks_out: list[dict] = []
+
+     # Path A: detection mode (bboxes_raw is populated)
+     bboxes_raw = getattr(aux, "bboxes_raw", None)
+     if bboxes_raw:
+         try:
+             from falcon_perception.visualization_utils import pair_bbox_entries  # type: ignore
+             pairs = pair_bbox_entries(bboxes_raw)
+             for entry in pairs:
+                 if hasattr(entry, "_asdict"):
+                     d = entry._asdict()
+                 elif isinstance(entry, dict):
+                     d = entry
+                 else:
+                     vals = list(entry)
+                     if len(vals) < 5:
+                         continue
+                     d = {"x1": vals[1], "y1": vals[2], "x2": vals[3], "y2": vals[4]}
+                 x1 = float(d.get("x1", 0)); y1 = float(d.get("y1", 0))
+                 x2 = float(d.get("x2", 0)); y2 = float(d.get("y2", 0))
+                 masks_out.append({
+                     "bbox_norm": {
+                         "x1": x1 / W if x1 > 1.5 else x1,
+                         "y1": y1 / H if y1 > 1.5 else y1,
+                         "x2": x2 / W if x2 > 1.5 else x2,
+                         "y2": y2 / H if y2 > 1.5 else y2,
+                     },
+                     "area_fraction": ((x2 - x1) * (y2 - y1)) / (W * H) if W and H else 0.0,
+                 })
+         except Exception as e:
+             _log(f"pair_bbox_entries failed: {e}")
+
+     # Path B: segmentation mode (masks_rle is populated)
+     if not masks_out:
+         masks_rle = getattr(aux, "masks_rle", None) or []
+         for m in masks_rle:
+             try:
+                 # Try to extract a bbox from the mask. Multiple possible shapes.
+                 if isinstance(m, dict) and "bbox" in m:
+                     bb = m["bbox"]  # could be [x,y,w,h] or [x1,y1,x2,y2]
+                     if len(bb) == 4:
+                         x1, y1 = float(bb[0]), float(bb[1])
+                         # Heuristic: if last two are smaller than first two, treat as w/h
+                         if bb[2] < bb[0] or bb[3] < bb[1]:
+                             x2, y2 = x1 + float(bb[2]), y1 + float(bb[3])
+                         else:
+                             x2, y2 = float(bb[2]), float(bb[3])
+                         masks_out.append({
+                             "bbox_norm": {
+                                 "x1": x1 / W if x1 > 1.5 else x1,
+                                 "y1": y1 / H if y1 > 1.5 else y1,
+                                 "x2": x2 / W if x2 > 1.5 else x2,
+                                 "y2": y2 / H if y2 > 1.5 else y2,
+                             },
+                             "area_fraction": float(m.get("area", (x2 - x1) * (y2 - y1))) / max(W * H, 1),
+                         })
+                     continue
+                 # Fall back to decoding the RLE mask via pycocotools
+                 from pycocotools import mask as maskUtils  # type: ignore
+                 import numpy as np  # type: ignore
+                 rle = m if isinstance(m, dict) else {"counts": m, "size": [H, W]}
+                 if "size" not in rle:
+                     rle["size"] = [H, W]
+                 if isinstance(rle.get("counts"), str):
+                     rle["counts"] = rle["counts"].encode()
+                 decoded = maskUtils.decode(rle)
+                 if decoded is None or decoded.size == 0:
+                     continue
+                 ys, xs = np.where(decoded > 0)
+                 if xs.size == 0:
+                     continue
+                 x1, y1 = int(xs.min()), int(ys.min())
+                 x2, y2 = int(xs.max()), int(ys.max())
+                 masks_out.append({
+                     "bbox_norm": {"x1": x1 / W, "y1": y1 / H, "x2": x2 / W, "y2": y2 / H},
+                     "area_fraction": float(decoded.sum()) / max(W * H, 1),
+                 })
+             except Exception as e:
+                 _log(f"mask parse failed: {e}")
+
+     return {"count": len(masks_out), "masks": masks_out}
+
+
+ def _heavy_install_and_load() -> None:
+     """Background thread: install heavy deps, download model, load inference engine."""
+     global _engine, _tokenizer, _image_processor, _model, _model_args, _sampling_params, _torch
+     try:
+         STATE["phase"] = "installing pip"
+         _log("installing transformers + qwen-vl-utils + accelerate + safetensors …")
+         if not _pip([
+             "transformers>=4.49.0,<5",
+             "qwen-vl-utils>=0.0.10",
+             "accelerate>=0.34",
+             "safetensors>=0.4",
+             "einops>=0.8.0",
+             "opencv-python>=4.10.0",
+             "scipy>=1.13.0",
+             "pycocotools>=2.0.7",
+             "tyro>=0.8.0",
+             "huggingface_hub>=0.26",
+             "numpy<2",  # falcon-perception is happier with numpy 1.x
+         ]):
+             raise RuntimeError("pip install of heavy deps failed")
+
+         STATE["phase"] = "installing falcon-perception"
+         _log("installing falcon-perception (--no-deps to preserve base torch)…")
+         if not _pip(["--no-deps", "falcon-perception"]):
+             raise RuntimeError("pip install of falcon-perception failed")
+
+         STATE["phase"] = "loading model"
+         _log("importing torch + falcon_perception …")
+         import torch as _t  # type: ignore
+         _torch = _t
+
+         from falcon_perception import (  # type: ignore
+             PERCEPTION_MODEL_ID,
+             build_prompt_for_task,
+             load_and_prepare_model,
+             setup_torch_config,
+         )
+         from falcon_perception.data import ImageProcessor  # type: ignore
+         from falcon_perception.paged_inference import (  # type: ignore
+             PagedInferenceEngine,
+             SamplingParams,
+             Sequence,
+         )
+
+         STATE["model_id"] = PERCEPTION_MODEL_ID
+         _log(f"model id: {PERCEPTION_MODEL_ID}")
+         _log("setting up torch …")
+         setup_torch_config()
+
+         _log("loading model + processor (downloads ~600 MB on first run, may take 2-5 min)…")
+         load_t0 = time.time()
+         _model, _tokenizer, _model_args = load_and_prepare_model(
+             hf_model_id=PERCEPTION_MODEL_ID,
+             hf_revision="main",
+             hf_local_dir=None,
+             device=None,  # let model pick CUDA
+             dtype="bfloat16",
+             compile=False,  # skip torch.compile to keep load fast (~30s vs 60s+)
+         )
+         _log("instantiating ImageProcessor + PagedInferenceEngine…")
+         _image_processor = ImageProcessor(patch_size=16, merge_size=1)
+         _engine = PagedInferenceEngine(
+             _model, _tokenizer, _image_processor,
+             max_batch_size=1,
+             max_seq_length=8192,
+             n_pages=128,
+             page_size=128,
+             prefill_length_limit=8192,
+             enable_hr_cache=False,
+             capture_cudagraph=False,
+         )
+         _sampling_params = SamplingParams(
+             stop_token_ids=[_tokenizer.eos_token_id, _tokenizer.end_of_query_token_id],
+         )
+
+         STATE["load_seconds"] = round(time.time() - load_t0, 1)
+         STATE["device"] = "cuda" if _torch.cuda.is_available() else "cpu"
+
+         # Quick warmup so the first real request isn't 30s slower than steady state
+         _log("warmup pass on a dummy image…")
+         warmup_img = Image.new("RGB", (256, 256), color=(128, 128, 128))
+         warmup_seqs = [Sequence(
+             text=build_prompt_for_task("anything", "detection"),
+             image=warmup_img,
+             min_image_size=256,
+             max_image_size=512,
+             task="detection",
+         )]
+         with _torch.inference_mode():
+             _engine.generate(warmup_seqs, sampling_params=_sampling_params,
+                              use_tqdm=False, print_stats=False)
+
+         STATE["phase"] = "ready"
+         STATE["model_loaded"] = True
+         _log(f"✓ READY in {time.time() - BOOT_T0:.1f}s total")
+
+     except Exception as e:
+         STATE["phase"] = "FAILED"
+         STATE["error"] = str(e)
+         _log(f"FATAL: {e}")
+         _log(traceback.format_exc())
+
+
+ # Kick off the install thread now (server hasn't started yet but the import is done)
+ threading.Thread(target=_heavy_install_and_load, daemon=True).start()
+
+
+ # ============================================================
+ # Run the server
+ # ============================================================
+
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.environ.get("PORT", "8000"))
+     _log(f"booting uvicorn on 0.0.0.0:{port}")
411
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
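
The pattern in the server above, answering status requests immediately while a daemon thread does the heavy install/load and flips `STATE["phase"]`, can be sketched without FastAPI. This is an illustrative stand-in: `_slow_load` and the `time.sleep` calls are mine, not from the repo.

```python
import threading
import time

STATE = {"phase": "booting", "model_loaded": False, "error": None}

def _slow_load() -> None:
    """Stand-in for the pip-install + model-load work (illustrative only)."""
    try:
        STATE["phase"] = "installing pip"
        time.sleep(0.05)   # pretend to install the heavy deps
        STATE["phase"] = "loading model"
        time.sleep(0.05)   # pretend to load weights + run the warmup pass
        STATE["phase"] = "ready"
        STATE["model_loaded"] = True
    except Exception as e:  # mirror the server's FAILED path
        STATE["phase"] = "FAILED"
        STATE["error"] = str(e)

# Kick off the loader before the (hypothetical) server starts; a /status
# handler would simply return a copy of STATE while this runs.
t = threading.Thread(target=_slow_load, daemon=True)
t.start()
print(STATE["phase"])  # some early phase: the caller is not blocked
t.join()
print(STATE["phase"])  # → ready
```

The point of the daemon flag is that a crash or exit of the main process never waits on the loader; the try/except keeps a failed install visible in `STATE` instead of silently killing the thread.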
pyproject.toml CHANGED
@@ -57,6 +57,18 @@ runpod = [
      "datasets>=3.0",
      "pyarrow>=17.0",
  ]
+ identify = [
+     # Open-set CLIP retrieval: build/verify/train/serve a card-style index
+     "torch>=2.1",
+     "torchvision>=0.16",
+     "numpy>=1.24,<2",
+     "fastapi>=0.115",
+     "uvicorn[standard]>=0.32",
+     "python-multipart>=0.0.20",
+     "pillow>=10.0",
+     "ultralytics>=8.3",
+     "clip @ git+https://github.com/openai/CLIP.git",
+ ]
  dev = [
      "pytest>=7.0",
      "ruff>=0.5.0",
@@ -73,8 +85,13 @@ data_label_factory = "data_label_factory.cli:main"
  data-label-factory = "data_label_factory.cli:main"

  [tool.setuptools]
- packages = ["data_label_factory", "data_label_factory.runpod"]
+ packages = [
+     "data_label_factory",
+     "data_label_factory.runpod",
+     "data_label_factory.identify",
+ ]

  [tool.setuptools.package-data]
  data_label_factory = ["*.py"]
  "data_label_factory.runpod" = ["*.py", "*.md", "Dockerfile", "requirements-pod.txt"]
+ "data_label_factory.identify" = ["*.py", "*.md"]
web/app/api/falcon-frame/route.ts CHANGED
@@ -24,6 +24,9 @@ type Bbox = {
    y2: number;
    score: number;
    label: string;
+   ref_url?: string; // URL to a reference image (for the live tracker sidebar)
+   margin?: number;
+   confident?: boolean;
  };

  const FALCON_URL = process.env.FALCON_URL ?? "http://localhost:8500/api/falcon";
@@ -96,17 +99,30 @@ export async function POST(req: NextRequest) {
      upstreamCount = data.count ?? 0;
      imgW = data.image_size?.[0] ?? data.width ?? 0;
      imgH = data.image_size?.[1] ?? data.height ?? 0;
-     // mac_tensor returns masks: [{bbox_norm: {x1,y1,x2,y2}, slot, area_fraction}]
+     // mac_tensor returns masks: [{bbox_norm:{x1,y1,x2,y2}, area_fraction, label?, score?, ref_filename?}]
+     // The label/score/ref_filename are present in identify-mode (CLIP retrieval).
+     // Construct an absolute ref_url from the upstream base + filename so the
+     // browser can render the reference card image directly without an extra
+     // proxy hop.
+     const upstreamBase = new URL(FALCON_URL);
+     upstreamBase.pathname = "/refs/";
      for (const m of data.masks ?? []) {
        const bn = m.bbox_norm ?? {};
        if (bn.x1 == null) continue;
+       let ref_url: string | undefined = undefined;
+       if (typeof m.ref_filename === "string" && m.ref_filename) {
+         ref_url = upstreamBase.toString() + m.ref_filename;
+       }
        bboxes.push({
          x1: bn.x1,
          y1: bn.y1,
          x2: bn.x2,
          y2: bn.y2,
-         score: m.area_fraction ?? 1,
-         label: query,
+         score: typeof m.score === "number" ? m.score : (m.area_fraction ?? 1),
+         label: typeof m.label === "string" && m.label ? m.label : query,
+         ref_url,
+         margin: typeof m.margin === "number" ? m.margin : undefined,
+         confident: typeof m.confident === "boolean" ? m.confident : undefined,
        });
      }
    }
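
The `ref_url` construction in the hunk above rebases the upstream URL onto `/refs/` and appends the filename. A rough Python equivalent of that `new URL(FALCON_URL)` + pathname swap (the function name and the `charizard.jpg` filename are illustrative, not from the repo):

```python
from urllib.parse import urlsplit, urlunsplit

def ref_url_for(falcon_url: str, ref_filename: str) -> str:
    """Keep scheme + host of the upstream URL, replace its path with
    /refs/<filename>, and drop query/fragment (illustrative helper)."""
    parts = urlsplit(falcon_url)
    return urlunsplit((parts.scheme, parts.netloc, "/refs/" + ref_filename, "", ""))

print(ref_url_for("http://localhost:8500/api/falcon", "charizard.jpg"))
# → http://localhost:8500/refs/charizard.jpg
```

Doing the rebase once per request (outside the mask loop, as the route does) avoids reparsing the base URL for every detection.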
web/app/canvas/live/page.tsx CHANGED
@@ -25,7 +25,14 @@ type SourceMode = "idle" | "file" | "webcam";
  type FalconResponse = {
    ok: boolean;
    count?: number;
-   bboxes?: Array<{ x1: number; y1: number; x2: number; y2: number; score: number; label: string }>;
+   bboxes?: Array<{
+     x1: number; y1: number; x2: number; y2: number;
+     score: number;
+     label: string;
+     ref_url?: string;
+     margin?: number;
+     confident?: boolean;
+   }>;
    image_size?: { w: number; h: number };
    elapsed_ms?: number;
    upstream?: string;
@@ -43,6 +50,11 @@ export default function LiveTrackerPage() {
    const streamRef = useRef<MediaStream | null>(null);
    const objectUrlRef = useRef<string | null>(null);

+   // Live query ref — read from inside the sendNextFrame loop instead of
+   // closure-captured `query` to avoid stale-closure bugs when the user
+   // types a new query mid-stream.
+   const queryRef = useRef<string>("fiber optic drone");
+
    const [mode, setMode] = useState<SourceMode>("idle");
    const [query, setQuery] = useState<string>("fiber optic drone");
    const [activeTracks, setActiveTracks] = useState<Track[]>([]);
@@ -98,7 +110,7 @@ export default function LiveTrackerPage() {
    const form = new FormData();
    form.set("image", blob, "frame.jpg");
-   form.set("query", query);
+   form.set("query", queryRef.current);

    let resp: FalconResponse;
    try {
@@ -133,6 +145,7 @@ export default function LiveTrackerPage() {
      y2: isNormalized ? b.y2 * H : b.y2,
      score: b.score,
      label: b.label,
+     ref_url: b.ref_url,
    };
  });
@@ -347,7 +360,10 @@ export default function LiveTrackerPage() {
    <input
      type="text"
      value={query}
-     onChange={(e) => setQuery(e.target.value)}
+     onChange={(e) => {
+       setQuery(e.target.value);
+       queryRef.current = e.target.value;
+     }}
      className="px-3 py-1.5 rounded-md bg-zinc-800 border border-zinc-700 text-zinc-100 text-sm w-64 focus:outline-none focus:ring-2 focus:ring-cyan-500"
      placeholder="e.g. fiber optic drone"
    />
@@ -394,19 +410,47 @@ export default function LiveTrackerPage() {
    {activeTracks.length === 0 ? (
      <div className="text-sm text-zinc-500">none yet</div>
    ) : (
-     <div className="space-y-2">
+     <div className="space-y-3">
        {activeTracks.map((t) => (
-         <div key={t.id} className="flex items-center justify-between text-sm">
-           <div className="flex items-center gap-2 min-w-0">
-             <span
-               className="h-3 w-3 rounded-sm border border-zinc-600 flex-shrink-0"
-               style={{ backgroundColor: t.color }}
-             />
-             <span className="text-zinc-100 truncate">#{t.id} {t.label}</span>
+         <div
+           key={t.id}
+           className="rounded-md border border-zinc-800 bg-zinc-950 p-2"
+         >
+           <div className="flex items-start gap-3">
+             {/* Reference card image (if the backend provided one) */}
+             {t.ref_url ? (
+               /* eslint-disable-next-line @next/next/no-img-element */
+               <img
+                 src={t.ref_url}
+                 alt={t.label}
+                 className="w-16 h-auto rounded border-2 flex-shrink-0"
+                 style={{ borderColor: t.color }}
+               />
+             ) : (
+               <div
+                 className="w-16 h-22 rounded border-2 flex-shrink-0 flex items-center justify-center text-xs text-zinc-600"
+                 style={{ borderColor: t.color }}
+               >
+                 no ref
+               </div>
+             )}
+             <div className="flex-1 min-w-0">
+               <div className="flex items-center gap-1.5">
+                 <span
+                   className="h-2 w-2 rounded-sm flex-shrink-0"
+                   style={{ backgroundColor: t.color }}
+                 />
+                 <span className="text-xs text-zinc-500 font-mono">#{t.id}</span>
+               </div>
+               <div className="text-sm text-zinc-100 leading-tight mt-1 break-words">
+                 {t.label}
+               </div>
+               <div className="text-xs text-zinc-500 mt-1.5 font-mono">
+                 {typeof t.score === "number" ? `score ${t.score.toFixed(2)} · ` : ""}
+                 seen {t.hits}/{t.age}f
+               </div>
+             </div>
            </div>
-           <span className="text-zinc-400 text-xs whitespace-nowrap ml-2">
-             {t.hits}/{t.age}f
-           </span>
          </div>
        ))}
      </div>
web/lib/iou-tracker.ts CHANGED
@@ -12,6 +12,7 @@ export type Detection = {
    y2: number;
    score?: number;
    label?: string;
+   ref_url?: string;
  };

  export type Track = Detection & {
@@ -103,6 +104,7 @@ export class IoUTracker {
    t.x1 = d.x1; t.y1 = d.y1; t.x2 = d.x2; t.y2 = d.y2;
    t.score = d.score ?? t.score;
    t.label = d.label ?? t.label;
+   t.ref_url = d.ref_url ?? t.ref_url;
    t.age += 1;
    t.hits += 1;
    t.framesSinceSeen = 0;
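
The matched-track update above keeps the last non-null `ref_url` so the sidebar card does not flicker when a frame's detection omits it. The IoU matching itself is not shown in this diff; a rough Python sketch of the two pieces (both helpers are mine, with field names following the TS `Track` type, and are not the tracker's actual implementation):

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def update_track(track: dict, det: dict) -> dict:
    """Mirror the matched-track update: overwrite geometry, keep the last
    non-null ref_url, bump age/hits, reset framesSinceSeen (illustrative)."""
    track.update({k: det[k] for k in ("x1", "y1", "x2", "y2")})
    track["ref_url"] = det.get("ref_url") or track.get("ref_url")
    track["age"] += 1
    track["hits"] += 1
    track["framesSinceSeen"] = 0
    return track

print(round(iou((0, 0, 2, 2), (1, 1, 3, 3)), 3))  # → 0.143
```

A tracker would compute `iou` between each live track and each new detection, greedily assign pairs above a threshold, and run `update_track` on the matches; unmatched detections become new tracks.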