# Hugging Face Space app — "added draft of demo" (commit 8816141, Niiaz Bureev)
import os
import gc
import json
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# =============================================================================
# Configuration via env vars (edit in Space Settings -> Variables if needed)
# =============================================================================
# Letter detector (YOLO) checkpoint on the Hugging Face Hub.
DET_REPO_ID = os.getenv("DET_REPO_ID", "mr3vial/paleo-hebrew-yolo")
DET_FILENAME = os.getenv("DET_FILENAME", "best.onnx") # recommended
DET_FALLBACK_PT = os.getenv("DET_FALLBACK_PT", "best.pt")  # used if ONNX load fails
# Letter classifier (timm model) weights plus its class-label list.
CLS_REPO_ID = os.getenv("CLS_REPO_ID", "mr3vial/paleo-hebrew-convnext")
CLS_WEIGHTS = os.getenv("CLS_WEIGHTS", "model.safetensors")
CLS_CLASSES = os.getenv("CLS_CLASSES", "classes.txt")
CLS_MODEL_NAME = os.getenv("CLS_MODEL_NAME", "convnext_large")
# mT5 seq2seq checkpoints: box/class transcript -> Hebrew, and -> English.
MT5_BOX2HE_REPO = os.getenv("MT5_BOX2HE_REPO", "mr3vial/paleo-hebrew-mt5-post-ocr-processing")
MT5_BOX2EN_REPO = os.getenv("MT5_BOX2EN_REPO", "mr3vial/paleo-hebrew-mt5-translate")
# JSONL file where Feedback-tab corrections are appended.
FEEDBACK_PATH = os.getenv("FEEDBACK_PATH", "/tmp/feedback.jsonl")
EXAMPLES_DIR = Path(__file__).parent / "examples"
# =============================================================================
# Helpers
# =============================================================================
def now_ts() -> str:
    """Current local time as 'YYYY-MM-DD HH:MM:SS'."""
    fmt = "%Y-%m-%d %H:%M:%S"
    return time.strftime(fmt)
def normalize_spaces(s: str) -> str:
    """Collapse all runs of whitespace to single spaces and trim the ends.

    A falsy input (None or "") yields "".
    """
    tokens = (s or "").split()
    return " ".join(tokens)
def get_device() -> str:
    """Return "cuda" when a GPU is visible to torch, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def safe_int(x: float, default: int = 0) -> int:
    """Round ``x`` to the nearest integer; return ``default`` on failure.

    Accepts anything ``float()`` accepts (numbers, numeric strings).
    NaN/inf and unparsable values fall back to ``default``.
    """
    try:
        rounded = round(float(x))  # round() raises for NaN/inf
    except Exception:
        return default
    return int(rounded)
def sort_boxes_reading_order(boxes_xyxy: List[List[float]], rtl: bool = True) -> List[int]:
    """Return box indices in approximate reading order.

    Boxes are clustered into text lines by vertical proximity (tolerance is
    0.6x the median box height), then each line is emitted right-to-left
    when ``rtl`` is True, otherwise left-to-right.
    """
    if not boxes_xyxy:
        return []
    # (index, x-center, y-center) per box, plus heights for the line tolerance.
    info = [
        (i, 0.5 * (x1 + x2), 0.5 * (y1 + y2))
        for i, (x1, y1, x2, y2) in enumerate(boxes_xyxy)
    ]
    box_heights = [max(1.0, b[3] - b[1]) for b in boxes_xyxy]
    tolerance = float(np.median(box_heights)) * 0.6
    # Sweep top-to-bottom, appending to the current line while the y-center
    # stays within tolerance of that line's running mean.
    lines: List[List[Tuple[int, float, float]]] = []
    for entry in sorted(info, key=lambda t: t[2]):
        if lines and abs(entry[2] - float(np.mean([p[2] for p in lines[-1]]))) <= tolerance:
            lines[-1].append(entry)
        else:
            lines.append([entry])
    order: List[int] = []
    for line in lines:
        order.extend(idx for idx, _, _ in sorted(line, key=lambda t: t[1], reverse=rtl))
    return order
# =============================================================================
# Load models (Detector, Classifier, mT5)
# =============================================================================
@torch.no_grad()
def load_detector():
    """Download and instantiate the YOLO letter detector.

    Tries the ONNX export first; if the download or model load fails, falls
    back to the .pt checkpoint. Returns (model, human-readable source name).
    """
    from ultralytics import YOLO
    try:
        onnx_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FILENAME)
        return YOLO(onnx_path), f"{DET_REPO_ID}/{DET_FILENAME}"
    except Exception:
        pt_path = hf_hub_download(repo_id=DET_REPO_ID, filename=DET_FALLBACK_PT)
        return YOLO(pt_path), f"{DET_REPO_ID}/{DET_FALLBACK_PT} (PT fallback)"
@torch.no_grad()
def load_classifier():
    """Download the letter classifier weights and class list.

    classes.txt lines may look like "NN_letter" or be bare letters; only the
    text after the last underscore is kept. Returns
    (model, letters, human-readable source name).
    """
    import timm
    from safetensors.torch import load_file as safetensors_load
    device = get_device()
    classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
    with open(classes_path, "r", encoding="utf-8") as f:
        raw = [line.strip() for line in f if line.strip()]
    # split("_")[-1] is a no-op for lines without an underscore.
    letters: List[str] = [ln.split("_")[-1] for ln in raw]
    weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
    model = timm.create_model(CLS_MODEL_NAME, pretrained=False, num_classes=len(letters))
    state = safetensors_load(weights_path)
    cleaned = {}
    for key, tensor in state.items():
        # Strip wrapper prefixes left by DataParallel / Lightning-style exports.
        for prefix in ("module.", "model."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = tensor
    model.load_state_dict(cleaned, strict=False)
    model.eval().to(device)
    return model, letters, f"{CLS_REPO_ID}/{CLS_WEIGHTS}"
def preprocess_crop_imagenet(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL crop to a normalized CHW float tensor (ImageNet stats)."""
    rgb = np.asarray(pil.convert("RGB"), dtype=np.float32) / 255.0
    chw = torch.from_numpy(rgb).permute(2, 0, 1).contiguous()
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return chw.sub(imagenet_mean).div(imagenet_std)
@torch.no_grad()
def load_mt5(repo_id: str):
    """Load an mT5 seq2seq checkpoint in half precision.

    bfloat16 on CPU and on bf16-capable GPUs, float16 otherwise — keeps
    several seq2seq models resident within Space RAM limits.
    Returns (tokenizer, model).
    """
    device = get_device()
    if device == "cpu" or torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, torch_dtype=dtype)
    return tokenizer, model.to(device).eval()
@torch.no_grad()
def mt5_generate(tok, model, text: str, max_new_tokens: int = 128) -> str:
    """Greedy-beam decode ``text`` through an mT5 model; return clean text."""
    encoded = tok([text], return_tensors="pt", truncation=True, max_length=2048).to(get_device())
    generated = model.generate(**encoded, max_new_tokens=max_new_tokens, num_beams=4, do_sample=False)
    decoded = tok.batch_decode(generated, skip_special_tokens=True)[0]
    return normalize_spaces(decoded)
# =============================================================================
# Optional Hebrew->English translators (user-selectable)
# =============================================================================
# (display label, internal kind) pairs for the translator dropdown; the
# "kind" value is what get_he2en_translator() / translate_he_to_en() switch on.
TRANSLATOR_CHOICES = [
    ("None (use direct box→en)", "none"),
    ("OPUS-MT he→en (Helsinki-NLP/opus-mt-tc-big-he-en)", "opus"),
    ("NLLB-200 distilled 600M (he→en) [CC-BY-NC]", "nllb"),
    ("M2M100 418M (he→en)", "m2m100"),
]
# Lazy cache holding at most one loaded translator: kind -> (tok, model, kind).
_TRANSLATOR_CACHE: Dict[str, Tuple[Any, Any, str]] = {}
@torch.no_grad()
def get_he2en_translator(kind: str):
    """Return (tokenizer, model, kind) for the requested he->en translator.

    At most one translator is kept resident: before loading a new kind the
    cache is emptied and CUDA memory released. Unrecognized kinds are cached
    as (None, None, "none").
    """
    cached = _TRANSLATOR_CACHE.get(kind)
    if cached is not None:
        return cached
    # Memory safety: drop any previously loaded translator before loading.
    _TRANSLATOR_CACHE.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    device = get_device()
    if kind == "opus":
        from transformers import MarianMTModel, MarianTokenizer
        repo = "Helsinki-NLP/opus-mt-tc-big-he-en"
        tok = MarianTokenizer.from_pretrained(repo)
        model = MarianMTModel.from_pretrained(repo).to(device).eval()
        entry = (tok, model, kind)
    elif kind == "nllb":
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        repo = "facebook/nllb-200-distilled-600M"
        tok = AutoTokenizer.from_pretrained(repo, src_lang="heb_Hebr")
        model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device).eval()
        entry = (tok, model, kind)
    elif kind == "m2m100":
        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
        repo = "facebook/m2m100_418M"
        tok = M2M100Tokenizer.from_pretrained(repo)
        tok.src_lang = "he"
        model = M2M100ForConditionalGeneration.from_pretrained(repo).to(device).eval()
        entry = (tok, model, kind)
    else:
        entry = (None, None, "none")
    _TRANSLATOR_CACHE[kind] = entry
    return entry
@torch.no_grad()
def translate_he_to_en(text_he: str, kind: str) -> str:
    """Translate Hebrew text to English via the selected optional translator.

    Parameters:
        text_he: Hebrew source string.
        kind: one of "none", "opus", "nllb", "m2m100" (see TRANSLATOR_CHOICES).

    Returns the English translation, or "" when kind is "none" or unknown.
    """
    if kind == "none":
        return ""
    # The cache may downgrade an unknown kind to "none" (tok/model are None).
    tok, model, kind = get_he2en_translator(kind)
    device = get_device()
    if kind == "opus":
        batch = tok([text_he], return_tensors="pt", padding=True, truncation=True).to(device)
        out = model.generate(**batch, max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
    if kind == "nllb":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        # Fix: `tok.lang_code_to_id` was removed from NLLB tokenizers in recent
        # transformers releases; convert_tokens_to_ids works on all versions.
        eng_id = tok.convert_tokens_to_ids("eng_Latn")
        out = model.generate(**batch, forced_bos_token_id=eng_id, max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
    if kind == "m2m100":
        batch = tok([text_he], return_tensors="pt", truncation=True, max_length=512).to(device)
        out = model.generate(**batch, forced_bos_token_id=tok.get_lang_id("en"), max_new_tokens=128, num_beams=4)
        return normalize_spaces(tok.batch_decode(out, skip_special_tokens=True)[0])
    return ""
# =============================================================================
# Pipeline Components
# =============================================================================
@dataclass
class Loaded:
    """Bundle of every lazily-loaded model; built once by ensure_loaded()."""
    det: Any                 # YOLO letter detector
    det_name: str            # human-readable source of the detector (repo/filename)
    cls: Any                 # timm letter classifier
    cls_letters: List[str]   # class index -> letter string
    cls_name: str            # human-readable source of the classifier weights
    mt5_he_tok: Any          # tokenizer for the boxes->Hebrew mT5 model
    mt5_he: Any              # boxes->Hebrew mT5 model
    mt5_en_tok: Any          # tokenizer for the boxes->English mT5 model
    mt5_en: Any              # boxes->English mT5 model
# Module-level singleton populated on first use.
LOADED: Optional[Loaded] = None

def ensure_loaded() -> Loaded:
    """Load all pipeline models once and memoize them in the module global."""
    global LOADED
    if LOADED is None:
        detector, detector_name = load_detector()
        classifier, letter_list, classifier_name = load_classifier()
        he_tok, he_model = load_mt5(MT5_BOX2HE_REPO)
        en_tok, en_model = load_mt5(MT5_BOX2EN_REPO)
        LOADED = Loaded(
            det=detector, det_name=detector_name,
            cls=classifier, cls_letters=letter_list, cls_name=classifier_name,
            mt5_he_tok=he_tok, mt5_he=he_model,
            mt5_en_tok=en_tok, mt5_en=en_model,
        )
    return LOADED
def yolo_predict(det_model, pil: Image.Image, conf: float, iou: float, max_det: int):
    """Run the detector on a PIL image; return (boxes_xyxy, confidences) lists."""
    results = det_model.predict(
        source=pil,  # ultralytics accepts PIL images directly — no tempfile
        conf=float(conf),
        iou=float(iou),
        max_det=int(max_det),
        verbose=False,
    )
    first = results[0]
    if first.boxes is None or len(first.boxes) == 0:
        return [], []
    coords = first.boxes.xyxy.detach().float().cpu().numpy().tolist()
    confidences = first.boxes.conf.detach().float().cpu().numpy().tolist()
    return coords, confidences
@torch.no_grad()
def classify_crops(cls_model, letters: List[str], pil: Image.Image, boxes_xyxy, pad: float, topk: int):
    """Crop each box (with proportional padding), classify, and rank letters.

    Returns (topk_list, top1_letters, crop_gallery):
      topk_list: per box, a list of (letter, probability) pairs, best first;
      top1_letters: per box, the single best letter;
      crop_gallery: (crop image, "NN: letter") pairs for the gallery widget.
    """
    device = get_device()
    img_w, img_h = pil.size
    crops = []
    for (x1, y1, x2, y2) in boxes_xyxy:
        # Padding is proportional to the larger box side (min side length 1).
        side = max(max(1.0, x2 - x1), max(1.0, y2 - y1))
        margin = int(round(side * float(pad)))
        left = max(0, safe_int(x1) - margin)
        top = max(0, safe_int(y1) - margin)
        right = min(img_w, safe_int(x2) + margin)
        bottom = min(img_h, safe_int(y2) + margin)
        crops.append(pil.crop((left, top, right, bottom)).resize((224, 224), Image.BICUBIC))
    if not crops:
        return [], [], []
    batch = torch.stack([preprocess_crop_imagenet(c) for c in crops], dim=0).to(device)
    probs = torch.softmax(cls_model(batch), dim=-1).detach().float().cpu().numpy()
    k = max(1, int(topk))
    topk_list, top1_letters = [], []
    for row_probs in probs:
        best = np.argsort(-row_probs)[:k]
        ranked = [(letters[j], float(row_probs[j])) for j in best]
        topk_list.append(ranked)
        top1_letters.append(ranked[0][0] if ranked else "?")
    crop_gallery = [(crop, f"{i+1:02d}: {top1_letters[i]}") for i, crop in enumerate(crops)]
    return topk_list, top1_letters, crop_gallery
def build_source_text(boxes_xyxy: List[List[float]], scores: List[float], topk: List[List[Tuple[str, float]]], rtl: bool) -> str:
    """Serialize detections into the text prompt format consumed by the mT5 models.

    Boxes are listed in reading order, one line each, with coordinates, the
    detector score, and the classifier's top-k candidates.
    """
    ordered = sort_boxes_reading_order(boxes_xyxy, rtl=rtl)
    out_lines = ["[BOXES_CLS_V3]", f"n={len(ordered)} coord_norm=none"]
    for pos, box_i in enumerate(ordered, start=1):
        x1, y1, x2, y2 = boxes_xyxy[box_i]
        score = scores[box_i] if box_i < len(scores) else 1.0
        candidates = topk[box_i] if box_i < len(topk) else []
        if candidates:
            cls_str = " ".join(f"{ch}:{p:.3f}" for ch, p in candidates)
        else:
            cls_str = "?"
        out_lines.append(
            f"{pos:02d} x1={x1:.1f} y1={y1:.1f} x2={x2:.1f} y2={y2:.1f} score={score:.3f} | cls {cls_str}"
        )
    return "\n".join(out_lines)
def make_annotated(pil: Image.Image, boxes_xyxy, top1_letters, scores):
    """Return a copy of the image with red boxes and "NN letter" labels drawn."""
    annotated = pil.copy()
    draw = ImageDraw.Draw(annotated)
    for i, box in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = (safe_int(v) for v in box)
        letter = top1_letters[i] if i < len(top1_letters) else "?"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        label = f"{i+1:02d}"
        if letter != "?":
            label += f" {letter}"
        draw.text((x1, y1), label, fill="red")
    return annotated
# =============================================================================
# Tab Action Handlers
# =============================================================================
def run_pipeline_tab(pil, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind):
    """End-to-end handler: detect -> classify -> mT5 decode (-> optional he->en MT).

    Returns (annotated image, crop gallery, hebrew text, english text).
    """
    if pil is None:
        return None, [], "", ""
    loaded = ensure_loaded()
    boxes, det_scores = yolo_predict(loaded.det, pil, conf, iou, max_det)
    if not boxes:
        return pil, [], "No boxes detected.", "No boxes detected."
    topk_list, top1_letters, crop_gallery = classify_crops(
        loaded.cls, loaded.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k
    )
    src = build_source_text(boxes, det_scores, topk_list, rtl=rtl)
    heb, eng = "", ""
    if output_mode == "he":
        heb = mt5_generate(loaded.mt5_he_tok, loaded.mt5_he, src)
    elif output_mode == "en_direct":
        eng = mt5_generate(loaded.mt5_en_tok, loaded.mt5_en, src)
    else:  # "he_then_en": Hebrew first, then translate (or direct model fallback)
        heb = mt5_generate(loaded.mt5_he_tok, loaded.mt5_he, src)
        if he2en_kind != "none":
            eng = translate_he_to_en(heb, he2en_kind)
        else:
            eng = mt5_generate(loaded.mt5_en_tok, loaded.mt5_en, src)
    annotated = make_annotated(pil, boxes, top1_letters, det_scores)
    return annotated, crop_gallery, heb, eng
def run_detector_tab(pil, conf, iou, max_det):
    """Detector-only handler: draw boxes and dump raw detections as JSON."""
    if pil is None:
        return None, "{}"
    loaded = ensure_loaded()
    boxes, scores = yolo_predict(loaded.det, pil, conf, iou, max_det)
    annotated = make_annotated(pil, boxes, [], scores)
    debug = {"num_boxes": len(boxes), "boxes_xyxy": boxes, "confidences": scores}
    return annotated, json.dumps(debug, indent=2)
def run_classifier_tab(pil, conf, iou, max_det, crop_pad, topk_k):
    """Classifier-only handler: crop detected boxes and show top-k letters."""
    if pil is None:
        return [], "{}"
    loaded = ensure_loaded()
    boxes, _ = yolo_predict(loaded.det, pil, conf, iou, max_det)
    if not boxes:
        return [], "No boxes found to classify."
    topk_list, _top1, crop_gallery = classify_crops(
        loaded.cls, loaded.cls_letters, pil, boxes, pad=crop_pad, topk=topk_k
    )
    details = {
        f"Box_{i+1:02d}": {"top_predictions": dict(row), "coordinates": box}
        for i, (row, box) in enumerate(zip(topk_list, boxes))
    }
    return crop_gallery, json.dumps(details, ensure_ascii=False, indent=2)
def save_feedback(heb_pred, eng_pred, heb_corr, eng_corr, notes):
    """Append one correction record as a JSON line to FEEDBACK_PATH."""
    record = {
        "ts": now_ts(),
        "heb_pred": heb_pred, "eng_pred": eng_pred,
        "heb_corr": heb_corr, "eng_corr": eng_corr,
        "notes": notes,
    }
    with open(FEEDBACK_PATH, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    return f"Saved to {FEEDBACK_PATH}"
# =============================================================================
# UI
# =============================================================================
# Let Gradio serve the bundled example images directly from disk.
if EXAMPLES_DIR.exists():
    gr.set_static_paths([str(EXAMPLES_DIR)])
def list_example_images(max_n: int = 24):
    """Collect up to ``max_n`` example image paths, shaped for gr.Examples."""
    if not EXAMPLES_DIR.exists():
        return []
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    images = [p for p in sorted(EXAMPLES_DIR.iterdir()) if p.suffix.lower() in allowed]
    return [[str(p)] for p in images[:max_n]]
# Build the Gradio UI: a shared input/settings column on the left, and one
# tab per pipeline stage (end-to-end, detector, classifier, feedback) on the
# right. Event wiring happens inside the Blocks context.
with gr.Blocks(title="Paleo-Hebrew Tablet Reader") as demo:
    gr.Markdown("# Paleo-Hebrew Epigraphy Pipeline")
    with gr.Row():
        # LEFT COLUMN: Global Inputs & Settings
        with gr.Column(scale=3):
            inp = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Vision Settings", open=True):
                conf = gr.Slider(0.05, 0.95, value=0.25, step=0.01, label="YOLO conf")
                iou = gr.Slider(0.10, 0.90, value=0.45, step=0.01, label="YOLO IoU")
                max_det = gr.Slider(1, 200, value=80, step=1, label="Max Detections")
                crop_pad = gr.Slider(0.0, 0.8, value=0.20, step=0.01, label="Crop Pad")
                topk_k = gr.Slider(1, 10, value=5, step=1, label="Classifier Top-K")
            with gr.Accordion("Translation Settings", open=True):
                rtl = gr.Checkbox(value=True, label="RTL order (right→left)")
                output_mode = gr.Radio(["he", "en_direct", "he_then_en"], value="he_then_en", label="Output mode")
                he2en_kind = gr.Dropdown(choices=TRANSLATOR_CHOICES, value="opus", label="Hebrew → English translator")
            # Only show the Examples widget when example images are bundled.
            ex = list_example_images()
            if ex:
                gr.Examples(examples=ex, inputs=[inp], cache_examples=False)
        # RIGHT COLUMN: Tabs for different pipeline stages
        with gr.Column(scale=7):
            with gr.Tabs():
                # TAB 1: Main E2E Pipeline
                with gr.Tab("End-to-End Pipeline"):
                    run_pipe_btn = gr.Button("Run Full Pipeline", variant="primary")
                    out_annot_pipe = gr.Image(type="pil", label="Detections (bbox + top1)")
                    with gr.Row():
                        out_he_pipe = gr.Textbox(label="Hebrew Output", lines=3, interactive=False)
                        out_en_pipe = gr.Textbox(label="English Output", lines=3, interactive=False)
                    # Back-compat: old Gradio galleries use .style(); newer versions removed it.
                    out_crops_pipe = gr.Gallery(label="Letter crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Letter crops")
                # TAB 2: Detector Only
                with gr.Tab("Detector (YOLO)"):
                    run_det_btn = gr.Button("Run Detector", variant="primary")
                    out_annot_det = gr.Image(type="pil", label="Detected Bounding Boxes")
                    out_json_det = gr.Code(label="Raw Detection Output", language="json")
                # TAB 3: Classifier Only
                with gr.Tab("Classifier (ConvNeXt)"):
                    run_cls_btn = gr.Button("Run Classifier", variant="primary")
                    # Same .style() back-compat shim as the pipeline gallery above.
                    out_crops_cls = gr.Gallery(label="Isolated Letter Crops").style(grid=6, height="auto") if hasattr(gr.Gallery, 'style') else gr.Gallery(label="Isolated Letter Crops")
                    out_json_cls = gr.Code(label="Top-K Probabilities per Box", language="json")
                # TAB 4: Feedback
                with gr.Tab("Feedback"):
                    gr.Markdown("Submit corrections to improve the dataset.")
                    heb_corr = gr.Textbox(label="Correct Hebrew", lines=2)
                    eng_corr = gr.Textbox(label="Correct English", lines=2)
                    notes = gr.Textbox(label="Notes (e.g., missed boxes, ambiguous glyphs)", lines=2)
                    save_btn = gr.Button("Submit Feedback")
                    save_status = gr.Textbox(label="Status", interactive=False)
    # --- Button Clicks ---
    # api_name=False keeps these handlers off the public /api surface.
    run_pipe_btn.click(
        fn=run_pipeline_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k, rtl, output_mode, he2en_kind],
        outputs=[out_annot_pipe, out_crops_pipe, out_he_pipe, out_en_pipe],
        api_name=False
    )
    run_det_btn.click(
        fn=run_detector_tab,
        inputs=[inp, conf, iou, max_det],
        outputs=[out_annot_det, out_json_det],
        api_name=False
    )
    run_cls_btn.click(
        fn=run_classifier_tab,
        inputs=[inp, conf, iou, max_det, crop_pad, topk_k],
        outputs=[out_crops_cls, out_json_cls],
        api_name=False
    )
    # Feedback saves the currently displayed pipeline outputs alongside corrections.
    save_btn.click(
        fn=save_feedback,
        inputs=[out_he_pipe, out_en_pipe, heb_corr, eng_corr, notes],
        outputs=[save_status],
        api_name=False
    )
# Queue requests (heavy model inference) before launching the app.
demo.queue(max_size=32).launch()