#!/usr/bin/env python3
"""
IAT (Illumination Adaptive Transformer) → ONNX Export Script

Monkey-patches ONNX-incompatible patterns in IAT source, exports all 3
checkpoints (exposure, lol_v1, lol_v2), and verifies each numerically at
multiple resolutions.

Patches applied:
1. IAT.apply_color:       tensordot → matmul (ONNX-friendly)
2. IAT.forward:           Python for-loop over batch → vectorized bmm + broadcast pow
3. Aff_channel.forward:   tensordot → matmul (fallback — needed for tracing)
"""

import argparse
import sys
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

# ---------------------------------------------------------------------------
# Add IAT source to path, fix Python 3.12+ compatibility
# ---------------------------------------------------------------------------
IAT_ROOT = Path(__file__).parent / "iat" / "IAT_enhance"
sys.path.insert(0, str(IAT_ROOT))

# IAT's global_net.py has `import imp` which was removed in Python 3.12.
# It's unused, so we provide a dummy module before importing IAT.
#
# FIX: must import the submodule explicitly — a bare `import importlib` does
# not guarantee `importlib.util` is loaded, and accessing it could raise
# AttributeError depending on what was imported earlier in the process.
import importlib.util

if not importlib.util.find_spec("imp"):
    import types

    sys.modules["imp"] = types.ModuleType("imp")

from model.IAT_main import IAT  # noqa: E402
from model.blocks import Aff_channel  # noqa: E402


# ===========================================================================
# Monkey-patches
# ===========================================================================

def _patched_apply_color(self, image, ccm):
    """Replace tensordot with matmul for ONNX compatibility.

    Original:
        torch.tensordot(image, ccm, dims=[[-1], [-1]])
        which computes image @ ccm.T (contract last dim of both)
    Replacement:
        torch.matmul(image, ccm.T)

    `image` is flattened to (N, 3) pixels, color-corrected, then restored
    to its original shape; output is clamped to (1e-8, 1.0] as in the
    original implementation.
    """
    shape = image.shape
    image = image.view(-1, 3)
    # ccm.permute(1, 0) == ccm.T for the 2D color-correction matrix.
    image = torch.matmul(image, ccm.permute(1, 0))
    image = image.view(shape)
    return torch.clamp(image, 1e-8, 1.0)


def _patched_forward(self, img_low):
    """Vectorized forward — no Python for-loop over batch dimension.

    Original:
        img_high = torch.stack(
            [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :]
             for i in range(b)],
            dim=0)
    Replacement:
        1. bmm for batched color matrix multiply
        2. broadcast pow for gamma

    Returns (mul, add, img_high); when `with_global` is False the global
    color/gamma correction is skipped entirely.
    """
    mul, add = self.local_net(img_low)
    img_high = (img_low.mul(mul)).add(add)

    if not self.with_global:
        return mul, add, img_high

    gamma, color = self.global_net(img_low)

    # img_high: (B, C, H, W) → (B, H, W, C) → (B, H*W, C)
    b, c, h, w = img_high.shape
    img_high = img_high.permute(0, 2, 3, 1).reshape(b, h * w, c)

    # Batched color matrix: (B, H*W, 3) @ (B, 3, 3) → (B, H*W, 3)
    # color is (B, 3, 3); we need img @ color^T for each batch element.
    color_t = color.permute(0, 2, 1)  # (B, 3, 3)
    img_high = torch.bmm(img_high, color_t)
    # Clamp before pow so negative/zero bases never reach the exponent.
    img_high = torch.clamp(img_high, 1e-8, 1.0)

    # Reshape back to (B, H, W, C) for broadcast pow.
    img_high = img_high.view(b, h, w, c)

    # gamma is (B, 1) — reshape to (B, 1, 1, 1) for broadcast.
    gamma_broadcast = gamma.unsqueeze(-1).unsqueeze(-1)  # (B, 1, 1, 1)
    img_high = img_high ** gamma_broadcast

    # (B, H, W, C) → (B, C, H, W)
    img_high = img_high.permute(0, 3, 1, 2)

    return mul, add, img_high


def _patched_aff_channel_forward(self, x):
    """Replace tensordot with matmul in Aff_channel for ONNX compatibility.

    Original:
        torch.tensordot(x, self.color, dims=[[-1], [-1]])
    Replacement:
        torch.matmul(x, self.color.T)

    `channel_first` controls whether the color mixing happens before or
    after the affine (alpha/beta) transform, mirroring the upstream class.
    """
    if self.channel_first:
        x1 = torch.matmul(x, self.color.permute(1, 0))
        x2 = x1 * self.alpha + self.beta
    else:
        x1 = x * self.alpha + self.beta
        x2 = torch.matmul(x1, self.color.permute(1, 0))
    return x2


# ===========================================================================
# Fallback patches (not needed for current export, documented for reference)
# ===========================================================================

# --- Fallback: query_Attention expand ---
# If export fails on expand in global attention (global_net.py):
#
# def _patched_query_attention_forward(self, x):
#     B, N, C = x.shape
#     # Original: self.q.expand(B, -1, -1) -- can fail with dynamic batch
#     # Fix: use repeat which traces cleanly
#     q = self.q.repeat(B, 1, 1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
#     ... rest of forward unchanged ...
#
# from model.global_net import query_Attention
# query_Attention.forward = _patched_query_attention_forward

# --- Fallback: gamma power operator ---
# If ** operator traces incorrectly for broadcast shapes:
#
# Replace in _patched_forward:
#     img_high = torch.pow(torch.clamp(img_high, 1e-8), gamma)
# With:
#     img_high = torch.exp(torch.log(torch.clamp(img_high, 1e-8)) * gamma)


# ===========================================================================
# Checkpoint configurations
# ===========================================================================

CHECKPOINTS = {
    "exposure": {
        "path": IAT_ROOT / "best_Epoch_exposure.pth",
        "model_kwargs": {"type": "exp"},
        "description": "Exposure correction",
    },
    "lol_v1": {
        "path": IAT_ROOT / "best_Epoch_lol_v1.pth",
        "model_kwargs": {"type": "lol"},
        "description": "LOL v1 low-light enhancement",
    },
    "lol_v2": {
        "path": IAT_ROOT / "best_Epoch_lol.pth",
        "model_kwargs": {"type": "lol"},
        "description": "LOL v2 low-light enhancement",
    },
}

VERIFICATION_RESOLUTIONS = [
    (256, 256),
    (512, 512),
    (768, 1024),  # H, W — non-square
]


# ===========================================================================
# Apply patches
# ===========================================================================

def apply_patches():
    """Monkey-patch IAT classes at runtime. Does not modify source files."""
    # Patch 1 & 2: IAT.apply_color and IAT.forward
    IAT.apply_color = _patched_apply_color
    IAT.forward = _patched_forward
    # Patch 3 (fallback — needed for Aff_channel tensordot tracing):
    Aff_channel.forward = _patched_aff_channel_forward
    print("[PATCH] IAT.apply_color: tensordot -> matmul")
    print("[PATCH] IAT.forward: for-loop -> vectorized bmm + broadcast pow")
    print("[PATCH] Aff_channel.forward: tensordot -> matmul")


# ===========================================================================
# Export
# ===========================================================================

def load_model(name: str) -> nn.Module:
    """Load an IAT model from checkpoint and put it in inference mode."""
    cfg = CHECKPOINTS[name]
    model = IAT(in_dim=3, with_global=True, **cfg["model_kwargs"])
    # weights_only=True refuses arbitrary pickled objects in the checkpoint.
    state_dict = torch.load(str(cfg["path"]), map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # idiomatic equivalent of model.train(False)
    return model


def export_onnx(model: nn.Module, output_path: Path, opset: int) -> None:
    """Export a single IAT model to ONNX with dynamic batch/height/width axes."""
    dummy_input = torch.randn(1, 3, 256, 256)
    torch.onnx.export(
        model,
        (dummy_input,),
        str(output_path),
        opset_version=opset,
        input_names=["input"],
        output_names=["mul", "add", "enhanced"],
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "mul": {0: "batch", 2: "height", 3: "width"},
            "add": {0: "batch", 2: "height", 3: "width"},
            "enhanced": {0: "batch", 2: "height", 3: "width"},
        },
    )


def verify_onnx(model: nn.Module, onnx_path: Path) -> bool:
    """Numerical verification of ONNX vs PyTorch at multiple resolutions.

    Returns True when every output at every resolution agrees within 1e-3
    (max absolute difference); differences in [1e-5, 1e-3) only warn.
    """
    import onnxruntime as ort

    session = ort.InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"],
    )

    all_ok = True
    for h, w in VERIFICATION_RESOLUTIONS:
        dummy = torch.randn(1, 3, h, w)

        # PyTorch reference
        with torch.no_grad():
            pt_mul, pt_add, pt_enhanced = model(dummy)

        # ONNX Runtime
        ort_inputs = {"input": dummy.numpy()}
        ort_mul, ort_add, ort_enhanced = session.run(None, ort_inputs)

        # Compare all three outputs; "enhanced" is the one that matters most.
        for name, pt_out, ort_out in [
            ("mul", pt_mul, ort_mul),
            ("add", pt_add, ort_add),
            ("enhanced", pt_enhanced, ort_enhanced),
        ]:
            max_diff = np.max(np.abs(pt_out.numpy() - ort_out))
            if max_diff < 1e-5:
                status = "OK"
                symbol = "+"
            elif max_diff < 1e-3:
                status = "WARN"
                symbol = "~"
            else:
                status = "FAIL"
                symbol = "X"
            print(f"    [{symbol}] {h}x{w} {name:10s} max_diff={max_diff:.2e} [{status}]")
            if max_diff >= 1e-3:
                print(f"    FAIL: max abs diff {max_diff:.6f} >= 1e-3")
                all_ok = False

    return all_ok


# ===========================================================================
# Main
# ===========================================================================

def main():
    parser = argparse.ArgumentParser(description="Export IAT checkpoints to ONNX")
    parser.add_argument(
        "--checkpoints",
        type=str,
        default="all",
        choices=["all", "exposure", "lol_v1", "lol_v2"],
        help="Which checkpoint(s) to export",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent / "outputs"),
        help="Directory for exported ONNX files",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Determine which checkpoints to export
    if args.checkpoints == "all":
        names = list(CHECKPOINTS.keys())
    else:
        names = [args.checkpoints]

    # Apply monkey-patches
    print("=" * 60)
    print("Applying ONNX-compatibility patches...")
    print("=" * 60)
    apply_patches()
    print()

    results = {}
    for name in names:
        cfg = CHECKPOINTS[name]
        onnx_path = output_dir / f"iat_{name}.onnx"

        print("=" * 60)
        print(f"Exporting: {name} ({cfg['description']})")
        print(f"  Checkpoint: {cfg['path']}")
        print(f"  Output: {onnx_path}")
        print("=" * 60)

        # Check checkpoint exists
        if not cfg["path"].exists():
            print(f"  SKIP: checkpoint not found at {cfg['path']}")
            results[name] = "SKIP"
            continue

        # Load
        t0 = time.time()
        model = load_model(name)
        print(f"  Loaded model in {time.time() - t0:.2f}s")

        # Export
        t0 = time.time()
        export_onnx(model, onnx_path, args.opset)
        export_time = time.time() - t0
        file_size_mb = onnx_path.stat().st_size / (1024 * 1024)
        print(f"  Exported in {export_time:.2f}s ({file_size_mb:.1f} MB)")

        # Verify
        print(f"  Verifying at {len(VERIFICATION_RESOLUTIONS)} resolutions...")
        ok = verify_onnx(model, onnx_path)
        results[name] = "PASS" if ok else "FAIL"
        print()

    # Summary
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    all_pass = True
    for name, status in results.items():
        if status == "PASS":
            symbol = "+"
        elif status == "SKIP":
            symbol = "-"
        else:
            symbol = "X"
        print(f"  [{symbol}] {name}: {status}")
        if status == "FAIL":
            all_pass = False

    if not all_pass:
        print("\nSome exports FAILED numerical verification!")
        sys.exit(1)
    else:
        print("\nAll exports passed!")
        sys.exit(0)


if __name__ == "__main__":
    main()