# Hugging Face Space app (page status residue removed: "Spaces: Sleeping")
| import os | |
| import gc | |
| import json | |
| import time | |
| from pathlib import Path | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# =============================================================================
# Configuration via env vars (edit in Space Settings -> Variables if needed)
# =============================================================================
DET_REPO_ID = os.getenv("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")  # YOLO letter-detector repo
DET_FILENAME = os.getenv("DET_FILENAME", "best.onnx")  # recommended
DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")  # used when the ONNX download/load fails
CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")  # letter-classifier repo
CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "model.safetensors")  # classifier checkpoint filename
CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.txt")  # one class label per line
CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")  # timm architecture name
MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")  # boxes -> Hebrew
MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")  # boxes -> English
FEEDBACK_PATH = os.getenv("FEEDBACK_PATH", "/tmp/feedback.jsonl")  # JSONL append target for user corrections
EXAMPLES_DIR = Path(__file__).parent / "examples"  # bundled demo images
| # ============================================================================= | |
| # Helpers | |
| # ============================================================================= | |
def now_ts() -> str:
    """Return the current local time as 'YYYY-MM-DD HH:MM:SS'."""
    fmt = "%Y-%m-%d %H:%M:%S"
    return time.strftime(fmt)
def normalize_spaces(s: str) -> str:
    """Collapse any whitespace runs in *s* to single spaces and trim the ends.

    None (or any falsy value) is treated as the empty string.
    """
    tokens = (s or "").split()
    return " ".join(tokens)
def get_device() -> str:
    """Return 'cuda' when torch can see a GPU, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def safe_int(x: float, default: int = 0) -> int:
    """Coerce *x* to the nearest int, returning *default* on failure.

    Accepts anything float() can parse (numbers, numeric strings).
    Returns *default* for None, non-numeric strings, or non-finite
    values (inf cannot be rounded to an int).
    """
    try:
        return int(round(float(x)))
    # Narrowed from a bare `except Exception`: these are the only errors
    # float()/round()/int() raise here (OverflowError for +/-inf).
    except (TypeError, ValueError, OverflowError):
        return default
def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """
    Rough reading order:
    - sort by y-center, then within a line by x-center (rtl: desc, ltr: asc).

    Boxes are grouped into text lines greedily: a box joins the current line
    when its y-center is within 0.6 * median box height of the line's mean.
    Returns the original indices in reading order.
    """
    if not boxes_xyxy:
        return []
    info = [
        (i, 0.5 * (x1 + x2), 0.5 * (y1 + y2))
        for i, (x1, y1, x2, y2) in enumerate(boxes_xyxy)
    ]
    box_heights = [max(1.0, b[3] - b[1]) for b in boxes_xyxy]
    tol = float(np.median(box_heights)) * 0.6
    rows: List[List[Tuple[int, float, float]]] = []
    for entry in sorted(info, key=lambda t: t[2]):
        if rows and abs(entry[2] - float(np.mean([p[2] for p in rows[-1]]))) <= tol:
            rows[-1].append(entry)
        else:
            rows.append([entry])
    order: List[int] = []
    for row in rows:
        order.extend(idx for idx, _, _ in sorted(row, key=lambda t: t[1], reverse=rtl))
    return order
| # ============================================================================= | |
| # Load models (Detector, Classifier, mT5) | |
| # ============================================================================= | |
def load_detector():
    """Download and load the YOLO letter detector from the Hub.

    Tries the ONNX export (DET_FILENAME) first; on any failure falls back
    to the .pt checkpoint (DET_FALLBACK_PT).

    Returns:
        (model, source_str): the ultralytics YOLO model and a human-readable
        description of which weights were loaded.
    """
    from ultralytics import YOLO
    try:
        det_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FILENAME)
        return YOLO(det_path), f"{DET_REPO_ID}/{DET_FILENAME}"
    except Exception:
        # Broad catch is deliberate: any download/load error triggers the PT fallback.
        pt_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FALLBACK_PT)
        return YOLO(pt_path), f"{DET_REPO_ID}/{DET_FALLBACK_PT} (PT fallback)"
def load_classifier():
    """Download and build the timm letter classifier.

    Reads the class list (one label per line; entries like "01_aleph" are
    stripped to the text after the last underscore), creates the model named
    by CLS_MODEL_NAME, and loads safetensors weights after removing common
    wrapper prefixes from state-dict keys.

    Returns:
        (model, letters, source_str): eval-mode model on the active device,
        the index->letter list, and a human-readable weights description.
    """
    import timm
    from safetensors.torch import load_file as safetensors_load
    device = get_device()
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    with open(classes_path, "r", encoding="utf-8") as f:
        raw = [ln.strip() for ln in f if ln.strip()]
    letters: List[str] = []
    for ln in raw:
        if "_" in ln:
            # e.g. "01_aleph" -> "aleph"
            letters.append(ln.split("_")[-1])
        else:
            letters.append(ln)
    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
    state = safetensors_load(weights_path)
    new_state = {}
    for k, v in state.items():
        nk = k
        for pref in ("module.", "model."):
            # Strip DataParallel-/wrapper-style prefixes so keys match the bare model.
            if nk.startswith(pref):
                nk = nk[len(pref):]
        new_state[nk] = v
    # strict=False: tolerate key mismatches between checkpoint and architecture.
    model.load_state_dict(new_state, strict=False)
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS}"
def preprocess_crop_imagenet(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL crop to a CHW float tensor normalized with ImageNet stats."""
    rgb = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    chw = torch.from_numpy(rgb).permute(2, 0, 1).contiguous()
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return chw.sub(imagenet_mean).div(imagenet_std)
def load_mt5(repo_id: str):
    """Load an mT5-style seq2seq model + tokenizer in reduced precision.

    Picks bfloat16 on CPU or bf16-capable GPUs, float16 otherwise.

    Returns:
        (tokenizer, model): model is in eval mode on the active device.
    """
    device = get_device()
    # Memory optimization: Load in bfloat16 to fit multiple models in Space RAM
    dtype = torch.bfloat16 if (device == "cpu" or torch.cuda.is_bf16_supported()) else torch.float16
    # use_fast=False forces the slow tokenizer — presumably for sentencepiece
    # compatibility with these checkpoints; TODO confirm.
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        repo_id,
        torch_dtype=dtype,
    ).to(device).eval()
    return tok, model
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Greedy beam-search generation for one input string; returns cleaned text."""
    device = get_device()
    encoded = tok([text], return_tensors="pt", truncation=True, max_length=2048).to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        num_beams=4,
        do_sample=False,
    )
    decoded = tok.batch_decode(generated, skip_special_tokens=True)[0]
    return normalize_spaces(decoded)
| # ============================================================================= | |
| # Optional Hebrew->English translators (user-selectable) | |
| # ============================================================================= | |
# (UI label, internal key) pairs for the Hebrew→English translator dropdown.
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]
# Only one translator is kept resident at a time (see get_he2en_translator).
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}  # kind -> (tok, model, kind)
def get_he2en_translator(kind: str):
    """Return (tokenizer, model, kind) for the requested he→en translator.

    The cache holds at most one translator: requesting a kind that is not
    cached first evicts everything and frees CPU/GPU memory before loading.
    Unknown kinds (including "none") resolve to (None, None, "none").
    """
    global _TRANSLATOR_CACHE
    if kind in _TRANSLATOR_CACHE:
        return _TRANSLATOR_CACHE[kind]
    # Memory Safety: Clear old translators from memory before loading a new one
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    device = get_device()
    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer
        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        repo = "facebook/nllb-200-distilled-600M"
        # src_lang pins the NLLB source-language token to Hebrew.
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
    elif kind == "m2m100":
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()
    else:
        # "none" or unrecognized: cache a no-op triple so repeat calls are cheap.
        _TRANSLATOR_CACHE[kind] = (None, None, "none")
        return _TRANSLATOR_CACHE[kind]
    _TRANSLATOR_CACHE[kind] = (tok, model, kind)
    return _TRANSLATOR_CACHE[kind]
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew text to English with the selected external MT model.

    Returns "" when kind is "none" or when the resolved translator kind is
    not one of the supported backends.
    """
    if kind == "none":
        return ""
    tok, model, kind = get_he2en_translator(kind)
    device = get_device()

    def _decode(out) -> str:
        # All backends share the same decode + whitespace cleanup.
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])

    if kind == "opus":
        enc = tok([text_he], return_tensors="pt", padding=True, truncation=True).to(device)
        return _decode(model.generate(**enc, max_new_tokens=128, num_beams=4))
    if kind == "nllb":
        enc = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        bos = tok.lang_code_to_id["eng_Latn"]
        return _decode(model.generate(**enc, forced_bos_token_id=bos, max_new_tokens=128, num_beams=4))
    if kind == "m2m100":
        enc = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        bos = tok.get_lang_id("en")
        return _decode(model.generate(**enc, forced_bos_token_id=bos, max_new_tokens=128, num_beams=4))
    return ""
| # ============================================================================= | |
| # Pipeline Components | |
| # ============================================================================= | |
@dataclass
class Loaded:
    """Bundle of all lazily-loaded pipeline models (populated by ensure_loaded).

    Bug fix: this class is constructed with keyword arguments
    (Loaded(det=..., ...)) but had no generated __init__ — the @dataclass
    decorator (already imported at the top of the file) was missing, so the
    first pipeline run would raise TypeError.
    """
    det: Any              # ultralytics YOLO detector
    det_name: str         # human-readable detector weights source
    cls: Any              # timm classifier model
    cls_letters: List[str]  # index -> letter label
    cls_name: str         # human-readable classifier weights source
    mt5_he_tok: Any       # tokenizer for boxes -> Hebrew mT5
    mt5_he: Any           # boxes -> Hebrew mT5 model
    mt5_en_tok: Any       # tokenizer for boxes -> English mT5
    mt5_en: Any           # boxes -> English mT5 model

# Memoized singleton; None until the first inference request.
LOADED: Optional[Loaded] = None
def ensure_loaded() -> Loaded:
    """Load all pipeline models on first use and memoize them in LOADED.

    Subsequent calls return the cached bundle without touching the network.
    NOTE(review): this keyword construction requires Loaded to provide a
    keyword __init__ (e.g. via @dataclass) — confirm the class definition.
    """
    global LOADED
    if LOADED is not None:
        return LOADED
    det, det_name = load_detector()
    cls, letters, cls_name = load_classifier()
    he_tok, he_model = load_mt5(MT5_BOX2HE_REPO)
    en_tok, en_model = load_mt5(MT5_BOX2EN_REPO)
    LOADED = Loaded(
        det=det, det_name=det_name,
        cls=cls, cls_letters=letters, cls_name=cls_name,
        mt5_he_tok=he_tok, mt5_he=he_model,
        mt5_en_tok=en_tok, mt5_en=en_model,
    )
    return LOADED
def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int):
    """Run YOLO detection on a PIL image.

    Returns (boxes_xyxy, confidences) as plain Python lists; both empty when
    nothing was detected.
    """
    # Direct inference, no tempfile needed
    results = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    first = results[0]
    detections = first.boxes
    if detections is None or len(detections) == 0:
        return [], []
    coords = detections.xyxy.detach().float().cpu().numpy()
    confidences = detections.conf.detach().float().cpu().numpy()
    return coords.tolist(), confidences.tolist()
def classify_crops(cls_model, letters: List[str], pil: Image.Image, boxes_xyxy, pad: float, topk: int):
    """Crop each detected box (with relative padding), classify, and rank letters.

    Args:
        cls_model: eval-mode classifier returning per-class logits.
        letters: index -> letter label list matching the model head.
        pil: full source image.
        boxes_xyxy: detector boxes in absolute pixel coords.
        pad: padding as a fraction of the larger box side.
        topk: number of candidate letters to keep per box.

    Returns:
        (topk_list, top1_letters, crop_gallery): per-box (letter, prob) lists,
        per-box best letter, and (crop image, caption) pairs for the UI gallery.
    """
    device = get_device()
    W, H = pil.size
    crops = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        w = max(1.0, x2 - x1)
        h = max(1.0, y2 - y1)
        # Padding in pixels, proportional to the larger box side.
        p = int(round(max(w, h) * float(pad)))
        a, b = max(0, safe_int(x1) - p), max(0, safe_int(y1) - p)
        c, d = min(W, safe_int(x2) + p), min(H, safe_int(y2) + p)
        crops.append(pil.crop((a, b, c, d)).resize((224, 224), Image.BICUBIC))
    if not crops:
        return [], [], []
    batch = torch.stack([preprocess_crop_imagenet(c) for c in crops], dim=0).to(device)
    # Fix: inference previously ran with autograd enabled; no_grad avoids
    # building the graph, cutting peak memory for large crop batches.
    with torch.no_grad():
        logits = cls_model(batch)
    probs = torch.softmax(logits, dim=-1).detach().float().cpu().numpy()
    topk_list, top1_letters = [], []
    for i in range(len(crops)):
        p = probs[i]
        idx = np.argsort(-p)[: max(1, int(topk))]
        row = [(letters[j], float(p[j])) for j in idx]
        topk_list.append(row)
        top1_letters.append(row[0][0] if row else "?")
    crop_gallery = [(crops[i], f"{i+1:02d}: {top1_letters[i]}") for i in range(len(crops))]
    return topk_list, top1_letters, crop_gallery
def build_source_text(boxes_xyxy: List[List[float]], scores: List[float], topk: List[List[Tuple[str, float]]], rtl: bool) -> str:
    """Serialize boxes + classifier candidates into the mT5 prompt format.

    Emits a "[BOXES_CLS_V3]" header, a count line, then one line per box in
    reading order with coordinates, detector score, and letter:prob pairs.
    """
    order = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    out_lines = ["[BOXES_CLS_V3]", f"n={len(order)} coord_norm=none"]
    for rank, box_i in enumerate(order, start=1):
        x1, y1, x2, y2 = boxes_xyxy[box_i]
        score = scores[box_i] if box_i < len(scores) else 1.0
        candidates = topk[box_i] if box_i < len(topk) else []
        if candidates:
            cls_str = " ".join(f"{ch}:{p:.3f}" for ch, p in candidates)
        else:
            cls_str = "?"
        out_lines.append(
            f"{rank:02d} x1={x1:.1f} y1={y1:.1f} x2={x2:.1f} y2={y2:.1f} score={score:.3f} | cls {cls_str}"
        )
    return "\n".join(out_lines)
def make_annotated(pil: Image.Image, boxes_xyxy, top1_letters, scores):
    """Return a copy of *pil* with numbered red boxes and top-1 letter labels.

    *scores* is accepted for interface compatibility but is not rendered.
    (Fix: removed a dead local that read scores[i] without ever using it.)
    """
    img = pil.copy()
    draw = ImageDraw.Draw(img)
    for i, bb in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = [safe_int(v) for v in bb]
        lab = top1_letters[i] if i < len(top1_letters) else "?"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        # Caption is "NN" plus the letter when a classification exists.
        text_label = f"{i+1:02d}" + (f" {lab}" if lab != "?" else "")
        draw.text((x1, y1), text_label, fill="red")
    return img
| # ============================================================================= | |
| # Tab Action Handlers | |
| # ============================================================================= | |
def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
    """Full pipeline: detect -> classify -> mT5 decode (+ optional he→en MT).

    Returns (annotated image, crop gallery, hebrew text, english text) for the UI.
    output_mode selects which decoders run: "he", "en_direct", or "he_then_en".
    """
    if pil is None: return None, [], "", ""
    L = ensure_loaded()
    boxes, det_scores = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes: return pil, [], "No boxes detected.", "No boxes detected."
    topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
    src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)
    heb = ""
    eng = ""
    if output_mode == "he":
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
    elif output_mode == "en_direct":
        eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
    else:  # he_then_en
        heb = mt5_generate(L.mt5_he_tok, L.mt5_he, src)
        if he2en_kind != "none":
            # Translate the decoded Hebrew with the selected external MT model.
            eng = translate_he_to_en(heb, he2en_kind)
        else:
            # No translator chosen: fall back to the direct boxes→English mT5.
            eng = mt5_generate(L.mt5_en_tok, L.mt5_en, src)
    annotated = make_annotated(pil, boxes, top1_letters, det_scores)
    return annotated, crop_gallery, heb, eng
def run_detector_tab(pil, conf, iou, max_det):
    """Detector-only tab: return the annotated image and raw detections as JSON."""
    if pil is None: return None, "{}"
    L = ensure_loaded()
    boxes, scores = yolo_predict(L.det, pil, conf, iou, max_det)
    # Empty top1_letters: boxes are numbered but unlabeled in this tab.
    annotated = make_annotated(pil, boxes, [], scores)
    debug = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
    return annotated, json.dumps(debug, indent=2)
def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
    """Classifier tab: detect boxes, then show crops and per-box top-k probabilities."""
    if pil is None: return [], "{}"
    L = ensure_loaded()
    boxes, _ = yolo_predict(L.det, pil, conf, iou, max_det)
    if not boxes: return [], "No boxes found to classify."
    topk_list, top1_letters, crop_gallery = classify_crops(L.cls, L.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k)
    details = {}
    for i, (row, box) in enumerate(zip(topk_list, boxes)):
        # row is a list of (letter, prob) pairs -> dict for readable JSON.
        details[f"Box_{i+1:02d}"] = {
            "top_predictions": dict(row),
            "coordinates": box
        }
    return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)
def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
    """Append one correction record to FEEDBACK_PATH as a JSONL line."""
    record = {
        "ts": now_ts(),
        "heb_pred": heb_pred, "eng_pred": eng_pred,
        "heb_corr": heb_corr, "eng_corr": eng_corr,
        "notes": notes,
    }
    serialized = json.dumps(record, ensure_ascii=False)
    with open(FEEDBACK_PATH, "a", encoding="utf-8") as fh:
        fh.write(serialized + "\n")
    return f"Saved to {FEEDBACK_PATH}"
| # ============================================================================= | |
| # UI | |
| # ============================================================================= | |
# Register the examples directory so Gradio can serve its files statically.
if EXAMPLES_DIR.exists():
    gr.set_static_paths([str(EXAMPLES_DIR)])

def list_example_images(max_n: int = 24):
    """Return up to *max_n* example image paths as [[path]] rows for gr.Examples."""
    if not EXAMPLES_DIR.exists(): return []
    paths = [p for p in sorted(EXAMPLES_DIR.iterdir()) if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
    return [[str(p)] for p in paths[:max_n]]
# Build the Gradio UI: global inputs on the left, stage-specific tabs on the right.
with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
    gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")
    with gr.Row():
        # LEFT COLUMN: Global Inputs & Settings
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="YOLO conf")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="YOLO IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top-K")
            with gr.Accordion("Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right→left)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"], value="he_then_en", label="Output mode")
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus", label="Hebrew → English translator")
            ex = list_example_images()
            if ex:
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
        # RIGHT COLUMN: Tabs for different pipeline stages
        with gr.Column(scale=7):
            with gr.Tabs():
                # TAB 1: Main E2E Pipeline
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew Output", lines=3, interactive=False)
                        out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
                    # Compat shim: older Gradio versions configured galleries via .style().
                    out_crops_pipe = gr.Gallery(label="Letter crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Letter crops")
                # TAB 2: Detector Only
                with gr.Tab("Detector (YOLO)"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")
                # TAB 3: Classifier Only
                with gr.Tab("Classifier (ConvNeXt)"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    # Same compat shim as above for the crops gallery.
                    out_crops_cls = gr.Gallery(label="Isolated Letter Crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top-K Probabilities per Box", language="json")
                # TAB 4: Feedback
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes (e.g., missed boxes, ambiguous glyphs)", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)
    # --- Button Clicks ---
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],
        outputs=[out_annot_pipe, out_crops_pipe, out_he_pipe, out_en_pipe],
        api_name=False
    )
    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_annot_det, out_json_det],
        api_name=False
    )
    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False
    )
    # Feedback reads the pipeline's predicted texts plus the user's corrections.
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False
    )
# Queue caps concurrent requests so model inference doesn't pile up.
demo.queue(max_size=32).launch()