| |
| """ |
| Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints. |
| |
| - Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper). |
| - Keeps math in FP16/BF16; this is a storage-only change in the file. |
| - Honors existing FP8 tensors in the input unless --recode-fp8 is set. |
| - Skips norms, biases, visual_proj.*, final_layer.* by design. |
| - Optional --aggressive converts modulation linears too. |
| |
| USAGE (simple): |
| python convert_fp8.py in.safetensors [out.safetensors] # out is optional |
| |
| USAGE (flags): |
| python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive |
| |
| Notes: |
- “auto” picks FP8_E5M2 on GPUs with compute capability below 9.0 (SM < 90, e.g., RTX 3090) or when no CUDA device is available; otherwise FP8_E4M3FN.
| - You can force a format: --fp8 e5m2 | e4m3fn |
| - Dry run: add --dry to print what would change without writing. |
| """ |
|
|
| import argparse |
| import re |
| from typing import Dict, Tuple |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file, save_file |
|
|
|
|
| |
|
|
| |
# Substrings that veto FP8 conversion no matter what the allow-list says:
# biases, norm parameters (including q_norm / k_norm), and the final_layer /
# visual_proj modules that the module docstring says are skipped by design.
_DENY_SUBSTRINGS = (
    ".bias",
    ".norm",
    "q_norm.",
    "k_norm.",
    "final_layer.",
    "visual_proj.",
)
|
|
| |
# Fully anchored regexes naming the .weight tensors eligible for FP8 storage.
# The modulation entries are listed here, but should_convert_to_fp8 only
# converts them when --aggressive is set.
_ALLOW_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # single_blocks.<i>: linear1, linear2.w1/w2/w3, linear_qkv, modulation.linear
        r"^single_blocks\.\d+\.linear1\.weight$",
        r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",
        r"^single_blocks\.\d+\.linear_qkv\.weight$",
        r"^single_blocks\.\d+\.modulation\.linear\.weight$",
        # triple_blocks.<i>: audio_mlp / v_cond_mlp fc1 and fc2
        r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
        r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",
        # triple_blocks.<i>: *_attn_qkv, text_cross_kv and *_self_proj
        r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
        r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",
        # triple_blocks.<i>: audio_mod / v_cond_mod modulation linears
        r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
    )
)
|
|
|
|
| |
|
|
| def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str: |
| """<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed).""" |
| suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn" |
| p = Path(in_path) |
| stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem) |
| ext = p.suffix or ".safetensors" |
| return str(p.with_name(f"{stem}_fp8_{suffix}{ext}")) |
|
|
|
|
| def pick_fp8_dtype(fp8_mode: str) -> torch.dtype: |
| """Pick target FP8 dtype.""" |
| m = fp8_mode.lower() |
| if m == "e5m2": |
| return torch.float8_e5m2 |
| if m == "e4m3fn": |
| return torch.float8_e4m3fn |
| |
| try: |
| major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) |
| except Exception: |
| major = 0 |
| return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn |
|
|
|
|
| def bytes_of(t: torch.Tensor) -> int: |
| """Size in bytes (FP8=1 byte/elt).""" |
| if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn): |
| return t.numel() * 1 |
| return t.numel() * t.element_size() |
|
|
|
|
def human_gb(nbytes: int) -> float:
    """Convert a byte count to GiB (binary gigabytes of 2**30 bytes)."""
    bytes_per_gib = 1 << 30
    return nbytes / bytes_per_gib
|
|
|
|
def _is_denied(name: str) -> bool:
    """Return True when *name* contains any substring listed in ``_DENY_SUBSTRINGS``."""
    for token in _DENY_SUBSTRINGS:
        if token in name:
            return True
    return False
|
|
|
|
def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
    """Decide whether the tensor named *name* should be stored as FP8.

    A name qualifies only if all of the following hold:
      * it ends in ``.weight``;
      * it contains none of the deny-listed substrings;
      * it matches at least one entry of ``_ALLOW_PATTERNS``.

    Modulation linears (``.modulation.linear.weight``,
    ``.audio_mod.linear.weight``, ``.v_cond_mod.linear.weight``) additionally
    depend on *aggressive*: for them, *aggressive* itself is returned.
    """
    if not name.endswith(".weight") or _is_denied(name):
        return False

    if not any(pattern.search(name) for pattern in _ALLOW_PATTERNS):
        return False

    modulation_markers = (
        ".modulation.linear.weight",
        ".audio_mod.linear.weight",
        ".v_cond_mod.linear.weight",
    )
    if any(marker in name for marker in modulation_markers):
        return aggressive
    return True
|
|
|
|
| |
|
|
def _saturating_fp8_cast(tensor: torch.Tensor, tgt_dtype: torch.dtype) -> torch.Tensor:
    """Cast *tensor* to the FP8 *tgt_dtype*, clamping floats to its finite range first.

    PyTorch's cast to FP8 is non-saturating. Out-of-range magnitudes become
    NaN for E4M3FN, which has no infinity, and inf for E5M2. Either way the
    stored weight is silently corrupted. Clamping to ``torch.finfo(tgt_dtype)``
    keeps every finite input finite; NaN inputs remain NaN.

    Tensors already in *tgt_dtype* are returned unchanged (same object).
    Non-floating tensors are cast directly, as before.
    """
    if tensor.dtype == tgt_dtype:
        return tensor
    if not tensor.is_floating_point():
        return tensor.to(dtype=tgt_dtype)
    limits = torch.finfo(tgt_dtype)
    if tensor.element_size() == 1:
        # 1-byte floats are FP8. Elementwise ops such as clamp may not be
        # implemented for them, so widen before clamping.
        tensor = tensor.to(torch.float32)
    return tensor.clamp(min=limits.min, max=limits.max).to(dtype=tgt_dtype)


def convert_state_dict(
    sd: Dict[str, torch.Tensor],
    fp8_mode: str = "auto",
    aggressive: bool = False,
    recode_fp8: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """Convert selected tensors of a state dict to FP8 storage.

    Per-tensor policy:
      * Already FP8 (E5M2 / E4M3FN): kept as-is. With *recode_fp8* set and a
        dtype different from the target, the tensor is re-encoded instead.
      * Otherwise: converted when ``should_convert_to_fp8(name, aggressive)``
        is true.
      * Everything else is passed through unchanged (the same tensor object).

    Every cast saturates to the target's finite range (``_saturating_fp8_cast``)
    instead of producing NaN/inf for out-of-range values.

    Args:
        sd: Mapping of tensor name to tensor. It is only read, never modified.
        fp8_mode: "auto", "e5m2" or "e4m3fn", resolved by ``pick_fp8_dtype``.
        aggressive: Also convert modulation linears.
        recode_fp8: Re-encode existing FP8 tensors to the target dtype.

    Returns:
        ``(new_sd, stats)``. *stats* holds the byte totals ``total_before`` and
        ``total_after``, plus the tensor counts ``converted_count``,
        ``kept_fp8_count`` and ``skipped_count``.
    """
    tgt_dtype = pick_fp8_dtype(fp8_mode)
    fp8_dtypes = (torch.float8_e5m2, torch.float8_e4m3fn)
    out: Dict[str, torch.Tensor] = {}
    stats = {
        "total_before": 0,
        "total_after": 0,
        "converted_count": 0,
        "kept_fp8_count": 0,
        "skipped_count": 0,
    }

    for name, tensor in sd.items():
        stats["total_before"] += bytes_of(tensor)

        if tensor.dtype in fp8_dtypes:
            # Existing FP8 is honoured. A "recode" to the same dtype would be a
            # no-op, so it is reported as kept rather than converted.
            if recode_fp8 and tensor.dtype != tgt_dtype:
                new_tensor = _saturating_fp8_cast(tensor, tgt_dtype)
                stats["converted_count"] += 1
            else:
                new_tensor = tensor
                stats["kept_fp8_count"] += 1
        elif should_convert_to_fp8(name, aggressive):
            new_tensor = _saturating_fp8_cast(tensor, tgt_dtype)
            stats["converted_count"] += 1
        else:
            new_tensor = tensor
            stats["skipped_count"] += 1

        out[name] = new_tensor
        stats["total_after"] += bytes_of(new_tensor)

    return out, stats
|
|
|
|
| |
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Positional arguments are ``in_path`` (required) and ``out_path``
    (optional; main() derives one when it is omitted). The flags are
    ``--fp8``, ``--aggressive``, ``--recode-fp8`` and ``--dry``.
    """
    parser = argparse.ArgumentParser(
        description="Convert selected weights in a safetensors file to FP8 storage."
    )

    # Positional arguments.
    parser.add_argument("in_path", help="Input .safetensors")
    parser.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")

    # Target storage format.
    parser.add_argument(
        "--fp8",
        choices=["auto", "e5m2", "e4m3fn"],
        default="auto",
        help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"',
    )

    # Boolean switches, registered in their original order so --help output is unchanged.
    boolean_flags = (
        ("--aggressive",
         "Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear)."),
        ("--recode-fp8",
         "Re-encode existing FP8 tensors to the chosen target dtype."),
        ("--dry",
         "Dry run: report only; do not write output file."),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args()
|
|
|
|
def main():
    """Command-line entry point: load, convert, report sizes, and optionally save.

    Steps:
      1. Load the input with ``load_file``.
      2. Resolve the target FP8 dtype and the output path (via
         ``default_out_path`` when none is given).
      3. Convert with ``convert_state_dict``.
      4. Print tensor counts and byte totals.

    With ``--dry`` nothing is written.

    The input's safetensors header metadata (``__metadata__``) is copied to the
    output. ``load_file`` returns tensors only, so saving ``new_sd`` alone would
    silently drop it.
    """
    # Local import keeps the module-level import block unchanged. safetensors is
    # already a dependency of this file.
    from safetensors import safe_open

    args = parse_args()

    print(f"[load] {args.in_path}")
    sd = load_file(args.in_path)

    tgt = pick_fp8_dtype(args.fp8)
    if not args.out_path:
        args.out_path = default_out_path(args.in_path, tgt)
        print(f"[auto-out] {args.out_path}")

    print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, "
          f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")

    new_sd, stats = convert_state_dict(
        sd,
        fp8_mode=args.fp8,
        aggressive=args.aggressive,
        recode_fp8=args.recode_fp8,
    )

    saved = stats["total_before"] - stats["total_after"]
    print(f"[stats] tensors: {len(sd)}")
    print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
          f"| skipped: {stats['skipped_count']}")
    print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
          f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")

    if args.dry:
        print("[dry] no file written")
        return

    # Only the header metadata is requested here; no tensors are fetched.
    with safe_open(args.in_path, framework="pt") as f:
        metadata = f.metadata()

    # default_out_path() maps an already-suffixed input of the same format back
    # onto itself. Make that in-place overwrite visible instead of silent.
    if Path(args.out_path).resolve() == Path(args.in_path).resolve():
        print(f"[warn] output path equals input path; overwriting {args.in_path}")

    print(f"[save] {args.out_path}")
    save_file(new_sd, args.out_path, metadata=metadata)
    print("[done]")
|
|
|
|
# Script entry point: run the converter only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|