feat(identify): open-set image retrieval subpackage
Browse files- CLAUDE.md +24 -0
- README.md +35 -0
- data_label_factory/identify/README.md +288 -0
- data_label_factory/identify/__init__.py +47 -0
- data_label_factory/identify/__main__.py +7 -0
- data_label_factory/identify/build_index.py +133 -0
- data_label_factory/identify/cli.py +62 -0
- data_label_factory/identify/serve.py +309 -0
- data_label_factory/identify/train.py +206 -0
- data_label_factory/identify/verify_index.py +102 -0
- data_label_factory/runpod/pod_falcon_server.py +411 -0
- pyproject.toml +18 -1
- web/app/api/falcon-frame/route.ts +19 -3
- web/app/canvas/live/page.tsx +58 -14
- web/lib/iou-tracker.ts +2 -0
CLAUDE.md
CHANGED
|
@@ -190,6 +190,30 @@ QWEN_URL=http://192.168.1.244:8291 data_label_factory status
|
|
| 190 |
|
| 191 |
---
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
## Optional GPU path
|
| 194 |
|
| 195 |
If a user has more than ~10k images and wants the run to finish in minutes
|
|
|
|
| 190 |
|
| 191 |
---
|
| 192 |
|
| 193 |
+
## Optional: open-set identification (`data_label_factory.identify`)
|
| 194 |
+
|
| 195 |
+
If a user wants to **identify** which one of N known things they're holding
|
| 196 |
+
up to a webcam (rather than detect arbitrary objects), point them at the
|
| 197 |
+
identify subpackage. It's a CLIP retrieval index — needs only 1 image per
|
| 198 |
+
class, no training required.
|
| 199 |
+
|
| 200 |
+
```bash
|
| 201 |
+
pip install -e ".[identify]"
|
| 202 |
+
python3 -m data_label_factory.identify index --refs ~/my-things/ --out my.npz
|
| 203 |
+
python3 -m data_label_factory.identify verify --index my.npz
|
| 204 |
+
# (optional) python3 -m data_label_factory.identify train --refs ~/my-things/ --out my-proj.pt
|
| 205 |
+
python3 -m data_label_factory.identify serve --index my.npz --refs ~/my-things/
|
| 206 |
+
# → web/canvas/live talks to it via FALCON_URL=http://localhost:8500/api/falcon
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
The full blueprint for any image set is at
|
| 210 |
+
`data_label_factory/identify/README.md`. **This is the right tool for
|
| 211 |
+
"trading cards / products / album covers / parts catalog identification"
|
| 212 |
+
use cases. The base data_label_factory pipeline is for closed-set bbox
|
| 213 |
+
detection.**
|
| 214 |
+
|
| 215 |
+
---
|
| 216 |
+
|
| 217 |
## Optional GPU path
|
| 218 |
|
| 219 |
If a user has more than ~10k images and wants the run to finish in minutes
|
README.md
CHANGED
|
@@ -272,6 +272,41 @@ runpod is just an option.
|
|
| 272 |
|
| 273 |
---
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
## Configuration reference
|
| 276 |
|
| 277 |
### Environment variables
|
|
|
|
| 272 |
|
| 273 |
---
|
| 274 |
|
| 275 |
+
## Optional: open-set image identification
|
| 276 |
+
|
| 277 |
+
The base pipeline produces COCO labels for training a closed-set **detector**.
|
| 278 |
+
The opt-in `data_label_factory.identify` subpackage produces a CLIP retrieval
|
| 279 |
+
**index** for open-set identification — given a known set of N reference images,
|
| 280 |
+
identify which one a webcam frame is showing. **Use it when you have 1 image
|
| 281 |
+
per class and want zero training time.**
|
| 282 |
+
|
| 283 |
+
```bash
|
| 284 |
+
pip install -e ".[identify]"
|
| 285 |
+
|
| 286 |
+
# Build an index from a folder of references
|
| 287 |
+
python3 -m data_label_factory.identify index --refs ~/my-cards/ --out my.npz
|
| 288 |
+
|
| 289 |
+
# Optional: contrastive fine-tune for fine-grained accuracy (~5 min on M4 MPS)
|
| 290 |
+
python3 -m data_label_factory.identify train --refs ~/my-cards/ --out my-proj.pt
|
| 291 |
+
python3 -m data_label_factory.identify index --refs ~/my-cards/ --out my.npz --projection my-proj.pt
|
| 292 |
+
|
| 293 |
+
# Self-test the index
|
| 294 |
+
python3 -m data_label_factory.identify verify --index my.npz
|
| 295 |
+
|
| 296 |
+
# Serve as a mac_tensor-shaped /api/falcon endpoint
|
| 297 |
+
python3 -m data_label_factory.identify serve --index my.npz --refs ~/my-cards/
|
| 298 |
+
# → web/canvas/live can hit it with FALCON_URL=http://localhost:8500/api/falcon
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
Built-in **rarity / variant detection** for free — if your filenames encode a
|
| 302 |
+
suffix like `_pscr`, `_scr`, `_ur`, the matched filename's suffix becomes a
|
| 303 |
+
separate `rarity` field on the response. See
|
| 304 |
+
[`data_label_factory/identify/README.md`](data_label_factory/identify/README.md)
|
| 305 |
+
for the full blueprint and concrete examples (trading cards, album covers,
|
| 306 |
+
industrial parts, plant species, …).
|
| 307 |
+
|
| 308 |
+
---
|
| 309 |
+
|
| 310 |
## Configuration reference
|
| 311 |
|
| 312 |
### Environment variables
|
data_label_factory/identify/README.md
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# `data_label_factory.identify` — open-set image retrieval
|
| 2 |
+
|
| 3 |
+
The companion to the main labeling pipeline. Where the base
|
| 4 |
+
`data_label_factory` produces COCO labels for training a closed-set
|
| 5 |
+
**detector**, this subpackage produces a CLIP-based **retrieval index** for
|
| 6 |
+
open-set **identification** — given a known set of N reference images,
|
| 7 |
+
identify which one a webcam frame is showing.
|
| 8 |
+
|
| 9 |
+
**Use this when:**
|
| 10 |
+
|
| 11 |
+
- You have **1 image per class** (a product catalog, a card collection, an
|
| 12 |
+
art portfolio, a parts diagram, …) and want a "what is this thing I'm
|
| 13 |
+
holding up?" tool.
|
| 14 |
+
- You want **zero training time** by default and the option to fine-tune for
|
| 15 |
+
more accuracy.
|
| 16 |
+
- You want to **add new items in seconds** by dropping a JPG in a folder
|
| 17 |
+
and re-indexing.
|
| 18 |
+
- You want **rarity / variant detection** for free — different prints of
|
| 19 |
+
the same item indexed under filenames that encode the variant.
|
| 20 |
+
|
| 21 |
+
**Use the base pipeline instead when:**
|
| 22 |
+
|
| 23 |
+
- You need to detect multiple object instances per image with bounding boxes
|
| 24 |
+
- Your objects appear in cluttered scenes and need a real detector
|
| 25 |
+
- You have many images per class and want a closed-set classifier
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## The step-by-step blueprint (works for ANY image set)
|
| 30 |
+
|
| 31 |
+
This is the entire workflow. Replace `~/my-collection/` with your reference
|
| 32 |
+
folder and you're done.
|
| 33 |
+
|
| 34 |
+
### Step 0 — install (one-time, ~1 min)
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
pip install -e ".[identify]"
|
| 38 |
+
# This pulls torch, pillow, clip, fastapi, ultralytics, and uvicorn
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Step 1 — gather references (5–30 min depending on source)
|
| 42 |
+
|
| 43 |
+
You need **one image per class**. The filename becomes the label, so be
|
| 44 |
+
deliberate:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
~/my-collection/
|
| 48 |
+
├── blue_eyes_white_dragon.jpg
|
| 49 |
+
├── dark_magician.jpg
|
| 50 |
+
├── exodia_the_forbidden_one.jpg
|
| 51 |
+
└── ...
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
**Naming rules:**
|
| 55 |
+
|
| 56 |
+
- The filename stem (minus extension) becomes the displayed label.
|
| 57 |
+
- Optional set-code prefixes are auto-stripped: `LOCH-JP001_dark_magician.jpg`
|
| 58 |
+
→ `Dark Magician`.
|
| 59 |
+
- Optional rarity suffixes are extracted as a separate field if they match
|
| 60 |
+
one of: `pscr`, `scr`, `ur`, `sr`, `op`, `utr`, `cr`, `ea`, `gmr`. Example:
|
| 61 |
+
`dark_magician_pscr.jpg` → name=`Dark Magician`, rarity=`PScR`.
|
| 62 |
+
- Underscores become spaces, then title-cased.
|
| 63 |
+
|
| 64 |
+
**Where to get reference images:**
|
| 65 |
+
|
| 66 |
+
| Domain | Source |
|
| 67 |
+
|---|---|
|
| 68 |
+
| Trading cards | ygoprodeck (Yu-Gi-Oh!), Pokémon TCG API, Scryfall (MTG), yugipedia |
|
| 69 |
+
| Products | Amazon listing main image, manufacturer site |
|
| 70 |
+
| Art / paintings | Wikimedia Commons, museum APIs |
|
| 71 |
+
| Industrial parts | Manufacturer catalog scrapes |
|
| 72 |
+
| Faces | Selfies (with permission!) |
|
| 73 |
+
| Album covers | MusicBrainz cover art archive |
|
| 74 |
+
| Movie posters | TMDB API |
|
| 75 |
+
|
| 76 |
+
**You can mix sources** — e.g. include both English and Japanese versions of
|
| 77 |
+
the same card under different filenames. The retrieval system treats them as
|
| 78 |
+
separate references but the cosine match will pick whichever is closer to
|
| 79 |
+
your live input.
|
| 80 |
+
|
| 81 |
+
### Step 2 — build the index (10 sec)
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python3 -m data_label_factory.identify index \
|
| 85 |
+
--refs ~/my-collection/ \
|
| 86 |
+
--out my-index.npz
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
This CLIP-encodes every image and saves the embeddings to a single `.npz`
|
| 90 |
+
file (~300 KB for 150 references). On Apple Silicon MPS this is ~50 ms per
|
| 91 |
+
image — 150 images takes about 8 seconds.
|
| 92 |
+
|
| 93 |
+
**Output**: `my-index.npz` containing `embeddings`, `names`, `filenames`.
|
| 94 |
+
|
| 95 |
+
### Step 3 — verify the index (5 sec)
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
python3 -m data_label_factory.identify verify --index my-index.npz
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Self-tests every reference: each one should match itself as the top-1
|
| 102 |
+
result. Reports:
|
| 103 |
+
|
| 104 |
+
- **Top-1 self-identification rate** (should be 100%)
|
| 105 |
+
- **Most-confusable pairs** — references with high mutual similarity
|
| 106 |
+
(visually similar items the model might confuse at runtime)
|
| 107 |
+
- **Margin analysis** — the gap between "correct match" and "best wrong
|
| 108 |
+
match" cosine scores. **This is the strongest predictor of live accuracy.**
|
| 109 |
+
|
| 110 |
+
**Margin guidelines:**
|
| 111 |
+
|
| 112 |
+
| Median margin | What it means | Action |
|
| 113 |
+
|---|---|---|
|
| 114 |
+
| **> 0.3** | Strong separation, live accuracy will be excellent | Ship it |
|
| 115 |
+
| **0.1 – 0.3** | Medium separation, expect some confusion on visually similar items | Consider Step 4 |
|
| 116 |
+
| **< 0.1** | References look too similar to off-the-shelf CLIP | **Run Step 4** (fine-tune) |
|
| 117 |
+
|
| 118 |
+
### Step 4 (OPTIONAL) — fine-tune the retrieval head (5–15 min)
|
| 119 |
+
|
| 120 |
+
If the verify output shows margin < 0.1, your domain (yugioh cards, MTG
|
| 121 |
+
cards, similar-looking product variants, …) confuses generic CLIP. Fix it
|
| 122 |
+
with a contrastive fine-tune:
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
python3 -m data_label_factory.identify train \
|
| 126 |
+
--refs ~/my-collection/ \
|
| 127 |
+
--out my-projection.pt \
|
| 128 |
+
--epochs 12
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
**What this does:**
|
| 132 |
+
|
| 133 |
+
- Loads frozen CLIP ViT-B/32
|
| 134 |
+
- Trains a small **projection head** (~400k params) on top of CLIP features
|
| 135 |
+
- Uses **K-cards-per-batch sampling** (16 distinct classes × 4 augmentations
|
| 136 |
+
= 64-image batches)
|
| 137 |
+
- Loss: **SupCon** (Khosla et al. 2020) — pulls augmentations of the same
|
| 138 |
+
class together, pushes different classes apart
|
| 139 |
+
- Augmentations: random crop, rotation ±20°, color jitter, perspective warp,
|
| 140 |
+
Gaussian blur, occasional grayscale
|
| 141 |
+
- Output: a **1.5 MB `.pt` file** containing the projection head weights
|
| 142 |
+
|
| 143 |
+
**Reference run** (150-class set, M4 Mac mini, MPS): 12 epochs in ~6 min.
|
| 144 |
+
Margin improvement: 0.07 → 0.36 (5× wider).
|
| 145 |
+
|
| 146 |
+
Then re-build the index with the projection head:
|
| 147 |
+
|
| 148 |
+
```bash
|
| 149 |
+
python3 -m data_label_factory.identify index \
|
| 150 |
+
--refs ~/my-collection/ \
|
| 151 |
+
--out my-index.npz \
|
| 152 |
+
--projection my-projection.pt
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
And re-verify to confirm the margin actually widened:
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
python3 -m data_label_factory.identify verify --index my-index.npz
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Step 5 — serve it as an HTTP endpoint (instant)
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
python3 -m data_label_factory.identify serve \
|
| 165 |
+
--index my-index.npz \
|
| 166 |
+
--refs ~/my-collection/ \
|
| 167 |
+
--projection my-projection.pt \
|
| 168 |
+
--port 8500
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
This starts a FastAPI server with:
|
| 172 |
+
|
| 173 |
+
- `POST /api/falcon` — multipart `image` + `query` → JSON response in the
|
| 174 |
+
same shape as `mac_tensor`'s `/api/falcon` endpoint, so it's a drop-in
|
| 175 |
+
replacement for any client that talks to mac_tensor (including the
|
| 176 |
+
data-label-factory `web/canvas/live` UI).
|
| 177 |
+
- `GET /refs/<filename>` — serves your reference images as a static mount
|
| 178 |
+
so a browser UI can display "this is what the model thinks you're showing".
|
| 179 |
+
- `GET /health` — JSON status with index size, projection state, request
|
| 180 |
+
counter, etc.
|
| 181 |
+
|
| 182 |
+
**Point the live tracker UI at it:**
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
# In web/.env.local
|
| 186 |
+
FALCON_URL=http://localhost:8500/api/falcon
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
Then open `http://localhost:3030/canvas/live` and click **Use Webcam**.
|
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
## Concrete examples
|
| 194 |
+
|
| 195 |
+
### Trading cards (the original use case)
|
| 196 |
+
|
| 197 |
+
```bash
|
| 198 |
+
# Step 1: download reference images via the gather command
|
| 199 |
+
data_label_factory gather --project projects/yugioh.yaml --max-per-query 1
|
| 200 |
+
# → produces ~/data-label-factory/yugioh/positive/cards/*.jpg
|
| 201 |
+
|
| 202 |
+
# Step 2-5: build, verify, train, serve
|
| 203 |
+
python3 -m data_label_factory.identify index --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh.npz
|
| 204 |
+
python3 -m data_label_factory.identify verify --index yugioh.npz
|
| 205 |
+
python3 -m data_label_factory.identify train --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh_proj.pt
|
| 206 |
+
python3 -m data_label_factory.identify index --refs ~/data-label-factory/yugioh/positive/cards/ --out yugioh.npz --projection yugioh_proj.pt
|
| 207 |
+
python3 -m data_label_factory.identify serve --index yugioh.npz --refs ~/data-label-factory/yugioh/positive/cards/ --projection yugioh_proj.pt
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
### Album covers ("Shazam for vinyl")
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
# Get reference images from MusicBrainz cover art archive (one per album)
|
| 214 |
+
mkdir ~/my-vinyl
|
| 215 |
+
# ... drop in jpgs named after the album ...
|
| 216 |
+
python3 -m data_label_factory.identify index --refs ~/my-vinyl --out vinyl.npz
|
| 217 |
+
python3 -m data_label_factory.identify serve --index vinyl.npz --refs ~/my-vinyl
|
| 218 |
+
# Hold up a record sleeve to your webcam → get the album back
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
### Industrial parts catalog ("which screw is this?")
|
| 222 |
+
|
| 223 |
+
```bash
|
| 224 |
+
mkdir ~/parts
|
| 225 |
+
# Drop in one studio shot per part: m3_bolt_10mm.jpg, hex_nut_5mm.jpg, ...
|
| 226 |
+
python3 -m data_label_factory.identify index --refs ~/parts --out parts.npz
|
| 227 |
+
python3 -m data_label_factory.identify train --refs ~/parts --out parts_proj.pt --epochs 20
|
| 228 |
+
python3 -m data_label_factory.identify index --refs ~/parts --out parts.npz --projection parts_proj.pt
|
| 229 |
+
python3 -m data_label_factory.identify serve --index parts.npz --refs ~/parts --projection parts_proj.pt
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
### Plant species ID
|
| 233 |
+
|
| 234 |
+
Same loop with reference images keyed by species name. You don't need PlantNet's
|
| 235 |
+
scale to be useful for **your** garden.
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## The data-label-factory loop, applied to retrieval
|
| 240 |
+
|
| 241 |
+
```
|
| 242 |
+
gather (web search / API / phone photos)
|
| 243 |
+
↓
|
| 244 |
+
label (the filename IS the label — naming convention does the work)
|
| 245 |
+
↓
|
| 246 |
+
verify (data_label_factory.identify verify — self-test)
|
| 247 |
+
↓
|
| 248 |
+
train (optional) (data_label_factory.identify train — fine-tune projection head)
|
| 249 |
+
↓
|
| 250 |
+
deploy (data_label_factory.identify serve — HTTP endpoint)
|
| 251 |
+
↓
|
| 252 |
+
review (data-label-factory web/canvas/live — sees this server as a falcon backend)
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
Same loop, same conventions, just **retrieval instead of detection**.
|
| 256 |
+
|
| 257 |
+
---
|
| 258 |
+
|
| 259 |
+
## Files in this folder
|
| 260 |
+
|
| 261 |
+
```
|
| 262 |
+
identify/
|
| 263 |
+
├── __init__.py package marker + lazy import
|
| 264 |
+
├── __main__.py enables `python3 -m data_label_factory.identify <cmd>`
|
| 265 |
+
├── cli.py argparse dispatcher for the four commands
|
| 266 |
+
├── train.py Step 4: contrastive fine-tune
|
| 267 |
+
├── build_index.py Step 2: CLIP encode + save index
|
| 268 |
+
├── verify_index.py Step 3: self-test + margin analysis
|
| 269 |
+
├── serve.py Step 5: FastAPI HTTP endpoint
|
| 270 |
+
└── README.md you are here
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
---
|
| 274 |
+
|
| 275 |
+
## Why this is **lazy-loaded** (not always-on)
|
| 276 |
+
|
| 277 |
+
The base `data_label_factory` package only depends on `pyyaml`, `pillow`, and
|
| 278 |
+
`requests` — kept lightweight so users running the labeling pipeline don't
|
| 279 |
+
pay any ML import cost. The `identify` subpackage adds heavy deps (torch,
|
| 280 |
+
clip, ultralytics, fastapi) and is only loaded when explicitly invoked via
|
| 281 |
+
`python3 -m data_label_factory.identify <command>`. Same opt-in pattern as
|
| 282 |
+
the `runpod` subpackage.
|
| 283 |
+
|
| 284 |
+
Install the heavy deps with the optional extra:
|
| 285 |
+
|
| 286 |
+
```bash
|
| 287 |
+
pip install -e ".[identify]"
|
| 288 |
+
```
|
data_label_factory/identify/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""data_label_factory.identify — open-set retrieval / card identification.
|
| 2 |
+
|
| 3 |
+
The companion to the bbox-grounding pipeline. Where the main `data_label_factory`
|
| 4 |
+
CLI produces COCO labels for training a closed-set detector, this subpackage
|
| 5 |
+
produces a CLIP-based retrieval index for open-set identification.
|
| 6 |
+
|
| 7 |
+
Use it when you have a known set of N reference images (cards, products, parts,
|
| 8 |
+
artworks, etc) and want to identify which one a webcam frame is showing — with
|
| 9 |
+
a single reference image per class and zero training time.
|
| 10 |
+
|
| 11 |
+
Pipeline stages
|
| 12 |
+
---------------
|
| 13 |
+
|
| 14 |
+
references/ ← user provides 1 image per class
|
| 15 |
+
↓
|
| 16 |
+
train_identifier ← optional: contrastive fine-tune of a small
|
| 17 |
+
↓ projection head on top of frozen CLIP
|
| 18 |
+
clip_proj.pt
|
| 19 |
+
↓
|
| 20 |
+
build_index ← CLIP-encode each reference + apply projection
|
| 21 |
+
↓ head, save embeddings to .npz
|
| 22 |
+
card_index.npz
|
| 23 |
+
↓
|
| 24 |
+
verify_index ← self-test: each reference should match itself
|
| 25 |
+
↓ as top-1 with high cosine similarity
|
| 26 |
+
serve_identifier ← HTTP server (mac_tensor /api/falcon-shaped)
|
| 27 |
+
↓ that the live tracker UI talks to
|
| 28 |
+
/api/falcon
|
| 29 |
+
|
| 30 |
+
This is the data-label-factory loop applied to retrieval instead of detection.
|
| 31 |
+
|
| 32 |
+
CLI
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
python3 -m data_label_factory.identify train --refs limit-over-pack/ --out clip_proj.pt
|
| 36 |
+
python3 -m data_label_factory.identify index --refs limit-over-pack/ --proj clip_proj.pt --out card_index.npz
|
| 37 |
+
python3 -m data_label_factory.identify verify --index card_index.npz --refs limit-over-pack/
|
| 38 |
+
python3 -m data_label_factory.identify serve --index card_index.npz --refs limit-over-pack/ --port 8500
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
__all__ = ["main"]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
"""Lazy entry point — only imports the heavy ML deps if user invokes the CLI."""
|
| 46 |
+
from .cli import main as _main
|
| 47 |
+
return _main()
|
data_label_factory/identify/__main__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enables `python3 -m data_label_factory.identify <command>`."""
|
| 2 |
+
|
| 3 |
+
from .cli import main
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
sys.exit(main())
|
data_label_factory/identify/build_index.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build a CLIP retrieval index from a folder of reference images.
|
| 2 |
+
|
| 3 |
+
Each image's filename becomes its display label (with set-code prefixes
|
| 4 |
+
stripped and rarity suffixes preserved). Optionally applies a fine-tuned
|
| 5 |
+
projection head produced by `data_label_factory.identify train`.
|
| 6 |
+
|
| 7 |
+
The output `.npz` contains three arrays:
|
| 8 |
+
embeddings (N, D) L2-normalized
|
| 9 |
+
names (N,) cleaned display names
|
| 10 |
+
filenames (N,) original filenames (so the server can serve refs)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main(argv: list[str] | None = None) -> int:
|
| 23 |
+
parser = argparse.ArgumentParser(
|
| 24 |
+
prog="data_label_factory.identify index",
|
| 25 |
+
description=(
|
| 26 |
+
"Encode every image in a reference folder with CLIP (optionally "
|
| 27 |
+
"passed through a fine-tuned projection head) and save the embeddings "
|
| 28 |
+
"as a searchable .npz index."
|
| 29 |
+
),
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument("--refs", required=True, help="Folder of reference images")
|
| 32 |
+
parser.add_argument("--out", default="card_index.npz", help="Output .npz path")
|
| 33 |
+
parser.add_argument("--projection", default=None,
|
| 34 |
+
help="Optional fine-tuned projection head .pt (from `train`)")
|
| 35 |
+
parser.add_argument("--clip-model", default="ViT-B/32")
|
| 36 |
+
args = parser.parse_args(argv)
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
import numpy as np
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
from PIL import Image
|
| 44 |
+
import clip
|
| 45 |
+
except ImportError as e:
|
| 46 |
+
raise SystemExit(
|
| 47 |
+
f"missing dependency: {e}\n"
|
| 48 |
+
"install with:\n"
|
| 49 |
+
" pip install torch pillow git+https://github.com/openai/CLIP.git"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
DEVICE = ("mps" if torch.backends.mps.is_available()
|
| 53 |
+
else "cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
+
print(f"[index] device={DEVICE}", flush=True)
|
| 55 |
+
|
| 56 |
+
refs = Path(args.refs)
|
| 57 |
+
if not refs.is_dir():
|
| 58 |
+
raise SystemExit(f"refs folder not found: {refs}")
|
| 59 |
+
|
| 60 |
+
print(f"[index] loading CLIP {args.clip_model} …", flush=True)
|
| 61 |
+
model, preprocess = clip.load(args.clip_model, device=DEVICE)
|
| 62 |
+
model.eval()
|
| 63 |
+
|
| 64 |
+
head = None
|
| 65 |
+
if args.projection and os.path.exists(args.projection):
|
| 66 |
+
print(f"[index] loading projection head from {args.projection}", flush=True)
|
| 67 |
+
|
| 68 |
+
class ProjectionHead(nn.Module):
|
| 69 |
+
def __init__(self, in_dim=512, hidden=512, out_dim=256):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.net = nn.Sequential(
|
| 72 |
+
nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
return F.normalize(self.net(x), dim=-1)
|
| 76 |
+
|
| 77 |
+
ckpt = torch.load(args.projection, map_location=DEVICE)
|
| 78 |
+
sd = ckpt.get("state_dict", ckpt)
|
| 79 |
+
head = ProjectionHead(
|
| 80 |
+
in_dim=ckpt.get("in_dim", 512),
|
| 81 |
+
hidden=ckpt.get("hidden", 512),
|
| 82 |
+
out_dim=ckpt.get("out_dim", 256),
|
| 83 |
+
).to(DEVICE)
|
| 84 |
+
head.load_state_dict(sd)
|
| 85 |
+
head.eval()
|
| 86 |
+
print(f"[index] out_dim={ckpt.get('out_dim', 256)}", flush=True)
|
| 87 |
+
|
| 88 |
+
files = sorted(f for f in os.listdir(refs)
|
| 89 |
+
if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")))
|
| 90 |
+
if not files:
|
| 91 |
+
raise SystemExit(f"no images in {refs}")
|
| 92 |
+
print(f"[index] {len(files)} reference images", flush=True)
|
| 93 |
+
|
| 94 |
+
embeddings, names, filenames = [], [], []
|
| 95 |
+
for i, fname in enumerate(files, 1):
|
| 96 |
+
path = refs / fname
|
| 97 |
+
# Strip set-code prefix (e.g. "LOCH-JP001_") and clean up underscores
|
| 98 |
+
stem = os.path.splitext(fname)[0]
|
| 99 |
+
stem = re.sub(r"^[A-Z]+-[A-Z]+\d+_", "", stem)
|
| 100 |
+
name = stem.replace("_", " ").title()
|
| 101 |
+
# "Pharaoh S Servant" → "Pharaoh's Servant"
|
| 102 |
+
name = re.sub(r"\b(\w+) S\b", r"\1's", name)
|
| 103 |
+
try:
|
| 104 |
+
img = Image.open(path).convert("RGB")
|
| 105 |
+
except Exception as e:
|
| 106 |
+
print(f"[index] skip {fname}: {e}", flush=True)
|
| 107 |
+
continue
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
tensor = preprocess(img).unsqueeze(0).to(DEVICE)
|
| 110 |
+
feat = model.encode_image(tensor).float()
|
| 111 |
+
feat = feat / feat.norm(dim=-1, keepdim=True)
|
| 112 |
+
if head is not None:
|
| 113 |
+
feat = head(feat)
|
| 114 |
+
embeddings.append(feat.cpu().numpy()[0].astype(np.float32))
|
| 115 |
+
names.append(name)
|
| 116 |
+
filenames.append(fname)
|
| 117 |
+
if i % 25 == 0 or i == len(files):
|
| 118 |
+
print(f"[index] [{i:3d}/{len(files)}] {name[:50]}", flush=True)
|
| 119 |
+
|
| 120 |
+
emb = np.stack(embeddings, axis=0)
|
| 121 |
+
out = Path(args.out)
|
| 122 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 123 |
+
np.savez(out,
|
| 124 |
+
embeddings=emb,
|
| 125 |
+
names=np.array(names, dtype=object),
|
| 126 |
+
filenames=np.array(filenames, dtype=object))
|
| 127 |
+
print(f"\n[index] ✓ wrote {out} ({emb.shape[0]} refs × {emb.shape[1]} dims, "
|
| 128 |
+
f"{out.stat().st_size / 1024:.1f} KB)", flush=True)
|
| 129 |
+
return 0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
|
| 133 |
+
sys.exit(main())
|
data_label_factory/identify/cli.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI dispatcher for `python3 -m data_label_factory.identify <command>`.
|
| 2 |
+
|
| 3 |
+
Subcommands:
|
| 4 |
+
index → build_index.main
|
| 5 |
+
verify → verify_index.main
|
| 6 |
+
train → train.main
|
| 7 |
+
serve → serve.main
|
| 8 |
+
|
| 9 |
+
Each is lazy-loaded so users only pay the import cost for the command they
|
| 10 |
+
actually invoke.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
HELP = """\
|
| 19 |
+
data_label_factory.identify — open-set image retrieval
|
| 20 |
+
|
| 21 |
+
usage: python3 -m data_label_factory.identify <command> [options]
|
| 22 |
+
|
| 23 |
+
commands:
|
| 24 |
+
index Build a CLIP retrieval index from a folder of reference images
|
| 25 |
+
verify Self-test an index and report margin / confusable pairs
|
| 26 |
+
train Contrastive fine-tune a projection head (improves accuracy)
|
| 27 |
+
serve Run an HTTP server that exposes the index as /api/falcon
|
| 28 |
+
|
| 29 |
+
run any command with --help for its options. The full blueprint is in
|
| 30 |
+
data_label_factory/identify/README.md.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main(argv: list[str] | None = None) -> int:
|
| 35 |
+
args = list(argv) if argv is not None else sys.argv[1:]
|
| 36 |
+
if not args or args[0] in ("-h", "--help", "help"):
|
| 37 |
+
print(HELP)
|
| 38 |
+
return 0
|
| 39 |
+
|
| 40 |
+
cmd = args[0]
|
| 41 |
+
rest = args[1:]
|
| 42 |
+
|
| 43 |
+
if cmd == "index":
|
| 44 |
+
from .build_index import main as _main
|
| 45 |
+
return _main(rest)
|
| 46 |
+
if cmd == "verify":
|
| 47 |
+
from .verify_index import main as _main
|
| 48 |
+
return _main(rest)
|
| 49 |
+
if cmd == "train":
|
| 50 |
+
from .train import main as _main
|
| 51 |
+
return _main(rest)
|
| 52 |
+
if cmd == "serve":
|
| 53 |
+
from .serve import main as _main
|
| 54 |
+
return _main(rest)
|
| 55 |
+
|
| 56 |
+
print(f"unknown command: {cmd}\n", file=sys.stderr)
|
| 57 |
+
print(HELP, file=sys.stderr)
|
| 58 |
+
return 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
sys.exit(main())
|
data_label_factory/identify/serve.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HTTP server that serves a CLIP retrieval index over a mac_tensor-shaped
|
| 2 |
+
/api/falcon endpoint. Compatible with the existing data-label-factory web UI
|
| 3 |
+
(`web/canvas/live`) without any client changes.
|
| 4 |
+
|
| 5 |
+
Architecture per request:
|
| 6 |
+
1. YOLOv8-World detects "card-shaped" regions (open-vocab "card" class)
|
| 7 |
+
2. Each region is cropped, CLIP-encoded, optionally projection-headed
|
| 8 |
+
3. Cosine-matched against the loaded index → top match per region
|
| 9 |
+
4. If YOLO finds nothing, falls back to classifying the center crop
|
| 10 |
+
5. Returns mac_tensor /api/falcon-shaped JSON so the existing proxy works
|
| 11 |
+
|
| 12 |
+
Also serves the reference images at /refs/<filename> so the live tracker UI
|
| 13 |
+
can show "this is what the model thinks you're holding" alongside the webcam.
|
| 14 |
+
|
| 15 |
+
Configurable via env vars:
|
| 16 |
+
CARD_INDEX path to .npz from `index` (default: card_index.npz)
|
| 17 |
+
CLIP_PROJ optional path to projection head .pt (default: clip_proj.pt)
|
| 18 |
+
REFS_DIR folder of reference images served at /refs/ (default: limit-over-pack)
|
| 19 |
+
YOLO_CONF YOLO confidence threshold (default: 0.05)
|
| 20 |
+
CLIP_SIM_THRESHOLD minimum cosine to accept a match (default: 0.70)
|
| 21 |
+
CLIP_MARGIN_THRESHOLD minimum top1−top2 cosine gap to be 'confident' (default: 0.04)
|
| 22 |
+
PORT HTTP port (default: 8500)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import io
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
import threading
|
| 32 |
+
import time
|
| 33 |
+
import traceback
|
| 34 |
+
from typing import Any
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(argv: list[str] | None = None) -> int:
    """Load CLIP (+ optional projection head) and the .npz index, then serve
    the mac_tensor-shaped /api/falcon endpoint via uvicorn until shutdown.

    Returns 0 after the server loop exits.
    """
    parser = argparse.ArgumentParser(
        prog="data_label_factory.identify serve",
        description=(
            "Run a mac_tensor-shaped /api/falcon HTTP server that serves a CLIP "
            "retrieval index. Compatible with the existing data-label-factory "
            "web/canvas/live UI without client changes."
        ),
    )
    # CLI flags mirror the env vars documented in the module docstring; the
    # env var supplies the default, an explicit flag wins when both are set.
    parser.add_argument("--index", default=os.environ.get("CARD_INDEX", "card_index.npz"),
                        help="Path to the .npz index built by `index`")
    parser.add_argument("--projection", default=os.environ.get("CLIP_PROJ", "clip_proj.pt"),
                        help="Path to the .pt projection head from `train` (optional)")
    parser.add_argument("--refs", default=os.environ.get("REFS_DIR", "limit-over-pack"),
                        help="Folder of reference images, served at /refs/")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", "8500")))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--sim-threshold", type=float,
                        default=float(os.environ.get("CLIP_SIM_THRESHOLD", "0.70")))
    parser.add_argument("--margin-threshold", type=float,
                        default=float(os.environ.get("CLIP_MARGIN_THRESHOLD", "0.04")))
    parser.add_argument("--yolo-conf", type=float,
                        default=float(os.environ.get("YOLO_CONF", "0.05")))
    parser.add_argument("--no-yolo", action="store_true",
                        help="Skip YOLO detection entirely; always classify the center crop")
    args = parser.parse_args(argv)

    # Heavy dependencies are imported lazily so `--help` (and the parent CLI
    # dispatcher) stay fast and dependency-free.
    try:
        import numpy as np
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from PIL import Image
        from fastapi import FastAPI, UploadFile, File, Form, HTTPException
        from fastapi.responses import JSONResponse, PlainTextResponse
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.staticfiles import StaticFiles
        import uvicorn
        import clip
    except ImportError as e:
        raise SystemExit(
            f"missing dependency: {e}\n"
            "install with:\n"
            "  pip install fastapi 'uvicorn[standard]' python-multipart pillow torch "
            "git+https://github.com/openai/CLIP.git\n"
            "  (and `pip install ultralytics` if you want YOLO detection)"
        )

    # Prefer Apple-Silicon MPS, then CUDA, then CPU.
    DEVICE = ("mps" if torch.backends.mps.is_available()
              else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"[serve] device={DEVICE}", flush=True)

    # ---------- load CLIP + projection head ----------
    print(f"[serve] loading CLIP ViT-B/32 …", flush=True)
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE)
    clip_model.eval()

    proj_head = None
    if args.projection and os.path.exists(args.projection):
        # NOTE: this class must stay architecturally identical to the
        # ProjectionHead in train.py — the saved state_dict is loaded into
        # it verbatim.
        class ProjectionHead(nn.Module):
            def __init__(self, in_dim=512, hidden=512, out_dim=256):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

            def forward(self, x):
                return F.normalize(self.net(x), dim=-1)

        ckpt = torch.load(args.projection, map_location=DEVICE)
        # Accept both a wrapped checkpoint ({"state_dict": ..., "in_dim": ...})
        # and a bare state dict.
        sd = ckpt.get("state_dict", ckpt)
        proj_head = ProjectionHead(
            in_dim=ckpt.get("in_dim", 512),
            hidden=ckpt.get("hidden", 512),
            out_dim=ckpt.get("out_dim", 256),
        ).to(DEVICE)
        proj_head.load_state_dict(sd)
        proj_head.eval()
        print(f"[serve] loaded fine-tuned projection head from {args.projection}", flush=True)
    else:
        print(f"[serve] no projection head — using raw CLIP features", flush=True)

    # ---------- load index ----------
    if not os.path.exists(args.index):
        raise SystemExit(f"index not found: {args.index}\n"
                         f"build one with: data_label_factory.identify index --refs <folder>")
    npz = np.load(args.index, allow_pickle=True)
    CARD_EMB = npz["embeddings"]    # presumably (N, D) unit-norm vectors — dot product below is used as cosine
    CARD_NAMES = list(npz["names"])
    # Older indexes may predate the "filenames" field; pad with empties.
    CARD_FILES = list(npz["filenames"]) if "filenames" in npz.files else ["" for _ in CARD_NAMES]
    print(f"[serve] loaded {len(CARD_NAMES)} refs from {args.index}", flush=True)

    # ---------- optional YOLO for multi-card detection ----------
    yolo = None
    if not args.no_yolo:
        try:
            from ultralytics import YOLO
            print(f"[serve] loading YOLOv8s-world for card detection …", flush=True)
            yolo = YOLO("yolov8s-world.pt")
            # Open-vocabulary prompt: restrict detections to card-like classes.
            yolo.set_classes(["card", "trading card", "playing card"])
            print(f"[serve] yolo ready (device={yolo.device})", flush=True)
        except Exception as e:
            # Best-effort: missing ultralytics (or a failed weight download)
            # degrades to the whole-frame fallback below, not a crash.
            print(f"[serve] YOLO unavailable ({e}); using whole-frame mode only", flush=True)

    # ---------- helpers ----------
    # Lowercase trailing token in a ref name → display rarity code.
    RARITY_SUFFIXES = {
        "pscr": "PScR", "scr": "ScR", "ur": "UR", "sr": "SR",
        "op": "OP", "utr": "UtR", "cr": "CR", "ea": "EA", "gmr": "GMR",
    }

    def _split_name_and_rarity(full: str) -> tuple[str, str]:
        # e.g. "Blue Dragon ur" → ("Blue Dragon", "UR"); no suffix → (full, "").
        parts = full.split()
        if parts and parts[-1].lower() in RARITY_SUFFIXES:
            return " ".join(parts[:-1]), RARITY_SUFFIXES[parts[-1].lower()]
        return full, ""

    def _embed_pil(pil) -> "np.ndarray":
        # CLIP-encode one PIL image into an L2-normalized float32 vector.
        # (The projection head re-normalizes in its forward(), so the output
        # is unit-norm either way.)
        with torch.no_grad():
            t = clip_preprocess(pil).unsqueeze(0).to(DEVICE)
            f = clip_model.encode_image(t).float()
            f = f / f.norm(dim=-1, keepdim=True)
            if proj_head is not None:
                f = proj_head(f)
        return f.cpu().numpy()[0].astype(np.float32)

    def _identify_crop(crop, top_k: int = 3) -> dict:
        # Cosine match the crop against every index entry; dot product equals
        # cosine because both sides are unit-norm.
        q = _embed_pil(crop)
        sims = CARD_EMB @ q
        order = np.argsort(-sims)[:top_k]
        top = [{
            "name": CARD_NAMES[i],
            "filename": CARD_FILES[i] if i < len(CARD_FILES) else "",
            "score": float(sims[i]),
        } for i in order]
        # top1−top2 gap; with a single-entry index the score itself stands in.
        margin = top[0]["score"] - top[1]["score"] if len(top) > 1 else top[0]["score"]
        return {
            "top": top,
            "best_name": top[0]["name"],
            "best_filename": top[0]["filename"],
            "best_score": top[0]["score"],
            "margin": float(margin),
            "confident": float(margin) >= args.margin_threshold,
        }

    # ---------- FastAPI app ----------
    app = FastAPI(title="data-label-factory identify worker")
    app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

    if os.path.isdir(args.refs):
        # Lets the live-tracker UI fetch the matched reference image.
        app.mount("/refs", StaticFiles(directory=args.refs), name="refs")
        print(f"[serve] mounted /refs/ from {args.refs}", flush=True)

    # Shared mutable stats for the status endpoints; _lock guards writers.
    _state = {"requests": 0, "last_query": ""}
    _lock = threading.Lock()

    @app.get("/")
    def root() -> PlainTextResponse:
        # NOTE(review): _state is read here without _lock — benign for a
        # human-readable status line, but not a consistent snapshot.
        return PlainTextResponse(
            f"data-label-factory identify · index={len(CARD_NAMES)} refs · "
            f"requests={_state['requests']} · last_query={_state['last_query']!r}\n"
            f"POST /api/falcon (multipart: image, query) — mac_tensor-shaped\n"
            f"GET /refs/<filename> — reference images\n"
            f"GET /health — JSON status\n"
        )

    @app.get("/health")
    def health() -> dict:
        return {
            "phase": "ready",
            "model_loaded": True,
            "device": DEVICE,
            "index_size": len(CARD_NAMES),
            "has_projection": proj_head is not None,
            "has_yolo": yolo is not None,
            "sim_threshold": args.sim_threshold,
            "margin_threshold": args.margin_threshold,
            "requests_served": _state["requests"],
            "last_query": _state["last_query"],
        }

    @app.post("/api/falcon")
    async def falcon(image: UploadFile = File(...), query: str = Form(...)) -> JSONResponse:
        # Per-request flow: decode → YOLO regions → per-region CLIP match →
        # center-crop fallback → mac_tensor-shaped JSON (see module docstring).
        t0 = time.time()
        try:
            pil = Image.open(io.BytesIO(await image.read())).convert("RGB")
        except Exception as e:
            raise HTTPException(400, f"bad image: {e}")
        W, H = pil.size

        with _lock:
            _state["last_query"] = query

        masks: list[dict] = []

        # 1. YOLO multi-card detection (if available)
        if yolo is not None:
            try:
                results = yolo.predict(pil, conf=args.yolo_conf, iou=0.5, verbose=False)
                if results:
                    boxes = getattr(results[0], "boxes", None)
                    if boxes is not None and boxes.xyxy is not None:
                        for x1, y1, x2, y2 in boxes.xyxy.cpu().numpy().tolist():
                            # Clamp the box to the frame and drop slivers
                            # (< 20 px on a side) before cropping.
                            bx1, by1 = max(0, int(x1)), max(0, int(y1))
                            bx2, by2 = min(W, int(x2)), min(H, int(y2))
                            if bx2 - bx1 < 20 or by2 - by1 < 20:
                                continue
                            crop = pil.crop((bx1, by1, bx2, by2))
                            info = _identify_crop(crop)
                            if info["best_score"] < args.sim_threshold:
                                continue
                            name, rarity = _split_name_and_rarity(info["best_name"])
                            display = f"{name} ({rarity})" if rarity else name
                            if not info["confident"]:
                                # Low top1−top2 margin: mark label as tentative.
                                display = f"{display}?"
                            masks.append({
                                "bbox_norm": {
                                    "x1": float(x1) / W, "y1": float(y1) / H,
                                    "x2": float(x2) / W, "y2": float(y2) / H,
                                },
                                "area_fraction": float((x2 - x1) * (y2 - y1)) / max(W * H, 1),
                                "label": display,
                                "name": name,
                                "rarity": rarity,
                                "score": info["best_score"],
                                "top_k": info["top"],
                                "margin": info["margin"],
                                "confident": info["confident"],
                                "ref_filename": info["best_filename"],
                            })
            except Exception as e:
                # Detection failure should not kill the request; fall through
                # to the whole-frame path below.
                print(f"[serve] yolo error: {e}", flush=True)

        # 2. Whole-frame fallback (single-card workflow)
        if not masks:
            # Center crop: trim 10% off each side, 5% off top/bottom.
            cx1, cy1 = int(W * 0.10), int(H * 0.05)
            cx2, cy2 = int(W * 0.90), int(H * 0.95)
            center = pil.crop((cx1, cy1, cx2, cy2))
            info = _identify_crop(center)
            if info["best_score"] >= args.sim_threshold and info["confident"]:
                name, rarity = _split_name_and_rarity(info["best_name"])
                display = f"{name} ({rarity})" if rarity else name
                masks.append({
                    "bbox_norm": {
                        "x1": cx1 / W, "y1": cy1 / H, "x2": cx2 / W, "y2": cy2 / H,
                    },
                    "area_fraction": (cx2 - cx1) * (cy2 - cy1) / max(W * H, 1),
                    "label": display,
                    "name": name,
                    "rarity": rarity,
                    "score": info["best_score"],
                    "top_k": info["top"],
                    "margin": info["margin"],
                    "confident": True,
                    "ref_filename": info["best_filename"],
                })

        with _lock:
            _state["requests"] += 1

        return JSONResponse(content={
            "image_size": [W, H],
            "count": len(masks),
            "masks": masks,
            "query": query,
            "elapsed_seconds": round(time.time() - t0, 3),
        })

    print(f"\n[serve] listening on http://{args.host}:{args.port}", flush=True)
    # Blocks until the server is shut down.
    uvicorn.run(app, host=args.host, port=args.port, log_level="warning")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
data_label_factory/identify/train.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contrastive fine-tune of a small projection head on top of frozen CLIP.
|
| 2 |
+
|
| 3 |
+
Wraps the proven training loop that took the 150-card index from cosine
|
| 4 |
+
margin 0.074 → 0.36 (5x improvement). The CLIP backbone stays frozen, only
|
| 5 |
+
a tiny ~400k-param projection MLP is trained, so this runs on Apple Silicon
|
| 6 |
+
MPS in ~5 minutes for a 150-class set.
|
| 7 |
+
|
| 8 |
+
Data generation: K cards × M augmentations per batch (default 16 × 4 = 64).
|
| 9 |
+
Loss: SupCon (Khosla et al. 2020).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
# Lazy heavy imports — only triggered when this module is actually invoked.

DEFAULT_PALETTE_HINT = "ViT-B/32 + 512→512→256 projection"


def main(argv: list[str] | None = None) -> int:
    """Contrastive fine-tune a projection head on top of frozen CLIP.

    Builds K-classes × M-augmentations batches from one reference image per
    class, trains a small MLP head with the SupCon loss, and saves the head
    (plus its hyperparameters) to ``--out``. Returns 0 on success.
    """
    parser = argparse.ArgumentParser(
        prog="data_label_factory.identify train",
        description=(
            "Contrastive fine-tune a small projection head on top of frozen CLIP. "
            "Use this when off-the-shelf CLIP retrieval is too noisy for your "
            f"reference set. Architecture: {DEFAULT_PALETTE_HINT}."
        ),
    )
    parser.add_argument("--refs", required=True,
                        help="Folder of reference images (1 per class). Filenames become labels.")
    parser.add_argument("--out", default="clip_proj.pt",
                        help="Output path for the trained projection head .pt")
    parser.add_argument("--epochs", type=int, default=12)
    parser.add_argument("--k-cards", type=int, default=16,
                        help="Distinct classes per training batch.")
    parser.add_argument("--m-augs", type=int, default=4,
                        help="Augmentations per class per batch.")
    parser.add_argument("--steps-per-epoch", type=int, default=80)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--clip-model", default="ViT-B/32")
    args = parser.parse_args(argv)

    # Heavy dependencies imported lazily so `--help` stays fast.
    try:
        import numpy as np
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from torch.utils.data import Dataset, DataLoader, Sampler
        from torchvision import transforms
        from PIL import Image
        import clip
    except ImportError as e:
        raise SystemExit(
            f"missing dependency: {e}\n"
            "install with:\n"
            "  pip install torch torchvision pillow git+https://github.com/openai/CLIP.git"
        )

    # Prefer Apple-Silicon MPS, then CUDA, then CPU.
    DEVICE = ("mps" if torch.backends.mps.is_available()
              else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"[train] device={DEVICE}", flush=True)

    refs = Path(args.refs)
    if not refs.is_dir():
        raise SystemExit(f"refs folder not found: {refs}")

    print(f"[train] loading CLIP {args.clip_model} …", flush=True)
    clip_model, clip_preprocess = clip.load(args.clip_model, device=DEVICE)
    clip_model.eval()
    # Freeze the backbone: only the projection head below receives gradients.
    for p in clip_model.parameters():
        p.requires_grad = False

    class CardDataset(Dataset):
        # One reference image per class; __getitem__ returns a fresh random
        # augmentation of that image together with its class index.
        def __init__(self, folder: Path, augs_per_card: int):
            files = sorted(f for f in os.listdir(folder)
                           if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp")))
            if not files:
                raise SystemExit(f"no images in {folder}")
            # All refs are decoded up front and held in memory.
            self.images = []
            for f in files:
                self.images.append(Image.open(folder / f).convert("RGB"))
            self.aug_per_card = augs_per_card
            self.aug = transforms.Compose([
                transforms.RandomResizedCrop(256, scale=(0.6, 1.0), ratio=(0.7, 1.4)),
                transforms.RandomRotation(20),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                transforms.RandomApply([transforms.GaussianBlur(5, sigma=(0.1, 2.0))], p=0.3),
                transforms.RandomGrayscale(p=0.05),
            ])

        def __len__(self):
            return len(self.images) * self.aug_per_card

        def __getitem__(self, idx):
            # Index wraps around the class list; the augmentation pipeline
            # supplies the per-sample variation.
            card_idx = idx % len(self.images)
            return clip_preprocess(self.aug(self.images[card_idx])), card_idx

    class KCardsSampler(Sampler):
        # Emits K distinct classes × M augmentations per batch (shuffled),
        # guaranteeing every batch has positives for the SupCon loss.
        def __init__(self, dataset, k_cards: int, m_augs: int, steps: int):
            self.n = len(dataset.images)
            self.k = k_cards
            self.m = m_augs
            self.steps = steps

        def __iter__(self):
            for _ in range(self.steps):
                cards = random.sample(range(self.n), self.k)
                batch = []
                for c in cards:
                    for _ in range(self.m):
                        batch.append(c)
                random.shuffle(batch)
                yield from batch

        def __len__(self):
            return self.steps * self.k * self.m

    class ProjectionHead(nn.Module):
        # NOTE: serve.py re-declares this exact architecture when loading the
        # checkpoint — keep the two definitions in sync.
        def __init__(self, in_dim=512, hidden=512, out_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

        def forward(self, x):
            return F.normalize(self.net(x), dim=-1)

    def supcon_loss(features: "torch.Tensor", labels: "torch.Tensor", temperature: float) -> "torch.Tensor":
        # Supervised contrastive loss (Khosla et al. 2020): pull same-label
        # embeddings together, push different-label ones apart.
        device = features.device
        bsz = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        sim = torch.matmul(features, features.T) / temperature
        # Subtract the per-row max before exp() for numerical stability.
        sim_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - sim_max.detach()
        # self_mask zeroes the diagonal: a sample is never its own positive.
        self_mask = torch.scatter(
            torch.ones_like(mask), 1,
            torch.arange(bsz, device=device).view(-1, 1), 0)
        pos_mask = mask * self_mask
        exp_logits = torch.exp(logits) * self_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        pos_count = pos_mask.sum(1)
        # Guard rows with no positives (shouldn't occur with KCardsSampler,
        # which always yields M >= 1 samples per class).
        pos_count = torch.where(pos_count == 0, torch.ones_like(pos_count), pos_count)
        return -((pos_mask * log_prob).sum(1) / pos_count).mean()

    print(f"[train] dataset from {refs}", flush=True)
    ds = CardDataset(refs, augs_per_card=args.m_augs)
    print(f"[train] {len(ds.images)} reference images", flush=True)

    sampler = KCardsSampler(ds, k_cards=args.k_cards, m_augs=args.m_augs,
                            steps=args.steps_per_epoch)
    loader = DataLoader(ds, batch_size=args.k_cards * args.m_augs,
                        sampler=sampler, num_workers=0, drop_last=True)

    head = ProjectionHead(in_dim=512, hidden=512, out_dim=256).to(DEVICE)
    print(f"[train] projection head: {sum(p.numel() for p in head.parameters()):,} params", flush=True)

    optimizer = torch.optim.AdamW(head.parameters(), lr=args.lr, weight_decay=1e-4)
    # Cosine decay over the full run (stepped per batch below).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs * args.steps_per_epoch)

    print(f"\n[train] {args.epochs} epochs · {args.steps_per_epoch} steps · "
          f"batch={args.k_cards * args.m_augs} (K={args.k_cards}×M={args.m_augs})\n", flush=True)
    t0 = time.time()
    for epoch in range(args.epochs):
        head.train()
        epoch_loss, n_batches = 0.0, 0
        for imgs, labels in loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            # Backbone features are computed without grad (frozen CLIP);
            # only the head participates in backprop.
            with torch.no_grad():
                feats = clip_model.encode_image(imgs).float()
                feats = feats / feats.norm(dim=-1, keepdim=True)
            proj = head(feats)
            loss = supcon_loss(proj, labels, temperature=args.temperature)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1
        print(f"[train] epoch {epoch + 1:2d}/{args.epochs} loss={epoch_loss / max(n_batches, 1):.4f} "
              f"({time.time() - t0:.0f}s)", flush=True)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Save hyperparameters alongside the weights so serve.py can rebuild the
    # head with the right dimensions (it reads in_dim/hidden/out_dim back).
    torch.save({
        "state_dict": head.state_dict(),
        "in_dim": 512, "hidden": 512, "out_dim": 256,
        "model": args.clip_model,
        "epochs": args.epochs, "k_cards": args.k_cards, "m_augs": args.m_augs,
        "ref_count": len(ds.images),
    }, out_path)
    print(f"\n[train] ✓ saved {out_path} ({out_path.stat().st_size / 1024:.0f} KB)", flush=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
data_label_factory/identify/verify_index.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Self-test the index for top-1 accuracy + report confusable pairs.
|
| 2 |
+
|
| 3 |
+
For each reference image, embed it and verify that its top-1 match in the
|
| 4 |
+
index is itself. Reports the cosine margin between correct and best-wrong
|
| 5 |
+
matches — the most useful number for predicting live accuracy.
|
| 6 |
+
|
| 7 |
+
Run this immediately after building an index to catch bad data BEFORE
|
| 8 |
+
deploying it to a live tracker.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main(argv: list[str] | None = None) -> int:
|
| 19 |
+
parser = argparse.ArgumentParser(
|
| 20 |
+
prog="data_label_factory.identify verify",
|
| 21 |
+
description=(
|
| 22 |
+
"Self-test a built index. Each reference image should match itself "
|
| 23 |
+
"as top-1; a wide cosine margin between correct and best-wrong matches "
|
| 24 |
+
"is the strongest predictor of live accuracy."
|
| 25 |
+
),
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument("--index", default="card_index.npz", help="Path to .npz from `index`")
|
| 28 |
+
parser.add_argument("--top-confusables", type=int, default=5,
|
| 29 |
+
help="How many of the most-confusable pairs to print")
|
| 30 |
+
args = parser.parse_args(argv)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import numpy as np
|
| 34 |
+
except ImportError:
|
| 35 |
+
raise SystemExit("numpy required: pip install numpy")
|
| 36 |
+
|
| 37 |
+
npz = np.load(args.index, allow_pickle=True)
|
| 38 |
+
EMB = npz["embeddings"]
|
| 39 |
+
NAMES = list(npz["names"])
|
| 40 |
+
print(f"[verify] index: {len(NAMES)} refs × {EMB.shape[1]} dims")
|
| 41 |
+
|
| 42 |
+
# Pairwise similarity matrix (small N, fits in memory)
|
| 43 |
+
sims = EMB @ EMB.T
|
| 44 |
+
np.fill_diagonal(sims, -1.0)
|
| 45 |
+
|
| 46 |
+
# Top confusable pairs
|
| 47 |
+
print(f"\nMost-confusable pairs (highest cosine sim between DIFFERENT refs):")
|
| 48 |
+
flat_idx = np.argpartition(sims.flatten(), -args.top_confusables * 2)[-args.top_confusables * 2:]
|
| 49 |
+
seen = set()
|
| 50 |
+
shown = 0
|
| 51 |
+
for fi in flat_idx[np.argsort(sims.flatten()[flat_idx])[::-1]]:
|
| 52 |
+
i, j = divmod(int(fi), len(NAMES))
|
| 53 |
+
if (j, i) in seen:
|
| 54 |
+
continue
|
| 55 |
+
seen.add((i, j))
|
| 56 |
+
print(f" {sims[i, j]:.3f} {NAMES[i][:42]} ↔ {NAMES[j][:42]}")
|
| 57 |
+
shown += 1
|
| 58 |
+
if shown >= args.top_confusables:
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
# Restore diagonal for self-test
|
| 62 |
+
np.fill_diagonal(sims, 1.0)
|
| 63 |
+
|
| 64 |
+
# Self-identity test: each ref's top-1 in EMB @ EMB[i] should be i
|
| 65 |
+
correct = 0
|
| 66 |
+
failures = []
|
| 67 |
+
for i in range(len(NAMES)):
|
| 68 |
+
row = EMB @ EMB[i]
|
| 69 |
+
top = int(np.argmax(row))
|
| 70 |
+
if top == i:
|
| 71 |
+
correct += 1
|
| 72 |
+
else:
|
| 73 |
+
failures.append((NAMES[i], NAMES[top], float(row[top]), float(row[i])))
|
| 74 |
+
|
| 75 |
+
pct = correct / len(NAMES) * 100
|
| 76 |
+
print(f"\nself-identity test: {correct}/{len(NAMES)} = {pct:.1f}% top-1 self-id")
|
| 77 |
+
for name, mismatch, score_wrong, score_right in failures[:10]:
|
| 78 |
+
print(f" ✗ {name[:42]} → matched {mismatch[:42]} "
|
| 79 |
+
f"(top={score_wrong:.3f} vs self={score_right:.3f})")
|
| 80 |
+
|
| 81 |
+
# Margin analysis: gap between "I matched myself" and "best wrong match"
|
| 82 |
+
correct_scores, best_wrong_scores = [], []
|
| 83 |
+
for i in range(len(NAMES)):
|
| 84 |
+
row = EMB @ EMB[i]
|
| 85 |
+
correct_scores.append(row[i])
|
| 86 |
+
row[i] = -1
|
| 87 |
+
best_wrong_scores.append(row.max())
|
| 88 |
+
|
| 89 |
+
median_correct = float(np.median(correct_scores))
|
| 90 |
+
median_wrong = float(np.median(best_wrong_scores))
|
| 91 |
+
margin = median_correct - median_wrong
|
| 92 |
+
print(f"\nthreshold analysis:")
|
| 93 |
+
print(f" median correct match score: {median_correct:.3f}")
|
| 94 |
+
print(f" median best-wrong-match score: {median_wrong:.3f}")
|
| 95 |
+
print(f" gap (margin): {margin:.3f}")
|
| 96 |
+
suggested = max(0.5, median_wrong + 0.05)
|
| 97 |
+
print(f" → recommended SIM_THRESHOLD = {suggested:.2f}")
|
| 98 |
+
return 0 if pct >= 99 else 1
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
sys.exit(main())
|
data_label_factory/runpod/pod_falcon_server.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
pod_falcon_server.py — single-file Falcon Perception HTTP server for a RunPod pod.
|
| 4 |
+
|
| 5 |
+
Designed to be curl-installed via the pod's dockerStartCmd. Two phases:
|
| 6 |
+
|
| 7 |
+
1. **Boot phase (instant):** start a FastAPI server on 0.0.0.0:8000 with two
|
| 8 |
+
endpoints: `/health` (always responds) and `/api/falcon` (returns 503 until
|
| 9 |
+
the model is loaded). A background thread starts heavy installation.
|
| 10 |
+
|
| 11 |
+
2. **Install phase (~5-10 min):** install pip deps, install falcon-perception
|
| 12 |
+
with --no-deps, download the Falcon model from Hugging Face, instantiate
|
| 13 |
+
the inference engine. As soon as it's ready, `/api/falcon` flips to live.
|
| 14 |
+
|
| 15 |
+
The endpoint shape MATCHES mac_tensor's /api/falcon so the existing
|
| 16 |
+
`web/app/api/falcon-frame/route.ts` proxy works against it without changes:
|
| 17 |
+
|
| 18 |
+
Request: multipart/form-data with `image` (file) + `query` (string)
|
| 19 |
+
Response: {
|
| 20 |
+
"image_size": [w, h],
|
| 21 |
+
"count": int,
|
| 22 |
+
"masks": [{"bbox_norm": {x1, y1, x2, y2}, "area_fraction": float}, ...],
|
| 23 |
+
"elapsed_seconds": float,
|
| 24 |
+
"cold_start": bool
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
You can poll progress via:
|
| 28 |
+
curl https://<pod-id>-8000.proxy.runpod.net/health
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import io
|
| 32 |
+
import os
|
| 33 |
+
import subprocess
|
| 34 |
+
import sys
|
| 35 |
+
import threading
|
| 36 |
+
import time
|
| 37 |
+
import traceback
|
| 38 |
+
from typing import Any
|
| 39 |
+
|
| 40 |
+
# ============================================================
|
| 41 |
+
# Boot phase — keep imports minimal so the server starts FAST
|
| 42 |
+
# ============================================================
|
| 43 |
+
|
| 44 |
+
print("[server] starting boot phase…", flush=True)
|
| 45 |
+
BOOT_T0 = time.time()
|
| 46 |
+
|
| 47 |
+
# Install fastapi + uvicorn synchronously since we need them for the boot server.
|
| 48 |
+
# These are tiny (~30 MB) so this takes ~10 seconds.
|
| 49 |
+
def _pip(args, retries=3):
|
| 50 |
+
for attempt in range(retries):
|
| 51 |
+
r = subprocess.run(
|
| 52 |
+
[sys.executable, "-m", "pip", "install", "--quiet", "--no-cache-dir"] + args,
|
| 53 |
+
capture_output=True, text=True,
|
| 54 |
+
)
|
| 55 |
+
if r.returncode == 0:
|
| 56 |
+
return True
|
| 57 |
+
print(f"[pip] attempt {attempt+1} failed: {r.stderr[:300]}", flush=True)
|
| 58 |
+
time.sleep(3)
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
# Synchronous install of the minimal web stack — versions are pinned
# (presumably for reproducibility on fresh pods; TODO confirm pin rationale).
# If even this fails, the pod can't serve /health, so exit loudly.
print("[server] installing fastapi + uvicorn + multipart…", flush=True)
if not _pip(["fastapi==0.115.6", "uvicorn[standard]==0.32.1", "python-multipart==0.0.20", "pillow"]):
    print("[server] CRITICAL: failed to install fastapi", flush=True)
    sys.exit(1)
|
| 65 |
+
|
| 66 |
+
# Now we can import fastapi
|
| 67 |
+
from fastapi import FastAPI, Request, UploadFile, File, Form, HTTPException
|
| 68 |
+
from fastapi.responses import JSONResponse, PlainTextResponse
|
| 69 |
+
from PIL import Image
|
| 70 |
+
|
| 71 |
+
app = FastAPI(title="data-label-factory falcon worker")
|
| 72 |
+
|
| 73 |
+
# Mutable process-wide status, exposed via /health and mutated by both the
# background install thread and the /api/falcon handler.
STATE: dict[str, Any] = {
    "phase": "boot",           # boot → installing pip → installing falcon-perception → loading model → ready | FAILED
    "boot_started": BOOT_T0,   # epoch seconds at process start
    "model_loaded": False,     # flips True after warmup; gates /api/falcon (503 until then)
    "install_log": [],         # rolling log lines, capped at 200 by _log()
    "error": None,             # str(exception) if the install/load thread died
    "model_id": None,          # HF model id, set once falcon_perception is imported
    "device": None,            # "cuda" or "cpu", set after model load
    "load_seconds": None,      # wall-clock seconds for model load + engine setup
    "cold_start_used": False,  # first /api/falcon request reports cold_start=True
    "requests_served": 0,      # count of successful inference responses
}
|
| 85 |
+
|
| 86 |
+
def _log(msg: str) -> None:
    """Echo *msg* with an elapsed-seconds prefix and record it in STATE's log."""
    stamped = f"[{time.time() - BOOT_T0:6.1f}s] {msg}"
    print(stamped, flush=True)
    entries = STATE["install_log"]
    entries.append(stamped)
    # Bound memory: keep only the newest 200 lines.
    if len(entries) > 200:
        STATE["install_log"] = entries[-200:]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ============================================================
|
| 96 |
+
# Endpoints
|
| 97 |
+
# ============================================================
|
| 98 |
+
|
| 99 |
+
@app.get("/")
def root() -> PlainTextResponse:
    """Human-readable one-screen banner for anyone hitting the bare root URL."""
    banner = (
        f"data-label-factory falcon worker · phase={STATE['phase']} "
        f"loaded={STATE['model_loaded']} requests={STATE['requests_served']}\n"
        f"see /health for full status, POST /api/falcon for inference\n"
    )
    return PlainTextResponse(banner)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.get("/health")
def health() -> dict:
    """Machine-readable install/load status; designed to be curl-polled."""
    status = {
        "phase": STATE["phase"],
        "model_loaded": STATE["model_loaded"],
        "model_id": STATE.get("model_id"),
        "device": STATE.get("device"),
        "load_seconds": STATE.get("load_seconds"),
        "uptime_seconds": round(time.time() - BOOT_T0, 1),
        "requests_served": STATE["requests_served"],
        "error": STATE["error"],
        # Tail of the install log so progress is visible without pod SSH.
        "recent_log": STATE["install_log"][-30:],
    }
    return status
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@app.post("/api/falcon")
async def falcon(image: UploadFile = File(...), query: str = Form(...)) -> JSONResponse:
    """Run one Falcon Perception inference on an uploaded frame.

    Accepts multipart/form-data with an `image` file and a `query` string,
    matching mac_tensor's /api/falcon shape so the existing web proxy works
    unchanged. Returns 503 while the background install/load is still running,
    400 on an undecodable image, 500 on an inference failure.
    """
    if not STATE["model_loaded"]:
        # Still installing/loading — include the log tail so the caller can
        # see progress without hitting /health separately.
        return JSONResponse(
            status_code=503,
            content={
                "error": "model not loaded yet",
                "phase": STATE["phase"],
                "loaded": False,
                "uptime": round(time.time() - BOOT_T0, 1),
                "recent": STATE["install_log"][-5:],
            },
        )

    t0 = time.time()
    img_bytes = await image.read()
    try:
        # convert("RGB") normalizes palettes/alpha so downstream always sees 3 channels.
        pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"bad image: {e}"})

    # NOTE(review): the cold-start flag is consumed even if inference below
    # fails, so a failed first request makes later ones report cold_start=False.
    cold = not STATE["cold_start_used"]
    STATE["cold_start_used"] = True

    try:
        result = _run_inference(pil, query)
    except Exception as e:
        # Last few trace lines only — enough to diagnose, small enough to ship.
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "trace": traceback.format_exc().splitlines()[-6:]},
        )

    # Counter increments only on success; failures don't count as "served".
    STATE["requests_served"] += 1
    return JSONResponse(content={
        "image_size": [pil.width, pil.height],
        "count": result["count"],
        "masks": result["masks"],
        "query": query,
        "elapsed_seconds": round(time.time() - t0, 3),
        "cold_start": cold,
    })
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ============================================================
|
| 167 |
+
# Heavy install + inference (loaded in background thread)
|
| 168 |
+
# ============================================================
|
| 169 |
+
|
| 170 |
+
_engine = None
|
| 171 |
+
_tokenizer = None
|
| 172 |
+
_image_processor = None
|
| 173 |
+
_model = None
|
| 174 |
+
_model_args = None
|
| 175 |
+
_sampling_params = None
|
| 176 |
+
_torch = None # cached torch module reference
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _run_inference(pil_img: "Image.Image", query: str) -> dict:
    """Single-image Falcon Perception forward pass.

    Uses task='segmentation' per the prior session learning ('detection mode
    returns empty bboxes'). Extracts bboxes from each segmentation mask via
    pycocotools mask decoding.

    Returns:
        dict with "count" (int) and "masks" — a list of
        {"bbox_norm": {x1, y1, x2, y2}, "area_fraction": float} entries.

    Raises:
        RuntimeError: if called before the background loader set up _engine.
    """
    if _engine is None:
        raise RuntimeError("model not loaded")

    # Deferred imports: falcon_perception only exists after the background
    # install thread has finished (see _heavy_install_and_load).
    from falcon_perception import build_prompt_for_task  # type: ignore
    from falcon_perception.paged_inference import Sequence  # type: ignore

    W, H = pil_img.size
    # _model_args.do_segmentation decides which prompt/decoding path applies.
    task = "segmentation" if getattr(_model_args, "do_segmentation", False) else "detection"
    prompt = build_prompt_for_task(query, task)

    sequences = [Sequence(
        text=prompt,
        image=pil_img,
        min_image_size=256,  # NOTE(review): assumed model-side resize bounds — confirm against falcon_perception docs
        max_image_size=1024,
        task=task,
    )]
    with _torch.inference_mode():
        # generate() mutates the Sequence objects in place; results land on seq.output_aux.
        _engine.generate(
            sequences,
            sampling_params=_sampling_params,
            use_tqdm=False,
            print_stats=False,
        )
    seq = sequences[0]
    aux = seq.output_aux

    masks_out: list[dict] = []

    # Path A: detection mode (bboxes_raw is populated)
    bboxes_raw = getattr(aux, "bboxes_raw", None)
    if bboxes_raw:
        try:
            from falcon_perception.visualization_utils import pair_bbox_entries  # type: ignore
            pairs = pair_bbox_entries(bboxes_raw)
            for entry in pairs:
                # Entries may be namedtuples, dicts, or plain sequences —
                # normalize all three shapes into a {x1,y1,x2,y2} dict.
                if hasattr(entry, "_asdict"):
                    d = entry._asdict()
                elif isinstance(entry, dict):
                    d = entry
                else:
                    vals = list(entry)
                    if len(vals) < 5:
                        continue
                    # presumably vals[0] is a label/score slot — TODO confirm
                    d = {"x1": vals[1], "y1": vals[2], "x2": vals[3], "y2": vals[4]}
                x1 = float(d.get("x1", 0)); y1 = float(d.get("y1", 0))
                x2 = float(d.get("x2", 0)); y2 = float(d.get("y2", 0))
                masks_out.append({
                    # Heuristic: values > 1.5 are treated as pixel coordinates
                    # and divided by image size; values ≤ 1.5 are assumed
                    # already normalized and passed through.
                    "bbox_norm": {
                        "x1": x1 / W if x1 > 1.5 else x1,
                        "y1": y1 / H if y1 > 1.5 else y1,
                        "x2": x2 / W if x2 > 1.5 else x2,
                        "y2": y2 / H if y2 > 1.5 else y2,
                    },
                    # NOTE(review): this divides by W*H regardless of whether the
                    # coords were pixel-space or normalized — if they were already
                    # normalized the fraction comes out tiny. Verify upstream units.
                    "area_fraction": ((x2 - x1) * (y2 - y1)) / (W * H) if W and H else 0.0,
                })
        except Exception as e:
            _log(f"pair_bbox_entries failed: {e}")

    # Path B: segmentation mode (masks_rle is populated)
    if not masks_out:
        masks_rle = getattr(aux, "masks_rle", None) or []
        for m in masks_rle:
            try:
                # Try to extract a bbox from the mask. Multiple possible shapes.
                if isinstance(m, dict) and "bbox" in m:
                    bb = m["bbox"]  # could be [x,y,w,h] or [x1,y1,x2,y2]
                    if len(bb) == 4:
                        x1, y1 = float(bb[0]), float(bb[1])
                        # Heuristic: if last two are smaller than first two, treat as w/h
                        if bb[2] < bb[0] or bb[3] < bb[1]:
                            x2, y2 = x1 + float(bb[2]), y1 + float(bb[3])
                        else:
                            x2, y2 = float(bb[2]), float(bb[3])
                        masks_out.append({
                            "bbox_norm": {
                                "x1": x1 / W if x1 > 1.5 else x1,
                                "y1": y1 / H if y1 > 1.5 else y1,
                                "x2": x2 / W if x2 > 1.5 else x2,
                                "y2": y2 / H if y2 > 1.5 else y2,
                            },
                            "area_fraction": float(m.get("area", (x2 - x1) * (y2 - y1))) / max(W * H, 1),
                        })
                    # Dict had a bbox (even if unusable) — don't also RLE-decode it.
                    continue
                # Fall back to decoding the RLE mask via pycocotools
                from pycocotools import mask as maskUtils  # type: ignore
                import numpy as np  # type: ignore
                # pycocotools expects {"counts": ..., "size": [H, W]}; patch in
                # the image size when the model omitted it.
                rle = m if isinstance(m, dict) else {"counts": m, "size": [H, W]}
                if "size" not in rle:
                    rle["size"] = [H, W]
                if isinstance(rle.get("counts"), str):
                    # pycocotools wants bytes for compressed counts.
                    rle["counts"] = rle["counts"].encode()
                decoded = maskUtils.decode(rle)
                if decoded is None or decoded.size == 0:
                    continue
                # Tight bbox = extents of the nonzero mask pixels.
                ys, xs = np.where(decoded > 0)
                if xs.size == 0:
                    continue
                x1, y1 = int(xs.min()), int(ys.min())
                x2, y2 = int(xs.max()), int(ys.max())
                masks_out.append({
                    "bbox_norm": {"x1": x1 / W, "y1": y1 / H, "x2": x2 / W, "y2": y2 / H},
                    "area_fraction": float(decoded.sum()) / max(W * H, 1),
                })
            except Exception as e:
                # Best-effort per-mask: one malformed mask shouldn't kill the frame.
                _log(f"mask parse failed: {e}")

    return {"count": len(masks_out), "masks": masks_out}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _heavy_install_and_load() -> None:
    """Background thread: install heavy deps, download model, load inference engine.

    Runs while the FastAPI server is already answering /health; advances
    STATE["phase"] through each stage and flips STATE["model_loaded"] at the
    end. On any failure, records the error in STATE and leaves phase=FAILED
    so /health exposes the diagnosis.
    """
    global _engine, _tokenizer, _image_processor, _model, _model_args, _sampling_params, _torch
    try:
        STATE["phase"] = "installing pip"
        _log("installing transformers + qwen-vl-utils + accelerate + safetensors …")
        if not _pip([
            "transformers>=4.49.0,<5",
            "qwen-vl-utils>=0.0.10",
            "accelerate>=0.34",
            "safetensors>=0.4",
            "einops>=0.8.0",
            "opencv-python>=4.10.0",
            "scipy>=1.13.0",
            "pycocotools>=2.0.7",
            "tyro>=0.8.0",
            "huggingface_hub>=0.26",
            "numpy<2",  # falcon-perception is happier with numpy 1.x
        ]):
            raise RuntimeError("pip install of heavy deps failed")

        STATE["phase"] = "installing falcon-perception"
        # --no-deps: the pod image ships a working torch build; letting
        # falcon-perception pull its own torch would clobber it.
        _log("installing falcon-perception (--no-deps to preserve base torch)…")
        if not _pip(["--no-deps", "falcon-perception"]):
            raise RuntimeError("pip install of falcon-perception failed")

        STATE["phase"] = "loading model"
        _log("importing torch + falcon_perception …")
        # Import torch only now — it didn't necessarily exist at module import time.
        import torch as _t  # type: ignore
        _torch = _t

        from falcon_perception import (  # type: ignore
            PERCEPTION_MODEL_ID,
            build_prompt_for_task,
            load_and_prepare_model,
            setup_torch_config,
        )
        from falcon_perception.data import ImageProcessor  # type: ignore
        from falcon_perception.paged_inference import (  # type: ignore
            PagedInferenceEngine,
            SamplingParams,
            Sequence,
        )

        STATE["model_id"] = PERCEPTION_MODEL_ID
        _log(f"model id: {PERCEPTION_MODEL_ID}")
        _log("setting up torch …")
        setup_torch_config()

        _log("loading model + processor (downloads ~600 MB on first run, may take 2-5 min)…")
        load_t0 = time.time()
        _model, _tokenizer, _model_args = load_and_prepare_model(
            hf_model_id=PERCEPTION_MODEL_ID,
            hf_revision="main",
            hf_local_dir=None,
            device=None,  # let model pick CUDA
            dtype="bfloat16",
            compile=False,  # skip torch.compile to keep load fast (~30s vs 60s+)
        )
        _log("instantiating ImageProcessor + PagedInferenceEngine…")
        # NOTE(review): patch_size/merge_size and the paging parameters below
        # are assumed to match the falcon-perception defaults for this model —
        # confirm against the library's reference config.
        _image_processor = ImageProcessor(patch_size=16, merge_size=1)
        _engine = PagedInferenceEngine(
            _model, _tokenizer, _image_processor,
            max_batch_size=1,
            max_seq_length=8192,
            n_pages=128,
            page_size=128,
            prefill_length_limit=8192,
            enable_hr_cache=False,
            capture_cudagraph=False,
        )
        _sampling_params = SamplingParams(
            stop_token_ids=[_tokenizer.eos_token_id, _tokenizer.end_of_query_token_id],
        )

        STATE["load_seconds"] = round(time.time() - load_t0, 1)
        STATE["device"] = "cuda" if _torch.cuda.is_available() else "cpu"

        # Quick warmup so the first real request isn't 30s slower than steady state
        _log("warmup pass on a dummy image…")
        warmup_img = Image.new("RGB", (256, 256), color=(128, 128, 128))
        warmup_seqs = [Sequence(
            text=build_prompt_for_task("anything", "detection"),
            image=warmup_img,
            min_image_size=256,
            max_image_size=512,
            task="detection",
        )]
        with _torch.inference_mode():
            _engine.generate(warmup_seqs, sampling_params=_sampling_params,
                             use_tqdm=False, print_stats=False)

        # Flip readiness LAST — /api/falcon starts accepting only after warmup.
        STATE["phase"] = "ready"
        STATE["model_loaded"] = True
        _log(f"✓ READY in {time.time() - BOOT_T0:.1f}s total")

    except Exception as e:
        # Surface the failure via /health instead of killing the process, so
        # the operator can read the traceback remotely.
        STATE["phase"] = "FAILED"
        STATE["error"] = str(e)
        _log(f"FATAL: {e}")
        _log(traceback.format_exc())
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# Kick off the install thread now (server hasn't started yet but the import is done).
# daemon=True so a stuck install never prevents the process from exiting.
threading.Thread(target=_heavy_install_and_load, daemon=True).start()
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ============================================================
|
| 404 |
+
# Run the server
|
| 405 |
+
# ============================================================
|
| 406 |
+
|
| 407 |
+
if __name__ == "__main__":
    # Entry point: bind all interfaces; the listen port is overridable via $PORT.
    import uvicorn

    listen_port = int(os.environ.get("PORT", "8000"))
    _log(f"booting uvicorn on 0.0.0.0:{listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
|
pyproject.toml
CHANGED
|
@@ -57,6 +57,18 @@ runpod = [
|
|
| 57 |
"datasets>=3.0",
|
| 58 |
"pyarrow>=17.0",
|
| 59 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
dev = [
|
| 61 |
"pytest>=7.0",
|
| 62 |
"ruff>=0.5.0",
|
|
@@ -73,8 +85,13 @@ data_label_factory = "data_label_factory.cli:main"
|
|
| 73 |
data-label-factory = "data_label_factory.cli:main"
|
| 74 |
|
| 75 |
[tool.setuptools]
|
| 76 |
-
packages = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
[tool.setuptools.package-data]
|
| 79 |
data_label_factory = ["*.py"]
|
| 80 |
"data_label_factory.runpod" = ["*.py", "*.md", "Dockerfile", "requirements-pod.txt"]
|
|
|
|
|
|
| 57 |
"datasets>=3.0",
|
| 58 |
"pyarrow>=17.0",
|
| 59 |
]
|
| 60 |
+
identify = [
|
| 61 |
+
# Open-set CLIP retrieval: build/verify/train/serve a card-style index
|
| 62 |
+
"torch>=2.1",
|
| 63 |
+
"torchvision>=0.16",
|
| 64 |
+
"numpy>=1.24,<2",
|
| 65 |
+
"fastapi>=0.115",
|
| 66 |
+
"uvicorn[standard]>=0.32",
|
| 67 |
+
"python-multipart>=0.0.20",
|
| 68 |
+
"pillow>=10.0",
|
| 69 |
+
"ultralytics>=8.3",
|
| 70 |
+
"clip @ git+https://github.com/openai/CLIP.git",
|
| 71 |
+
]
|
| 72 |
dev = [
|
| 73 |
"pytest>=7.0",
|
| 74 |
"ruff>=0.5.0",
|
|
|
|
| 85 |
data-label-factory = "data_label_factory.cli:main"
|
| 86 |
|
| 87 |
[tool.setuptools]
|
| 88 |
+
packages = [
|
| 89 |
+
"data_label_factory",
|
| 90 |
+
"data_label_factory.runpod",
|
| 91 |
+
"data_label_factory.identify",
|
| 92 |
+
]
|
| 93 |
|
| 94 |
[tool.setuptools.package-data]
|
| 95 |
data_label_factory = ["*.py"]
|
| 96 |
"data_label_factory.runpod" = ["*.py", "*.md", "Dockerfile", "requirements-pod.txt"]
|
| 97 |
+
"data_label_factory.identify" = ["*.py", "*.md"]
|
web/app/api/falcon-frame/route.ts
CHANGED
|
@@ -24,6 +24,9 @@ type Bbox = {
|
|
| 24 |
y2: number;
|
| 25 |
score: number;
|
| 26 |
label: string;
|
|
|
|
|
|
|
|
|
|
| 27 |
};
|
| 28 |
|
| 29 |
const FALCON_URL = process.env.FALCON_URL ?? "http://localhost:8500/api/falcon";
|
|
@@ -96,17 +99,30 @@ export async function POST(req: NextRequest) {
|
|
| 96 |
upstreamCount = data.count ?? 0;
|
| 97 |
imgW = data.image_size?.[0] ?? data.width ?? 0;
|
| 98 |
imgH = data.image_size?.[1] ?? data.height ?? 0;
|
| 99 |
-
// mac_tensor returns masks: [{bbox_norm:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
for (const m of data.masks ?? []) {
|
| 101 |
const bn = m.bbox_norm ?? {};
|
| 102 |
if (bn.x1 == null) continue;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
bboxes.push({
|
| 104 |
x1: bn.x1,
|
| 105 |
y1: bn.y1,
|
| 106 |
x2: bn.x2,
|
| 107 |
y2: bn.y2,
|
| 108 |
-
score: m.area_fraction ?? 1,
|
| 109 |
-
label: query,
|
|
|
|
|
|
|
|
|
|
| 110 |
});
|
| 111 |
}
|
| 112 |
}
|
|
|
|
| 24 |
y2: number;
|
| 25 |
score: number;
|
| 26 |
label: string;
|
| 27 |
+
ref_url?: string; // URL to a reference image (for the live tracker sidebar)
|
| 28 |
+
margin?: number;
|
| 29 |
+
confident?: boolean;
|
| 30 |
};
|
| 31 |
|
| 32 |
const FALCON_URL = process.env.FALCON_URL ?? "http://localhost:8500/api/falcon";
|
|
|
|
| 99 |
upstreamCount = data.count ?? 0;
|
| 100 |
imgW = data.image_size?.[0] ?? data.width ?? 0;
|
| 101 |
imgH = data.image_size?.[1] ?? data.height ?? 0;
|
| 102 |
+
// mac_tensor returns masks: [{bbox_norm:{x1,y1,x2,y2}, area_fraction, label?, score?, ref_filename?}]
|
| 103 |
+
// The label/score/ref_filename are present in identify-mode (CLIP retrieval).
|
| 104 |
+
// Construct an absolute ref_url from the upstream base + filename so the
|
| 105 |
+
// browser can render the reference card image directly without an extra
|
| 106 |
+
// proxy hop.
|
| 107 |
+
const upstreamBase = new URL(FALCON_URL);
|
| 108 |
+
upstreamBase.pathname = "/refs/";
|
| 109 |
for (const m of data.masks ?? []) {
|
| 110 |
const bn = m.bbox_norm ?? {};
|
| 111 |
if (bn.x1 == null) continue;
|
| 112 |
+
let ref_url: string | undefined = undefined;
|
| 113 |
+
if (typeof m.ref_filename === "string" && m.ref_filename) {
|
| 114 |
+
ref_url = upstreamBase.toString() + m.ref_filename;
|
| 115 |
+
}
|
| 116 |
bboxes.push({
|
| 117 |
x1: bn.x1,
|
| 118 |
y1: bn.y1,
|
| 119 |
x2: bn.x2,
|
| 120 |
y2: bn.y2,
|
| 121 |
+
score: typeof m.score === "number" ? m.score : (m.area_fraction ?? 1),
|
| 122 |
+
label: typeof m.label === "string" && m.label ? m.label : query,
|
| 123 |
+
ref_url,
|
| 124 |
+
margin: typeof m.margin === "number" ? m.margin : undefined,
|
| 125 |
+
confident: typeof m.confident === "boolean" ? m.confident : undefined,
|
| 126 |
});
|
| 127 |
}
|
| 128 |
}
|
web/app/canvas/live/page.tsx
CHANGED
|
@@ -25,7 +25,14 @@ type SourceMode = "idle" | "file" | "webcam";
|
|
| 25 |
type FalconResponse = {
|
| 26 |
ok: boolean;
|
| 27 |
count?: number;
|
| 28 |
-
bboxes?: Array<{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
image_size?: { w: number; h: number };
|
| 30 |
elapsed_ms?: number;
|
| 31 |
upstream?: string;
|
|
@@ -43,6 +50,11 @@ export default function LiveTrackerPage() {
|
|
| 43 |
const streamRef = useRef<MediaStream | null>(null);
|
| 44 |
const objectUrlRef = useRef<string | null>(null);
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
const [mode, setMode] = useState<SourceMode>("idle");
|
| 47 |
const [query, setQuery] = useState<string>("fiber optic drone");
|
| 48 |
const [activeTracks, setActiveTracks] = useState<Track[]>([]);
|
|
@@ -98,7 +110,7 @@ export default function LiveTrackerPage() {
|
|
| 98 |
|
| 99 |
const form = new FormData();
|
| 100 |
form.set("image", blob, "frame.jpg");
|
| 101 |
-
form.set("query",
|
| 102 |
|
| 103 |
let resp: FalconResponse;
|
| 104 |
try {
|
|
@@ -133,6 +145,7 @@ export default function LiveTrackerPage() {
|
|
| 133 |
y2: isNormalized ? b.y2 * H : b.y2,
|
| 134 |
score: b.score,
|
| 135 |
label: b.label,
|
|
|
|
| 136 |
};
|
| 137 |
});
|
| 138 |
|
|
@@ -347,7 +360,10 @@ export default function LiveTrackerPage() {
|
|
| 347 |
<input
|
| 348 |
type="text"
|
| 349 |
value={query}
|
| 350 |
-
onChange={(e) =>
|
|
|
|
|
|
|
|
|
|
| 351 |
className="px-3 py-1.5 rounded-md bg-zinc-800 border border-zinc-700 text-zinc-100 text-sm w-64 focus:outline-none focus:ring-2 focus:ring-cyan-500"
|
| 352 |
placeholder="e.g. fiber optic drone"
|
| 353 |
/>
|
|
@@ -394,19 +410,47 @@ export default function LiveTrackerPage() {
|
|
| 394 |
{activeTracks.length === 0 ? (
|
| 395 |
<div className="text-sm text-zinc-500">none yet</div>
|
| 396 |
) : (
|
| 397 |
-
<div className="space-y-
|
| 398 |
{activeTracks.map((t) => (
|
| 399 |
-
<div
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
/
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
</div>
|
| 407 |
-
<span className="text-zinc-400 text-xs whitespace-nowrap ml-2">
|
| 408 |
-
{t.hits}/{t.age}f
|
| 409 |
-
</span>
|
| 410 |
</div>
|
| 411 |
))}
|
| 412 |
</div>
|
|
|
|
| 25 |
type FalconResponse = {
|
| 26 |
ok: boolean;
|
| 27 |
count?: number;
|
| 28 |
+
bboxes?: Array<{
|
| 29 |
+
x1: number; y1: number; x2: number; y2: number;
|
| 30 |
+
score: number;
|
| 31 |
+
label: string;
|
| 32 |
+
ref_url?: string;
|
| 33 |
+
margin?: number;
|
| 34 |
+
confident?: boolean;
|
| 35 |
+
}>;
|
| 36 |
image_size?: { w: number; h: number };
|
| 37 |
elapsed_ms?: number;
|
| 38 |
upstream?: string;
|
|
|
|
| 50 |
const streamRef = useRef<MediaStream | null>(null);
|
| 51 |
const objectUrlRef = useRef<string | null>(null);
|
| 52 |
|
| 53 |
+
// Live query ref — read from inside the sendNextFrame loop instead of
|
| 54 |
+
// closure-captured `query` to avoid stale-closure bugs when the user
|
| 55 |
+
// types a new query mid-stream.
|
| 56 |
+
const queryRef = useRef<string>("fiber optic drone");
|
| 57 |
+
|
| 58 |
const [mode, setMode] = useState<SourceMode>("idle");
|
| 59 |
const [query, setQuery] = useState<string>("fiber optic drone");
|
| 60 |
const [activeTracks, setActiveTracks] = useState<Track[]>([]);
|
|
|
|
| 110 |
|
| 111 |
const form = new FormData();
|
| 112 |
form.set("image", blob, "frame.jpg");
|
| 113 |
+
form.set("query", queryRef.current);
|
| 114 |
|
| 115 |
let resp: FalconResponse;
|
| 116 |
try {
|
|
|
|
| 145 |
y2: isNormalized ? b.y2 * H : b.y2,
|
| 146 |
score: b.score,
|
| 147 |
label: b.label,
|
| 148 |
+
ref_url: b.ref_url,
|
| 149 |
};
|
| 150 |
});
|
| 151 |
|
|
|
|
| 360 |
<input
|
| 361 |
type="text"
|
| 362 |
value={query}
|
| 363 |
+
onChange={(e) => {
|
| 364 |
+
setQuery(e.target.value);
|
| 365 |
+
queryRef.current = e.target.value;
|
| 366 |
+
}}
|
| 367 |
className="px-3 py-1.5 rounded-md bg-zinc-800 border border-zinc-700 text-zinc-100 text-sm w-64 focus:outline-none focus:ring-2 focus:ring-cyan-500"
|
| 368 |
placeholder="e.g. fiber optic drone"
|
| 369 |
/>
|
|
|
|
| 410 |
{activeTracks.length === 0 ? (
|
| 411 |
<div className="text-sm text-zinc-500">none yet</div>
|
| 412 |
) : (
|
| 413 |
+
<div className="space-y-3">
|
| 414 |
{activeTracks.map((t) => (
|
| 415 |
+
<div
|
| 416 |
+
key={t.id}
|
| 417 |
+
className="rounded-md border border-zinc-800 bg-zinc-950 p-2"
|
| 418 |
+
>
|
| 419 |
+
<div className="flex items-start gap-3">
|
| 420 |
+
{/* Reference card image (if backend provided one) */}
|
| 421 |
+
{t.ref_url ? (
|
| 422 |
+
/* eslint-disable-next-line @next/next/no-img-element */
|
| 423 |
+
<img
|
| 424 |
+
src={t.ref_url}
|
| 425 |
+
alt={t.label}
|
| 426 |
+
className="w-16 h-auto rounded border-2 flex-shrink-0"
|
| 427 |
+
style={{ borderColor: t.color }}
|
| 428 |
+
/>
|
| 429 |
+
) : (
|
| 430 |
+
<div
|
| 431 |
+
className="w-16 h-22 rounded border-2 flex-shrink-0 flex items-center justify-center text-xs text-zinc-600"
|
| 432 |
+
style={{ borderColor: t.color }}
|
| 433 |
+
>
|
| 434 |
+
no ref
|
| 435 |
+
</div>
|
| 436 |
+
)}
|
| 437 |
+
<div className="flex-1 min-w-0">
|
| 438 |
+
<div className="flex items-center gap-1.5">
|
| 439 |
+
<span
|
| 440 |
+
className="h-2 w-2 rounded-sm flex-shrink-0"
|
| 441 |
+
style={{ backgroundColor: t.color }}
|
| 442 |
+
/>
|
| 443 |
+
<span className="text-xs text-zinc-500 font-mono">#{t.id}</span>
|
| 444 |
+
</div>
|
| 445 |
+
<div className="text-sm text-zinc-100 leading-tight mt-1 break-words">
|
| 446 |
+
{t.label}
|
| 447 |
+
</div>
|
| 448 |
+
<div className="text-xs text-zinc-500 mt-1.5 font-mono">
|
| 449 |
+
{typeof t.score === "number" ? `score ${t.score.toFixed(2)} · ` : ""}
|
| 450 |
+
seen {t.hits}/{t.age}f
|
| 451 |
+
</div>
|
| 452 |
+
</div>
|
| 453 |
</div>
|
|
|
|
|
|
|
|
|
|
| 454 |
</div>
|
| 455 |
))}
|
| 456 |
</div>
|
web/lib/iou-tracker.ts
CHANGED
|
@@ -12,6 +12,7 @@ export type Detection = {
|
|
| 12 |
y2: number;
|
| 13 |
score?: number;
|
| 14 |
label?: string;
|
|
|
|
| 15 |
};
|
| 16 |
|
| 17 |
export type Track = Detection & {
|
|
@@ -103,6 +104,7 @@ export class IoUTracker {
|
|
| 103 |
t.x1 = d.x1; t.y1 = d.y1; t.x2 = d.x2; t.y2 = d.y2;
|
| 104 |
t.score = d.score ?? t.score;
|
| 105 |
t.label = d.label ?? t.label;
|
|
|
|
| 106 |
t.age += 1;
|
| 107 |
t.hits += 1;
|
| 108 |
t.framesSinceSeen = 0;
|
|
|
|
| 12 |
y2: number;
|
| 13 |
score?: number;
|
| 14 |
label?: string;
|
| 15 |
+
ref_url?: string;
|
| 16 |
};
|
| 17 |
|
| 18 |
export type Track = Detection & {
|
|
|
|
| 104 |
t.x1 = d.x1; t.y1 = d.y1; t.x2 = d.x2; t.y2 = d.y2;
|
| 105 |
t.score = d.score ?? t.score;
|
| 106 |
t.label = d.label ?? t.label;
|
| 107 |
+
t.ref_url = d.ref_url ?? t.ref_url;
|
| 108 |
t.age += 1;
|
| 109 |
t.hits += 1;
|
| 110 |
t.framesSinceSeen = 0;
|