mr3vial committed on
Commit
b4148b8
·
1 Parent(s): bb38d46
Files changed (1) hide show
  1. app.py +94 -255
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import gc
3
  import json
4
  import time
 
5
  from pathlib import Path
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, Optional, Tuple
@@ -23,8 +24,8 @@ DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")
23
 
24
  CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
25
  CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "convnext_weights.pt")
26
- CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.json")
27
- CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")
28
 
29
  MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")
30
  MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")
@@ -51,10 +52,6 @@ def safe_int(x: float, default: int = 0) -> int:
51
  return default
52
 
53
  def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
54
- """
55
- Rough reading order:
56
- - sort by y-center, then within a line by x-center (rtl: desc, ltr: asc).
57
- """
58
  if not boxes_xyxy:
59
  return []
60
  centers = []
@@ -102,12 +99,22 @@ def load_detector():
102
  @torch.no_grad()
103
  def load_classifier():
104
  import timm
105
- from safetensors.torch import load_file as safetensors_load
106
  device = get_device()
107
 
108
  classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
109
- with open(classes_path, "r", encoding="utf-8") as f:
110
- raw = [ln.strip() for ln in f if ln.strip()]
 
 
 
 
 
 
 
 
 
 
111
 
112
  letters: List[str] = []
113
  for ln in raw:
@@ -118,7 +125,10 @@ def load_classifier():
118
 
119
  weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
120
  model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
121
- state = safetensors_load(weights_path)
 
 
 
122
 
123
  new_state = {}
124
  for k, v in state.items():
@@ -141,7 +151,6 @@ def preprocess_crop_imagenet(pil: Image.Image) -> torch.Tensor:
141
  @torch.no_grad()
142
  def load_mt5(repo_id: str):
143
  device = get_device()
144
- # Memory optimization: Load in bfloat16 to fit multiple models in Space RAM
145
  dtype = torch.bfloat16 if (device == "cpu" or torch.cuda.is_bf16_supported()) else torch.float16
146
  tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
147
  model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -157,251 +166,89 @@ def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
157
  out = model.generate(**inp, max_new_tokens=max_new_tokens, num_beams=4, do_sample=False)
158
  return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
159
 
160
# =============================================================================
# Optional Hebrew->English translators (user-selectable)
# =============================================================================
# Each entry is a (human-readable label, internal kind key) pair; the kind keys
# are the values get_he2en_translator() / translate_he_to_en() switch on.
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]

# Translator cache: kind -> (tokenizer, model, kind).
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}
171
-
172
@torch.no_grad()
def get_he2en_translator(kind: str):
    """Return a ``(tokenizer, model, kind)`` triple for a Hebrew->English translator.

    Recognised kinds are ``"opus"``, ``"nllb"`` and ``"m2m100"``. Any other
    value (including ``"none"``) yields ``(None, None, "none")``.

    Only one translator is kept in ``_TRANSLATOR_CACHE`` at a time: before a
    different one is loaded, the cache is cleared and memory is reclaimed.
    """
    cached = _TRANSLATOR_CACHE.get(kind)
    if cached is not None:
        return cached

    if kind not in ("opus", "nllb", "m2m100"):
        # Nothing to load. Return the placeholder *without* touching the
        # cache, so an already-loaded translator is not evicted for nothing.
        return (None, None, "none")

    # Memory safety: drop the previously loaded translator before loading a
    # new one.
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = get_device()

    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer
        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        repo = "facebook/nllb-200-distilled-600M"
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
    else:  # kind == "m2m100"
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()

    entry = (tok, model, kind)
    _TRANSLATOR_CACHE[kind] = entry
    return entry
211
-
212
@torch.no_grad()
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew text to English with the translator selected by ``kind``.

    Returns an empty string when ``kind`` is ``"none"`` or when
    ``get_he2en_translator`` reports that no translator exists for it.
    Otherwise returns the first decoded beam-search hypothesis, passed
    through ``normalize_spaces``.
    """
    if kind == "none":
        return ""

    # The loader returns the effective kind ("none" for unrecognised values).
    tok, model, kind = get_he2en_translator(kind)

    if kind == "opus":
        tok_kwargs = {"padding": True, "truncation": True}
        gen_kwargs = {}
    elif kind == "nllb":
        tok_kwargs = {"truncation": True, "max_length": 512}
        # convert_tokens_to_ids is used instead of the lang_code_to_id mapping,
        # which newer transformers releases no longer provide on this tokenizer.
        gen_kwargs = {"forced_bos_token_id": tok.convert_tokens_to_ids("eng_Latn")}
    elif kind == "m2m100":
        tok_kwargs = {"truncation": True, "max_length": 512}
        gen_kwargs = {"forced_bos_token_id": tok.get_lang_id("en")}
    else:
        return ""

    batch = tok([text_he], return_tensors="pt", **tok_kwargs).to(get_device())
    out = model.generate(**batch, **gen_kwargs, max_new_tokens=128, num_beams=4)
    return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
235
-
236
- # =============================================================================
237
- # Pipeline Components
238
- # =============================================================================
239
@dataclass
class Loaded:
    """Container for every model object the handlers use, filled in by ensure_loaded()."""

    # Detector and its name, as returned by load_detector().
    det: Any
    det_name: str

    # Classifier, its list of letters and its name, as returned by load_classifier().
    cls: Any
    cls_letters: List[str]
    cls_name: str

    # Tokenizer/model pair loaded from MT5_BOX2HE_REPO.
    mt5_he_tok: Any
    mt5_he: Any

    # Tokenizer/model pair loaded from MT5_BOX2EN_REPO.
    mt5_en_tok: Any
    mt5_en: Any
250
-
251
# Process-wide singleton; stays None until the first ensure_loaded() call.
LOADED: Optional[Loaded] = None


def ensure_loaded() -> Loaded:
    """Load the detector, classifier and both mT5 models once and reuse them.

    The first call builds a ``Loaded`` instance and stores it in the
    module-level ``LOADED``; later calls return that same instance.
    """
    global LOADED
    if LOADED is None:
        detector, detector_name = load_detector()
        classifier, class_letters, classifier_name = load_classifier()
        hebrew_tok, hebrew_model = load_mt5(MT5_BOX2HE_REPO)
        english_tok, english_model = load_mt5(MT5_BOX2EN_REPO)

        LOADED = Loaded(
            det=detector,
            det_name=detector_name,
            cls=classifier,
            cls_letters=class_letters,
            cls_name=classifier_name,
            mt5_he_tok=hebrew_tok,
            mt5_he=hebrew_model,
            mt5_en_tok=english_tok,
            mt5_en=english_model,
        )
    return LOADED
270
-
271
def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int):
    """Run the detector on one image and return ``(boxes_xyxy, scores)`` as plain lists.

    The image is passed to ``det_model.predict`` directly (no temporary file).
    Both lists are empty when the first result has no boxes.
    """
    results = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    first = results[0]
    if first.boxes is None or len(first.boxes) == 0:
        return [], []

    def _to_list(tensor):
        # Detach and move to CPU as float before converting to nested Python lists.
        return tensor.detach().float().cpu().numpy().tolist()

    return _to_list(first.boxes.xyxy), _to_list(first.boxes.conf)
286
-
287
@torch.no_grad()
def classify_crops(cls_model, letters: List[str], pil: Image.Image, boxes_xyxy, pad: float, topk: int):
    """Crop every box from ``pil``, classify the crops and rank the letters.

    Each box is grown on all sides by ``pad`` times its longer side, clamped
    to the image bounds, and resized to 224x224 (bicubic).

    Returns ``(topk_list, top1_letters, crop_gallery)``:
      * ``topk_list``    - per crop, a list of ``(letter, probability)`` pairs,
                           most probable first, at least one entry long;
      * ``top1_letters`` - per crop, the most probable letter;
      * ``crop_gallery`` - per crop, ``(crop image, "NN: letter")``.
    All three are empty lists when ``boxes_xyxy`` yields no crops.
    """
    device = get_device()
    img_w, img_h = pil.size

    crops = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        box_w = max(1.0, x2 - x1)
        box_h = max(1.0, y2 - y1)
        margin = int(round(max(box_w, box_h) * float(pad)))
        left = max(0, safe_int(x1) - margin)
        top = max(0, safe_int(y1) - margin)
        right = min(img_w, safe_int(x2) + margin)
        bottom = min(img_h, safe_int(y2) + margin)
        patch = pil.crop((left, top, right, bottom))
        crops.append(patch.resize((224, 224), Image.BICUBIC))

    if not crops:
        return [], [], []

    tensors = [preprocess_crop_imagenet(c) for c in crops]
    batch = torch.stack(tensors, dim=0).to(device)
    logits = cls_model(batch)
    probs = torch.softmax(logits, dim=-1).detach().float().cpu().numpy()

    # Always keep at least the top-1 candidate.
    keep = max(1, int(topk))
    topk_list, top1_letters = [], []
    for i, _ in enumerate(crops):
        row_probs = probs[i]
        best = np.argsort(-row_probs)[:keep]
        ranked = [(letters[j], float(row_probs[j])) for j in best]
        topk_list.append(ranked)
        top1_letters.append(ranked[0][0] if ranked else "?")

    crop_gallery = [
        (crop, f"{i+1:02d}: {top1_letters[i]}")
        for i, crop in enumerate(crops)
    ]
    return topk_list, top1_letters, crop_gallery
317
-
318
def build_source_text(boxes_xyxy: List[List[float]], scores: List[float], topk: List[List[Tuple[str, float]]], rtl: bool) -> str:
    """Serialise the detections into the ``[BOXES_CLS_V3]`` text fed to the mT5 models.

    The first two lines are the tag and an ``n=<count> coord_norm=none``
    header. Then comes one line per box, in the order given by
    ``sort_boxes_reading_order``, with its coordinates, detection score and
    ``letter:prob`` candidates. A missing score defaults to 1.0 and missing
    candidates are written as ``?``.
    """
    order = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    out = ["[BOXES_CLS_V3]", f"n={len(order)} coord_norm=none"]

    for j, i in enumerate(order, start=1):
        x1, y1, x2, y2 = boxes_xyxy[i]
        sc = scores[i] if i < len(scores) else 1.0
        cands = topk[i] if i < len(topk) else []
        if cands:
            cls_str = " ".join(f"{ch}:{p:.3f}" for ch, p in cands)
        else:
            cls_str = "?"
        out.append(f"{j:02d} x1={x1:.1f} y1={y1:.1f} x2={x2:.1f} y2={y2:.1f} score={sc:.3f} | cls {cls_str}")

    return "\n".join(out)
328
-
329
def make_annotated(pil: Image.Image, boxes_xyxy, top1_letters, scores):
    """Return a copy of ``pil`` with each box outlined in red and labelled.

    The label is the 1-based, two-digit box number, followed by the box's
    top-1 letter when one is available.

    ``scores`` is accepted for interface compatibility but is not drawn.
    The original computed a per-box score that was never used, and that dead
    local has been removed.
    """
    canvas = pil.copy()
    draw = ImageDraw.Draw(canvas)
    for i, bb in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = [safe_int(v) for v in bb]
        # "?" marks a box without a classification; it is left out of the label.
        lab = top1_letters[i] if i < len(top1_letters) else "?"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        text_label = f"{i+1:02d}" + (f" {lab}" if lab != "?" else "")
        draw.text((x1, y1), text_label, fill="red")
    return canvas
340
-
341
  # =============================================================================
342
  # Tab Action Handlers
343
  # =============================================================================
344
  def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
345
- if pil is None: return None, [], "", ""
346
- L = ensure_loaded()
347
-
348
- boxes, det_scores = yolo_predict(L.det, pil, conf, iou, max_det)
349
- if not boxes: return pil, [], "No boxes detected.", "No boxes detected."
350
-
351
- topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
352
- src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)
353
-
354
- heb = ""
355
- eng = ""
356
- if output_mode == "he":
357
- heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
358
- elif output_mode == "en_direct":
359
- eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
360
- else: # he_then_en
361
- heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
362
- if he2en_kind != "none":
363
- eng = translate_he_to_en(heb, he2en_kind)
364
- else:
365
  eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
366
-
367
- annotated = make_annotated(pil, boxes, top1_letters, det_scores)
368
- return annotated, crop_gallery, heb, eng
 
 
 
 
 
 
 
 
 
369
 
370
  def run_detector_tab(pil, conf, iou, max_det):
371
- if pil is None: return None, "{}"
372
- L = ensure_loaded()
373
- boxes, scores = yolo_predict(L.det, pil, conf, iou, max_det)
374
- annotated = make_annotated(pil, boxes, [], scores)
375
-
376
- debug = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
377
- return annotated, json.dumps(debug, indent=2)
 
 
 
 
 
378
 
379
  def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
380
- if pil is None: return [], "{}"
381
- L = ensure_loaded()
382
- boxes, _ = yolo_predict(L.det, pil, conf, iou, max_det)
383
- if not boxes: return [], "No boxes found to classify."
384
-
385
- topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
386
-
387
- details = {}
388
- for i, (row, box) in enumerate(zip(topk_list, boxes)):
389
- details[f"Box_{i+1:02d}"] = {
390
- "top_predictions": dict(row),
391
- "coordinates": box
392
- }
393
- return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)
 
 
 
 
 
394
 
395
  def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
396
- rec = {
397
- "ts": now_ts(),
398
- "heb_pred": heb_pred, "eng_pred": eng_pred,
399
- "heb_corr": heb_corr, "eng_corr": eng_corr,
400
- "notes": notes,
401
- }
402
- with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
403
- f.write(json.dumps(rec, ensure_ascii=False) + "\n")
404
- return f"Saved to {FEEDBACK_PATH}"
 
 
 
405
 
406
  # =============================================================================
407
  # UI
@@ -418,7 +265,6 @@ with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
418
  gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")
419
 
420
  with gr.Row():
421
- # LEFT COLUMN: Global Inputs & Settings
422
  with gr.Column(scale=3):
423
  inp = gr.Image(type="pil", label="Input Image")
424
 
@@ -438,11 +284,8 @@ with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
438
  if ex:
439
  gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
440
 
441
- # RIGHT COLUMN: Tabs for different pipeline stages
442
  with gr.Column(scale=7):
443
  with gr.Tabs():
444
-
445
- # TAB 1: Main E2E Pipeline
446
  with gr.Tab("End-to-End Pipeline"):
447
  run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
448
  out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
@@ -451,19 +294,16 @@ with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
451
  out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
452
  out_crops_pipe = gr.Gallery(label="Letter crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Letter crops")
453
 
454
- # TAB 2: Detector Only
455
  with gr.Tab("Detector (YOLO)"):
456
  run_det_btn = gr.Button("Run Detector", variant="primary")
457
  out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
458
- out_json_det = gr.Code(label="Raw Detection Output", language="json")
459
 
460
- # TAB 3: Classifier Only
461
  with gr.Tab("Classifier (ConvNeXt)"):
462
  run_cls_btn = gr.Button("Run Classifier", variant="primary")
463
  out_crops_cls = gr.Gallery(label="Isolated Letter Crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Isolated Letter Crops")
464
- out_json_cls = gr.Code(label="Top-K Probabilities per Box", language="json")
465
 
466
- # TAB 4: Feedback
467
  with gr.Tab("Feedback"):
468
  gr.Markdown("Submit corrections to improve the dataset.")
469
  heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
@@ -472,7 +312,6 @@ with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
472
  save_btn = gr.Button("Submit Feedback")
473
  save_status = gr.Textbox(label="Status", interactive=False)
474
 
475
- # --- Button Clicks ---
476
  run_pipe_btn.click(
477
  fn=run_pipeline_tab,
478
  inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],
 
2
  import gc
3
  import json
4
  import time
5
+ import traceback
6
  from pathlib import Path
7
  from dataclasses import dataclass
8
  from typing import Any, Dict, List, Optional, Tuple
 
24
 
25
  CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
26
  CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "convnext_weights.pt")
27
+ CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.txt")
28
+ CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_base")
29
 
30
  MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")
31
  MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")
 
52
  return default
53
 
54
  def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
 
 
 
 
55
  if not boxes_xyxy:
56
  return []
57
  centers = []
 
99
  @torch.no_grad()
100
  def load_classifier():
101
  import timm
102
+
103
  device = get_device()
104
 
105
  classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
106
+
107
+ # Safely read the class list (supports both .json and .txt)
108
+ if classes_path.endswith(".json"):
109
+ with open(classes_path, "r", encoding="utf-8") as f:
110
+ data = json.load(f)
111
+ if isinstance(data, dict):
112
+ raw = [str(v) for v in data.values()]
113
+ else:
114
+ raw = [str(x) for x in data]
115
+ else:
116
+ with open(classes_path, "r", encoding="utf-8") as f:
117
+ raw = [ln.strip() for ln in f if ln.strip()]
118
 
119
  letters: List[str] = []
120
  for ln in raw:
 
125
 
126
  weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
127
  model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
128
+
129
+ # Load classic .pt weights instead of safetensors
130
+ ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
131
+ state = ckpt.get("model", ckpt.get("state_dict", ckpt))
132
 
133
  new_state = {}
134
  for k, v in state.items():
 
151
  @torch.no_grad()
152
  def load_mt5(repo_id: str):
153
  device = get_device()
 
154
  dtype = torch.bfloat16 if (device == "cpu" or torch.cuda.is_bf16_supported()) else torch.float16
155
  tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
156
  model = AutoModelForSeq2SeqLM.from_pretrained(
 
166
  out = model.generate(**inp, max_new_tokens=max_new_tokens, num_beams=4, do_sample=False)
167
  return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # =============================================================================
170
  # Tab Action Handlers
171
  # =============================================================================
172
  def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
173
+ try:
174
+ if pil is None: return None, [], "", ""
175
+ pil = pil.convert("RGB") # Защита от RGBA картинок
176
+
177
+ L = ensure_loaded()
178
+
179
+ boxes, det_scores = yolo_predict(L.det, pil, conf, iou, max_det)
180
+ if not boxes: return pil, [], "No boxes detected.", "No boxes detected."
181
+
182
+ topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
183
+ src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)
184
+
185
+ heb = ""
186
+ eng = ""
187
+ if output_mode == "he":
188
+ heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
189
+ elif output_mode == "en_direct":
 
 
 
190
  eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
191
+ else: # he_then_en
192
+ heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
193
+ if he2en_kind != "none":
194
+ eng = translate_he_to_en(heb, he2en_kind)
195
+ else:
196
+ eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
197
+
198
+ annotated = make_annotated(pil, boxes, top1_letters, det_scores)
199
+ return annotated, crop_gallery, heb, eng
200
+ except Exception as e:
201
+ err_msg = f"ERROR:\n{traceback.format_exc()}"
202
+ return None, [], err_msg, err_msg
203
 
204
  def run_detector_tab(pil, conf, iou, max_det):
205
+ try:
206
+ if pil is None: return None, "{}"
207
+ pil = pil.convert("RGB")
208
+
209
+ L = ensure_loaded()
210
+ boxes, scores = yolo_predict(L.det, pil, conf, iou, max_det)
211
+ annotated = make_annotated(pil, boxes, [], scores)
212
+
213
+ debug = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
214
+ return annotated, json.dumps(debug, indent=2)
215
+ except Exception as e:
216
+ return None, json.dumps({"error": traceback.format_exc()}, indent=2)
217
 
218
  def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
219
+ try:
220
+ if pil is None: return [], "{}"
221
+ pil = pil.convert("RGB")
222
+
223
+ L = ensure_loaded()
224
+ boxes, _ = yolo_predict(L.det, pil, conf, iou, max_det)
225
+ if not boxes: return [], "No boxes found to classify."
226
+
227
+ topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
228
+
229
+ details = {}
230
+ for i, (row, box) in enumerate(zip(topk_list, boxes)):
231
+ details[f"Box_{i+1:02d}"] = {
232
+ "top_predictions": dict(row),
233
+ "coordinates": box
234
+ }
235
+ return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)
236
+ except Exception as e:
237
+ return [], json.dumps({"error": traceback.format_exc()}, indent=2)
238
 
239
  def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
240
+ try:
241
+ rec = {
242
+ "ts": now_ts(),
243
+ "heb_pred": heb_pred, "eng_pred": eng_pred,
244
+ "heb_corr": heb_corr, "eng_corr": eng_corr,
245
+ "notes": notes,
246
+ }
247
+ with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
248
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
249
+ return f"Saved to {FEEDBACK_PATH}"
250
+ except Exception as e:
251
+ return f"Error saving: {str(e)}"
252
 
253
  # =============================================================================
254
  # UI
 
265
  gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")
266
 
267
  with gr.Row():
 
268
  with gr.Column(scale=3):
269
  inp = gr.Image(type="pil", label="Input Image")
270
 
 
284
  if ex:
285
  gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
286
 
 
287
  with gr.Column(scale=7):
288
  with gr.Tabs():
 
 
289
  with gr.Tab("End-to-End Pipeline"):
290
  run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
291
  out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
 
294
  out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
295
  out_crops_pipe = gr.Gallery(label="Letter crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Letter crops")
296
 
 
297
  with gr.Tab("Detector (YOLO)"):
298
  run_det_btn = gr.Button("Run Detector", variant="primary")
299
  out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
300
+ out_json_det = gr.Code(label="Raw Detection Output / Error Logs", language="json")
301
 
 
302
  with gr.Tab("Classifier (ConvNeXt)"):
303
  run_cls_btn = gr.Button("Run Classifier", variant="primary")
304
  out_crops_cls = gr.Gallery(label="Isolated Letter Crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Isolated Letter Crops")
305
+ out_json_cls = gr.Code(label="Top-K Probabilities per Box / Error Logs", language="json")
306
 
 
307
  with gr.Tab("Feedback"):
308
  gr.Markdown("Submit corrections to improve the dataset.")
309
  heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
 
312
  save_btn = gr.Button("Submit Feedback")
313
  save_status = gr.Textbox(label="Status", interactive=False)
314
 
 
315
  run_pipe_btn.click(
316
  fn=run_pipeline_tab,
317
  inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],