File size: 12,837 Bytes
9d29c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2eb4e3
9d29c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0524a1e
 
 
 
 
9d29c62
 
 
 
 
 
 
 
0524a1e
 
 
 
 
 
 
 
9d29c62
 
 
 
 
 
 
 
 
 
 
a162bd1
 
 
9d29c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# ocr_strip_engine.py - V303.2 (Adaptive Pipeline & Two-Pass Sniper)
import os
import io
import json
import re
import logging
import asyncio
from pathlib import Path
import numpy as np
from PIL import Image
import cv2
from utils.safe_json import safe_extract_json  # V1.0: Canonical JSON extractor

logger = logging.getLogger(__name__)

# Output directory for the debug snapshots that transcribe() saves when
# debug_mode is enabled. It is created at import time so those saves never
# hit a missing directory.
DEBUG_DIR = Path("/tmp") / "debug_ocr"
DEBUG_DIR.mkdir(exist_ok=True, parents=True)

# --- Constants for block detection ---
MIN_BLOCK_W = 50     # _find_raw_blocks: minimum bounding-box width (px) to keep
MIN_BLOCK_H = 10     # _find_raw_blocks: minimum bounding-box height (px) to keep
LARGE_BLOCK_H = 100  # _merge_same_row: boxes at least this tall are never merged
ROW_MERGE_GAP = 2    # _merge_same_row: max vertical gap (px) bridged when merging

def _adaptive_preprocess(image_bytes: bytes, low_res_threshold_kb: float = 500) -> np.ndarray:
    """
    V303.2: Adaptive Preprocessing Pipeline.

    Decode an encoded image and enhance it for OCR. How aggressively to
    process is decided by the *encoded payload size*
    (``len(image_bytes) / 1024``), not by pixel dimensions.

    Args:
        image_bytes: Encoded image bytes (any format cv2.imdecode accepts).
        low_res_threshold_kb: Payloads smaller than this many KiB take the
            low-res path (2x upscale, strong CLAHE, 2x2 morphological close).
            Larger payloads take the high-res path (contrast/brightness
            scaling plus mild CLAHE). Defaults to 500.

    Returns:
        The enhanced image as a 3-channel BGR array. The single-channel
        result is replicated into all three channels.

    Raises:
        ValueError: if ``image_bytes`` is empty or cannot be decoded as an image.
    """
    if not image_bytes:
        # cv2.imdecode asserts on an empty buffer with an opaque cv2.error.
        raise ValueError("[OCR-ADAPTIVE] Empty image payload")

    np_img_raw = np.frombuffer(image_bytes, np.uint8)
    img_bgr = cv2.imdecode(np_img_raw, cv2.IMREAD_COLOR)
    if img_bgr is None:
        # imdecode reports a corrupt/unsupported payload by returning None.
        # Without this check the failure would only surface later as a
        # cryptic cv2.cvtColor assertion.
        raise ValueError(f"[OCR-ADAPTIVE] Could not decode image ({len(image_bytes)} bytes)")

    file_size_kb = len(image_bytes) / 1024

    logger.info(f"📸 [OCR-ADAPTIVE] Input image size: {file_size_kb:.1f} KB")

    if file_size_kb < low_res_threshold_kb:
        # --- LOW-RES MODE (PC Screenshots / Snips) ---
        logger.info("🔧 [OCR-ADAPTIVE] Low-Res mode triggered: Upscaling and applying heavy morphology.")
        # 1. Upscale x2 to save thin pixels
        img_bgr = cv2.resize(img_bgr, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        # 2. Strong CLAHE
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        cl1 = clahe.apply(gray)
        # 3. Morph Close (thicken lines)
        # NOTE(review): a close (dilate then erode) suppresses *dark* detail
        # smaller than the 2x2 kernel on a light background, rather than
        # thickening dark strokes. Confirm the intended text polarity.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        processed = cv2.morphologyEx(cl1, cv2.MORPH_CLOSE, kernel)
    else:
        # --- HIGH-RES MODE (Phone Camera in Production) ---
        logger.info("📱 [OCR-ADAPTIVE] High-Res mode triggered: Mild enhancement only.")
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

        # V1.1: Pre-normalization (Contrast/Brightness Balance).
        # Computes |gray * alpha + beta|, saturated to uint8.
        alpha = 1.2  # Contrast
        beta = 10    # Brightness
        gray = cv2.convertScaleAbs(gray, alpha=alpha, beta=beta)

        # Mild CLAHE just to balance lighting, NO morphological distortion
        clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
        processed = clahe.apply(gray)

    return cv2.cvtColor(processed, cv2.COLOR_GRAY2BGR)


def _find_raw_blocks(np_bgr: np.ndarray) -> list[tuple]:
    """
    Detect candidate text regions as axis-aligned bounding boxes.

    Pipeline:
      1. Convert to grayscale and apply a 7x7 Gaussian blur.
      2. Binarise with an inverted Gaussian adaptive threshold (block 11, C=2),
         so ink becomes 255.
      3. Dilate once with a 60x3 rectangle. This fuses horizontally adjacent
         glyphs into line-shaped blobs; the 3px height limits vertical growth.
      4. Take the bounding rect of every external contour.
      5. Keep only rects at least MIN_BLOCK_W wide and MIN_BLOCK_H tall.

    Returns:
        (x, y, w, h) tuples sorted top-to-bottom by y (stable for equal y).
    """
    grayscale = cv2.cvtColor(np_bgr, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (7, 7), 0)
    ink_mask = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )
    # Kernel optimized for both preprocessing modes
    line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (60, 3))
    fused = cv2.dilate(ink_mask, line_kernel, iterations=1)
    outlines, _hierarchy = cv2.findContours(fused, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rects = (cv2.boundingRect(outline) for outline in outlines)
    big_enough = [
        (left, top, width, height)
        for (left, top, width, height) in rects
        if width >= MIN_BLOCK_W and height >= MIN_BLOCK_H
    ]
    return sorted(big_enough, key=lambda rect: rect[1])

def _filter_nested(blocks: list[tuple]) -> list[tuple]:
    filtered = []
    for i, (x1, y1, w1, h1) in enumerate(blocks):
        r1, b1 = x1 + w1, y1 + h1
        nested = False
        for j, (x2, y2, w2, h2) in enumerate(blocks):
            if i == j: continue
            r2, b2 = x2 + w2, y2 + h2
            if x1 >= x2 and y1 >= y2 and r1 <= r2 and b1 <= b2:
                nested = True
                break
        if not nested: filtered.append((x1, y1, w1, h1))
    return filtered

def _merge_same_row(blocks: list[tuple]) -> list[tuple]:
    """
    Fuse vertically adjacent boxes into single row boxes.

    Walks ``blocks`` in the order given (the pipeline passes them sorted by
    y) while maintaining a running box. The next box is unioned into the
    running box when it starts no lower than ROW_MERGE_GAP px below the
    running box's bottom edge. If either the running box or the next box is
    at least LARGE_BLOCK_H tall, no merge happens: the running box is
    emitted and the next box starts a new one.

    Returns:
        Merged (x, y, w, h) tuples in walk order; [] for empty input.
    """
    if not blocks:
        return []

    rows = []
    x0, y0, w0, h0 = blocks[0]
    run = (x0, y0, w0, h0)
    for x, y, w, h in blocks[1:]:
        rx, ry, rw, rh = run
        too_tall = rh >= LARGE_BLOCK_H or h >= LARGE_BLOCK_H
        if not too_tall and y <= ry + rh + ROW_MERGE_GAP:
            left, top = min(rx, x), min(ry, y)
            right, bottom = max(rx + rw, x + w), max(ry + rh, y + h)
            run = (left, top, right - left, bottom - top)
        else:
            rows.append(run)
            run = (x, y, w, h)
    rows.append(run)
    return rows

def _extract_blocks(np_bgr: np.ndarray) -> list[tuple]:
    """
    Run the full block-detection chain on a BGR image.

    Steps: detect raw boxes (_find_raw_blocks), then drop boxes nested
    inside others (_filter_nested), then fuse adjacent rows
    (_merge_same_row).

    Returns:
        The resulting (x, y, w, h) boxes.
    """
    return _merge_same_row(_filter_nested(_find_raw_blocks(np_bgr)))

def get_best_sniper_roi(img):
    """
    V1.1: Math Structural Heatmap Prior.
    Prioritises regions with a high density of mathematical symbols.

    Binarises ``img`` with a fixed inverted threshold (150) and scores every
    external contour whose bounding box is at least 20px wide and 10px tall:

        score = density * (1.5 if density > 0.15 else 1.0) / (1 + 0.005 * y)

    Here ``density`` is the fraction of ink pixels inside the box. The
    denominator biases the choice toward the top of the image. The
    best-scoring box (first one on ties) is padded by 50px above and below,
    clamped to the image, and returned as a full-width horizontal strip.

    Args:
        img: BGR image array.

    Returns:
        (strip, confidence). If no contour qualifies, returns the top 350
        rows and 0.0.
    """
    ROI_HOMOGRAPHY_THRESHOLD = 50000  # Threshold for local correction (w*h)
    # V8.6.4: ROI Hardening — Safe Margins (Padding 50px).
    # Prevents cutting off minus signs from exponents (e.g., e^-x becoming e^x).
    SAFE_PADDING = 50

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, ink_mask = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
    outlines, _ = cv2.findContours(ink_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _score(x, y, w, h):
        """Heatmap-weighted, position-damped symbol density of one box."""
        # Symbol density check (heatmap prior)
        density = np.sum(ink_mask[y:y + h, x:x + w] == 255) / (w * h)
        # Small but dense clusters (usually math symbols) get a 1.5x boost
        boost = 1.5 if density > 0.15 else 1.0
        # Top-heavy bias: lower regions are progressively damped
        return density * boost * (1.0 / (1.0 + 0.005 * y))

    scored = []
    for outline in outlines:
        x, y, w, h = cv2.boundingRect(outline)
        if w >= 20 and h >= 10:
            scored.append((_score(x, y, w, h), (x, y, w, h)))

    if not scored:
        logger.warning("⚠️ No valid candidates found. Fallback.")
        return img[:350, :], 0.0

    # max() returns the first maximal entry, so ties go to the earliest contour.
    confidence, (x, y, w, h) = max(scored, key=lambda entry: entry[0])

    top = max(0, y - SAFE_PADDING)
    bottom = min(img.shape[0], y + h + SAFE_PADDING)

    logger.info(
        f"📐 [OCR-BBOX] Sniper ROI (V8.6.4) — x={x}, y={y}, w={w}, h={h} | "
        f"img=({img.shape[1]}x{img.shape[0]}) | "
        f"crop=[{top}:{bottom}, :] | "
        f"confidence={confidence:.3f}"
    )

    # Conditional local-homography hint (log only for now)
    if (w * h) > ROI_HOMOGRAPHY_THRESHOLD:
        logger.info(f"📐 [V1.1] ROI ({w}x{h}) exceeds threshold. Local adjustment recommended.")

    return img[top:bottom, :], confidence

def apply_conditional_homography(roi_img: np.ndarray) -> np.ndarray:
    """
    V1.1: Local homography gate (placeholder).

    Logs whether the ROI's area (height * width) reaches the 5000 px size at
    which local perspective correction would be applied. Smaller ROIs are
    meant to stay untouched to avoid distortion. No warp is implemented yet,
    so the input array is returned unchanged (the same object) in both cases.

    TODO: implement the correction (e.g. findHomography + warpPerspective).
    """
    height, width = roi_img.shape[:2]
    area = height * width

    if area >= 5000:
        # Placeholder: the actual perspective warp would happen here.
        logger.info(f"📐 [V1.1] ROI large enough ({area}) - ready for local perspective correction.")
    else:
        logger.info(f"📐 [V1.1] ROI too small ({area}) - skipping local homography.")
    return roi_img

async def transcribe(image_bytes: bytes, vision_model, debug_mode: bool = False) -> tuple[list[dict], float]:
    """
    Run the two-pass ("sniper" + "reader") OCR pipeline on one image.

    Steps:
      1. Adaptive preprocessing (_adaptive_preprocess).
      2. Select the best math-like horizontal strip (get_best_sniper_roi) and
         pass it through apply_conditional_homography.
      3. Send two requests concurrently to
         ``vision_model.generate_content_async``:
         - the sniper strip, with a prompt asking for the main function only;
         - the full enhanced image, with a prompt asking for all Hebrew text
           and secondary math.
      4. Parse both replies. Merge them with sniper blocks first; reader
         blocks are appended only if an equal block is not already present.
         Apply finalize_ocr_text to every "text" block.

    Args:
        image_bytes: Encoded image bytes.
        vision_model: Object whose ``generate_content_async(parts,
            generation_config=...)`` awaits to a response exposing ``.text``.
            This is presumably a google.generativeai GenerativeModel —
            confirm with the callers.
        debug_mode: If true, save both pass images under DEBUG_DIR.

    Returns:
        (blocks, roi_confidence). If the model calls, parsing or merging
        raise, the result is a single Hebrew error text block with
        confidence 0.0.

    Note:
        Preprocessing, ROI selection and the debug saves run *before* the
        try block, so exceptions from them propagate to the caller.
    """
    logger.info("🪡 [OCR-STRIP] V303.6 Production Lock Pipeline Starting...")

    # 1. Adaptive Preprocessing
    # NOTE(review): steps 1-2 are synchronous OpenCV work on the event-loop
    # thread. Consider asyncio.to_thread if inputs can be large.
    np_enhanced_bgr = _adaptive_preprocess(image_bytes)

    # 2. Get Sniper ROI (V1.1)
    sniper_bgr, roi_confidence = get_best_sniper_roi(np_enhanced_bgr)

    # 2b. Apply Local Homography if needed (V1.1). This is currently a
    # logging-only placeholder that returns its input unchanged.
    sniper_bgr = apply_conditional_homography(sniper_bgr)

    sniper_image = Image.fromarray(cv2.cvtColor(sniper_bgr, cv2.COLOR_BGR2RGB))

    # 3. Reader Pass: the full enhanced image, kept for surrounding context.
    reader_image = Image.fromarray(cv2.cvtColor(np_enhanced_bgr, cv2.COLOR_BGR2RGB))

    if debug_mode:
        sniper_image.save(DEBUG_DIR / "ocr_pass1_sniper.jpg")
        reader_image.save(DEBUG_DIR / "ocr_pass2_reader.jpg")

    # 4. The Prompts
    sniper_prompt = (
        "Extract ONLY the main mathematical function defined in this image. "
        "It is usually preceded by words like 'נתונה הפונקציה'. "
        "CRITICAL: If any exponent or fraction bar appears small or ambiguous, zoom mentally and transcribe it explicitly using ^ notation. "
        "RETURN ONLY A JSON ARRAY: [{\"type\": \"math\", \"content\": \"...\"}]"
    )

    reader_prompt = (
        "Extract all Hebrew text and secondary mathematical content from this image. "
        "RETURN ONLY A JSON ARRAY: [{\"type\": \"text\"|\"math\", \"content\": \"...\"}]"
    )

    # 5. Concurrent Processing
    try:
        from google.generativeai.types import GenerationConfig
        # temperature 0 / top_k 1: greedy decoding, for reproducible transcription.
        gen_config = GenerationConfig(temperature=0.0, top_p=0.1, top_k=1)
        pass1_task = vision_model.generate_content_async([sniper_prompt, sniper_image], generation_config=gen_config)
        pass2_task = vision_model.generate_content_async([reader_prompt, reader_image], generation_config=gen_config)

        pass1_response, pass2_response = await asyncio.gather(pass1_task, pass2_task)

        blocks_pass1 = _parse_structured_json(pass1_response.text)
        blocks_pass2 = _parse_structured_json(pass2_response.text)

        # Merge, prioritizing pass 1 for the function definition
        final_blocks = blocks_pass1 + [b for b in blocks_pass2 if b not in blocks_pass1]

        # V303.7: Apply final OCR hotfixes
        for block in final_blocks:
            if block.get("type") == "text":
                # Coerce missing/None/non-string content. Otherwise a single
                # malformed block raises here, and the broad except below
                # would discard every successfully parsed block.
                block["content"] = finalize_ocr_text(str(block.get("content") or ""))

        logger.info(f"✅ V303.6 Complete. Sniper: {len(blocks_pass1)}, Reader: {len(blocks_pass2)} (Confidence: {roi_confidence:.2f})")
        return final_blocks, roi_confidence

    except Exception as e:
        logger.exception("CRITICAL FLOW ERROR")
        logger.error(f"❌ OCR V303.6 FAILED: {e}")
        return [{"type": "text", "content": "שגיאת תקשורת בפענוח."}], 0.0

# Matches "ציר" (Hebrew: axis) followed by a single glyph that OCR commonly
# returns in place of the Latin "y": Hebrew ayin (ע), Latin "E" (looks like y
# in some fonts) or final tsadi (ץ). The lookahead requires that glyph to
# stand alone, so ordinary words that merely begin with it (e.g. "ציר על")
# are not mangled into "ציר yל". There is no leading boundary because Hebrew
# prefixes attach directly ("לציר ע" -> "לציר y").
_AXIS_Y_MISREAD_RE = re.compile(r"ציר [עEץ](?!\w)")


def finalize_ocr_text(text: str) -> str:
    """V1.1.2: Corrects common OCR misinterpretations in Hebrew context.

    Rewrites a misread y-axis label ("ציר ע" / "ציר E" / "ציר ץ", where the
    glyph is a standalone token) to "ציר y".

    Args:
        text: OCR text block content.

    Returns:
        The corrected text, or "" for falsy input (empty string, None).
    """
    if not text:
        return ""
    return _AXIS_Y_MISREAD_RE.sub("ציר y", text)

def _parse_structured_json(raw_text: str) -> list[dict]:
    """
    V1.0: Uses canonical safe_extract_json (logs RAW, fail-closed).

    Normalises a model reply into a list of block dicts:
      * JSON array -> its dict items. One level of nested arrays is unwrapped,
        because the LLM sometimes wraps the array in another array. Non-blank
        strings at either level become {"type": "text", "content": <stripped>}.
        Other items (numbers, null, deeper nesting) are dropped.
      * a single JSON object without a truthy "logic_error" key -> [object].
      * anything else -> [] (an error is logged with the first 200 chars).
    """
    result = safe_extract_json(raw_text, caller="OCR", allow_array=True)
    if isinstance(result, list):
        flat = []
        for item in result:
            # Unwrap one nesting level so [[...]] is handled like [...].
            members = item if isinstance(item, list) else (item,)
            for member in members:
                if isinstance(member, dict):
                    flat.append(member)
                elif isinstance(member, str) and member.strip():
                    # V9.0.2 FIX: wrap bare strings in a text block. Nested
                    # members get the same treatment, so strings inside a
                    # wrapped array are not silently lost.
                    flat.append({"type": "text", "content": member.strip()})
        return flat
    if isinstance(result, dict) and not result.get("logic_error"):
        return [result]
    logger.error(f"[OCR] _parse_structured_json: parse failed for: {raw_text[:200]!r}")
    return []


def paginate_image(image_bytes, debug_mode=False):
    """
    Decode ``image_bytes`` into a list of RGB PIL images ("pages").

    Currently this always returns a single page: the whole image converted
    to RGB. ``debug_mode`` is accepted for interface compatibility and is
    unused.
    """
    # Image.open is lazy and holds onto its source. convert() loads the pixels
    # into a new, independent image, so the source can be closed
    # deterministically instead of being left to the garbage collector.
    with Image.open(io.BytesIO(image_bytes)) as source:
        return [source.convert("RGB")]
    
def flatten_to_text(structured: list[dict]) -> str:
    """
    Render OCR blocks as one space-separated string.

    Blocks whose "type" is "math" are wrapped in ``$...$``. Every other block,
    including one with no "type", is emitted as plain content. Missing or
    None content renders as "" (so ``$$`` for math), and non-string content
    is converted with str(). This prevents one malformed block from producing
    "$None$" or making str.join raise.
    """
    def _content(item: dict) -> str:
        """The block's content as a string ('' when missing or None)."""
        value = item.get("content")
        return "" if value is None else str(value)

    return " ".join(
        f"${_content(item)}$" if item.get("type") == "math" else _content(item)
        for item in structured
    )