# IAT-ONNX / export_iat_onnx.py
# (Uploaded via huggingface_hub by kpezhgorski, commit b7b07be, verified)
#!/usr/bin/env python3
"""
IAT (Illumination Adaptive Transformer) β†’ ONNX Export Script
Monkey-patches ONNX-incompatible patterns in IAT source, exports all 3
checkpoints (exposure, lol_v1, lol_v2), and verifies each numerically
at multiple resolutions.
Patches applied:
1. IAT.apply_color: tensordot β†’ matmul (ONNX-friendly)
2. IAT.forward: Python for-loop over batch β†’ vectorized bmm + broadcast pow
3. Aff_channel.forward: tensordot β†’ matmul (fallback β€” needed for tracing)
"""
import argparse
import sys
import os
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Add IAT source to path, fix Python 3.12+ compatibility
# ---------------------------------------------------------------------------
IAT_ROOT = Path(__file__).parent / "iat" / "IAT_enhance"
sys.path.insert(0, str(IAT_ROOT))
# IAT's global_net.py has `import imp` which was removed in Python 3.12.
# It's unused, so we provide a dummy module before importing IAT.
# NOTE: `importlib.util` is a submodule and is NOT guaranteed to be bound by
# a bare `import importlib`; it must be imported explicitly.
import importlib.util  # noqa: E402
if importlib.util.find_spec("imp") is None:  # find_spec returns None when absent
    import types
    sys.modules["imp"] = types.ModuleType("imp")
from model.IAT_main import IAT  # noqa: E402
from model.blocks import Aff_channel  # noqa: E402
# ===========================================================================
# Monkey-patches
# ===========================================================================
def _patched_apply_color(self, image, ccm):
"""Replace tensordot with matmul for ONNX compatibility.
Original: torch.tensordot(image, ccm, dims=[[-1], [-1]])
which computes image @ ccm.T (contract last dim of both)
Replacement: torch.matmul(image, ccm.T)
"""
shape = image.shape
image = image.view(-1, 3)
image = torch.matmul(image, ccm.permute(1, 0))
image = image.view(shape)
return torch.clamp(image, 1e-8, 1.0)
def _patched_forward(self, img_low):
"""Vectorized forward β€” no Python for-loop over batch dimension.
Original:
img_high = torch.stack([self.apply_color(img_high[i,:,:,:], color[i,:,:])
**gamma[i,:] for i in range(b)], dim=0)
Replacement:
1. bmm for batched color matrix multiply
2. broadcast pow for gamma
"""
mul, add = self.local_net(img_low)
img_high = (img_low.mul(mul)).add(add)
if not self.with_global:
return mul, add, img_high
gamma, color = self.global_net(img_low)
# img_high: (B, C, H, W) β†’ (B, H, W, C) β†’ (B, H*W, C)
b, c, h, w = img_high.shape
img_high = img_high.permute(0, 2, 3, 1).reshape(b, h * w, c)
# Batched color matrix: (B, H*W, 3) @ (B, 3, 3) β†’ (B, H*W, 3)
# color is (B, 3, 3), we need img @ color^T for each batch element
color_t = color.permute(0, 2, 1) # (B, 3, 3)
img_high = torch.bmm(img_high, color_t)
img_high = torch.clamp(img_high, 1e-8, 1.0)
# Reshape back to (B, H, W, C) for broadcast pow
img_high = img_high.view(b, h, w, c)
# gamma is (B, 1) β€” reshape to (B, 1, 1, 1) for broadcast
gamma_broadcast = gamma.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1, 1)
img_high = img_high ** gamma_broadcast
# (B, H, W, C) β†’ (B, C, H, W)
img_high = img_high.permute(0, 3, 1, 2)
return mul, add, img_high
def _patched_aff_channel_forward(self, x):
"""Replace tensordot with matmul in Aff_channel for ONNX compatibility.
Original: torch.tensordot(x, self.color, dims=[[-1], [-1]])
Replacement: torch.matmul(x, self.color.T)
"""
if self.channel_first:
x1 = torch.matmul(x, self.color.permute(1, 0))
x2 = x1 * self.alpha + self.beta
else:
x1 = x * self.alpha + self.beta
x2 = torch.matmul(x1, self.color.permute(1, 0))
return x2
# ===========================================================================
# Fallback patches (not needed for current export, documented for reference)
# ===========================================================================
# --- Fallback: query_Attention expand ---
# If export fails on expand in global attention (global_net.py):
#
# def _patched_query_attention_forward(self, x):
# B, N, C = x.shape
# # Original: self.q.expand(B, -1, -1) -- can fail with dynamic batch
# # Fix: use repeat which traces cleanly
# q = self.q.repeat(B, 1, 1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# ... rest of forward unchanged ...
#
# from model.global_net import query_Attention
# query_Attention.forward = _patched_query_attention_forward
# --- Fallback: gamma power operator ---
# If ** operator traces incorrectly for broadcast shapes:
#
# Replace in _patched_forward:
# img_high = torch.pow(torch.clamp(img_high, 1e-8), gamma)
# With:
# img_high = torch.exp(torch.log(torch.clamp(img_high, 1e-8)) * gamma)
# ===========================================================================
# Checkpoint configurations
# ===========================================================================
# Mapping of checkpoint short-name -> export configuration.
#   path:         on-disk .pth checkpoint inside the vendored IAT repo
#   model_kwargs: extra keyword args forwarded to the IAT constructor
#   description:  human-readable label used in log output
CHECKPOINTS = {
    "exposure": {
        "path": IAT_ROOT / "best_Epoch_exposure.pth",
        "model_kwargs": {"type": "exp"},
        "description": "Exposure correction",
    },
    "lol_v1": {
        "path": IAT_ROOT / "best_Epoch_lol_v1.pth",
        "model_kwargs": {"type": "lol"},
        "description": "LOL v1 low-light enhancement",
    },
    "lol_v2": {
        "path": IAT_ROOT / "best_Epoch_lol.pth",
        "model_kwargs": {"type": "lol"},
        "description": "LOL v2 low-light enhancement",
    },
}
# (H, W) pairs used for PyTorch-vs-ONNX numerical verification; the
# non-square entry guards against transposed height/width handling.
VERIFICATION_RESOLUTIONS = [
    (256, 256),
    (512, 512),
    (768, 1024),  # H, W — non-square
]
# ===========================================================================
# Apply patches
# ===========================================================================
def apply_patches():
    """Install the ONNX-compatibility monkey-patches on the IAT classes.

    Patching happens purely at runtime; the IAT source files on disk are
    never modified.
    """
    IAT.apply_color = _patched_apply_color
    IAT.forward = _patched_forward
    Aff_channel.forward = _patched_aff_channel_forward
    for message in (
        "[PATCH] IAT.apply_color: tensordot -> matmul",
        "[PATCH] IAT.forward: for-loop -> vectorized bmm + broadcast pow",
        "[PATCH] Aff_channel.forward: tensordot -> matmul",
    ):
        print(message)
# ===========================================================================
# Export
# ===========================================================================
def load_model(name: str) -> nn.Module:
    """Build the IAT model for *name* and load its checkpoint weights.

    The checkpoint is loaded onto CPU with ``weights_only=True`` and the
    model is returned in eval mode.
    """
    cfg = CHECKPOINTS[name]
    net = IAT(in_dim=3, with_global=True, **cfg["model_kwargs"])
    weights = torch.load(str(cfg["path"]), map_location="cpu", weights_only=True)
    net.load_state_dict(weights)
    net.eval()  # identical to train(False)
    return net
def export_onnx(model: nn.Module, output_path: Path, opset: int) -> None:
    """Export *model* to ONNX at *output_path* with fully dynamic axes.

    A fixed 1x3x256x256 dummy tensor drives the trace; batch, height and
    width are declared dynamic on the input and on all three outputs.
    """
    example = torch.randn(1, 3, 256, 256)
    dynamic = {0: "batch", 2: "height", 3: "width"}
    torch.onnx.export(
        model,
        (example,),
        str(output_path),
        opset_version=opset,
        input_names=["input"],
        output_names=["mul", "add", "enhanced"],
        dynamic_axes={
            tensor_name: dict(dynamic)
            for tensor_name in ("input", "mul", "add", "enhanced")
        },
    )
def verify_onnx(model: nn.Module, onnx_path: Path) -> bool:
    """Numerically compare ONNX Runtime against PyTorch at several resolutions.

    All three outputs (mul, add, enhanced) are checked at every resolution
    in VERIFICATION_RESOLUTIONS. Thresholds on max abs diff:
    < 1e-5 -> OK, < 1e-3 -> WARN, otherwise FAIL.
    Returns True iff no comparison failed.
    """
    import onnxruntime as ort
    session = ort.InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"],
    )
    passed = True
    for h, w in VERIFICATION_RESOLUTIONS:
        sample = torch.randn(1, 3, h, w)
        # PyTorch reference outputs.
        with torch.no_grad():
            ref_mul, ref_add, ref_enh = model(sample)
        # ONNX Runtime outputs for the same input.
        got_mul, got_add, got_enh = session.run(None, {"input": sample.numpy()})
        comparisons = (
            ("mul", ref_mul, got_mul),
            ("add", ref_add, got_add),
            ("enhanced", ref_enh, got_enh),
        )
        for name, ref, got in comparisons:
            max_diff = np.max(np.abs(ref.numpy() - got))
            if max_diff < 1e-5:
                status, symbol = "OK", "+"
            elif max_diff < 1e-3:
                status, symbol = "WARN", "~"
            else:
                status, symbol = "FAIL", "X"
            print(f" [{symbol}] {h}x{w} {name:10s} max_diff={max_diff:.2e} [{status}]")
            if max_diff >= 1e-3:
                print(f" FAIL: max abs diff {max_diff:.6f} >= 1e-3")
                passed = False
    return passed
# ===========================================================================
# Main
# ===========================================================================
def main():
    """CLI entry point: parse args, patch IAT, export and verify checkpoints.

    Exits with status 0 when every exported model passes numerical
    verification (missing checkpoints are skipped and do not count as
    failures), and status 1 when any export fails verification.
    """
    parser = argparse.ArgumentParser(description="Export IAT checkpoints to ONNX")
    parser.add_argument(
        "--checkpoints",
        type=str,
        default="all",
        choices=["all", "exposure", "lol_v1", "lol_v2"],
        help="Which checkpoint(s) to export",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path(__file__).parent / "outputs"),
        help="Directory for exported ONNX files",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version",
    )
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Determine which checkpoints to export
    if args.checkpoints == "all":
        names = list(CHECKPOINTS.keys())
    else:
        names = [args.checkpoints]
    # Apply monkey-patches (must happen before any model is instantiated,
    # since the patches replace methods on the IAT classes themselves).
    print("=" * 60)
    print("Applying ONNX-compatibility patches...")
    print("=" * 60)
    apply_patches()
    print()
    # Per-checkpoint status: "PASS", "FAIL", or "SKIP".
    results = {}
    for name in names:
        cfg = CHECKPOINTS[name]
        onnx_path = output_dir / f"iat_{name}.onnx"
        print("=" * 60)
        print(f"Exporting: {name} ({cfg['description']})")
        print(f" Checkpoint: {cfg['path']}")
        print(f" Output: {onnx_path}")
        print("=" * 60)
        # Check checkpoint exists — skip (not fail) when the .pth is absent.
        if not cfg["path"].exists():
            print(f" SKIP: checkpoint not found at {cfg['path']}")
            results[name] = "SKIP"
            continue
        # Load
        t0 = time.time()
        model = load_model(name)
        print(f" Loaded model in {time.time() - t0:.2f}s")
        # Export
        t0 = time.time()
        export_onnx(model, onnx_path, args.opset)
        export_time = time.time() - t0
        file_size_mb = onnx_path.stat().st_size / (1024 * 1024)
        print(f" Exported in {export_time:.2f}s ({file_size_mb:.1f} MB)")
        # Verify exported model numerically against the patched PyTorch model.
        print(f" Verifying at {len(VERIFICATION_RESOLUTIONS)} resolutions...")
        ok = verify_onnx(model, onnx_path)
        results[name] = "PASS" if ok else "FAIL"
        print()
    # Summary
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    all_pass = True
    for name, status in results.items():
        if status == "PASS":
            symbol = "+"
        elif status == "SKIP":
            symbol = "-"
        else:
            symbol = "X"
        print(f" [{symbol}] {name}: {status}")
        # Only explicit FAILs affect the exit code; SKIPs are tolerated.
        if status == "FAIL":
            all_pass = False
    if not all_pass:
        print("\nSome exports FAILED numerical verification!")
        sys.exit(1)
    else:
        print("\nAll exports passed!")
        sys.exit(0)
if __name__ == "__main__":
    main()