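# Page residue from the hosting site (not part of the program):
#   author avatar caption: "mr3vial's picture"
#   latest commit: "Update app.py" (e96a9b4, verified)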
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Paleo‑Hebrew Epigraphy Pipeline (Gradio)"""
import os
import gc
import io
import json
import time
import base64
import tempfile
import traceback
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
# Ask huggingface_hub to hide download progress bars. setdefault() keeps any
# value the environment already provides.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
try:
    # Also flip the programmatic switch when it can be imported.
    from huggingface_hub.utils import disable_progress_bars
    disable_progress_bars()
except Exception:
    # Deliberate best effort: a failure here is ignored and start-up continues.
    pass
import plotly.graph_objects as go
from plotly.colors import qualitative
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Configure the root logger at INFO level and create the named logger that the
# loaders and UI callbacks in this file report through.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("paleo_demo")
# -----------------------------------------------------------------------------
# Configuration via environment variables
# -----------------------------------------------------------------------------
# Detector: Hub repo plus the primary weights file. DET_FALLBACK_PT is the
# file load_detector() falls back to when the primary one cannot be loaded.
DET_REPO_ID = os.getenv("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")
DET_FILENAME = os.getenv("DET_FILENAME", "best.onnx")
DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")
# Classifier: Hub repo, checkpoint file, label file, and the architecture name
# passed to timm.create_model() in load_classifier().
CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "convnext_weights.pt")
CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.json")
CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")
# Seq2seq repos loaded by ensure_loaded(): serialized boxes -> Hebrew text and
# serialized boxes -> English text.
MT5_BOX2HE_REPO = os.getenv(
    "MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing"
)
MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")
# JSON-lines file that save_feedback() appends to.
FEEDBACK_PATH = os.getenv("FEEDBACK_PATH", "/tmp/feedback.jsonl")
# Optional directory of example images located next to this file.
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Height used for the input image widget and the Plotly figures.
IMG_HEIGHT = 420
# -----------------------------------------------------------------------------
# Translator choices (Hebrew -> English) for the he_then_en mode
# -----------------------------------------------------------------------------
# (label, value) pairs offered in the translator dropdown; the value is the
# `kind` string understood by get_he2en_translator() / translate_he_to_en().
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]
# kind -> (tokenizer, model, kind). get_he2en_translator() clears this dict
# before loading a different translator, so it holds at most one entry.
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}
# -----------------------------------------------------------------------------
# Hebrew normalization (final forms -> base forms)
# -----------------------------------------------------------------------------
# Word-final Hebrew letter forms mapped to their regular (base) forms.
FINAL_MAP: Dict[str, str] = {
    "ך": "כ",  # final kaf
    "ם": "מ",  # final mem
    "ן": "נ",  # final nun
    "ף": "פ",  # final pe
    "ץ": "צ",  # final tsadi
}


def normalize_hebrew_letter(ch: str) -> str:
    """Return the base form of a final-form letter; any other input is returned unchanged."""
    if ch in FINAL_MAP:
        return FINAL_MAP[ch]
    return ch
# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------
def now_ts() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    stamp_format = "%Y-%m-%d %H:%M:%S"
    return time.strftime(stamp_format)
def normalize_spaces(s: str) -> str:
    """Collapse every whitespace run in *s* to one space and trim both ends.

    A falsy *s* (``None`` or ``""``) yields ``""``.
    """
    if not s:
        return ""
    # str.split() with no separator already discards leading/trailing whitespace.
    words = s.split()
    return " ".join(words)
def get_device() -> str:
    """Return ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def safe_int(x: float, default: int = 0) -> int:
    """Round *x* to the nearest integer, returning *default* when that fails.

    Anything ``float()`` accepts is allowed; values that cannot be converted
    or rounded (``None``, non-numeric text, NaN, infinity) give *default*.
    """
    try:
        as_float = float(x)
        rounded = round(as_float)
        return int(rounded)
    except Exception:
        return default
def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """Return box indices in a heuristic reading order.

    Boxes are visited from top to bottom by vertical centre and grouped into
    text lines: a box joins the current line when its centre lies within
    0.6 x the median box height of that line's mean centre, otherwise it
    starts a new line. Within a line, boxes are ordered by horizontal centre,
    right-to-left when *rtl* is true and left-to-right otherwise.
    """
    if not boxes_xyxy:
        return []

    # (original index, x centre, y centre) for every box.
    entries: List[Tuple[int, float, float]] = [
        (n, 0.5 * (left + right), 0.5 * (top + bottom))
        for n, (left, top, right, bottom) in enumerate(boxes_xyxy)
    ]
    box_heights: List[float] = [max(1.0, bottom - top) for _, top, _, bottom in boxes_xyxy]
    tolerance = float(np.median(box_heights)) * 0.6

    rows: List[List[Tuple[int, float, float]]] = []
    for entry in sorted(entries, key=lambda e: e[2]):
        if rows:
            current = rows[-1]
            row_y = float(np.mean([member[2] for member in current]))
            if abs(entry[2] - row_y) <= tolerance:
                current.append(entry)
                continue
        rows.append([entry])

    order: List[int] = []
    for row in rows:
        row.sort(key=lambda e: e[1], reverse=rtl)
        order.extend(member[0] for member in row)
    return order
# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
def _build_yolo(yolo_cls: Any, weights_path: str) -> Any:
    """Instantiate *yolo_cls* for detection from *weights_path*."""
    try:
        return yolo_cls(weights_path, task="detect")
    except TypeError:
        # Presumably a constructor without the `task` keyword; retry without it.
        return yolo_cls(weights_path)


@torch.inference_mode()
def load_detector() -> Tuple[Any, str]:
    """Download and instantiate the YOLO detector.

    The primary weights (``DET_FILENAME``) are tried first. If downloading or
    loading them fails for any reason, the failure is logged and the
    ``DET_FALLBACK_PT`` weights from the same repo are used instead.

    Returns:
        ``(model, description)`` where *description* names the repo and file
        that were loaded, suffixed with ``" (PT fallback)"`` for the fallback.

    Raises:
        Whatever the fallback download or load raises when it also fails.
    """
    from ultralytics import YOLO  # lazy import

    try:
        det_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FILENAME)
        return _build_yolo(YOLO, det_path), f"{DET_REPO_ID}/{DET_FILENAME}"
    except Exception:
        # Still best effort, but record why the primary weights were skipped.
        log.warning(
            "Could not load detector %s/%s; falling back to %s",
            DET_REPO_ID,
            DET_FILENAME,
            DET_FALLBACK_PT,
            exc_info=True,
        )

    pt_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FALLBACK_PT)
    return _build_yolo(YOLO, pt_path), f"{DET_REPO_ID}/{DET_FALLBACK_PT} (PT fallback)"
def _read_class_labels(classes_path: str) -> List[str]:
    """Read classifier labels from *classes_path* and clean them.

    Accepted JSON shapes: a dict with a ``"classes"`` entry, any other dict
    (its values are used), or a list. A file that cannot be handled as JSON is
    read as plain text, one label per line. For each label only the part after
    the last underscore is kept, surrounding quotes/commas/spaces are removed,
    and empty results are dropped.
    """
    raw: List[Any] = []
    try:
        with open(classes_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict) and "classes" in data:
            raw = list(data["classes"])
        elif isinstance(data, dict):
            raw = list(data.values())
        elif isinstance(data, list):
            raw = data
    except (ValueError, TypeError):
        # Invalid JSON (or an unusable "classes" value): fall back to lines.
        with open(classes_path, "r", encoding="utf-8") as f:
            raw = [ln.strip() for ln in f if ln.strip()]

    letters: List[str] = []
    for item in raw:
        val = str(item).strip().split("_")[-1].strip('"\' ,')
        if val:
            letters.append(val)
    return letters


def _strip_state_dict_prefixes(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return *state* with a leading ``module.`` and then ``model.`` removed from keys."""
    cleaned: Dict[str, Any] = {}
    for key, value in state.items():
        name = key
        for prefix in ("module.", "model."):
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


@torch.inference_mode()
def load_classifier() -> Tuple[Any, List[str], str]:
    """Download the letter classifier and its label list.

    The class count is taken from the checkpoint's ``head.fc.weight`` tensor.
    When the label file disagrees with it, the labels are truncated or padded
    with ``cls_NN`` placeholders and a warning is logged.

    Returns:
        ``(model, letters, description)`` with the model in eval mode on the
        device reported by ``get_device()``.

    Raises:
        RuntimeError: if the checkpoint does not contain a state dict, or the
            state dict has no ``head.fc.weight`` entry.
    """
    import timm  # lazy import

    device = get_device()
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    letters = _read_class_labels(classes_path)

    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    # NOTE(review): torch.load unpickles the file; only point CLS_REPO_ID at a
    # trusted repo (or consider weights_only=True if the installed torch has it).
    ckpt = torch.load(weights_path, map_location="cpu")
    # Only call .get() on dict checkpoints so that other objects reach the
    # explicit RuntimeError below instead of failing with AttributeError.
    state = ckpt
    if isinstance(ckpt, dict):
        state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    if not isinstance(state, dict):
        raise RuntimeError(f"Bad checkpoint format: type(state)={type(state)}")

    new_state = _strip_state_dict_prefixes(state)
    if "head.fc.weight" not in new_state:
        raise RuntimeError("Checkpoint is missing head.fc.weight; cannot infer number of classes.")
    num_classes_ckpt = int(new_state["head.fc.weight"].shape[0])

    if len(letters) != num_classes_ckpt:
        log.warning(
            "classes.json has %d labels but checkpoint expects %d classes; truncating/padding.",
            len(letters),
            num_classes_ckpt,
        )
        if len(letters) > num_classes_ckpt:
            letters = letters[:num_classes_ckpt]
        else:
            letters = letters + [f"cls_{i:02d}" for i in range(len(letters), num_classes_ckpt)]

    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=num_classes_ckpt)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    incompatible = model.load_state_dict(new_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        log.warning(
            "Classifier checkpoint mismatch: %d missing keys, %d unexpected keys.",
            len(incompatible.missing_keys),
            len(incompatible.unexpected_keys),
        )
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS} ({CLS_MODEL_NAME}, C={num_classes_ckpt})"
@torch.inference_mode()
def load_mt5(repo_id: str) -> Tuple[Any, Any]:
    """Load the tokenizer and seq2seq model stored at *repo_id*.

    On CUDA the model is loaded in bfloat16 when supported, otherwise float16;
    on CPU it is loaded in float32. The model is moved to the active device
    and put in eval mode.

    Returns:
        ``(tokenizer, model)``.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    device = get_device()
    dtype = torch.float32
    if device == "cuda":
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(repo_id, torch_dtype=dtype)
    seq2seq = seq2seq.to(device).eval()
    return tokenizer, seq2seq
@torch.inference_mode()
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Run 4-beam, non-sampling generation on *text* and return the decoded output.

    The input is truncated to 2048 tokens; the decoded string has special
    tokens removed and its whitespace normalised.
    """
    encoded = tok([text], return_tensors="pt", truncation=True, max_length=2048)
    encoded = encoded.to(get_device())
    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        num_beams=4,
        do_sample=False,
    )
    decoded = tok.batch_decode(generated, skip_special_tokens=True)
    return normalize_spaces(decoded[0])
# -----------------------------------------------------------------------------
# Hebrew -> English translators (optional, only used in he_then_en mode)
# -----------------------------------------------------------------------------
@torch.inference_mode()
def get_he2en_translator(kind: str) -> Tuple[Any, Any, str]:
    """Return ``(tokenizer, model, kind)`` for the requested Hebrew->English translator.

    A cached entry for *kind* is returned as is. Otherwise the cache is
    emptied first (followed by a GC pass and a CUDA cache flush) before the
    new model is loaded. An unrecognised *kind* is cached and returned as
    ``(None, None, "none")``.
    """
    cached = _TRANSLATOR_CACHE.get(kind)
    if cached is not None:
        return cached

    # Release whatever translator was loaded before.
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = get_device()
    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer

        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tokenizer = MarianTokenizer.from_pretrained(repo)
        translator = MarianMTModel.from_pretrained(repo).to(device).eval()
        entry = (tokenizer, translator, kind)
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        repo = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        translator = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
        entry = (tokenizer, translator, kind)
    elif kind == "m2m100":
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

        repo = "facebook/m2m100_418M"
        tokenizer = M2M100Tokenizer.from_pretrained(repo)
        tokenizer.src_lang = "he"
        translator = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()
        entry = (tokenizer, translator, kind)
    else:
        entry = (None, None, "none")

    _TRANSLATOR_CACHE[kind] = entry
    return _TRANSLATOR_CACHE[kind]
@torch.inference_mode()
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew *text_he* to English with the translator named by *kind*.

    Args:
        text_he: Hebrew source text.
        kind: ``"opus"``, ``"nllb"``, ``"m2m100"`` or ``"none"``.

    Returns:
        The whitespace-normalised translation. ``""`` is returned when *kind*
        is ``"none"`` or unrecognised, and when *text_he* is empty or blank;
        in the blank case no translator is loaded.
    """
    # Blank input has nothing to translate: skip model loading and decoding.
    if kind == "none" or not (text_he or "").strip():
        return ""

    tok, model, kind = get_he2en_translator(kind)
    if tok is None or model is None:
        return ""

    gen_kwargs: dict = {"max_new_tokens": 128, "num_beams": 4}
    if kind == "opus":
        batch = tok([text_he], return_tensors="pt", padding=True, truncation=True)
    elif kind == "nllb":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512)
        # The multilingual decoder is told the target language via the first token.
        gen_kwargs["forced_bos_token_id"] = tok.convert_tokens_to_ids("eng_Latn")
    elif kind == "m2m100":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512)
        gen_kwargs["forced_bos_token_id"] = tok.get_lang_id("en")
    else:
        return ""

    batch = batch.to(get_device())
    out = model.generate(**batch, **gen_kwargs)
    return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
# -----------------------------------------------------------------------------
# Core pipeline
# -----------------------------------------------------------------------------
@dataclass
class Loaded:
    """Bundle of every model object the pipeline needs, built once by ensure_loaded()."""

    # Detector returned by load_detector() and the description of its source.
    det: Any
    det_name: str
    # Classifier returned by load_classifier(), its labels, and its description.
    cls: Any
    cls_letters: List[str]
    cls_name: str
    # Tokenizer/model pair loaded from MT5_BOX2HE_REPO.
    mt5_he_tok: Any
    mt5_he: Any
    # Tokenizer/model pair loaded from MT5_BOX2EN_REPO.
    mt5_en_tok: Any
    mt5_en: Any
# Process-wide model bundle; stays None until the first ensure_loaded() call.
LOADED: Optional[Loaded] = None


def ensure_loaded() -> Loaded:
    """Return the shared model bundle, loading every model on first use."""
    global LOADED
    if LOADED is None:
        detector, detector_label = load_detector()
        classifier, class_labels, classifier_label = load_classifier()
        hebrew_tok, hebrew_model = load_mt5(MT5_BOX2HE_REPO)
        english_tok, english_model = load_mt5(MT5_BOX2EN_REPO)
        LOADED = Loaded(
            det=detector,
            det_name=detector_label,
            cls=classifier,
            cls_letters=class_labels,
            cls_name=classifier_label,
            mt5_he_tok=hebrew_tok,
            mt5_he=hebrew_model,
            mt5_en_tok=english_tok,
            mt5_en=english_model,
        )
        log.info(
            "Loaded models: detector=%s classifier=%s box2he=%s box2en=%s",
            detector_label,
            classifier_label,
            MT5_BOX2HE_REPO,
            MT5_BOX2EN_REPO,
        )
    return LOADED
@torch.inference_mode()
def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int) -> Tuple[List[List[float]], List[float]]:
    """Run the detector on *pil* and return its boxes and confidences.

    The image is handed to the detector in memory. This avoids encoding a
    temporary PNG for every request and avoids reopening a still-open
    ``NamedTemporaryFile`` by name, which does not work on Windows.

    Args:
        det_model: Detector object exposing ``predict(...)``.
        pil: Input image.
        conf: Confidence threshold passed to the detector.
        iou: IoU threshold passed to the detector.
        max_det: Maximum number of detections.

    Returns:
        ``(boxes_xyxy, scores)`` as plain Python lists; both are empty when
        nothing was detected.
    """
    results = det_model.predict(
        source=pil,
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    if not results:
        return [], []

    found = results[0].boxes
    if found is None or len(found) == 0:
        return [], []

    xyxy = found.xyxy.detach().float().cpu().numpy().tolist()
    scores = found.conf.detach().float().cpu().numpy().tolist()
    return xyxy, scores
@torch.inference_mode()
def classify_crops(
    cls_model: Any,
    letters: List[str],
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    pad: float,
    topk: int,
) -> Tuple[List[List[Tuple[str, float]]], List[str], List[Tuple[Image.Image, str]]]:
    """Crop every box out of *pil*, classify the crops, and report the top-k labels.

    Each box is enlarged by ``pad`` x its longer side (clamped to the image),
    resized to 96x96, scaled to [0, 1] and normalised with the mean/std
    constants below before being classified in a single batch.

    Returns:
        ``(topk_list, top1_letters, crop_gallery)``: per-box ``(label, prob)``
        candidates with final letter forms mapped to base forms, the best
        label per box, and ``(crop, caption)`` pairs for the gallery. All
        three are empty when *boxes_xyxy* is empty.
    """
    device = get_device()
    img_w, img_h = pil.size

    crops: List[Image.Image] = []
    for left, top, right, bottom in boxes_xyxy:
        longest_side = max(max(1.0, right - left), max(1.0, bottom - top))
        margin = int(round(longest_side * float(pad)))
        region = (
            max(0, safe_int(left) - margin),
            max(0, safe_int(top) - margin),
            min(img_w, safe_int(right) + margin),
            min(img_h, safe_int(bottom) + margin),
        )
        crops.append(pil.crop(region).resize((96, 96), Image.BICUBIC))

    if not crops:
        return [], [], []

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    tensors: List[torch.Tensor] = []
    for crop in crops:
        pixels = np.asarray(crop.convert("RGB"), dtype=np.float32) / 255.0
        chw = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()
        tensors.append((chw - mean) / std)

    logits = cls_model(torch.stack(tensors, dim=0).to(device))
    probs = torch.softmax(logits, dim=-1).detach().float().cpu().numpy()

    keep = max(1, int(topk))
    topk_list: List[List[Tuple[str, float]]] = []
    top1_letters: List[str] = []
    for n in range(len(crops)):
        dist = probs[n]
        best = np.argsort(-dist)[:keep]
        candidates = [(normalize_hebrew_letter(letters[k]), float(dist[k])) for k in best]
        topk_list.append(candidates)
        top1_letters.append(candidates[0][0] if candidates else "?")

    crop_gallery = [
        (crop, f"{n+1:02d}: {top1_letters[n]}") for n, crop in enumerate(crops)
    ]
    return topk_list, top1_letters, crop_gallery
# -----------------------------------------------------------------------------
# mT5 source serialization
# -----------------------------------------------------------------------------
def _compute_box_normalizer_from_boxes(boxes: List[List[float]]) -> Tuple[float, float, float, float]:
xs = [b[0] for b in boxes] + [b[2] for b in boxes]
ys = [b[1] for b in boxes] + [b[3] for b in boxes]
minx = float(min(xs)) if xs else 0.0
maxx = float(max(xs)) if xs else 1.0
miny = float(min(ys)) if ys else 0.0
maxy = float(max(ys)) if ys else 1.0
scalex = max(1e-6, maxx - minx)
scaley = max(1e-6, maxy - miny)
return minx, miny, scalex, scaley
def build_source_text_mt5(
    *,
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    det_scores: List[float],
    topk_list: List[List[Tuple[str, float]]],
    coord_norm: str,
    order_mode: str,
    rtl: bool,
) -> str:
    """Serialize detections and classifier candidates into the mT5 input text.

    Layout of the result:
        [BOXES_CLS_V3]
        n=<count> coord_norm=boxes|det|none
        01 x1=... y1=... x2=... y2=... w=... h=... xc=... yc=... score=... | cls א:0.123 ...

    *coord_norm* selects the coordinate frame: ``boxes`` rescales to the union
    of all boxes, ``det`` to the image size, ``none`` keeps pixels. *order_mode*
    is ``reading`` (heuristic reading order, honouring *rtl*) or ``detector``
    (input order). Unknown values fall back to ``boxes`` / ``reading``. With
    no boxes only the two header lines are returned, reporting
    ``coord_norm=none``.
    """
    norm_mode = (coord_norm or "boxes").strip().lower()
    if norm_mode not in ("boxes", "det", "none"):
        norm_mode = "boxes"
    ordering = (order_mode or "reading").strip().lower()
    if ordering not in ("reading", "detector"):
        ordering = "reading"

    if not boxes_xyxy:
        return "[BOXES_CLS_V3]\n" + "n=0 coord_norm=none"

    if ordering == "reading":
        sequence = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    else:
        sequence = list(range(len(boxes_xyxy)))

    img_w, img_h = pil.size
    if norm_mode == "boxes":
        off_x, off_y, span_x, span_y = _compute_box_normalizer_from_boxes(boxes_xyxy)
    elif norm_mode == "det":
        off_x, off_y, span_x, span_y = 0.0, 0.0, float(img_w), float(img_h)
    else:
        off_x, off_y, span_x, span_y = 0.0, 0.0, 1.0, 1.0
    rescale = norm_mode != "none"

    out_lines = ["[BOXES_CLS_V3]", f"n={len(sequence)} coord_norm={norm_mode}"]
    for rank, box_idx in enumerate(sequence, start=1):
        x1, y1, x2, y2 = boxes_xyxy[box_idx]
        score = det_scores[box_idx] if box_idx < len(det_scores) else 1.0
        if rescale:
            x1 = (x1 - off_x) / span_x
            x2 = (x2 - off_x) / span_x
            y1 = (y1 - off_y) / span_y
            y2 = (y2 - off_y) / span_y

        # Make sure the first coordinate of each axis is the smaller one.
        if x1 <= x2:
            xa, xb = x1, x2
        else:
            xa, xb = x2, x1
        if y1 <= y2:
            ya, yb = y1, y2
        else:
            ya, yb = y2, y1

        width = max(1e-6, xb - xa)
        height = max(1e-6, yb - ya)
        centre_x = 0.5 * (xa + xb)
        centre_y = 0.5 * (ya + yb)

        candidates = topk_list[box_idx] if box_idx < len(topk_list) else []
        if candidates:
            cls_str = " ".join([f"{ch}:{p:.3f}" for ch, p in candidates])
        else:
            cls_str = "?"

        out_lines.append(
            f"{rank:02d} x1={xa:.4f} y1={ya:.4f} x2={xb:.4f} y2={yb:.4f} "
            f"w={width:.4f} h={height:.4f} xc={centre_x:.4f} yc={centre_y:.4f} score={score:.3f} | cls {cls_str}"
        )
    return "\n".join(out_lines)
# -----------------------------------------------------------------------------
# Plotly overlay (hover tooltips on bboxes)
# -----------------------------------------------------------------------------
def _pil_to_data_uri(pil: Image.Image, fmt: str = "PNG") -> str:
    """Encode *pil* in format *fmt* and return it as a base64 ``data:`` URI."""
    with io.BytesIO() as stream:
        pil.save(stream, format=fmt)
        payload = stream.getvalue()
    encoded = base64.b64encode(payload).decode("ascii")
    mime_subtype = fmt.lower()
    return f"data:image/{mime_subtype};base64,{encoded}"
def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
s = hex_color.lstrip("#")
return int(s[0:2], 16), int(s[2:4], 16), int(s[4:6], 16)
def make_bbox_figure(
    pil: Image.Image,
    boxes_xyxy: List[List[float]],
    labels: Optional[List[str]] = None,
    det_scores: Optional[List[float]] = None,
    topk_list: Optional[List[List[Tuple[str, float]]]] = None,
    height: int = IMG_HEIGHT,
) -> go.Figure:
    """Build a Plotly figure showing *pil* with one hoverable rectangle per box.

    Each tooltip lists the box number and top-1 label (the box number again
    when no label is given), plus the detector confidence and the top-k
    candidates when those are supplied. Colours cycle through the Plotly
    qualitative palette. Axes are hidden and the y axis is flipped so that
    image row 0 is at the top.
    """
    img_w, img_h = pil.size
    fig = go.Figure()
    background = {
        "source": _pil_to_data_uri(pil),
        "xref": "x",
        "yref": "y",
        "x": 0,
        "y": 0,
        "sizex": img_w,
        "sizey": img_h,
        "sizing": "stretch",
        "layer": "below",
    }
    fig.add_layout_image(background)

    colors = qualitative.Plotly
    for n, (left, top, right, bottom) in enumerate(boxes_xyxy):
        red, green, blue = _hex_to_rgb(colors[n % len(colors)])

        if labels and n < len(labels):
            top1 = labels[n]
        else:
            top1 = f"{n+1:02d}"
        score = det_scores[n] if det_scores and n < len(det_scores) else None
        candidates_text = ""
        if topk_list and n < len(topk_list) and topk_list[n]:
            candidates_text = ", ".join([f"{c}:{p:.3f}" for c, p in topk_list[n]])

        tooltip = [f"#{n+1:02d}", f"top1: {top1}"]
        if score is not None:
            tooltip.append(f"det_conf: {score:.3f}")
        if candidates_text:
            tooltip.append(f"topk: {candidates_text}")

        # Closed rectangle outline: the first corner is repeated at the end.
        outline = go.Scatter(
            x=[left, right, right, left, left],
            y=[top, top, bottom, bottom, top],
            mode="lines",
            fill="toself",
            line=dict(width=3, color=f"rgb({red},{green},{blue})"),
            fillcolor=f"rgba({red},{green},{blue},0.25)",
            hoverinfo="text",
            text="<br>".join(tooltip),
            showlegend=False,
        )
        fig.add_trace(outline)

    fig.update_xaxes(visible=False, range=[0, img_w], constrain="domain")
    fig.update_yaxes(visible=False, range=[img_h, 0], scaleanchor="x")
    fig.update_layout(height=height, margin=dict(l=0, r=0, t=30, b=0), hovermode="closest")
    return fig
# -----------------------------------------------------------------------------
# UI callbacks
# -----------------------------------------------------------------------------
def run_pipeline_tab(
    pil: Optional[Image.Image],
    conf: float,
    iou: float,
    max_det: int,
    crop_pad: float,
    topk_k: int,
    rtl: bool,
    output_mode: str,
    he2en_kind: str,
    coord_norm: str,
    order_mode: str,
) -> Tuple[
    Optional[Any],
    List[Tuple[Image.Image, str]],
    str,
    str,
    str,
    str,
    List[List[Tuple[str, float]]],
]:
    """Callback for the end-to-end tab: detect, classify, then generate text.

    *output_mode* ``"he"`` produces Hebrew only, ``"en_direct"`` English only
    via the box->English model, and any other value produces Hebrew followed
    by English: through the *he2en_kind* translator, or through the
    box->English model when *he2en_kind* is ``"none"``.

    Returns:
        ``(figure, crop_gallery, hebrew, english, debug_text, mt5_source,
        topk_list)``. Failures are logged and reported as ``"ERROR"`` texts
        with the traceback in the debug slot; nothing is raised.
    """
    try:
        if pil is None:
            no_image = json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2)
            return None, [], "", "", no_image, "", []

        pil = pil.convert("RGB")
        models = ensure_loaded()
        boxes, det_scores = yolo_predict(models.det, pil, conf, iou, max_det)

        if not boxes:
            empty_fig = make_bbox_figure(pil, [], height=IMG_HEIGHT)
            empty_dbg = {"ts": now_ts(), "detector": models.det_name, "n_boxes": 0}
            return (
                empty_fig,
                [],
                "No boxes detected.",
                "No boxes detected.",
                json.dumps(empty_dbg, ensure_ascii=False, indent=2),
                "",
                [],
            )

        topk_list, top1_letters, crop_gallery = classify_crops(
            models.cls, models.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k
        )
        src = build_source_text_mt5(
            pil=pil,
            boxes_xyxy=boxes,
            det_scores=det_scores,
            topk_list=topk_list,
            coord_norm=coord_norm,
            order_mode=order_mode,
            rtl=rtl,
        )

        heb, eng = "", ""
        if output_mode == "en_direct":
            eng = mt5_generate(models.mt5_en_tok, models.mt5_en, src)
        else:
            heb = mt5_generate(models.mt5_he_tok, models.mt5_he, src)
            if output_mode != "he":
                if he2en_kind != "none":
                    eng = translate_he_to_en(heb, he2en_kind)
                else:
                    eng = mt5_generate(models.mt5_en_tok, models.mt5_en, src)

        fig = make_bbox_figure(
            pil,
            boxes,
            labels=top1_letters,
            det_scores=det_scores,
            topk_list=topk_list,
            height=IMG_HEIGHT,
        )
        dbg = {
            "ts": now_ts(),
            "device": get_device(),
            "detector": models.det_name,
            "classifier": models.cls_name,
            "n_boxes": len(boxes),
            "coord_norm": coord_norm,
            "order_mode": order_mode,
            "rtl": bool(rtl),
            "output_mode": output_mode,
            "he2en_kind": he2en_kind,
        }
        dbg_text = json.dumps(dbg, ensure_ascii=False, indent=2)
        return fig, crop_gallery, heb, eng, dbg_text, src, topk_list
    except Exception:
        log.exception("run_pipeline_tab failed")
        return None, [], "ERROR", "ERROR", traceback.format_exc(), "", []
def run_detector_tab(
    pil: Optional[Image.Image], conf: float, iou: float, max_det: int
) -> Tuple[Optional[Any], str]:
    """Callback for the detector tab: run detection only.

    Returns:
        ``(figure, debug_json)``. With no image the figure is ``None`` and the
        JSON carries an error entry; on failure the figure is ``None`` and the
        second item is the traceback text.
    """
    if pil is None:
        return None, json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2)
    try:
        pil = pil.convert("RGB")
        models = ensure_loaded()
        boxes, scores = yolo_predict(models.det, pil, conf, iou, max_det)
        fig = make_bbox_figure(pil, boxes, det_scores=scores, height=IMG_HEIGHT)
        report = {
            "ts": now_ts(),
            "detector": models.det_name,
            "n_boxes": len(boxes),
            "boxes_xyxy": boxes,
            "scores": scores,
        }
        return fig, json.dumps(report, ensure_ascii=False, indent=2)
    except Exception:
        log.exception("run_detector_tab failed")
        return None, traceback.format_exc()
def run_classifier_tab(
    pil: Optional[Image.Image],
    conf: float,
    iou: float,
    max_det: int,
    crop_pad: float,
    topk_k: int,
) -> Tuple[List[Tuple[Image.Image, str]], str]:
    """Callback for the classifier tab: detect boxes, then classify each crop.

    Returns:
        ``(crop_gallery, debug_json)``. The gallery is empty when there is no
        image, nothing was detected, or an error occurred; in the error case
        the second item is the traceback text.
    """
    if pil is None:
        return [], json.dumps({"error": "No input image"}, ensure_ascii=False, indent=2)
    try:
        pil = pil.convert("RGB")
        models = ensure_loaded()
        boxes, det_scores = yolo_predict(models.det, pil, conf, iou, max_det)
        if not boxes:
            nothing = {"ts": now_ts(), "n_boxes": 0, "msg": "No boxes detected"}
            return [], json.dumps(nothing, ensure_ascii=False, indent=2)

        topk_list, top1_letters, crop_gallery = classify_crops(
            models.cls, models.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k
        )
        per_box: Dict[str, Any] = {
            f"Box_{n+1:02d}": {"xyxy": box, "det_score": float(score), "topk": candidates}
            for n, (candidates, box, score) in enumerate(zip(topk_list, boxes, det_scores))
        }
        report = {
            "ts": now_ts(),
            "classifier": models.cls_name,
            "n_boxes": len(boxes),
            "details": per_box,
        }
        return crop_gallery, json.dumps(report, ensure_ascii=False, indent=2)
    except Exception:
        log.exception("run_classifier_tab failed")
        return [], traceback.format_exc()
def show_topk_popup(topk_list: List[List[Tuple[str, float]]], evt: gr.SelectData) -> str:
    """Render the top-k candidates of the selected crop as modal HTML.

    Args:
        topk_list: Per-crop ``(label, probability)`` candidates.
        evt: Selection event; its ``index`` attribute picks the crop.

    Returns:
        The modal markup, or ``""`` when the event carries no usable integer
        index or the index is out of range.
    """
    from html import escape  # local import: only needed for this markup

    idx = getattr(evt, "index", None)
    # A non-integer index (e.g. a coordinate pair) cannot address topk_list.
    if not isinstance(idx, int) or not topk_list or idx < 0 or idx >= len(topk_list):
        return ""

    rows = topk_list[idx]
    # Labels come from a data file, so escape them before embedding in HTML.
    lines = "\n".join([f"{escape(str(lab)):>2} {p:.4f}" for lab, p in rows])
    return f"""
<div class="modal-backdrop" onclick="this.parentElement.innerHTML = ''">
<div class="modal-card">
<div class="modal-title">Top‑K for crop #{idx + 1}</div>
<pre class="modal-pre">{lines}</pre>
<div class="modal-hint">Click anywhere outside to close</div>
</div>
</div>
"""
def save_feedback(
    heb_pred: str,
    eng_pred: str,
    heb_corr: str,
    eng_corr: str,
    notes: str,
) -> str:
    """Append one feedback record to ``FEEDBACK_PATH`` as a JSON line.

    Args:
        heb_pred: Hebrew text produced by the pipeline.
        eng_pred: English text produced by the pipeline.
        heb_corr: User-supplied Hebrew correction.
        eng_corr: User-supplied English correction.
        notes: Free-form user notes.

    Returns:
        A confirmation message, or the traceback text when saving failed
        (the failure is logged, not raised).
    """
    try:
        rec = {
            "ts": now_ts(),
            "heb_pred": heb_pred,
            "eng_pred": eng_pred,
            "heb_corr": heb_corr,
            "eng_corr": eng_corr,
            "notes": notes,
        }
        # A bare filename has an empty dirname, and os.makedirs("") raises;
        # only create the directory when the path actually names one.
        parent_dir = os.path.dirname(FEEDBACK_PATH)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        return f"Saved to {FEEDBACK_PATH}"
    except Exception:
        log.exception("save_feedback failed")
        return traceback.format_exc()
# -----------------------------------------------------------------------------
# Examples helper
# -----------------------------------------------------------------------------
# Register the examples directory as a Gradio static path when it exists.
if EXAMPLES_DIR.exists():
    gr.set_static_paths([str(EXAMPLES_DIR)])


def list_example_images(max_n: int = 24) -> List[List[str]]:
    """Return up to *max_n* example image paths, each wrapped in a one-item list.

    Entries of ``EXAMPLES_DIR`` are taken in sorted order and kept when their
    suffix is .png, .jpg, .jpeg or .webp (case-insensitive). A missing
    directory gives an empty list.
    """
    if not EXAMPLES_DIR.exists():
        return []
    image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    rows: List[List[str]] = []
    for entry in sorted(EXAMPLES_DIR.iterdir()):
        if entry.suffix.lower() in image_suffixes:
            rows.append([str(entry)])
    return rows[:max_n]
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# Styles for the top-k modal: show_topk_popup() emits markup using these
# classes into the HTML component whose elem_id is "topk_modal".
CSS = """
#topk_modal .modal-backdrop {
position: fixed;
inset: 0;
background: rgba(0, 0, 0, 0.35);
display: flex;
align-items: center;
justify-content: center;
z-index: 9999;
}
#topk_modal .modal-card {
background: white;
border-radius: 14px;
padding: 18px 22px;
min-width: 280px;
max-width: 520px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.25);
}
#topk_modal .modal-title {
font-size: 18px;
font-weight: 700;
margin-bottom: 10px;
}
#topk_modal .modal-pre {
font-size: 16px;
line-height: 1.25;
margin: 0;
white-space: pre;
}
#topk_modal .modal-hint {
opacity: 0.6;
margin-top: 10px;
}
"""
# Gradio UI: settings and examples in the left column, result tabs on the
# right, followed by the event wiring and the app launch.
with gr.Blocks(title="Paleo‑Hebrew Tablet Reader", css=CSS) as demo:
    gr.Markdown("# Paleo‑Hebrew Epigraphy Pipeline")
    # Per-crop top-k candidates from the last pipeline run, read by the modal.
    topk_state = gr.State([])
    modal = gr.HTML(value="", elem_id="topk_modal")
    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image", height=IMG_HEIGHT)
            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="Box Confidence")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top‑K")
            with gr.Accordion("mT5 + Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right to left) [for reading-order mode]")
                order_mode = gr.Radio(["reading", "detector"], value="reading", label="Box order")
                coord_norm = gr.Radio(["boxes", "det", "none"], value="boxes", label="coord_norm (like training)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"], value="he_then_en", label="Output mode")
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus", label="Hebrew to English translator")
            with gr.Accordion("How to run a VLM yourself (optional)", open=False):
                # The backslashes in the sample code are doubled so that the
                # rendered snippet shows a literal \n\n instead of line breaks
                # inside the string.
                gr.Markdown(
                    """
If you want a VLM post‑OCR step (outside this Space), you can run it locally / on a GPU box.
High-level steps:
1) Build a `tools_text` block from YOLO+classifier: `[DETECTOR ...]` + `[CLASSIFIER topk]`.
2) Feed image + prompt + tools_text into your Qwen3‑VL LoRA model.
Skeleton (pseudo-code):
```python
from transformers import AutoProcessor, AutoModelForImageTextToText
proc = AutoProcessor.from_pretrained("mr3vial/paleo-hebrew-qwen3-vl-lora-post-ocr-processing")
model = AutoModelForImageTextToText.from_pretrained("mr3vial/paleo-hebrew-qwen3-vl-lora-post-ocr-processing")
prompt = "Transcribe the text on the tablet.\\n\\n" + tools_text
inputs = proc(text=prompt, images=[pil_image], return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64)
pred = proc.batch_decode(out, skip_special_tokens=True)[0]
```
"""
                )
            ex = list_example_images()
            if ex:
                gr.Markdown("### Examples")
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
        with gr.Column(scale=7):
            with gr.Tabs():
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_fig_pipe = gr.Plot(label="Detections (hover tooltip)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew", lines=2, interactive=False)
                        out_en_pipe = gr.Textbox(label="English", lines=2, interactive=False)
                    out_crops_pipe = gr.Gallery(columns=8, height=360, label="Letter crops")
                    with gr.Accordion("Debug", open=False):
                        out_dbg_pipe = gr.Code(label="Debug JSON", language="json")
                        out_mt5_src = gr.Textbox(label="mT5 source text (fed into mT5)", lines=10, interactive=False)
                with gr.Tab("Detector"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_fig_det = gr.Plot(label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")
                with gr.Tab("Classifier"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    out_crops_cls = gr.Gallery(columns=8, height=360, label="Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top‑K per Box", language="json")
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)
    # Wiring
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind, coord_norm, order_mode],
        outputs=[out_fig_pipe, out_crops_pipe, out_he_pipe, out_en_pipe, out_dbg_pipe, out_mt5_src, topk_state],
        api_name=False,
    )
    # Selecting a crop in the pipeline gallery opens the top-k modal.
    out_crops_pipe.select(fn=show_topk_popup, inputs=[topk_state], outputs=[modal])
    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_fig_det, out_json_det],
        api_name=False,
    )
    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False,
    )
    # Feedback records the currently displayed predictions with the corrections.
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False,
    )
demo.queue(max_size=32).launch(show_error=True)