| |
"""
IAT (Illumination Adaptive Transformer) - ONNX Export Script

Monkey-patches ONNX-incompatible patterns in the IAT source, exports the
checkpoints (exposure, lol_v1, lol_v2; all of them by default), and verifies
each export numerically at multiple resolutions.

Patches applied:
    1. IAT.apply_color: tensordot -> matmul (ONNX-friendly)
    2. IAT.forward: Python for-loop over batch -> vectorized bmm + broadcast pow
    3. Aff_channel.forward: tensordot -> matmul (fallback; needed for tracing)
"""
|
|
| import argparse |
| import sys |
| import os |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| |
| |
| |
# Root of the vendored IAT checkout; it is put first on sys.path so that its
# top-level `model` package can be imported below.
IAT_ROOT = Path(__file__).parent.joinpath("iat", "IAT_enhance")
sys.path.insert(0, str(IAT_ROOT))
|
|
| |
| |
# The deprecated `imp` module no longer exists on newer Pythons (removed in
# 3.12), while the vendored IAT code presumably still imports it (TODO:
# confirm which module). Register an empty placeholder so `import imp`
# succeeds; any attribute access on the placeholder would still fail.
#
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly.
import importlib.util

# Check sys.modules first: find_spec() raises ValueError for a module whose
# __spec__ is None, which is exactly what the placeholder below has.
if "imp" not in sys.modules and importlib.util.find_spec("imp") is None:
    import types

    sys.modules["imp"] = types.ModuleType("imp")
|
|
| from model.IAT_main import IAT |
| from model.blocks import Aff_channel |
|
|
|
|
| |
| |
| |
|
|
| def _patched_apply_color(self, image, ccm): |
| """Replace tensordot with matmul for ONNX compatibility. |
| |
| Original: torch.tensordot(image, ccm, dims=[[-1], [-1]]) |
| which computes image @ ccm.T (contract last dim of both) |
| Replacement: torch.matmul(image, ccm.T) |
| """ |
| shape = image.shape |
| image = image.view(-1, 3) |
| image = torch.matmul(image, ccm.permute(1, 0)) |
| image = image.view(shape) |
| return torch.clamp(image, 1e-8, 1.0) |
|
|
|
|
| def _patched_forward(self, img_low): |
| """Vectorized forward β no Python for-loop over batch dimension. |
| |
| Original: |
| img_high = torch.stack([self.apply_color(img_high[i,:,:,:], color[i,:,:]) |
| **gamma[i,:] for i in range(b)], dim=0) |
| |
| Replacement: |
| 1. bmm for batched color matrix multiply |
| 2. broadcast pow for gamma |
| """ |
| mul, add = self.local_net(img_low) |
| img_high = (img_low.mul(mul)).add(add) |
|
|
| if not self.with_global: |
| return mul, add, img_high |
|
|
| gamma, color = self.global_net(img_low) |
| |
| b, c, h, w = img_high.shape |
| img_high = img_high.permute(0, 2, 3, 1).reshape(b, h * w, c) |
|
|
| |
| |
| color_t = color.permute(0, 2, 1) |
| img_high = torch.bmm(img_high, color_t) |
| img_high = torch.clamp(img_high, 1e-8, 1.0) |
|
|
| |
| img_high = img_high.view(b, h, w, c) |
|
|
| |
| gamma_broadcast = gamma.unsqueeze(-1).unsqueeze(-1) |
| img_high = img_high ** gamma_broadcast |
|
|
| |
| img_high = img_high.permute(0, 3, 1, 2) |
| return mul, add, img_high |
|
|
|
|
| def _patched_aff_channel_forward(self, x): |
| """Replace tensordot with matmul in Aff_channel for ONNX compatibility. |
| |
| Original: torch.tensordot(x, self.color, dims=[[-1], [-1]]) |
| Replacement: torch.matmul(x, self.color.T) |
| """ |
| if self.channel_first: |
| x1 = torch.matmul(x, self.color.permute(1, 0)) |
| x2 = x1 * self.alpha + self.beta |
| else: |
| x1 = x * self.alpha + self.beta |
| x2 = torch.matmul(x1, self.color.permute(1, 0)) |
| return x2 |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
# Exportable checkpoints, keyed by the name used on the command line and in
# the ONNX output filename. "model_kwargs" is forwarded to the IAT
# constructor. NOTE: lol_v2 reads best_Epoch_lol.pth (no version suffix).
CHECKPOINTS = {
    name: {
        "path": IAT_ROOT / filename,
        "model_kwargs": {"type": model_type},
        "description": description,
    }
    for name, filename, model_type, description in (
        ("exposure", "best_Epoch_exposure.pth", "exp", "Exposure correction"),
        ("lol_v1", "best_Epoch_lol_v1.pth", "lol", "LOL v1 low-light enhancement"),
        ("lol_v2", "best_Epoch_lol.pth", "lol", "LOL v2 low-light enhancement"),
    )
}
|
|
# (height, width) pairs at which each ONNX export is compared against
# PyTorch; sizes beyond the 256x256 export example exercise the dynamic
# height/width axes.
VERIFICATION_RESOLUTIONS = [(256, 256), (512, 512), (768, 1024)]
|
|
|
|
| |
| |
| |
|
|
def apply_patches():
    """Replace ONNX-incompatible IAT methods with the module's patched versions.

    Assignments are made on the classes at runtime, so they apply to every
    instance; no source file on disk is modified.
    """
    replacements = (
        (IAT, "apply_color", _patched_apply_color,
         "IAT.apply_color: tensordot -> matmul"),
        (IAT, "forward", _patched_forward,
         "IAT.forward: for-loop -> vectorized bmm + broadcast pow"),
        (Aff_channel, "forward", _patched_aff_channel_forward,
         "Aff_channel.forward: tensordot -> matmul"),
    )
    # Install every patch first, then report, matching the original order.
    for owner, attr, impl, _ in replacements:
        setattr(owner, attr, impl)
    for *_, summary in replacements:
        print(f"[PATCH] {summary}")
|
|
|
|
| |
| |
| |
|
|
def load_model(name: str) -> nn.Module:
    """Build the IAT network for checkpoint ``name`` and load its weights.

    Weights are read onto the CPU with ``weights_only=True`` (restricted
    unpickling) and loaded with the default strict key matching. The model is
    returned in eval mode.

    Raises:
        KeyError: if ``name`` is not a key of ``CHECKPOINTS``.
    """
    entry = CHECKPOINTS[name]
    net = IAT(in_dim=3, with_global=True, **entry["model_kwargs"])
    weights = torch.load(str(entry["path"]), map_location="cpu", weights_only=True)
    net.load_state_dict(weights)
    # eval() is train(False) and returns the module itself.
    return net.eval()
|
|
|
|
def export_onnx(model: nn.Module, output_path: Path, opset: int) -> None:
    """Export ``model`` to ONNX at ``output_path`` using a random example input.

    The example input has shape ``(1, 3, 256, 256)``. The graph has one input
    (``input``) and three outputs (``mul``, ``add``, ``enhanced``). Axes 0, 2
    and 3 (batch, height, width) are declared dynamic on all four tensors.
    """
    example = torch.randn(1, 3, 256, 256)
    dynamic_dims = {0: "batch", 2: "height", 3: "width"}
    tensor_names = ("input", "mul", "add", "enhanced")

    torch.onnx.export(
        model,
        (example,),
        str(output_path),
        opset_version=opset,
        input_names=["input"],
        output_names=["mul", "add", "enhanced"],
        # A fresh copy per tensor, as in a hand-written literal.
        dynamic_axes={tensor: dict(dynamic_dims) for tensor in tensor_names},
    )
|
|
|
|
def _grade_max_diff(max_diff: float) -> "tuple[str, str]":
    """Map a max abs difference to ``(status, symbol)``; NaN grades as FAIL."""
    # Written as "< threshold" tests so a NaN falls through to FAIL.
    if max_diff < 1e-5:
        return "OK", "+"
    if max_diff < 1e-3:
        return "WARN", "~"
    return "FAIL", "X"


def verify_onnx(model: nn.Module, onnx_path: Path) -> bool:
    """Compare ONNX Runtime and PyTorch outputs at each verification resolution.

    For every ``(h, w)`` in ``VERIFICATION_RESOLUTIONS``, a random
    ``(1, 3, h, w)`` input is run through ``model`` and through the ONNX file
    (CPU execution provider). Each of the three outputs is then graded by its
    maximum absolute difference: below 1e-5 is OK, below 1e-3 is WARN, and
    anything else is FAIL.

    Returns:
        True if no output failed. A shape mismatch or a non-finite difference
        (NaN/Inf in an output) counts as a failure.
    """
    import onnxruntime as ort

    session = ort.InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"],
    )

    all_ok = True
    for h, w in VERIFICATION_RESOLUTIONS:
        dummy = torch.randn(1, 3, h, w)

        with torch.no_grad():
            pt_mul, pt_add, pt_enhanced = model(dummy)

        ort_inputs = {"input": dummy.numpy()}
        ort_mul, ort_add, ort_enhanced = session.run(None, ort_inputs)

        for name, pt_out, ort_out in [
            ("mul", pt_mul, ort_mul),
            ("add", pt_add, ort_add),
            ("enhanced", pt_enhanced, ort_enhanced),
        ]:
            pt_np = pt_out.numpy()
            # Without this check, NumPy broadcasting could let mismatched
            # shapes produce a (meaningless) difference.
            if pt_np.shape != ort_out.shape:
                print(
                    f" [X] {h}x{w} {name:10s} shape mismatch "
                    f"torch={pt_np.shape} onnx={ort_out.shape} [FAIL]"
                )
                all_ok = False
                continue

            max_diff = float(np.max(np.abs(pt_np - ort_out)))
            status, symbol = _grade_max_diff(max_diff)
            print(f" [{symbol}] {h}x{w} {name:10s} max_diff={max_diff:.2e} [{status}]")

            # Decide failure from the grade, not from `max_diff >= 1e-3`:
            # that comparison is False for NaN and would let NaN outputs pass.
            if status == "FAIL":
                if np.isfinite(max_diff):
                    print(f" FAIL: max abs diff {max_diff:.6f} >= 1e-3")
                else:
                    print(f" FAIL: non-finite max abs diff ({max_diff}); outputs contain NaN/Inf")
                all_ok = False

    return all_ok
|
|
|
|
| |
| |
| |
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command line: checkpoint selection, output directory, opset."""
    parser = argparse.ArgumentParser(description="Export IAT checkpoints to ONNX")
    parser.add_argument(
        "--checkpoints",
        type=str,
        default="all",
        # Derived from CHECKPOINTS so the CLI cannot drift from the table.
        choices=["all", *CHECKPOINTS],
        help="Which checkpoint(s) to export",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent / "outputs"),
        help="Directory for exported ONNX files",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version",
    )
    return parser.parse_args()


def _export_checkpoint(name: str, output_dir: Path, opset: int) -> str:
    """Load, export and verify one checkpoint; return "PASS", "FAIL" or "SKIP"."""
    cfg = CHECKPOINTS[name]
    onnx_path = output_dir / f"iat_{name}.onnx"

    print("=" * 60)
    print(f"Exporting: {name} ({cfg['description']})")
    print(f" Checkpoint: {cfg['path']}")
    print(f" Output: {onnx_path}")
    print("=" * 60)

    if not cfg["path"].exists():
        print(f" SKIP: checkpoint not found at {cfg['path']}")
        return "SKIP"

    # perf_counter is the monotonic clock intended for measuring durations.
    t0 = time.perf_counter()
    model = load_model(name)
    print(f" Loaded model in {time.perf_counter() - t0:.2f}s")

    t0 = time.perf_counter()
    export_onnx(model, onnx_path, opset)
    export_time = time.perf_counter() - t0
    file_size_mb = onnx_path.stat().st_size / (1024 * 1024)
    print(f" Exported in {export_time:.2f}s ({file_size_mb:.1f} MB)")

    print(f" Verifying at {len(VERIFICATION_RESOLUTIONS)} resolutions...")
    ok = verify_onnx(model, onnx_path)
    print()
    return "PASS" if ok else "FAIL"


def _print_summary(results: "dict[str, str]") -> int:
    """Print the per-checkpoint summary and return the process exit code."""
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    symbols = {"PASS": "+", "SKIP": "-"}
    for name, status in results.items():
        print(f" [{symbols.get(status, 'X')}] {name}: {status}")

    statuses = set(results.values())
    if "FAIL" in statuses:
        print("\nSome exports FAILED numerical verification!")
        return 1
    if "PASS" not in statuses:
        # Every requested checkpoint was skipped: nothing was exported, so
        # reporting success would be misleading.
        print("\nNo checkpoints were exported (all skipped)!")
        return 1
    print("\nAll exports passed!")
    return 0


def main():
    """Export the selected IAT checkpoints to ONNX and verify each one.

    The ONNX-compatibility patches are applied first. The process exits with
    status 1 if any export fails numerical verification or if no checkpoint
    was exported at all (all skipped); otherwise it exits with status 0.
    """
    args = _parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    names = list(CHECKPOINTS) if args.checkpoints == "all" else [args.checkpoints]

    print("=" * 60)
    print("Applying ONNX-compatibility patches...")
    print("=" * 60)
    apply_patches()
    print()

    results = {}
    for name in names:
        results[name] = _export_checkpoint(name, output_dir, args.opset)

    sys.exit(_print_summary(results))
|
|
|
|
if __name__ == "__main__":
    # main() normally terminates the process itself via sys.exit(); should it
    # ever return, SystemExit(None) ends the script with status 0.
    raise SystemExit(main())
|
|