Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """Paleo‑Hebrew Epigraphy Pipeline (Gradio)""" | |
| import os | |
| import gc | |
| import io | |
| import json | |
| import time | |
| import base64 | |
| import tempfile | |
| import traceback | |
| import logging | |
| from pathlib import Path | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# Quiet Hugging Face Hub downloads. setdefault keeps any value the user already
# exported; the explicit disable_progress_bars() call covers the current process.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
try:
    from huggingface_hub.utils import disable_progress_bars
    disable_progress_bars()
except Exception:
    # Best effort only: any import or call failure is ignored on purpose.
    pass
| import plotly.graph_objects as go | |
| from plotly.colors import qualitative | |
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("paleo_demo")

# -----------------------------------------------------------------------------
# Configuration via environment variables
# -----------------------------------------------------------------------------
# Each setting may be overridden from the environment; the second argument is
# the default used when the variable is not set.

# Letter detector: ONNX export first, PyTorch weights as a fallback.
DET_REPO_ID = os.environ.get("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")
DET_FILENAME = os.environ.get("DET_FILENAME", "best.onnx")
DET_FALLBACK_PT = os.environ.get("DET_FALLBACK_PT", "best.pt")

# Letter classifier: timm architecture name, weights file and label file.
CLS_REPO_ID = os.environ.get("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
CLS_WEIGHTS = os.environ.get("CLS_WEIGHTS", "convnext_weights.pt")
CLS_CLASSES = os.environ.get("CLS_CLASSES", "classes.json")
CLS_MODEL_NAME = os.environ.get("CLS_MODEL_NAME", "convnext_large")

# mT5 checkpoints: serialized boxes -> Hebrew, and serialized boxes -> English.
MT5_BOX2HE_REPO = os.environ.get(
    "MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing"
)
MT5_BOX2EN_REPO = os.environ.get("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")

# JSON-lines file that user feedback is appended to.
FEEDBACK_PATH = os.environ.get("FEEDBACK_PATH", "/tmp/feedback.jsonl")

# Optional "examples" folder that sits next to this file.
EXAMPLES_DIR = Path(__file__).with_name("examples")

# Pixel height of the input image widget and of the Plotly overlays.
IMG_HEIGHT = 420
# -----------------------------------------------------------------------------
# Translator choices (Hebrew -> English) for the he_then_en mode
# -----------------------------------------------------------------------------
# (dropdown label, internal key). The key is what get_he2en_translator and
# translate_he_to_en dispatch on; "none" means "use the direct box->en mT5".
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]

# Loaded translators keyed by kind: kind -> (tokenizer, model, kind).
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}

# -----------------------------------------------------------------------------
# Hebrew normalization (final forms -> base forms)
# -----------------------------------------------------------------------------
# Final (word-end) letter forms mapped to their regular counterparts.
FINAL_MAP: Dict[str, str] = {
    "ך": "כ",
    "ם": "מ",
    "ן": "נ",
    "ף": "פ",
    "ץ": "צ",
}


def normalize_hebrew_letter(ch: str) -> str:
    """Return the base form of a final-form Hebrew letter; other input is returned as-is."""
    if ch in FINAL_MAP:
        return FINAL_MAP[ch]
    return ch
# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------
def now_ts() -> str:
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS``."""
    current = time.localtime()
    return time.strftime("%Y-%m-%d %H:%M:%S", current)
def normalize_spaces(s: str) -> str:
    """Collapse every whitespace run to one space and trim both ends (``None`` -> "")."""
    words = (s or "").split()  # split() without args already drops edge whitespace
    return " ".join(words)
def get_device() -> str:
    """Return ``"cuda"`` when a CUDA device is available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def safe_int(x: float, default: int = 0) -> int:
    """Round ``x`` (anything ``float()`` accepts) to the nearest int.

    Returns ``default`` if conversion fails for any reason (e.g. ``None``,
    non-numeric text, NaN or infinity).
    """
    try:
        as_float = float(x)
        return int(round(as_float))
    except Exception:
        return default
def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """Heuristic reading order: group by y (lines), sort by x within line.

    Boxes are visited top-to-bottom by vertical centre. A box joins the current
    line when its centre is within 0.6 x the median box height of that line's
    mean centre; otherwise it starts a new line. Each line is then ordered by
    horizontal centre, right-to-left when ``rtl`` is true.

    Returns:
        Indices into ``boxes_xyxy`` in reading order (``[]`` for no boxes).
    """
    if not boxes_xyxy:
        return []
    entries = [
        (idx, 0.5 * (x1 + x2), 0.5 * (y1 + y2))
        for idx, (x1, y1, x2, y2) in enumerate(boxes_xyxy)
    ]
    tolerance = float(np.median([max(1.0, y2 - y1) for _, y1, _, y2 in boxes_xyxy])) * 0.6

    rows: List[List[Tuple[int, float, float]]] = []
    for entry in sorted(entries, key=lambda e: e[2]):
        if rows and abs(entry[2] - float(np.mean([e[2] for e in rows[-1]]))) <= tolerance:
            rows[-1].append(entry)
        else:
            rows.append([entry])

    order: List[int] = []
    for row in rows:
        order += [idx for idx, _, _ in sorted(row, key=lambda e: e[1], reverse=rtl)]
    return order
# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
def load_detector() -> Tuple[Any, str]:
    """Download and construct the YOLO letter detector.

    Tries ``DET_FILENAME`` (ONNX) from ``DET_REPO_ID`` first. If downloading or
    constructing it fails for any reason, the failure is logged and
    ``DET_FALLBACK_PT`` is loaded instead.

    Returns:
        ``(model, description)``, where ``description`` names the repo/file used.

    Raises:
        Whatever ``hf_hub_download`` or ``YOLO`` raise if the fallback also fails.
    """
    from ultralytics import YOLO  # lazy import

    def _build(path: str) -> Any:
        # Retry without ``task`` if this ultralytics version rejects the keyword.
        try:
            return YOLO(path, task="detect")
        except TypeError:
            return YOLO(path)

    try:
        det_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FILENAME)
        return _build(det_path), f"{DET_REPO_ID}/{DET_FILENAME}"
    except Exception:
        # Make the fallback visible instead of silently switching weights.
        log.warning(
            "Could not load detector %s/%s; falling back to %s",
            DET_REPO_ID,
            DET_FILENAME,
            DET_FALLBACK_PT,
            exc_info=True,
        )
    pt_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FALLBACK_PT)
    return _build(pt_path), f"{DET_REPO_ID}/{DET_FALLBACK_PT} (PT fallback)"
def _dict_labels_in_index_order(data: Dict[Any, Any]) -> List[Any]:
    """Return the labels of a classes mapping ordered by class index.

    ``{"0": "א", "1": "ב", ...}`` (index -> label) is sorted numerically by key,
    so a file written with sorted string keys ("0", "1", "10", "2", ...) still
    maps index i to the i-th label. ``{"א": 0, ...}`` (label -> index) is
    inverted. Any other mapping keeps the file's value order.
    """
    items = list(data.items())
    if items and all(isinstance(v, int) and not isinstance(v, bool) for _, v in items):
        return [k for k, _ in sorted(items, key=lambda kv: kv[1])]
    if items and all(isinstance(k, str) and k.isdecimal() for k, _ in items):
        return [v for _, v in sorted(items, key=lambda kv: int(kv[0]))]
    return [v for _, v in items]


def _read_class_labels(classes_path: str) -> List[str]:
    """Read classifier labels from a JSON file, falling back to one label per line.

    The JSON may be ``{"classes": [...]}``, an index/label mapping, or a list.
    Each entry is stringified, reduced to the part after its last ``_``
    (e.g. ``"05_ה"`` -> ``"ה"``), and stripped of quotes, commas and spaces.
    Empty results are dropped.
    """
    raw: List[Any] = []
    try:
        with open(classes_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict) and "classes" in data:
            raw = list(data["classes"])
        elif isinstance(data, dict):
            raw = _dict_labels_in_index_order(data)
        elif isinstance(data, list):
            raw = data
    except Exception:
        # Not JSON (or an unusable shape): treat the file as plain text.
        with open(classes_path, "r", encoding="utf-8") as f:
            raw = [ln.strip() for ln in f if ln.strip()]
    letters: List[str] = []
    for item in raw:
        val = str(item).strip()
        val = val.split("_")[-1] if "_" in val else val
        val = val.strip("\"' ,")
        if val:
            letters.append(val)
    return letters


def _load_classifier_state(weights_path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint and return its state dict with wrapper prefixes removed.

    Accepts a bare state dict or one nested under ``"model"`` / ``"state_dict"``.
    A leading ``"module."`` and then a leading ``"model."`` are stripped from every key.

    Raises:
        RuntimeError: if no dict-shaped state can be found (including when the
            checkpoint itself is not a dict, which used to surface as an
            opaque AttributeError).
    """
    # NOTE(security): torch.load unpickles the file. Only point CLS_REPO_ID at
    # trusted repositories; weights_only=True is safer if the checkpoint is
    # known to contain plain tensors only (verify before switching).
    ckpt = torch.load(weights_path, map_location="cpu")
    if isinstance(ckpt, dict):
        state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    else:
        state = ckpt
    if not isinstance(state, dict):
        raise RuntimeError(f"Bad checkpoint format: type(state)={type(state)}")
    new_state: Dict[str, torch.Tensor] = {}
    for k, v in state.items():
        nk = k
        for pref in ("module.", "model."):
            if nk.startswith(pref):
                nk = nk[len(pref):]
        new_state[nk] = v
    return new_state


def load_classifier() -> Tuple[Any, List[str], str]:
    """Download and build the timm letter classifier together with its labels.

    The class count is read from the checkpoint's ``head.fc.weight``. The label
    list is truncated, or padded with ``cls_XX`` placeholders, to match it.

    Returns:
        ``(model, letters, description)``; the model is in eval mode on the
        active device.

    Raises:
        RuntimeError: for an unusable checkpoint or one lacking ``head.fc.weight``.
    """
    import timm  # lazy import

    device = get_device()
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    letters = _read_class_labels(classes_path)

    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    new_state = _load_classifier_state(weights_path)
    if "head.fc.weight" not in new_state:
        raise RuntimeError("Checkpoint is missing head.fc.weight; cannot infer number of classes.")
    num_classes_ckpt = int(new_state["head.fc.weight"].shape[0])

    if len(letters) != num_classes_ckpt:
        log.warning(
            "classes.json has %d labels but checkpoint expects %d classes; truncating/padding.",
            len(letters),
            num_classes_ckpt,
        )
        if len(letters) > num_classes_ckpt:
            letters = letters[:num_classes_ckpt]
        else:
            letters = letters + [f"cls_{i:02d}" for i in range(len(letters), num_classes_ckpt)]

    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=num_classes_ckpt)
    # strict=False tolerates naming differences, but a partially initialised
    # model would silently predict garbage, so report any mismatch.
    incompatible = model.load_state_dict(new_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        log.warning(
            "Classifier checkpoint mismatch: %d missing keys %s, %d unexpected keys %s",
            len(incompatible.missing_keys),
            list(incompatible.missing_keys)[:5],
            len(incompatible.unexpected_keys),
            list(incompatible.unexpected_keys)[:5],
        )
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS} ({CLS_MODEL_NAME}, C={num_classes_ckpt})"
def load_mt5(repo_id: str) -> Tuple[Any, Any]:
    """Load an mT5 seq2seq checkpoint and its slow tokenizer onto the active device.

    On CUDA the model runs in bfloat16 when the GPU supports it, otherwise in
    float32. float16 is deliberately not used: T5-family models (mT5 included)
    are known to overflow to inf/NaN in fp16, which yields empty or garbage
    generations.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    device = get_device()
    if device == "cuda" and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, torch_dtype=dtype).to(device).eval()
    return tok, model
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Decode ``text`` with 4-beam, non-sampling search and return the first result.

    Input is truncated to 2048 tokens; the decoded output is whitespace-normalized.
    """
    encoded = tok([text], return_tensors="pt", truncation=True, max_length=2048)
    encoded = encoded.to(get_device())
    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        num_beams=4,
        do_sample=False,
    )
    first = tok.batch_decode(generated, skip_special_tokens=True)[0]
    return normalize_spaces(first)
# -----------------------------------------------------------------------------
# Hebrew -> English translators (optional, only used in he_then_en mode)
# -----------------------------------------------------------------------------
def get_he2en_translator(kind: str) -> Tuple[Any, Any, str]:
    """Return ``(tokenizer, model, kind)`` for a Hebrew->English translator.

    Supported kinds are "opus", "nllb" and "m2m100". Only one translator is kept
    in ``_TRANSLATOR_CACHE``: loading a new kind evicts the previous one and
    empties the CUDA cache first. Any other kind (including "none") returns
    ``(None, None, "none")`` without touching the cache, so a stray value can
    no longer evict an already-loaded model.
    """
    if kind in _TRANSLATOR_CACHE:
        return _TRANSLATOR_CACHE[kind]
    if kind not in ("opus", "nllb", "m2m100"):
        return None, None, "none"

    # Keep a single translator resident: drop the old one and release memory.
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = get_device()
    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer

        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        repo = "facebook/nllb-200-distilled-600M"
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
    else:  # "m2m100"
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()

    _TRANSLATOR_CACHE[kind] = (tok, model, kind)
    return _TRANSLATOR_CACHE[kind]
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew text to English with the selected translator.

    Args:
        text_he: Hebrew source text (in this app, the box->Hebrew mT5 output).
        kind: A TRANSLATOR_CHOICES key: "none", "opus", "nllb" or "m2m100".

    Returns:
        Whitespace-normalized English (4-beam search, up to 128 new tokens), or
        "" when ``kind`` is "none"/unsupported or the input is blank. Blank input
        is short-circuited so no model is loaded or run for it.
    """
    if kind == "none" or not (text_he or "").strip():
        return ""
    tok, model, kind = get_he2en_translator(kind)
    device = get_device()
    if kind == "opus":
        batch = tok([text_he], return_tensors="pt", padding=True, truncation=True).to(device)
        extra = {}
    elif kind == "nllb":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        # NLLB/M2M100 pick the output language via the forced first token.
        extra = {"forced_bos_token_id": tok.convert_tokens_to_ids("eng_Latn")}
    elif kind == "m2m100":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        extra = {"forced_bos_token_id": tok.get_lang_id("en")}
    else:
        return ""
    out = model.generate(**batch, **extra, max_new_tokens=128, num_beams=4)
    return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
# -----------------------------------------------------------------------------
# Core pipeline
# -----------------------------------------------------------------------------
@dataclass
class Loaded:
    """Every model the pipeline needs, built once by ensure_loaded().

    Must be a dataclass: ensure_loaded() constructs it with keyword arguments,
    which a plain annotated class does not accept.
    """

    det: Any  # YOLO detector from load_detector()
    det_name: str  # repo/file the detector was loaded from
    cls: Any  # timm classifier from load_classifier()
    cls_letters: List[str]  # label for each classifier output index
    cls_name: str  # description of the classifier checkpoint
    mt5_he_tok: Any  # tokenizer of the boxes -> Hebrew mT5
    mt5_he: Any  # boxes -> Hebrew mT5 model
    mt5_en_tok: Any  # tokenizer of the boxes -> English mT5
    mt5_en: Any  # boxes -> English mT5 model
# Process-wide model bundle; stays None until the first ensure_loaded() call.
LOADED: Optional[Loaded] = None


def ensure_loaded() -> Loaded:
    """Return the shared model bundle, building it on first use.

    The first call loads the detector, the classifier and both mT5 models,
    stores them in the module-level ``LOADED`` and logs their sources. Later
    calls return that same object without reloading.
    """
    global LOADED
    if LOADED is None:
        detector, detector_name = load_detector()
        classifier, letters, classifier_name = load_classifier()
        he_tokenizer, he_model = load_mt5(MT5_BOX2HE_REPO)
        en_tokenizer, en_model = load_mt5(MT5_BOX2EN_REPO)
        LOADED = Loaded(
            det=detector,
            det_name=detector_name,
            cls=classifier,
            cls_letters=letters,
            cls_name=classifier_name,
            mt5_he_tok=he_tokenizer,
            mt5_he=he_model,
            mt5_en_tok=en_tokenizer,
            mt5_en=en_model,
        )
        log.info(
            "Loaded models: detector=%s classifier=%s box2he=%s box2en=%s",
            detector_name,
            classifier_name,
            MT5_BOX2HE_REPO,
            MT5_BOX2EN_REPO,
        )
    return LOADED
def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int) -> Tuple[List[List[float]], List[float]]:
    """Run the YOLO detector on ``pil``.

    The PIL image is handed to ultralytics directly, since it accepts
    in-memory images. This avoids an extra PNG encode/decode per request, and
    avoids reopening a ``NamedTemporaryFile`` by name, which fails on Windows
    while the file is still open.

    Returns:
        ``(boxes_xyxy, scores)``: pixel ``[x1, y1, x2, y2]`` boxes and their
        confidences. Both lists are empty when nothing is detected.
    """
    res = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    if not res:
        return [], []
    r0 = res[0]
    if r0.boxes is None or len(r0.boxes) == 0:
        return [], []
    xyxy = r0.boxes.xyxy.detach().float().cpu().numpy().tolist()
    scores = r0.boxes.conf.detach().float().cpu().numpy().tolist()
    return xyxy, scores
def classify_crops(
    cls_model: Any,
    letters: List[str],
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    pad: float,
    topk: int,
) -> Tuple[List[List[Tuple[str, float]]], List[str], List[Tuple[Image.Image, str]]]:
    """Crop every box, classify all crops in one batch, and rank the letters.

    Each box is grown by ``round(pad * max(w, h))`` pixels per side, clamped to
    the image, resized to 96x96 (bicubic) and normalized with ImageNet mean/std.

    Returns:
        ``(topk_list, top1_letters, crop_gallery)`` in the order of ``boxes_xyxy``:
        per-box ``(letter, prob)`` lists (at most ``max(1, topk)`` entries, final
        letter forms folded to base forms), the top-1 letter per box, and
        ``(crop, caption)`` pairs for a gallery. All three are empty for no boxes.
    """
    device = get_device()
    W, H = pil.size
    crops: List[Image.Image] = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        w = max(1.0, x2 - x1)
        h = max(1.0, y2 - y1)
        p = int(round(max(w, h) * float(pad)))
        # Clamp to the image and keep every crop at least 1 px wide and high, so
        # a box on or beyond the border cannot yield an inverted or empty crop
        # (which PIL may reject). In-bounds boxes are unaffected.
        a = min(max(0, safe_int(x1) - p), max(0, W - 1))
        b = min(max(0, safe_int(y1) - p), max(0, H - 1))
        c = max(min(W, safe_int(x2) + p), a + 1)
        d = max(min(H, safe_int(y2) + p), b + 1)
        crops.append(pil.crop((a, b, c, d)).resize((96, 96), Image.BICUBIC))
    if not crops:
        return [], [], []

    # ImageNet normalization constants, built once rather than per crop.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def preprocess(img: Image.Image) -> torch.Tensor:
        # HWC uint8 image -> normalized CHW float32 tensor.
        arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
        t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()
        return (t - mean) / std

    batch = torch.stack([preprocess(crop) for crop in crops], dim=0).to(device)
    # Inference only: without this the forward pass records an autograd graph
    # and keeps the activations of the whole batch alive.
    with torch.inference_mode():
        logits = cls_model(batch)
        probs = torch.softmax(logits, dim=-1).float().cpu().numpy()

    k = max(1, int(topk))
    topk_list: List[List[Tuple[str, float]]] = []
    top1_letters: List[str] = []
    for row_probs in probs:
        idx = np.argsort(-row_probs)[:k]
        # Same cls_XX placeholder load_classifier pads with, should the label
        # list ever be shorter than the model output.
        row = [
            (
                normalize_hebrew_letter(letters[j] if j < len(letters) else f"cls_{j:02d}"),
                float(row_probs[j]),
            )
            for j in idx
        ]
        topk_list.append(row)
        top1_letters.append(row[0][0] if row else "?")
    crop_gallery = [(crops[i], f"{i+1:02d}: {top1_letters[i]}") for i in range(len(crops))]
    return topk_list, top1_letters, crop_gallery
# -----------------------------------------------------------------------------
# mT5 source serialization
# -----------------------------------------------------------------------------
def _compute_box_normalizer_from_boxes(boxes: List[List[float]]) -> Tuple[float, float, float, float]:
    """Return ``(minx, miny, scalex, scaley)`` spanning the union of all boxes.

    With no boxes this is the identity transform ``(0.0, 0.0, 1.0, 1.0)``.
    Scales are floored at 1e-6 so callers can always divide by them.
    """
    if not boxes:
        return 0.0, 0.0, 1.0, 1.0
    xs = [v for box in boxes for v in (box[0], box[2])]
    ys = [v for box in boxes for v in (box[1], box[3])]
    lo_x, hi_x = float(min(xs)), float(max(xs))
    lo_y, hi_y = float(min(ys)), float(max(ys))
    return lo_x, lo_y, max(1e-6, hi_x - lo_x), max(1e-6, hi_y - lo_y)
def build_source_text_mt5(
    *,
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    det_scores: List[float],
    topk_list: List[List[Tuple[str, float]]],
    coord_norm: str,
    order_mode: str,
    rtl: bool,
) -> str:
    """Serialize boxes+cls into the format mT5 was trained on.

    Header:
        [BOXES_CLS_V3]
        n=.. coord_norm=boxes|det|none
    Body (one line per box, numbered from 01 in the chosen order):
        01 x1=... y1=... x2=... y2=... w=... h=... xc=... yc=... score=... | cls א:0.123 ...

    Args:
        pil: Source image; only its size is read (for ``coord_norm="det"``).
        boxes_xyxy: Detector boxes as ``[x1, y1, x2, y2]`` pixel coordinates.
        det_scores: Detector confidence per box; a missing entry is written as 1.000.
        topk_list: ``(letter, prob)`` candidates per box, indexed like
            ``boxes_xyxy``; a box without candidates is written as ``cls ?``.
        coord_norm: "boxes" rescales to the bounding range of all boxes, "det"
            divides by image width/height, "none" keeps pixels. Unknown values
            fall back to "boxes".
        order_mode: "reading" orders boxes with sort_boxes_reading_order,
            "detector" keeps input order. Unknown values fall back to "reading".
        rtl: Passed to sort_boxes_reading_order in "reading" mode.

    Returns:
        The newline-joined serialization. With no boxes it is only the header
        with ``n=0 coord_norm=none``, whatever ``coord_norm`` was requested.
    """
    coord_norm = (coord_norm or "boxes").strip().lower()
    if coord_norm not in ("boxes", "det", "none"):
        coord_norm = "boxes"
    order_mode = (order_mode or "reading").strip().lower()
    if order_mode not in ("reading", "detector"):
        order_mode = "reading"
    if not boxes_xyxy:
        return "[BOXES_CLS_V3]\n" + "n=0 coord_norm=none"
    # idxs[j] is the original box index emitted as line j+1.
    if order_mode == "reading":
        idxs = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    else:
        idxs = list(range(len(boxes_xyxy)))
    W, H = pil.size
    # Affine map: coord -> (coord - min) / scale, per axis.
    if coord_norm == "boxes":
        minx, miny, scalex, scaley = _compute_box_normalizer_from_boxes(boxes_xyxy)
    elif coord_norm == "det":
        minx, miny, scalex, scaley = 0.0, 0.0, float(W), float(H)
    else:
        minx, miny, scalex, scaley = 0.0, 0.0, 1.0, 1.0
    lines = ["[BOXES_CLS_V3]", f"n={len(idxs)} coord_norm={coord_norm}"]
    for j, i in enumerate(idxs, start=1):
        x1, y1, x2, y2 = boxes_xyxy[i]
        sc = det_scores[i] if i < len(det_scores) else 1.0
        if coord_norm != "none":
            x1 = (x1 - minx) / scalex
            x2 = (x2 - minx) / scalex
            y1 = (y1 - miny) / scaley
            y2 = (y2 - miny) / scaley
        # Re-order corners so x1<=x2 and y1<=y2 even for inverted input boxes.
        xa, xb = (x1, x2) if x1 <= x2 else (x2, x1)
        ya, yb = (y1, y2) if y1 <= y2 else (y2, y1)
        # Floor width/height at 1e-6 so they are never zero or negative.
        w = max(1e-6, xb - xa)
        h = max(1e-6, yb - ya)
        xc = 0.5 * (xa + xb)
        yc = 0.5 * (ya + yb)
        cands = topk_list[i] if i < len(topk_list) else []
        cls_str = " ".join([f"{ch}:{p:.3f}" for ch, p in cands]) if cands else "?"
        lines.append(
            f"{j:02d} x1={xa:.4f} y1={ya:.4f} x2={xb:.4f} y2={yb:.4f} "
            f"w={w:.4f} h={h:.4f} xc={xc:.4f} yc={yc:.4f} score={sc:.3f} | cls {cls_str}"
        )
    return "\n".join(lines)
# -----------------------------------------------------------------------------
# Plotly overlay (hover tooltips on bboxes)
# -----------------------------------------------------------------------------
def _pil_to_data_uri(pil: Image.Image, fmt: str = "PNG") -> str:
    """Encode ``pil`` in format ``fmt`` as a base64 ``data:image/...`` URI."""
    with io.BytesIO() as buffer:
        pil.save(buffer, format=fmt)
        payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    mime = "image/" + fmt.lower()
    return f"data:{mime};base64,{payload}"
def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
    """Convert a Plotly colour string to an ``(r, g, b)`` tuple of ints.

    Accepts ``#RRGGBB`` (the format of ``plotly.colors.qualitative.Plotly``),
    the ``#RGB`` shorthand, and ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings
    (used by several other Plotly palettes), so the palette can be swapped
    without touching the drawing code. Hex strings of 6+ digits behave as before.

    Raises:
        ValueError: for strings in none of these forms.
    """
    s = hex_color.strip()
    if s.lower().startswith(("rgb(", "rgba(")) and s.endswith(")"):
        parts = [part.strip() for part in s[s.index("(") + 1 : -1].split(",")]
        if len(parts) not in (3, 4):
            raise ValueError(f"Unsupported colour string: {hex_color!r}")
        return int(float(parts[0])), int(float(parts[1])), int(float(parts[2]))
    s = s.lstrip("#")
    if len(s) == 3:
        # CSS shorthand: "#abc" == "#aabbcc".
        s = "".join(ch * 2 for ch in s)
    # Too-short strings raise ValueError from int() below.
    return int(s[0:2], 16), int(s[2:4], 16), int(s[4:6], 16)
def make_bbox_figure(
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    labels: Optional[List[str]] = None,
    det_scores: Optional[List[float]] = None,
    topk_list: Optional[List[List[Tuple[str, float]]]] = None,
    height: int = IMG_HEIGHT,
) -> go.Figure:
    """Render ``pil`` with one filled, colour-cycled rectangle per box.

    Each rectangle's hover text shows its 1-based index, the top-1 label (or the
    index when no label is given), and, when available, the detector confidence
    and the top-k candidates. Axes are hidden; the y axis runs from ``H`` down
    to ``0`` so pixel coordinates line up with the background image.
    """
    img_w, img_h = pil.size
    background = dict(
        source=_pil_to_data_uri(pil),
        xref="x",
        yref="y",
        x=0,
        y=0,
        sizex=img_w,
        sizey=img_h,
        sizing="stretch",
        layer="below",
    )
    fig = go.Figure()
    fig.add_layout_image(background)

    colors = qualitative.Plotly
    for n, (x1, y1, x2, y2) in enumerate(boxes_xyxy):
        red, green, blue = _hex_to_rgb(colors[n % len(colors)])
        top1 = labels[n] if labels and n < len(labels) else f"{n+1:02d}"
        score = det_scores[n] if det_scores and n < len(det_scores) else None
        candidates = topk_list[n] if topk_list and n < len(topk_list) else None

        hover = [f"#{n+1:02d}", f"top1: {top1}"]
        if score is not None:
            hover.append(f"det_conf: {score:.3f}")
        if candidates:
            hover.append("topk: " + ", ".join(f"{c}:{p:.3f}" for c, p in candidates))

        fig.add_trace(
            go.Scatter(
                x=[x1, x2, x2, x1, x1],
                y=[y1, y1, y2, y2, y1],
                mode="lines",
                fill="toself",
                line=dict(width=3, color=f"rgb({red},{green},{blue})"),
                fillcolor=f"rgba({red},{green},{blue},0.25)",
                hoverinfo="text",
                text="<br>".join(hover),
                showlegend=False,
            )
        )

    fig.update_xaxes(visible=False, range=[0, img_w], constrain="domain")
    fig.update_yaxes(visible=False, range=[img_h, 0], scaleanchor="x")
    fig.update_layout(height=height, margin=dict(l=0, r=0, t=30, b=0), hovermode="closest")
    return fig
# -----------------------------------------------------------------------------
# UI callbacks
# -----------------------------------------------------------------------------
def _generate_texts(models: Loaded, src: str, output_mode: str, he2en_kind: str) -> Tuple[str, str]:
    """Produce ``(hebrew, english)`` from the serialized mT5 source.

    "he": Hebrew only. "en_direct": English from the box->en mT5 only.
    Anything else ("he_then_en"): Hebrew, then English from the chosen he->en
    translator, or from the box->en mT5 when ``he2en_kind`` is "none".
    """
    if output_mode == "he":
        return mt5_generate(models.mt5_he_tok, models.mt5_he, src), ""
    if output_mode == "en_direct":
        return "", mt5_generate(models.mt5_en_tok, models.mt5_en, src)
    hebrew = mt5_generate(models.mt5_he_tok, models.mt5_he, src)
    if he2en_kind == "none":
        english = mt5_generate(models.mt5_en_tok, models.mt5_en, src)
    else:
        english = translate_he_to_en(hebrew, he2en_kind)
    return hebrew, english


def run_pipeline_tab(
    pil: Optional[Image.Image],
    conf: float,
    iou: float,
    max_det: int,
    crop_pad: float,
    topk_k: int,
    rtl: bool,
    output_mode: str,
    he2en_kind: str,
    coord_norm: str,
    order_mode: str,
) -> Tuple[
    Optional[Any],
    List[Tuple[Image.Image, str]],
    str,
    str,
    str,
    str,
    List[List[Tuple[str, float]]],
]:
    """Gradio callback for the end-to-end tab: detect, classify, serialize, generate.

    Returns ``(figure, crop_gallery, hebrew, english, debug_json, mt5_source,
    topk_list)``. Any exception is logged and its traceback returned in the
    debug slot instead of being raised.
    """
    try:
        if pil is None:
            return None, [], "", "", json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2), "", []
        image = pil.convert("RGB")
        models = ensure_loaded()
        boxes, det_scores = yolo_predict(models.det, image, conf, iou, max_det)
        if not boxes:
            empty_fig = make_bbox_figure(image, [], height=IMG_HEIGHT)
            info = {"ts": now_ts(), "detector": models.det_name, "n_boxes": 0}
            message = "No boxes detected."
            return empty_fig, [], message, message, json.dumps(info, ensure_ascii=False, indent=2), "", []

        topk_list, top1_letters, crop_gallery = classify_crops(
            models.cls, models.cls_letters, image, boxes, pad=crop_pad, topk=topk_k
        )
        src = build_source_text_mt5(
            pil=image,
            boxes_xyxy=boxes,
            det_scores=det_scores,
            topk_list=topk_list,
            coord_norm=coord_norm,
            order_mode=order_mode,
            rtl=rtl,
        )
        heb, eng = _generate_texts(models, src, output_mode, he2en_kind)

        fig = make_bbox_figure(
            image,
            boxes,
            labels=top1_letters,
            det_scores=det_scores,
            topk_list=topk_list,
            height=IMG_HEIGHT,
        )
        dbg = {
            "ts": now_ts(),
            "device": get_device(),
            "detector": models.det_name,
            "classifier": models.cls_name,
            "n_boxes": len(boxes),
            "coord_norm": coord_norm,
            "order_mode": order_mode,
            "rtl": bool(rtl),
            "output_mode": output_mode,
            "he2en_kind": he2en_kind,
        }
        return fig, crop_gallery, heb, eng, json.dumps(dbg, ensure_ascii=False, indent=2), src, topk_list
    except Exception:
        log.exception("run_pipeline_tab failed")
        return None, [], "ERROR", "ERROR", traceback.format_exc(), "", []
def run_detector_tab(
    pil: Optional[Image.Image], conf: float, iou: float, max_det: int
) -> Tuple[Optional[Any], str]:
    """Gradio callback for the Detector tab.

    Returns ``(figure, json_report)``; the report lists the raw boxes and
    scores. Failures are logged and the traceback is returned as the report.
    """
    try:
        if pil is None:
            return None, json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2)
        image = pil.convert("RGB")
        models = ensure_loaded()
        boxes, scores = yolo_predict(models.det, image, conf, iou, max_det)
        figure = make_bbox_figure(image, boxes, det_scores=scores, height=IMG_HEIGHT)
        report = dict(
            ts=now_ts(),
            detector=models.det_name,
            n_boxes=len(boxes),
            boxes_xyxy=boxes,
            scores=scores,
        )
        return figure, json.dumps(report, ensure_ascii=False, indent=2)
    except Exception:
        log.exception("run_detector_tab failed")
        return None, traceback.format_exc()
def run_classifier_tab(
    pil: Optional[Image.Image],
    conf: float,
    iou: float,
    max_det: int,
    crop_pad: float,
    topk_k: int,
) -> Tuple[List[Tuple[Image.Image, str]], str]:
    """Gradio callback for the Classifier tab: detect boxes, then classify each crop.

    Returns ``(crop_gallery, json_report)``. For every box, the report gives its
    coordinates, detector score and top-k letters. Failures are logged and the
    traceback is returned as the report.
    """
    try:
        if pil is None:
            return [], json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2)
        image = pil.convert("RGB")
        models = ensure_loaded()
        boxes, det_scores = yolo_predict(models.det, image, conf, iou, max_det)
        if not boxes:
            empty = {"ts": now_ts(), "n_boxes": 0, "msg": "No boxes detected"}
            return [], json.dumps(empty, ensure_ascii=False, indent=2)
        topk_list, _top1, crop_gallery = classify_crops(
            models.cls, models.cls_letters, image, boxes, pad=crop_pad, topk=topk_k
        )
        details: Dict[str, Any] = {
            f"Box_{n + 1:02d}": {"xyxy": box, "det_score": float(score), "topk": row}
            for n, (row, box, score) in enumerate(zip(topk_list, boxes, det_scores))
        }
        report = {"ts": now_ts(), "classifier": models.cls_name, "n_boxes": len(boxes), "details": details}
        return crop_gallery, json.dumps(report, ensure_ascii=False, indent=2)
    except Exception:
        log.exception("run_classifier_tab failed")
        return [], traceback.format_exc()
def show_topk_popup(topk_list: List[List[Tuple[str, float]]], evt: gr.SelectData) -> str:
    """Render the top-k candidates of the clicked gallery crop as modal HTML.

    Returns "" (no modal) when the event has no usable index. Changes from the
    plain version:

    * the backdrop closes only on clicks outside the card, as the hint says;
    * each render carries a unique ``data-nonce``, so the HTML differs each time
      and the same crop can be reopened after the inline handler emptied the
      container client-side;
    * labels are HTML-escaped before being embedded.
    """
    import html  # local import: only this callback needs it

    idx = getattr(evt, "index", None)
    # Defensive: accept a [row, col]-style index as well as a plain int.
    if isinstance(idx, (list, tuple)):
        idx = idx[0] if idx else None
    if (
        not isinstance(idx, int)
        or isinstance(idx, bool)
        or not topk_list
        or idx < 0
        or idx >= len(topk_list)
    ):
        return ""
    rows = topk_list[idx]
    lines = "\n".join(f"{html.escape(str(lab)):>2} {float(p):.4f}" for lab, p in rows)
    nonce = time.time_ns()
    return f"""
<div class="modal-backdrop" data-nonce="{nonce}" onclick="if (event.target === this) this.parentElement.innerHTML = ''">
  <div class="modal-card">
    <div class="modal-title">Top‑K for crop #{idx + 1}</div>
    <pre class="modal-pre">{lines}</pre>
    <div class="modal-hint">Click anywhere outside to close</div>
  </div>
</div>
"""
def save_feedback(
    heb_pred: str,
    eng_pred: str,
    heb_corr: str,
    eng_corr: str,
    notes: str,
) -> str:
    """Append one JSON line of predictions, corrections and notes to FEEDBACK_PATH.

    Returns a status message for the UI. On failure the error is logged and the
    traceback is returned instead of being raised.
    """
    try:
        rec = {
            "ts": now_ts(),
            "heb_pred": heb_pred,
            "eng_pred": eng_pred,
            "heb_corr": heb_corr,
            "eng_corr": eng_corr,
            "notes": notes,
        }
        parent = os.path.dirname(FEEDBACK_PATH)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        return f"Saved to {FEEDBACK_PATH}"
    except Exception:
        log.exception("save_feedback failed")
        return traceback.format_exc()
# -----------------------------------------------------------------------------
# Examples helper
# -----------------------------------------------------------------------------
# Register the examples folder (only if present) via gr.set_static_paths.
if EXAMPLES_DIR.exists():
    gr.set_static_paths([str(EXAMPLES_DIR)])
def list_example_images(max_n: int = 24) -> List[List[str]]:
    """Return image files from EXAMPLES_DIR as ``[[path], ...]`` rows for gr.Examples.

    Files are sorted by path, filtered to .png/.jpg/.jpeg/.webp (any case) and
    sliced with ``[:max_n]``. Returns ``[]`` when the folder does not exist.
    """
    if not EXAMPLES_DIR.exists():
        return []
    image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    images = [
        entry
        for entry in sorted(EXAMPLES_DIR.iterdir())
        if entry.suffix.lower() in image_suffixes
    ]
    return [[str(entry)] for entry in images[:max_n]]
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# Styles for the top-k modal that show_topk_popup renders into the #topk_modal
# HTML component. The card uses Gradio theme variables (with light-mode
# fallbacks): a hard-coded white background combined with the dark theme's
# light text colour made the popup unreadable in dark mode.
CSS = """
#topk_modal .modal-backdrop {
    position: fixed;
    inset: 0;
    background: rgba(0, 0, 0, 0.35);
    display: flex;
    align-items: center;
    justify-content: center;
    z-index: 9999;
}
#topk_modal .modal-card {
    background: var(--background-fill-primary, white);
    color: var(--body-text-color, #111);
    border-radius: 14px;
    padding: 18px 22px;
    min-width: 280px;
    max-width: 520px;
    box-shadow: 0 10px 40px rgba(0, 0, 0, 0.25);
}
#topk_modal .modal-title {
    font-size: 18px;
    font-weight: 700;
    margin-bottom: 10px;
}
#topk_modal .modal-pre {
    font-size: 16px;
    line-height: 1.25;
    margin: 0;
    white-space: pre;
}
#topk_modal .modal-hint {
    opacity: 0.6;
    margin-top: 10px;
}
"""
# Gradio app: inputs and settings on the left (scale 3), result tabs on the
# right (scale 7). Component creation order inside each ``with`` defines the
# on-screen layout, so statements here are order-sensitive.
with gr.Blocks(title="Paleo‑Hebrew Tablet Reader", css=CSS) as demo:
    gr.Markdown("# Paleo‑Hebrew Epigraphy Pipeline")
    # Top-k rows from the last pipeline run; read by the gallery select handler.
    topk_state = gr.State([])
    # Target of show_topk_popup; styled by the #topk_modal rules in CSS.
    modal = gr.HTML(value="", elem_id="topk_modal")
    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image", height=IMG_HEIGHT)
            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="Box Confidence")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top‑K")
            with gr.Accordion("mT5 + Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right to left) [for reading-order mode]")
                order_mode = gr.Radio(["reading", "detector"], value="reading", label="Box order")
                coord_norm = gr.Radio(["boxes", "det", "none"], value="boxes", label="coord_norm (like training)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"], value="he_then_en", label="Output mode")
                # Only consulted by run_pipeline_tab when output_mode is "he_then_en".
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus", label="Hebrew to English translator")
            with gr.Accordion("How to run a VLM yourself (optional)", open=False):
                # Raw string: the code sample contains a literal "\n\n" that must
                # be shown as written, not turned into real line breaks.
                gr.Markdown(
                    r"""
If you want a VLM post‑OCR step (outside this Space), you can run it locally / on a GPU box.

High-level steps:

1) Build a `tools_text` block from YOLO+classifier: `[DETECTOR ...]` + `[CLASSIFIER topk]`.
2) Feed image + prompt + tools_text into your Qwen3‑VL LoRA model.

Skeleton (pseudo-code):

```python
from transformers import AutoProcessor, AutoModelForImageTextToText
proc = AutoProcessor.from_pretrained("mr3vial/paleo-hebrew-qwen3-vl-lora-post-ocr-processing")
model = AutoModelForImageTextToText.from_pretrained("mr3vial/paleo-hebrew-qwen3-vl-lora-post-ocr-processing")
prompt = "Transcribe the text on the tablet.\n\n" + tools_text
inputs = proc(text=prompt, images=[pil_image], return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64)
pred = proc.batch_decode(out, skip_special_tokens=True)[0]
```
"""
                )
            ex = list_example_images()
            if ex:
                gr.Markdown("### Examples")
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
        with gr.Column(scale=7):
            with gr.Tabs():
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_fig_pipe = gr.Plot(label="Detections (hover tooltip)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew", lines=2, interactive=False)
                        out_en_pipe = gr.Textbox(label="English", lines=2, interactive=False)
                    out_crops_pipe = gr.Gallery(columns=8, height=360, label="Letter crops")
                    with gr.Accordion("Debug", open=False):
                        out_dbg_pipe = gr.Code(label="Debug JSON", language="json")
                        out_mt5_src = gr.Textbox(label="mT5 source text (fed into mT5)", lines=10, interactive=False)
                with gr.Tab("Detector"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_fig_det = gr.Plot(label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")
                with gr.Tab("Classifier"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    out_crops_cls = gr.Gallery(columns=8, height=360, label="Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top‑K per Box", language="json")
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)
    # Wiring (must stay inside the Blocks context). Output order matches the
    # tuples returned by each callback.
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind, coord_norm, order_mode],
        outputs=[out_fig_pipe, out_crops_pipe, out_he_pipe, out_en_pipe, out_dbg_pipe, out_mt5_src, topk_state],
        api_name=False,
    )
    # Clicking a crop opens the top-k modal for that crop.
    out_crops_pipe.select(fn=show_topk_popup, inputs=[topk_state], outputs=[modal])
    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_fig_det, out_json_det],
        api_name=False,
    )
    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False,
    )
    # Feedback records the pipeline tab's current predictions with the corrections.
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False,
    )
demo.queue(max_size=32).launch(show_error=True)