"""Paleo-Hebrew tablet reader (Gradio Space).

Pipeline stages:
    1. YOLO detector finds letter bounding boxes.
    2. ConvNeXt (timm) classifier labels each cropped letter.
    3. mT5 seq2seq models decode the box/class transcript into Hebrew
       and/or English; optional third-party models translate he -> en.
"""

import gc
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# =============================================================================
# Configuration via env vars (edit in Space Settings -> Variables if needed)
# =============================================================================
DET_REPO_ID = os.getenv("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")
DET_FILENAME = os.getenv("DET_FILENAME", "best.onnx")  # recommended
DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")

CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "model.safetensors")
CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.txt")
CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")

MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")
MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")

FEEDBACK_PATH = os.getenv("FEEDBACK_PATH", "/tmp/feedback.jsonl")
EXAMPLES_DIR = Path(__file__).parent / "examples"


# =============================================================================
# Helpers
# =============================================================================
def now_ts() -> str:
    """Human-readable local timestamp used for feedback records."""
    return time.strftime("%Y-%m-%d %H:%M:%S")


def normalize_spaces(s: str) -> str:
    """Collapse all runs of whitespace into single spaces (None-safe)."""
    return " ".join((s or "").strip().split())


def get_device() -> str:
    """Return 'cuda' when a GPU is available, else 'cpu'."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def safe_int(x: float, default: int = 0) -> int:
    """Round *x* to the nearest int, returning *default* for non-numeric input."""
    try:
        return int(round(float(x)))
    except (TypeError, ValueError):  # narrowed from bare Exception
        return default


def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """Return indices of *boxes_xyxy* in rough reading order.

    Boxes are grouped into text lines by y-center (tolerance = 0.6 x the
    median box height), then ordered within each line by x-center
    (rtl: descending, ltr: ascending).
    """
    if not boxes_xyxy:
        return []
    centers = []
    heights = []
    for i, (x1, y1, x2, y2) in enumerate(boxes_xyxy):
        xc = 0.5 * (x1 + x2)
        yc = 0.5 * (y1 + y2)
        centers.append((i, xc, yc))
        heights.append(max(1.0, y2 - y1))
    line_tol = float(np.median(heights)) * 0.6
    centers_sorted = sorted(centers, key=lambda t: t[2])
    # Greedy line assignment: a box joins the previous line when its
    # y-center is within line_tol of that line's running mean y.
    lines: List[List[Tuple[int, float, float]]] = []
    for item in centers_sorted:
        if not lines:
            lines.append([item])
            continue
        _, _, yc = item
        prev_line = lines[-1]
        prev_y = float(np.mean([p[2] for p in prev_line]))
        if abs(yc - prev_y) <= line_tol:
            prev_line.append(item)
        else:
            lines.append([item])
    idxs: List[int] = []
    for line in lines:
        line_sorted = sorted(line, key=lambda t: t[1], reverse=rtl)
        idxs.extend([i for i, _, _ in line_sorted])
    return idxs


# =============================================================================
# Load models (Detector, Classifier, mT5)
# =============================================================================
@torch.no_grad()
def load_detector():
    """Download and load the YOLO detector.

    Prefers the ONNX export; falls back to the .pt checkpoint when the
    ONNX file is missing from the repo or fails to load.
    Returns (model, human-readable source name).
    """
    from ultralytics import YOLO
    try:
        det_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FILENAME)
        return YOLO(det_path), f"{DET_REPO_ID}/{DET_FILENAME}"
    except Exception:
        pt_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FALLBACK_PT)
        return YOLO(pt_path), f"{DET_REPO_ID}/{DET_FALLBACK_PT} (PT fallback)"


@torch.no_grad()
def load_classifier():
    """Load the timm ConvNeXt letter classifier and its class labels.

    Returns (model, letters, human-readable source name). Class names of
    the form "NN_letter" are reduced to the trailing letter.
    """
    import timm
    from safetensors.torch import load_file as safetensors_load

    device = get_device()
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    with open(classes_path, "r", encoding="utf-8") as f:
        raw = [ln.strip() for ln in f if ln.strip()]
    letters: List[str] = []
    for ln in raw:
        if "_" in ln:
            letters.append(ln.split("_")[-1])
        else:
            letters.append(ln)

    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
    state = safetensors_load(weights_path)
    # Strip common checkpoint prefixes so keys match the timm model.
    new_state = {}
    for k, v in state.items():
        nk = k
        for pref in ("module.", "model."):
            if nk.startswith(pref):
                nk = nk[len(pref):]
        new_state[nk] = v
    model.load_state_dict(new_state, strict=False)
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS}"


def preprocess_crop_imagenet(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL crop to a CHW float tensor with ImageNet normalization."""
    arr = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (t - mean) / std


@torch.no_grad()
def load_mt5(repo_id: str):
    """Load an mT5 seq2seq model + slow tokenizer from *repo_id*.

    Memory optimization: load in bfloat16 (or float16 when the GPU lacks
    bf16) so several models fit in Space RAM.
    """
    device = get_device()
    dtype = torch.bfloat16 if (device == "cpu" or torch.cuda.is_bf16_supported()) else torch.float16
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        repo_id,
        torch_dtype=dtype,
    ).to(device).eval()
    return tok, model


@torch.no_grad()
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Greedy beam-search generation (4 beams) for one input string."""
    device = get_device()
    inp = tok([text], return_tensors="pt", truncation=True, max_length=2048).to(device)
    out = model.generate(**inp, max_new_tokens=max_new_tokens, num_beams=4, do_sample=False)
    return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])


# =============================================================================
# Optional Hebrew->English translators (user-selectable)
# =============================================================================
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]

_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}  # kind -> (tok, model, kind)


@torch.no_grad()
def get_he2en_translator(kind: str):
    """Return (tokenizer, model, kind) for the requested translator.

    Only one translator is kept resident at a time: loading a new kind
    first evicts the old one and frees CUDA memory.
    """
    if kind in _TRANSLATOR_CACHE:
        return _TRANSLATOR_CACHE[kind]

    # Memory safety: clear old translators before loading a new one.
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = get_device()
    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer
        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        repo = "facebook/nllb-200-distilled-600M"
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
    elif kind == "m2m100":
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()
    else:
        _TRANSLATOR_CACHE[kind] = (None, None, "none")
        return _TRANSLATOR_CACHE[kind]

    _TRANSLATOR_CACHE[kind] = (tok, model, kind)
    return _TRANSLATOR_CACHE[kind]


@torch.no_grad()
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew text to English with the selected translator.

    Returns "" when *kind* is "none" or unrecognized.
    """
    if kind == "none":
        return ""
    tok, model, kind = get_he2en_translator(kind)
    device = get_device()

    if kind == "opus":
        batch = tok([text_he], return_tensors="pt", padding=True, truncation=True).to(device)
        out = model.generate(**batch, max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    if kind == "nllb":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        # FIX: transformers >= 4.34 removed `lang_code_to_id` from fast
        # tokenizers; convert_tokens_to_ids works on every version.
        eng_id = tok.convert_tokens_to_ids("eng_Latn")
        out = model.generate(**batch, forced_bos_token_id=eng_id,
                             max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    if kind == "m2m100":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        out = model.generate(**batch, forced_bos_token_id=tok.get_lang_id("en"),
                             max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    return ""


# =============================================================================
# Pipeline Components
# =============================================================================
@dataclass
class Loaded:
    """Container for all lazily-loaded models of the pipeline."""
    det: Any
    det_name: str
    cls: Any
    cls_letters: List[str]
    cls_name: str
    mt5_he_tok: Any
    mt5_he: Any
    mt5_en_tok: Any
    mt5_en: Any


LOADED: Optional[Loaded] = None


def ensure_loaded() -> Loaded:
    """Load all pipeline models once and cache them in the module global."""
    global LOADED
    if LOADED is not None:
        return LOADED
    det, det_name = load_detector()
    cls, letters, cls_name = load_classifier()
    he_tok, he_model = load_mt5(MT5_BOX2HE_REPO)
    en_tok, en_model = load_mt5(MT5_BOX2EN_REPO)
    LOADED = Loaded(
        det=det, det_name=det_name,
        cls=cls, cls_letters=letters, cls_name=cls_name,
        mt5_he_tok=he_tok, mt5_he=he_model,
        mt5_en_tok=en_tok, mt5_en=en_model,
    )
    return LOADED


def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int):
    """Run YOLO on a PIL image; returns (boxes_xyxy, confidences) as lists."""
    # Direct inference, no tempfile needed.
    res = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    r0 = res[0]
    if r0.boxes is None or len(r0.boxes) == 0:
        return [], []
    xyxy = r0.boxes.xyxy.detach().float().cpu().numpy().tolist()
    scores = r0.boxes.conf.detach().float().cpu().numpy().tolist()
    return xyxy, scores


@torch.no_grad()
def classify_crops(cls_model, letters: List[str], pil: Image.Image, boxes_xyxy,
                   pad: float, topk: int):
    """Classify each padded box crop.

    Returns (topk_list, top1_letters, crop_gallery) where topk_list[i] is a
    list of (letter, prob) pairs and crop_gallery pairs each crop with a
    "NN: letter" caption.
    """
    device = get_device()
    W, H = pil.size
    crops = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        w = max(1.0, x2 - x1)
        h = max(1.0, y2 - y1)
        # Pad each box by a fraction of its larger side, clamped to image.
        p = int(round(max(w, h) * float(pad)))
        a, b = max(0, safe_int(x1) - p), max(0, safe_int(y1) - p)
        c, d = min(W, safe_int(x2) + p), min(H, safe_int(y2) + p)
        crops.append(pil.crop((a, b, c, d)).resize((224, 224), Image.BICUBIC))
    if not crops:
        return [], [], []

    batch = torch.stack([preprocess_crop_imagenet(c) for c in crops], dim=0).to(device)
    logits = cls_model(batch)
    probs = torch.softmax(logits, dim=-1).detach().float().cpu().numpy()

    topk_list, top1_letters = [], []
    for i in range(len(crops)):
        p = probs[i]
        idx = np.argsort(-p)[: max(1, int(topk))]
        row = [(letters[j], float(p[j])) for j in idx]
        topk_list.append(row)
        top1_letters.append(row[0][0] if row else "?")

    crop_gallery = [(crops[i], f"{i+1:02d}: {top1_letters[i]}") for i in range(len(crops))]
    return topk_list, top1_letters, crop_gallery


def build_source_text(boxes_xyxy: List[List[float]], scores: List[float],
                      topk: List[List[Tuple[str, float]]], rtl: bool) -> str:
    """Serialize boxes + classifier candidates into the mT5 input format.

    Boxes are emitted in reading order under a "[BOXES_CLS_V3]" header,
    one line per box with coordinates, detector score and top-k classes.
    """
    idxs = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    lines = ["[BOXES_CLS_V3]", f"n={len(idxs)} coord_norm=none"]
    for j, i in enumerate(idxs, start=1):
        x1, y1, x2, y2 = boxes_xyxy[i]
        sc = scores[i] if i < len(scores) else 1.0
        cands = topk[i] if i < len(topk) else []
        cls_str = " ".join([f"{ch}:{p:.3f}" for ch, p in cands]) if cands else "?"
        lines.append(f"{j:02d} x1={x1:.1f} y1={y1:.1f} x2={x2:.1f} y2={y2:.1f} score={sc:.3f} | cls {cls_str}")
    return "\n".join(lines)


def make_annotated(pil: Image.Image, boxes_xyxy, top1_letters, scores):
    """Return a copy of *pil* with red boxes and "NN [letter]" labels drawn."""
    img = pil.copy()
    draw = ImageDraw.Draw(img)
    for i, bb in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = [safe_int(v) for v in bb]
        lab = top1_letters[i] if i < len(top1_letters) else "?"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        text_label = f"{i+1:02d}" + (f" {lab}" if lab != "?" else "")
        draw.text((x1, y1), text_label, fill="red")
    return img


# =============================================================================
# Tab Action Handlers
# =============================================================================
def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
    """End-to-end handler: detect -> classify -> decode Hebrew/English."""
    if pil is None:
        return None, [], "", ""
    L = ensure_loaded()

    boxes, det_scores = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes:
        return pil, [], "No boxes detected.", "No boxes detected."

    topk_list, top1_letters, crop_gallery = classify_crops(
        L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
    src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)

    heb = ""
    eng = ""
    if output_mode == "he":
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
    elif output_mode == "en_direct":
        eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
    else:  # he_then_en
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
        if he2en_kind != "none":
            eng = translate_he_to_en(heb, he2en_kind)
        else:
            # No translator selected: fall back to the direct box->en model.
            eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)

    annotated = make_annotated(pil, boxes, top1_letters, det_scores)
    return annotated, crop_gallery, heb, eng


def run_detector_tab(pil, conf, iou, max_det):
    """Detector-only handler: annotated image + raw JSON dump."""
    if pil is None:
        return None, "{}"
    L = ensure_loaded()
    boxes, scores = yolo_predict(L.det, pil, conf, iou, max_det)
    annotated = make_annotated(pil, boxes, [], scores)
    debug = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
    return annotated, json.dumps(debug, indent=2)


def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
    """Classifier-only handler: crop gallery + per-box top-k JSON."""
    if pil is None:
        return [], "{}"
    L = ensure_loaded()
    boxes, _ = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes:
        # FIX: keep the gr.Code(language="json") output valid JSON.
        return [], json.dumps({"message": "No boxes found to classify."})
    topk_list, top1_letters, crop_gallery = classify_crops(
        L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
    details = {}
    for i, (row, box) in enumerate(zip(topk_list, boxes)):
        details[f"Box_{i+1:02d}"] = {
            "top_predictions": dict(row),
            "coordinates": box,
        }
    return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)


def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
    """Append one correction record as a JSON line to FEEDBACK_PATH."""
    rec = {
        "ts": now_ts(),
        "heb_pred": heb_pred,
        "eng_pred": eng_pred,
        "heb_corr": heb_corr,
        "eng_corr": eng_corr,
        "notes": notes,
    }
    # Robustness: FEEDBACK_PATH may point into a directory that does not
    # exist yet (env override); create it before appending.
    Path(FEEDBACK_PATH).parent.mkdir(parents=True, exist_ok=True)
    with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return f"Saved to {FEEDBACK_PATH}"


# =============================================================================
# UI
# =============================================================================
# gr.set_static_paths exists only in Gradio >= 4; guard for older versions
# (the .style() fallback below shows this file intends to support them).
if EXAMPLES_DIR.exists() and hasattr(gr, "set_static_paths"):
    gr.set_static_paths([str(EXAMPLES_DIR)])


def list_example_images(max_n: int = 24):
    """Return up to *max_n* example image paths as gr.Examples rows."""
    if not EXAMPLES_DIR.exists():
        return []
    paths = [p for p in sorted(EXAMPLES_DIR.iterdir())
             if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
    return [[str(p)] for p in paths[:max_n]]


def _make_gallery(label: str):
    """Build a Gallery, using the legacy .style() API when present."""
    if hasattr(gr.Gallery, "style"):
        return gr.Gallery(label=label).style(grid=6, height="auto")
    return gr.Gallery(label=label)


with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
    gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")

    with gr.Row():
        # LEFT COLUMN: Global Inputs & Settings
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image")

            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="YOLO conf")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="YOLO IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top-K")

            with gr.Accordion("Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right→left)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"],
                                       value="he_then_en", label="Output mode")
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus",
                                         label="Hebrew → English translator")

            ex = list_example_images()
            if ex:
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)

        # RIGHT COLUMN: Tabs for different pipeline stages
        with gr.Column(scale=7):
            with gr.Tabs():
                # TAB 1: Main E2E Pipeline
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew Output", lines=3, interactive=False)
                        out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
                    out_crops_pipe = _make_gallery("Letter crops")

                # TAB 2: Detector Only
                with gr.Tab("Detector (YOLO)"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")

                # TAB 3: Classifier Only
                with gr.Tab("Classifier (ConvNeXt)"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    out_crops_cls = _make_gallery("Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top-K Probabilities per Box", language="json")

                # TAB 4: Feedback
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes (e.g., missed boxes, ambiguous glyphs)", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)

    # --- Button Clicks ---
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],
        outputs=[out_annot_pipe, out_crops_pipe, out_he_pipe, out_en_pipe],
        api_name=False,
    )
    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_annot_det, out_json_det],
        api_name=False,
    )
    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False,
    )
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False,
    )

demo.queue(max_size=32).launch()