Niiaz Bureev committed on
Commit
8816141
·
1 Parent(s): 9afba6d

added draft of demo

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +504 -0
  3. requirements.txt +13 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .ipynb_checkpoints/
app.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ from PIL import Image, ImageDraw
11
+
12
+ import gradio as gr
13
+ import torch
14
+ from huggingface_hub import hf_hub_download
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
16
+
17
+ # =============================================================================
18
+ # Configuration via env vars (edit in Space Settings -> Variables if needed)
19
+ # =============================================================================
20
+ DET_REPO_ID = os.getenv("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")
21
+ DET_FILENAME = os.getenv("DET_FILENAME", "best.onnx") # recommended
22
+ DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")
23
+
24
+ CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
25
+ CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "model.safetensors")
26
+ CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.txt")
27
+ CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")
28
+
29
+ MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")
30
+ MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")
31
+
32
+ FEEDBACK_PATH = os.getenv("FEEDBACK_PATH", "/tmp/feedback.jsonl")
33
+ EXAMPLES_DIR = Path(__file__).parent / "examples"
34
+
35
+ # =============================================================================
36
+ # Helpers
37
+ # =============================================================================
38
def now_ts() -> str:
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    return stamp
40
+
41
def normalize_spaces(s: str) -> str:
    """Collapse every whitespace run in *s* to one space and trim the ends.

    A falsy input (None, "") yields "".
    """
    tokens = (s or "").strip().split()
    return " ".join(tokens)
43
+
44
def get_device() -> str:
    """Return 'cuda' when a CUDA device is visible to torch, else 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
46
+
47
def safe_int(x: float, default: int = 0) -> int:
    """Round *x* to the nearest integer; return *default* when conversion fails.

    The whole float->round->int chain sits inside the try so that NaN/inf
    inputs also fall back to *default*.
    """
    try:
        result = int(round(float(x)))
    except Exception:
        result = default
    return result
52
+
53
def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """
    Return box indices in approximate reading order.

    Boxes are grouped into text lines by vertical proximity (tolerance = 0.6x
    the median box height), then each line is ordered by horizontal center:
    right-to-left when *rtl* is True, left-to-right otherwise.
    """
    if not boxes_xyxy:
        return []

    entries = []      # (index, x-center, y-center)
    box_heights = []
    for idx, (x1, y1, x2, y2) in enumerate(boxes_xyxy):
        entries.append((idx, 0.5 * (x1 + x2), 0.5 * (y1 + y2)))
        box_heights.append(max(1.0, y2 - y1))
    tol = float(np.median(box_heights)) * 0.6

    # Sweep top-to-bottom; open a new line whenever a box's y-center drifts
    # farther than `tol` from the running mean of the current line.
    lines: List[List[Tuple[int, float, float]]] = []
    for entry in sorted(entries, key=lambda e: e[2]):
        if lines:
            current = lines[-1]
            mean_y = float(np.mean([e[2] for e in current]))
            if abs(entry[2] - mean_y) <= tol:
                current.append(entry)
                continue
        lines.append([entry])

    ordered: List[int] = []
    for line in lines:
        for idx, _, _ in sorted(line, key=lambda e: e[1], reverse=rtl):
            ordered.append(idx)
    return ordered
88
+
89
# =============================================================================
# Load models (Detector, Classifier, mT5)
# =============================================================================
@torch.no_grad()
def load_detector():
    """Download and build the YOLO letter detector.

    Tries the ONNX export first; any failure (download or model construction)
    falls back to the PyTorch checkpoint. Returns (model, source_label).
    """
    from ultralytics import YOLO

    candidates = [
        (DET_FILENAME, ""),
        (DET_FALLBACK_PT, " (PT fallback)"),
    ]
    for pos, (fname, suffix) in enumerate(candidates):
        try:
            path = hf_hub_download(repo_id=DET_REPO_ID, filename=fname)
            return YOLO(path), f"{DET_REPO_ID}/{fname}{suffix}"
        except Exception:
            # Last candidate: let the error surface instead of hiding it.
            if pos == len(candidates) - 1:
                raise
101
+
102
@torch.no_grad()
def load_classifier():
    """Download and build the ConvNeXt letter classifier.

    Returns (model, letters, source_label), where *letters* maps class index
    -> letter string, in the order of the classes file.
    """
    import timm
    from safetensors.torch import load_file as safetensors_load
    device = get_device()

    # One class name per non-empty line of the classes file.
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    with open(classes_path, "r", encoding="utf-8") as f:
        raw = [ln.strip() for ln in f if ln.strip()]

    # Class names may look like "<prefix>_<letter>"; keep only the final part.
    letters: List[str] = []
    for ln in raw:
        if "_" in ln:
            letters.append(ln.split("_")[-1])
        else:
            letters.append(ln)

    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
    state = safetensors_load(weights_path)

    # Strip common checkpoint wrappers ("module." from DataParallel, "model."
    # from wrapper-style saves) so keys line up with the bare timm model.
    # NOTE(review): both prefixes are stripped in sequence, so "module.model.x"
    # also becomes "x" — presumably intentional; confirm against the checkpoint.
    new_state = {}
    for k, v in state.items():
        nk = k
        for pref in ("module.", "model."):
            if nk.startswith(pref):
                nk = nk[len(pref):]
        new_state[nk] = v
    # strict=False tolerates missing/unexpected keys rather than raising.
    model.load_state_dict(new_state, strict=False)
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS}"
133
+
134
def preprocess_crop_imagenet(pil: "Image.Image") -> torch.Tensor:
    """Convert a PIL crop to a CHW float tensor normalized with ImageNet stats."""
    rgb = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    chw = torch.from_numpy(rgb).permute(2, 0, 1).contiguous()
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (chw - imagenet_mean) / imagenet_std
140
+
141
@torch.no_grad()
def load_mt5(repo_id: str):
    """Load an mT5 seq2seq model and its slow tokenizer from *repo_id*.

    Weights are loaded in reduced precision so several models fit in Space
    RAM: bfloat16 wherever supported (always on CPU), float16 otherwise.
    """
    device = get_device()
    use_bf16 = device == "cpu" or torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, torch_dtype=dtype)
    model = model.to(device).eval()
    return tok, model
152
+
153
@torch.no_grad()
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Run deterministic beam-search generation on *text*, returning clean text."""
    enc = tok([text], return_tensors="pt", truncation=True, max_length=2048)
    enc = enc.to(get_device())
    out_ids = model.generate(**enc, max_new_tokens=max_new_tokens, num_beams=4, do_sample=False)
    decoded = tok.batch_decode(out_ids, skip_special_tokens=True)[0]
    return normalize_spaces(decoded)
159
+
160
# =============================================================================
# Optional Hebrew->English translators (user-selectable)
# =============================================================================
# (UI dropdown label, internal key) pairs; the key selects a branch in
# get_he2en_translator / translate_he_to_en.
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]

# Effectively a single-entry cache: it is cleared before each new load to
# bound memory (see get_he2en_translator).
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}  # kind -> (tok, model, kind)
171
+
172
@torch.no_grad()
def get_he2en_translator(kind: str):
    """Return (tokenizer, model, kind) for the requested he->en translator.

    Only one translator is kept resident at a time: any previously cached
    translator is evicted before a new one loads. Unknown kinds (including
    "none") are cached as (None, None, "none") so callers can skip translation.
    """
    global _TRANSLATOR_CACHE
    if kind in _TRANSLATOR_CACHE:
        return _TRANSLATOR_CACHE[kind]

    # Memory Safety: Clear old translators from memory before loading a new one
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = get_device()

    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer
        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()

    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        repo = "facebook/nllb-200-distilled-600M"
        # src_lang sets the source-language token the tokenizer prepends.
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()

    elif kind == "m2m100":
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()

    else:
        # Unknown/none: cache a sentinel under the requested key.
        _TRANSLATOR_CACHE[kind] = (None, None, "none")
        return _TRANSLATOR_CACHE[kind]

    # NOTE(review): translators load in full fp32, unlike the bf16/fp16 mT5
    # models — presumably fine since only one is resident; confirm RAM headroom.
    _TRANSLATOR_CACHE[kind] = (tok, model, kind)
    return _TRANSLATOR_CACHE[kind]
211
+
212
@torch.no_grad()
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew *text_he* to English with the selected MT model.

    *kind* is a TRANSLATOR_CHOICES key; "none" (or any unknown kind) returns
    "" so the caller can fall back to the direct box->en mT5 path.
    """
    if kind == "none":
        return ""
    tok, model, kind = get_he2en_translator(kind)
    device = get_device()

    if kind == "opus":
        batch = tok([text_he], return_tensors="pt", padding=True, truncation=True).to(device)
        out = model.generate(**batch, max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    if kind == "nllb":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        # `tok.lang_code_to_id` is deprecated and removed from newer NLLB
        # tokenizers; convert_tokens_to_ids works across transformers versions.
        eng_id = tok.convert_tokens_to_ids("eng_Latn")
        out = model.generate(**batch, forced_bos_token_id=eng_id, max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    if kind == "m2m100":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        out = model.generate(**batch, forced_bos_token_id=tok.get_lang_id("en"), max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    # Sentinel (None, None, "none") from the cache lands here.
    return ""
235
+
236
# =============================================================================
# Pipeline Components
# =============================================================================
@dataclass
class Loaded:
    """Bundle of all lazily-loaded pipeline models (built by ensure_loaded)."""
    det: Any                  # YOLO letter detector
    det_name: str             # human-readable source of the detector weights
    cls: Any                  # ConvNeXt letter classifier
    cls_letters: List[str]    # class index -> letter string
    cls_name: str             # human-readable source of the classifier weights
    mt5_he_tok: Any           # tokenizer for the boxes->Hebrew mT5
    mt5_he: Any               # boxes->Hebrew mT5 model
    mt5_en_tok: Any           # tokenizer for the boxes->English mT5
    mt5_en: Any               # boxes->English mT5 model

# Process-wide singleton, populated on first call to ensure_loaded().
LOADED: Optional[Loaded] = None
252
+
253
def ensure_loaded() -> Loaded:
    """Load every model on first call; afterwards return the cached bundle."""
    global LOADED
    if LOADED is None:
        det, det_name = load_detector()
        cls_model, letters, cls_name = load_classifier()
        he_tok, he_model = load_mt5(MT5_BOX2HE_REPO)
        en_tok, en_model = load_mt5(MT5_BOX2EN_REPO)
        LOADED = Loaded(
            det=det,
            det_name=det_name,
            cls=cls_model,
            cls_letters=letters,
            cls_name=cls_name,
            mt5_he_tok=he_tok,
            mt5_he=he_model,
            mt5_en_tok=en_tok,
            mt5_en=en_model,
        )
    return LOADED
270
+
271
def yolo_predict(det_model, pil: "Image.Image", conf: float, iou: float, max_det: int):
    """Run the detector on one PIL image.

    Returns (list of [x1, y1, x2, y2] boxes, list of confidences); both empty
    when nothing is detected. Inference runs directly on the PIL object.
    """
    results = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    first = results[0]
    det_boxes = first.boxes
    if det_boxes is None or len(det_boxes) == 0:
        return [], []
    xyxy = det_boxes.xyxy.detach().float().cpu().numpy().tolist()
    scores = det_boxes.conf.detach().float().cpu().numpy().tolist()
    return xyxy, scores
286
+
287
@torch.no_grad()
def classify_crops(cls_model, letters: List[str], pil: Image.Image, boxes_xyxy, pad: float, topk: int):
    """Crop each detected box, classify it, and return per-box letter candidates.

    Returns (topk_list, top1_letters, crop_gallery):
      - topk_list: per box, a list of (letter, probability) sorted descending;
      - top1_letters: per box, the best letter (or "?");
      - crop_gallery: (crop image, "NN: letter") pairs for the Gradio gallery.
    """
    device = get_device()
    W, H = pil.size
    crops = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        w = max(1.0, x2 - x1)
        h = max(1.0, y2 - y1)
        # Pad each crop by a fraction of its larger side, clamped to the image.
        p = int(round(max(w, h) * float(pad)))
        a, b = max(0, safe_int(x1) - p), max(0, safe_int(y1) - p)
        c, d = min(W, safe_int(x2) + p), min(H, safe_int(y2) + p)
        # 224x224 matches the classifier's expected input resolution.
        crops.append(pil.crop((a, b, c, d)).resize((224, 224), Image.BICUBIC))

    if not crops:
        return [], [], []

    # Single batched forward pass over all crops.
    batch = torch.stack([preprocess_crop_imagenet(c) for c in crops], dim=0).to(device)
    logits = cls_model(batch)
    probs = torch.softmax(logits, dim=-1).detach().float().cpu().numpy()

    topk_list, top1_letters = [], []
    for i in range(len(crops)):
        p = probs[i]
        # Indices of the top-k probabilities, best first (at least one).
        idx = np.argsort(-p)[: max(1, int(topk))]
        row = [(letters[j], float(p[j])) for j in idx]
        topk_list.append(row)
        top1_letters.append(row[0][0] if row else "?")

    crop_gallery = [(crops[i], f"{i+1:02d}: {top1_letters[i]}") for i in range(len(crops))]
    return topk_list, top1_letters, crop_gallery
317
+
318
def build_source_text(boxes_xyxy: List[List[float]], scores: List[float], topk: List[List[Tuple[str, float]]], rtl: bool) -> str:
    """Serialize detections + classifier candidates into the mT5 source format.

    Boxes are emitted in reading order under a [BOXES_CLS_V3] header, one per
    line: raw pixel coordinates, detector score, then the top-k letter
    candidates with probabilities.
    """
    order = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    out = ["[BOXES_CLS_V3]", f"n={len(order)} coord_norm=none"]
    for rank, box_i in enumerate(order, start=1):
        x1, y1, x2, y2 = boxes_xyxy[box_i]
        score = scores[box_i] if box_i < len(scores) else 1.0
        cands = topk[box_i] if box_i < len(topk) else []
        if cands:
            cls_str = " ".join(f"{ch}:{p:.3f}" for ch, p in cands)
        else:
            cls_str = "?"
        out.append(
            f"{rank:02d} x1={x1:.1f} y1={y1:.1f} x2={x2:.1f} y2={y2:.1f} "
            f"score={score:.3f} | cls {cls_str}"
        )
    return "\n".join(out)
328
+
329
def make_annotated(pil: "Image.Image", boxes_xyxy, top1_letters, scores):
    """Return a copy of *pil* with numbered red boxes (and top-1 letters) drawn.

    *scores* is accepted for signature compatibility with callers but is not
    rendered. (The original computed a per-box score into a dead local.)
    """
    img = pil.copy()
    draw = ImageDraw.Draw(img)
    for i, bb in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = [safe_int(v) for v in bb]
        lab = top1_letters[i] if i < len(top1_letters) else "?"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        # Label is the 1-based box number plus the letter when one is known.
        text_label = f"{i+1:02d}" + (f" {lab}" if lab != "?" else "")
        draw.text((x1, y1), text_label, fill="red")
    return img
340
+
341
# =============================================================================
# Tab Action Handlers
# =============================================================================
def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
    """End-to-end handler: detect -> classify -> mT5 decode (+ optional he->en MT).

    Returns (annotated image, crop gallery, hebrew text, english text),
    matching the four Gradio outputs of the pipeline tab.
    """
    if pil is None: return None, [], "", ""
    L = ensure_loaded()

    boxes, det_scores = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes: return pil, [], "No boxes detected.", "No boxes detected."

    topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
    src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)

    heb = ""
    eng = ""
    if output_mode == "he":
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
    elif output_mode == "en_direct":
        eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
    else:  # he_then_en: decode Hebrew, then translate it (or direct en fallback)
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
        if he2en_kind != "none":
            eng = translate_he_to_en(heb, he2en_kind)
        else:
            eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)

    annotated = make_annotated(pil, boxes, top1_letters, det_scores)
    return annotated, crop_gallery, heb, eng
369
+
370
def run_detector_tab(pil, conf, iou, max_det):
    """Detector-only handler: return the annotated image plus raw boxes as JSON."""
    if pil is None:
        return None, "{}"
    models = ensure_loaded()
    boxes, scores = yolo_predict(models.det, pil, conf, iou, max_det)
    annotated = make_annotated(pil, boxes, [], scores)
    payload = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
    return annotated, json.dumps(payload, indent=2)
378
+
379
def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
    """Classifier-only handler: detect boxes, then show crops + top-k probabilities.

    Returns (gallery entries, JSON string) for the classifier tab outputs.
    """
    if pil is None: return [], "{}"
    L = ensure_loaded()
    boxes, _ = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes:
        # Keep the second output valid JSON — it feeds gr.Code(language="json"),
        # consistent with the "{}" returned for a missing image above.
        return [], json.dumps({"message": "No boxes found to classify."})

    topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)

    details = {}
    for i, (row, box) in enumerate(zip(topk_list, boxes)):
        details[f"Box_{i+1:02d}"] = {
            "top_predictions": dict(row),  # letter -> probability
            "coordinates": box,            # [x1, y1, x2, y2] in pixels
        }
    return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)
394
+
395
def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
    """Append one feedback record (predictions + corrections) to the JSONL log."""
    record = {
        "ts": now_ts(),
        "heb_pred": heb_pred, "eng_pred": eng_pred,
        "heb_corr": heb_corr, "eng_corr": eng_corr,
        "notes": notes,
    }
    line = json.dumps(record, ensure_ascii=False)
    with open(FEEDBACK_PATH, "a", encoding="utf-8") as fh:
        fh.write(line + "\n")
    return f"Saved to {FEEDBACK_PATH}"
405
+
406
# =============================================================================
# UI
# =============================================================================
# Let Gradio serve the bundled example images directly from disk.
if EXAMPLES_DIR.exists():
    gr.set_static_paths([str(EXAMPLES_DIR)])
411
+
412
def list_example_images(max_n: int = 24):
    """Return up to *max_n* example image paths as gr.Examples rows."""
    if not EXAMPLES_DIR.exists():
        return []
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    rows = []
    for path in sorted(EXAMPLES_DIR.iterdir()):
        if path.suffix.lower() in allowed:
            rows.append([str(path)])
            if len(rows) == max_n:
                break
    return rows
416
+
417
# Top-level Gradio app: one shared input/settings column on the left, one tab
# per pipeline stage on the right. Models load lazily on the first button click
# (ensure_loaded), not at import time.
with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
    gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")

    with gr.Row():
        # LEFT COLUMN: Global Inputs & Settings
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image")

            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="YOLO conf")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="YOLO IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top-K")

            with gr.Accordion("Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right→left)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"], value="he_then_en", label="Output mode")
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus", label="Hebrew → English translator")

            ex = list_example_images()
            if ex:
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)

        # RIGHT COLUMN: Tabs for different pipeline stages
        with gr.Column(scale=7):
            with gr.Tabs():

                # TAB 1: Main E2E Pipeline
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew Output", lines=3, interactive=False)
                        out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
                    # Gradio 3.x exposes Gallery.style(); newer versions removed it,
                    # so the layout call is applied only when available.
                    out_crops_pipe = gr.Gallery(label="Letter crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Letter crops")

                # TAB 2: Detector Only
                with gr.Tab("Detector (YOLO)"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")

                # TAB 3: Classifier Only
                with gr.Tab("Classifier (ConvNeXt)"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    out_crops_cls = gr.Gallery(label="Isolated Letter Crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top-K Probabilities per Box", language="json")

                # TAB 4: Feedback
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes (e.g., missed boxes, ambiguous glyphs)", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)

    # --- Button Clicks ---
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],
        outputs=[out_annot_pipe, out_crops_pipe, out_he_pipe, out_en_pipe],
        api_name=False
    )

    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_annot_det, out_json_det],
        api_name=False
    )

    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False
    )

    # Feedback reads the pipeline tab's current text outputs as the "predicted"
    # fields, so it should be submitted after a pipeline run.
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False
    )

# Queue caps concurrent/pending requests; launch() starts the server.
demo.queue(max_size=32).launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ huggingface_hub>=0.20.0
3
+ transformers>=4.40.0
4
+ sentencepiece
5
+ torch
6
+ numpy
7
+ Pillow
8
+
9
+ ultralytics>=8.0.0
10
+ onnxruntime
11
+
12
+ timm
13
+ safetensors