BuddyMath / ocr_strip_engine.py
dotandru's picture
V9.0.2: CRITICAL Fix - Data Anchor Failure via Robust OCR Payload Flattener
a162bd1
raw
history blame
12.8 kB
# ocr_strip_engine.py - V303.2 (Adaptive Pipeline & Two-Pass Sniper)
import os
import io
import json
import re
import logging
import asyncio
from pathlib import Path
import numpy as np
from PIL import Image
import cv2
from utils.safe_json import safe_extract_json # V1.0: Canonical JSON extractor
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Directory where transcribe() writes its two debug images when debug_mode is on.
DEBUG_DIR = Path("/tmp/debug_ocr")
# Created (including parents) at import time; exist_ok avoids an error if it already exists.
DEBUG_DIR.mkdir(parents=True, exist_ok=True)
# --- Constants for block detection ---
# Minimum bounding-box width / height for a contour to be kept by _find_raw_blocks.
MIN_BLOCK_W = 50
MIN_BLOCK_H = 10
# _merge_same_row never merges a block whose height is at least this value.
LARGE_BLOCK_H = 100
# Largest vertical gap between a block's top and the previous block's bottom
# for which _merge_same_row still merges the two.
ROW_MERGE_GAP = 2
def _adaptive_preprocess(image_bytes: bytes) -> np.ndarray:
    """
    V303.2: Adaptive preprocessing pipeline.

    Decodes the encoded image and chooses an enhancement path from the size
    of the encoded payload in KB (not from the pixel dimensions):

    * below 500 KB  -> "low-res" path: 2x cubic upscale, CLAHE (clip 2.5)
      and a 2x2 morphological close.
    * 500 KB and up -> "high-res" path: contrast/brightness scaling
      (alpha 1.2, beta 10) followed by a milder CLAHE (clip 1.5).

    Returns the enhanced grayscale result converted back to 3-channel BGR.
    """
    decoded_bgr = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    file_size_kb = len(image_bytes) / 1024
    logger.info(f"📸 [OCR-ADAPTIVE] Input image size: {file_size_kb:.1f} KB")

    if file_size_kb >= 500:
        # --- HIGH-RES MODE (Phone Camera in Production) ---
        logger.info("📱 [OCR-ADAPTIVE] High-Res mode triggered: Mild enhancement only.")
        mono = cv2.cvtColor(decoded_bgr, cv2.COLOR_BGR2GRAY)
        # V1.1: pre-normalization (contrast gain 1.2, brightness offset 10).
        mono = cv2.convertScaleAbs(mono, alpha=1.2, beta=10)
        # Mild CLAHE only balances lighting; no morphology on this path.
        equalizer = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
        enhanced = equalizer.apply(mono)
    else:
        # --- LOW-RES MODE (PC Screenshots / Snips) ---
        logger.info("🔧 [OCR-ADAPTIVE] Low-Res mode triggered: Upscaling and applying heavy morphology.")
        # 1. Upscale x2 so thin strokes survive the later steps.
        upscaled = cv2.resize(decoded_bgr, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
        mono = cv2.cvtColor(upscaled, cv2.COLOR_BGR2GRAY)
        # 2. Strong CLAHE.
        equalizer = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        equalized = equalizer.apply(mono)
        # 3. Morphological close with a 2x2 rectangle (thickens lines).
        close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        enhanced = cv2.morphologyEx(equalized, cv2.MORPH_CLOSE, close_kernel)

    return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
def _find_raw_blocks(np_bgr: np.ndarray) -> list[tuple]:
    """
    Locate candidate blocks as bounding boxes, ordered top to bottom.

    The image is blurred, adaptively thresholded (inverted) and dilated with
    a wide 60x3 kernel so that neighbouring marks on one line fuse into a
    single contour. Bounding boxes smaller than MIN_BLOCK_W x MIN_BLOCK_H are
    discarded.

    Returns a list of (x, y, w, h) tuples sorted by y.
    """
    blurred = cv2.GaussianBlur(cv2.cvtColor(np_bgr, cv2.COLOR_BGR2GRAY), (7, 7), 0)
    binary_inv = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )
    # Kernel optimized for both preprocessing modes.
    smear_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (60, 3))
    smeared = cv2.dilate(binary_inv, smear_kernel, iterations=1)
    contours, _ = cv2.findContours(smeared, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rects = (cv2.boundingRect(contour) for contour in contours)
    kept = [
        (x, y, w, h)
        for (x, y, w, h) in rects
        if w >= MIN_BLOCK_W and h >= MIN_BLOCK_H
    ]
    return sorted(kept, key=lambda rect: rect[1])
def _filter_nested(blocks: list[tuple]) -> list[tuple]:
filtered = []
for i, (x1, y1, w1, h1) in enumerate(blocks):
r1, b1 = x1 + w1, y1 + h1
nested = False
for j, (x2, y2, w2, h2) in enumerate(blocks):
if i == j: continue
r2, b2 = x2 + w2, y2 + h2
if x1 >= x2 and y1 >= y2 and r1 <= r2 and b1 <= b2:
nested = True
break
if not nested: filtered.append((x1, y1, w1, h1))
return filtered
def _merge_same_row(blocks: list[tuple]) -> list[tuple]:
if not blocks: return []
merged = []
cur_x, cur_y, cur_w, cur_h = blocks[0]
for (x, y, w, h) in blocks[1:]:
cur_b = cur_y + cur_h
if cur_h >= LARGE_BLOCK_H or h >= LARGE_BLOCK_H:
merged.append((cur_x, cur_y, cur_w, cur_h))
cur_x, cur_y, cur_w, cur_h = x, y, w, h
continue
if y <= cur_b + ROW_MERGE_GAP:
union_x, union_y = min(cur_x, x), min(cur_y, y)
union_r, union_b = max(cur_x + cur_w, x + w), max(cur_y + cur_h, y + h)
cur_x, cur_y, cur_w, cur_h = union_x, union_y, union_r - union_x, union_b - union_y
else:
merged.append((cur_x, cur_y, cur_w, cur_h))
cur_x, cur_y, cur_w, cur_h = x, y, w, h
merged.append((cur_x, cur_y, cur_w, cur_h))
return merged
def _extract_blocks(np_bgr: np.ndarray) -> list[tuple]:
    """
    Run the full block-detection chain on a BGR image.

    Raw contour boxes are found first, boxes nested inside other boxes are
    removed, and the remainder is merged row-wise. Returns (x, y, w, h) tuples.
    """
    return _merge_same_row(_filter_nested(_find_raw_blocks(np_bgr)))
def get_best_sniper_roi(img):
    """
    V1.1: Math structural heatmap prior.

    Prioritizes regions with a high density of mathematical symbols.

    Every external contour of the inverse-thresholded image (at least 20 px
    wide and 10 px tall) is scored by the share of foreground pixels in its
    bounding box, boosted x1.5 when that share exceeds 0.15, and weighted
    towards the top of the image. The best-scoring box selects a full-width
    horizontal strip padded by 50 px above and below.

    Returns:
        (strip, confidence). When no contour qualifies, the top 350 rows of
        the image and a confidence of 0.0 are returned instead.
    """
    ROI_HOMOGRAPHY_THRESHOLD = 50000  # Threshold for local correction (w*h)
    # V8.6.4: ROI hardening - safe margins so that small marks such as the
    # minus sign of an exponent (e^-x vs e^x) are not cropped away.
    SAFE_PADDING = 50

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, ink_mask = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(ink_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    scored_boxes = []
    for contour in contours:
        bx, by, bw, bh = cv2.boundingRect(contour)
        if bw >= 20 and bh >= 10:
            # V1.1: symbol density check (heatmap prior).
            patch = ink_mask[by:by+bh, bx:bx+bw]
            ink_pixels = np.sum(patch == 255)
            density = ink_pixels / (bw * bh)
            # Dense clusters (usually math symbols) receive a bonus.
            heatmap_prior = density * (1.5 if density > 0.15 else 1.0)
            # Top-heavy bias: the weight decays as y grows.
            position_weight = (1.0 / (1.0 + 0.005 * by))
            scored_boxes.append((heatmap_prior * position_weight, (bx, by, bw, bh)))

    if not scored_boxes:
        logger.warning("⚠️ No valid candidates found. Fallback.")
        return img[:350, :], 0.0

    best_score, (x, y, w, h) = max(scored_boxes, key=lambda entry: entry[0])

    y_start = max(0, y - SAFE_PADDING)
    y_end = min(img.shape[0], y + h + SAFE_PADDING)

    logger.info(
        f"📐 [OCR-BBOX] Sniper ROI (V8.6.4) — x={x}, y={y}, w={w}, h={h} | "
        f"img=({img.shape[1]}x{img.shape[0]}) | "
        f"crop=[{y_start}:{y_end}, :] | "
        f"confidence={best_score:.3f}"
    )

    # Conditional local-homography hint (log only for now).
    if (w * h) > ROI_HOMOGRAPHY_THRESHOLD:
        logger.info(f"📐 [V1.1] ROI ({w}x{h}) exceeds threshold. Local adjustment recommended.")

    return img[y_start:y_end, :], best_score
def apply_conditional_homography(roi_img: np.ndarray) -> np.ndarray:
    """
    V1.1: Local homography gate (placeholder).

    Logs whether the ROI is large enough (area of at least 5000 pixels) to
    warrant a perspective correction. No warp is performed yet: the input
    image is returned unchanged in both cases.
    """
    height, width = roi_img.shape[:2]
    pixel_area = height * width

    if pixel_area >= 5000:
        # Placeholder: a real implementation would call findHomography +
        # warpPerspective here.
        logger.info(f"📐 [V1.1] ROI large enough ({pixel_area}) - ready for local perspective correction.")
    else:
        # Small ROI - keep the original to prevent distortion.
        logger.info(f"📐 [V1.1] ROI too small ({pixel_area}) - skipping local homography.")

    return roi_img
async def transcribe(image_bytes: bytes, vision_model, debug_mode: bool = False) -> tuple[list[dict], float]:
    """
    Two-pass OCR: a focused "sniper" pass plus a full-image "reader" pass.

    Args:
        image_bytes: Encoded image to transcribe.
        vision_model: Object exposing ``generate_content_async(parts,
            generation_config=...)`` whose response has a ``.text`` attribute.
        debug_mode: When true, both pass images are saved under DEBUG_DIR.

    Returns:
        ``(blocks, roi_confidence)``. ``blocks`` holds the sniper blocks
        followed by the reader blocks not already present among them. If
        anything in the model/parse stage raises, a single text block with an
        error message and a confidence of 0.0 are returned instead.

    Note:
        Preprocessing and ROI selection run outside the ``try`` block, so
        their exceptions propagate to the caller.
    """
    logger.info("🪡 [OCR-STRIP] V303.6 Production Lock Pipeline Starting...")

    # 1. Adaptive preprocessing.
    np_enhanced_bgr = _adaptive_preprocess(image_bytes)

    # 2. Sniper ROI (V1.1), then the (currently pass-through) homography gate.
    sniper_bgr, roi_confidence = get_best_sniper_roi(np_enhanced_bgr)
    sniper_bgr = apply_conditional_homography(sniper_bgr)
    sniper_image = Image.fromarray(cv2.cvtColor(sniper_bgr, cv2.COLOR_BGR2RGB))

    # 3. Reader pass uses the full enhanced image for context.
    reader_image = Image.fromarray(cv2.cvtColor(np_enhanced_bgr, cv2.COLOR_BGR2RGB))

    if debug_mode:
        sniper_image.save(DEBUG_DIR / "ocr_pass1_sniper.jpg")
        reader_image.save(DEBUG_DIR / "ocr_pass2_reader.jpg")

    # 4. The prompts.
    sniper_prompt = (
        "Extract ONLY the main mathematical function defined in this image. "
        "It is usually preceded by words like 'נתונה הפונקציה'. "
        "CRITICAL: If any exponent or fraction bar appears small or ambiguous, zoom mentally and transcribe it explicitly using ^ notation. "
        "RETURN ONLY A JSON ARRAY: [{\"type\": \"math\", \"content\": \"...\"}]"
    )
    reader_prompt = (
        "Extract all Hebrew text and secondary mathematical content from this image. "
        "RETURN ONLY A JSON ARRAY: [{\"type\": \"text\"|\"math\", \"content\": \"...\"}]"
    )

    # 5. Concurrent processing.
    try:
        from google.generativeai.types import GenerationConfig
        gen_config = GenerationConfig(temperature=0.0, top_p=0.1, top_k=1)

        pass1_task = vision_model.generate_content_async([sniper_prompt, sniper_image], generation_config=gen_config)
        pass2_task = vision_model.generate_content_async([reader_prompt, reader_image], generation_config=gen_config)
        pass1_response, pass2_response = await asyncio.gather(pass1_task, pass2_task)

        blocks_pass1 = _parse_structured_json(pass1_response.text)
        blocks_pass2 = _parse_structured_json(pass2_response.text)

        # Merge, prioritizing pass 1 for the function definition.
        final_blocks = blocks_pass1 + [b for b in blocks_pass2 if b not in blocks_pass1]

        # V303.7: apply final OCR hotfixes to text blocks.
        for block in final_blocks:
            if block.get("type") != "text":
                continue
            # Model output is untrusted: "content" may be missing or not a
            # string. One malformed block must not abort the whole result, so
            # normalise it instead of indexing/calling str methods blindly.
            content = block.get("content")
            if not isinstance(content, str):
                content = str(content) if content else ""
            block["content"] = finalize_ocr_text(content)

        logger.info(f"✅ V303.6 Complete. Sniper: {len(blocks_pass1)}, Reader: {len(blocks_pass2)} (Confidence: {roi_confidence:.2f})")
        return final_blocks, roi_confidence
    except Exception as e:
        # Top-level boundary: log with traceback and degrade to an error block.
        logger.exception("CRITICAL FLOW ERROR")
        logger.error(f"❌ OCR V303.6 FAILED: {e}")
        return [{"type": "text", "content": "שגיאת תקשורת בפענוח."}], 0.0
# Hebrew word for "axis" and the single characters OCR tends to emit in
# place of a Latin "y" right after it.
_AXIS_WORD = "ציר"
_Y_MISREADS = ("ע", "E", "ץ")
# The negative lookahead restricts the fix to a standalone character, so real
# words that merely start with one of these letters are left intact.
_AXIS_Y_MISREAD_RE = re.compile(
    _AXIS_WORD + " (?:" + "|".join(_Y_MISREADS) + r")(?!\w)"
)


def finalize_ocr_text(text: str) -> str:
    """
    V1.1.2: Correct common OCR misreadings of "y" after the Hebrew word for axis.

    Args:
        text: OCR text; a falsy value (empty string or None) yields "".

    Returns:
        The text with every "<axis-word> <misread>" occurrence rewritten to
        "<axis-word> y", where the misread character stands alone (it is not
        followed by another letter, digit or underscore).
    """
    if not text:
        return ""
    return _AXIS_Y_MISREAD_RE.sub(_AXIS_WORD + " y", text)
def _flatten_ocr_items(items: list, out: list[dict]) -> None:
    """Append the blocks found in ``items`` to ``out``, descending into nested lists."""
    for item in items:
        if isinstance(item, list):
            _flatten_ocr_items(item, out)
        elif isinstance(item, dict):
            out.append(item)
        elif isinstance(item, str) and item.strip():
            # V9.0.2 FIX: bare strings are wrapped in a text block.
            out.append({"type": "text", "content": item.strip()})
        # Anything else (numbers, null, blank strings) carries no block data.


def _parse_structured_json(raw_text: str) -> list[dict]:
    """
    V1.0: Parse a model reply into a flat list of block dicts.

    Uses the canonical ``safe_extract_json`` extractor. A list result is
    flattened at any nesting depth (the LLM sometimes wraps the array in
    further arrays): dicts are kept, non-blank strings become text blocks,
    and everything else is dropped. A dict result without a truthy
    ``logic_error`` entry is returned as a single-element list.

    Returns:
        The list of block dicts, or [] (after logging an error) when the
        reply could not be parsed into a usable shape.
    """
    result = safe_extract_json(raw_text, caller="OCR", allow_array=True)

    if isinstance(result, list):
        flat: list[dict] = []
        _flatten_ocr_items(result, flat)
        return flat

    if isinstance(result, dict) and not result.get("logic_error"):
        return [result]

    # Guard the slice so a None reply cannot raise inside the error path.
    logger.error(f"[OCR] _parse_structured_json: parse failed for: {(raw_text or '')[:200]!r}")
    return []
def paginate_image(image_bytes, debug_mode=False):
    """
    Wrap the encoded image as a one-page list of RGB PIL images.

    ``debug_mode`` is accepted but not used by this implementation.
    """
    stream = io.BytesIO(image_bytes)
    page = Image.open(stream).convert("RGB")
    return [page]
def flatten_to_text(structured: list[dict]) -> str:
    """
    Render structured OCR blocks as one space-separated string.

    Blocks whose "type" is "math" are wrapped in dollar signs; every other
    block contributes its "content" as is. A missing "content" counts as "".
    """
    def render(entry: dict) -> str:
        if entry.get("type") == "math":
            return f"${entry.get('content', '')}$"
        return entry.get("content", "")

    return " ".join([render(entry) for entry in structured])