MiniMax-M2.7-NVFP4-GB10-AC / quantize-ac-protected.py
saricles's picture
Add sanitized quantization recipe
9b860c0 verified
#!/usr/bin/env python3
"""Protected NVFP4 quantization for MiniMax-M2.7 (agentic + coder calibration).
Recipe used to produce `MiniMax-M2.7-NVFP4-GB10-AC`: a 7-dataset agentic + coder
calibration (896 samples queued, 888 tokenized @ 49,152 max-seq) with per-sample
OOM-defer protection and amax-only checkpointing. Designed to run on Hugging
Face Jobs or any equivalent GPU host with access to the BF16 source model.
Three phases dispatched by PHASE env var. The happy path is a single Phase A
invocation on a GPU flavor large enough to hold the 230B BF16 model:
PHASE=A — Calibration with per-sample OOM defer, checkpoint every N samples
(amax-only format, ~60 MB each — NOT a full state_dict, which would
be ~460 GB). At end, populate starved-expert amax from weights and
write phase-a.result sentinel. On status=complete, export and
publish inline while the model is still GPU-loaded.
PHASE=B — Fallback only: resume from the latest good checkpoint on a
larger-memory GPU flavor; process deferred samples and any
remaining starved experts, save final state, inline export.
PHASE=C — Manual recovery: restore from saved checkpoint and re-export if
Phase A/B completed calibration but crashed during export. Requires
a GPU flavor large enough to host the BF16 model.
Key properties:
- All calibration samples contribute (no skipping on OOM — deferred instead).
- Amax-only checkpoints via modelopt_state + `_amax` buffers, not mto.save().
- Two-phase bucket commit: state-N.tar then state-N.ok marker with sha256,
so a torn upload never masquerades as a valid checkpoint.
- MaxCalibrator._calib_amax is SEEDED at restore so resume max-merges rather
than overwrites prior-phase amax (the non-obvious gotcha that breaks naive
resumes).
- Starved-expert amax (quantizers that never received traffic) is populated
via the enable_stats_collection(wq); wq(weight); finish_stats_collection
pattern, producing the correct per-channel shape.
- Wallclock watchdog thread lets Phase A exit cleanly before budget burn.
Env vars:
PHASE = A | B | C (required)
INPUT_DIR = path to BF16 source model (default: /model)
OUTPUT_DIR = export target path (default: /out)
TARGET_REPO_ID = HF Hub model repo to publish to (required for Phase A
inline export; Phase B/C also publish to it)
BUCKET_REPO_ID = HF Hub dataset repo used as checkpoint workspace
(required for all phases)
BUCKET_PREFIX = path prefix inside the bucket repo (default: runs/ac/)
NUM_CALIB_PER_DS = calibration samples per dataset (default: 128)
MAX_SEQ = max sequence length, tokens (default: 49152 = 48K)
CKPT_EVERY = checkpoint cadence in completed samples (default: 50)
WALLCLOCK_BUDGET_S = soft deadline before Phase A starts deferring (default:
21600 = 6h; hard watchdog = budget + 1800s)
STARVED_EXPERT_PCT_ABORT = abort if more than this % of quantizers ended starved
(default: 1.0)
The sanity gate in `_config_sanity_gate` is model-specific (MiniMax-M2.7: 256
experts × 62 layers × 125 shards) and will need adjustment for a different MoE
architecture. Everything else is model-agnostic.
"""
from __future__ import annotations
# Conv1D shim (transformers 4.57+ moved it out of modeling_utils).
import transformers.modeling_utils # noqa: E402
if not hasattr(transformers.modeling_utils, "Conv1D"):
from transformers.pytorch_utils import Conv1D as _Conv1D # noqa: E402
transformers.modeling_utils.Conv1D = _Conv1D
import gc
import hashlib
import json
import os
import random
import shutil
import sys
import tempfile
import threading
import time
import traceback
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import HfApi
import modelopt.torch.quantization as mtq
import modelopt.torch.opt as mto
from modelopt.torch.opt.conversion import modelopt_state, restore_from_modelopt_state
from modelopt.torch.quantization.model_calib import (
enable_stats_collection, finish_stats_collection,
)
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
# ---------- Config ---------------------------------------------------------
# PHASE selects which entry point runs. It is validated with an explicit raise
# (not `assert`, which `python -O` strips) so a missing or misspelled value
# fails with the same kind of clear message as the other required env vars.
PHASE = os.environ.get("PHASE", "").upper()
if PHASE not in {"A", "B", "C"}:
    raise RuntimeError(
        f"PHASE env var is required and must be A|B|C, got {PHASE!r}"
    )
INPUT_DIR = Path(os.environ.get("INPUT_DIR", "/model"))
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/out"))
BUCKET_REPO_ID = os.environ.get("BUCKET_REPO_ID", "")  # e.g. "<your-username>/quant-lab"
# Normalised to end with exactly one trailing slash.
BUCKET_PREFIX = os.environ.get("BUCKET_PREFIX", "runs/ac/").rstrip("/") + "/"
TARGET_REPO_ID = os.environ.get("TARGET_REPO_ID", "")  # e.g. "<your-username>/MiniMax-M2.7-NVFP4-AC"
# Required-env validation: fail fast with a clear message rather than deep in
# an HfApi call. All phases need the bucket; Phase A also needs the target for
# its inline export on success.
if not BUCKET_REPO_ID:
    raise RuntimeError(
        "BUCKET_REPO_ID env var is required (HF Hub dataset repo used for "
        "checkpoint workspace). Set it to e.g. '<your-username>/quant-lab'."
    )
if PHASE == "A" and not TARGET_REPO_ID:
    raise RuntimeError(
        "TARGET_REPO_ID env var is required for Phase A (inline export target). "
        "Set it to e.g. '<your-username>/MiniMax-M2.7-NVFP4-AC'."
    )
# Calibration sizing and budget knobs. Each one is read from the environment
# and falls back to the default given as the second argument.
NUM_CALIB_PER_DS = int(os.getenv("NUM_CALIB_PER_DS", "128"))
MAX_SEQ = int(os.getenv("MAX_SEQ", "49152"))
CKPT_EVERY = int(os.getenv("CKPT_EVERY", "50"))
WALLCLOCK_BUDGET_S = int(os.getenv("WALLCLOCK_BUDGET_S", "21600"))

# 7-dataset agentic + coder calibration mix: code, tool calling, math, science
# and SWE-agent trajectories for the target workloads, plus a general-chat
# anchor (ultrachat) to keep plain conversational activations in range.
#
# Each entry is (dataset name, load_dataset kwargs). The row schemas handled
# by _extract_text are: messages (list OR JSON string), trajectory,
# instruction+output, query+answers.
CALIB_DATASETS = [
    ("theblackcat102/evol-codealpaca-v1", dict(split="train")),
    ("Salesforce/xlam-function-calling-60k", dict(split="train")),
    ("open-r1/Mixture-of-Thoughts", dict(name="code", split="train")),
    ("open-r1/Mixture-of-Thoughts", dict(name="math", split="train")),
    ("open-r1/Mixture-of-Thoughts", dict(name="science", split="train")),
    ("SWE-bench/SWE-smith-trajectories", dict(split="tool")),
    ("HuggingFaceH4/ultrachat_200k", dict(split="train_sft")),
]

DEVICE = "cuda"
# Percentage threshold compared against the starved-quantizer share in Phase A.
STARVED_EXPERT_PCT_ABORT = float(os.getenv("STARVED_EXPERT_PCT_ABORT", "1.0"))
def _log(msg: str) -> None:
    """Print `msg` to stdout, prefixed with the local time and current phase."""
    stamp = time.strftime("%H:%M:%S")
    line = "[{}] [phase-{}] {}".format(stamp, PHASE, msg)
    # Flush on every line so output is not held in a buffer.
    print(line, flush=True)
def _free_gpu() -> None:
    """Run the Python garbage collector, then release cached CUDA memory.

    The CUDA steps are skipped entirely when no CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# ---------- Wallclock watchdog ---------------------------------------------
class WallclockWatchdog:
    """Two-stage wall-clock guard based on the monotonic clock.

    Soft stage: a daemon thread polls every 10 s and latches `.expired` once
    more than `budget_s` seconds have elapsed. The caller is expected to read
    `.expired` between units of work and wind down on its own.

    Hard stage: a second daemon thread polls every 30 s and, once more than
    `budget_s + hard_cap_margin_s` seconds have elapsed, prints a message and
    calls `os._exit(HARD_KILL_EXIT_CODE)`. This is the safety net for the case
    where the caller never gets to look at `.expired` (e.g. it is stuck inside
    a forward pass).

    Args:
        budget_s: soft deadline in seconds, measured from construction.
        name: label used to name the two watchdog threads.
        hard_cap_margin_s: extra seconds after `budget_s` before the hard kill.
    """

    # Exit status of the hard kill. Deliberately different from the graceful
    # wallclock-cap status (43) so the two outcomes can be told apart.
    HARD_KILL_EXIT_CODE = 45

    def __init__(self, budget_s: int, name: str = "phase-a",
                 hard_cap_margin_s: int = 1800):
        self._budget_s = budget_s
        self._hard_cap_s = budget_s + hard_cap_margin_s
        self._start = time.monotonic()
        self._stop = threading.Event()
        self._expired = threading.Event()
        self._name = name
        self._soft_thread = threading.Thread(
            target=self._run_soft, name=f"{name}-watchdog-soft", daemon=True)
        self._hard_thread = threading.Thread(
            target=self._run_hard, name=f"{name}-watchdog-hard", daemon=True)
        self._soft_thread.start()
        self._hard_thread.start()

    def _run_soft(self):
        """Latch `.expired` once the soft budget is exceeded, then exit."""
        # Event.wait returns True as soon as shutdown() is called, ending the loop.
        while not self._stop.wait(10):
            if time.monotonic() - self._start > self._budget_s:
                self._expired.set()
                return

    def _run_hard(self):
        """Kill the process once the hard cap is exceeded."""
        while not self._stop.wait(30):
            if time.monotonic() - self._start > self._hard_cap_s:
                # The main loop failed to exit on the soft cap, so kill the
                # process to bound the job's cost. os._exit skips normal
                # interpreter cleanup; the caller relies on checkpoints
                # already uploaded to resume from.
                print(f"[watchdog] HARD CAP {self._hard_cap_s}s exceeded — "
                      f"os._exit({self.HARD_KILL_EXIT_CODE})",
                      flush=True)
                os._exit(self.HARD_KILL_EXIT_CODE)

    @property
    def expired(self) -> bool:
        """True once the soft budget has been exceeded."""
        return self._expired.is_set()

    def elapsed_s(self) -> int:
        """Whole seconds elapsed since construction (works after shutdown)."""
        return int(time.monotonic() - self._start)

    def shutdown(self):
        """Stop both watchdog threads; disarms the hard kill as well."""
        self._stop.set()
# ---------- Amax-only checkpoint I/O ---------------------------------------
def _amax_buffers_cpu(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Return every buffer of `model` whose name ends in `._amax`, on CPU.

    Keys are the fully qualified buffer names from `named_buffers()`; values
    are detached tensors moved to CPU. No extra `.clone()` is made: moving a
    tensor off a non-CPU device already produces a copy.

    NOTE(review): this reads only the registered `_amax` buffers. While stats
    collection is active the calibrators may be holding the running maximum
    elsewhere (e.g. `_calib_amax`) — confirm that mid-calibration checkpoints
    capture the values that are intended.
    """
    snapshot: dict[str, torch.Tensor] = {}
    for buf_name, buf in model.named_buffers():
        if buf is None or not buf_name.endswith("._amax"):
            continue
        snapshot[buf_name] = buf.detach().cpu()
    return snapshot
# Process-wide memo for modelopt_state(model); filled on first use.
_CACHED_MODELOPT_STATE: dict | None = None


def _get_cached_modelopt_state(model: torch.nn.Module) -> dict:
    """Return `modelopt_state(model)`, computing it only on the first call.

    Later calls return the memoized value and ignore `model`. This assumes the
    quantizer layout no longer changes once the first checkpoint is taken.
    """
    global _CACHED_MODELOPT_STATE
    cached = _CACHED_MODELOPT_STATE
    if cached is None:
        cached = modelopt_state(model)
        _CACHED_MODELOPT_STATE = cached
    return cached
def _save_ckpt_local(model: torch.nn.Module, path: Path, samples_done: int,
                     deferred: list, starved: list | None = None,
                     phase: str = PHASE,
                     deferred_batches: list | None = None) -> int:
    """Serialize an amax-only checkpoint to `path` and return its size in bytes.

    The payload holds the (cached) modelopt state, the CPU copies of all
    `_amax` buffers, and the bookkeeping passed in by the caller.

    Args:
        model: quantized model to snapshot.
        path: destination file.
        samples_done: number of samples completed so far.
        deferred: bookkeeping records for deferred samples.
        starved: optional list stored under "starved_experts".
        phase: phase label stored in the payload (defaults to the current PHASE).
        deferred_batches: optional tokenized inputs of the deferred samples,
            stored so a later phase can replay them without rebuilding the
            calibration set.
    """
    payload = dict(
        modelopt_state=_get_cached_modelopt_state(model),
        amax_buffers=_amax_buffers_cpu(model),
        samples_completed=samples_done,
        deferred_samples=deferred,
        deferred_batches=deferred_batches,
        starved_experts=starved,
        phase=phase,
    )
    torch.save(payload, path)
    written = path.stat().st_size
    return written
def _sha256_of(path: Path) -> str:
    """Return the hex SHA-256 of the file at `path`, read in 1 MiB chunks."""
    chunk_size = 1 << 20
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while piece := stream.read(chunk_size):
            digest.update(piece)
    return digest.hexdigest()
def _bucket_commit_ckpt(api: HfApi, local_path: Path, bucket_path: str,
                        sha256_hex: str) -> None:
    """Two-phase commit: upload state-N.tar, then upload the state-N.ok marker.

    The marker is a small JSON document with the tar's sha256, a timestamp and
    the tar's size. Its upload is attempted up to three times with exponential
    backoff (1 s, 2 s) between attempts. If every attempt fails, the tar that
    was just uploaded is deleted (best effort) so it cannot be mistaken for a
    committed checkpoint.

    Raises:
        AssertionError: the marker upload failed on every attempt; the last
            upload error is attached as the cause.
        Exception: whatever the initial tar upload raises (it is not retried).
    """
    api.upload_file(
        path_or_fileobj=str(local_path),
        path_in_repo=bucket_path,
        repo_id=BUCKET_REPO_ID,
        repo_type="dataset",
        commit_message=f"ckpt upload {bucket_path}",
    )
    ok_path = bucket_path.rsplit(".tar", 1)[0] + ".ok"
    marker = json.dumps({
        "sha256": sha256_hex,
        "timestamp": time.time(),
        "size": local_path.stat().st_size,
    }).encode()
    max_attempts = 3
    last_err: Exception | None = None
    for attempt in range(1, max_attempts + 1):
        try:
            api.upload_file(
                path_or_fileobj=marker,
                path_in_repo=ok_path,
                repo_id=BUCKET_REPO_ID,
                repo_type="dataset",
                commit_message=f"ckpt marker {ok_path}",
            )
            return
        except Exception as e:  # any upload failure is retryable here
            last_err = e
            if attempt < max_attempts:
                _log(f"ckpt .ok upload attempt {attempt}/{max_attempts} failed: {e}; retrying")
                time.sleep(2 ** (attempt - 1))
            else:
                # Nothing left to retry: do not sleep or claim a retry.
                _log(f"ckpt .ok upload attempt {attempt}/{max_attempts} failed: {e}; giving up")
    # Marker terminally failed — delete the orphan tar to avoid torn-state confusion.
    _log(f"ckpt .ok TERMINAL FAIL; deleting orphan {bucket_path}")
    try:
        api.delete_file(
            path_in_repo=bucket_path,
            repo_id=BUCKET_REPO_ID,
            repo_type="dataset",
            commit_message=f"orphan cleanup after .ok failure on {bucket_path}",
        )
    except Exception as e:
        # Best effort only: the missing .ok marker already keeps the tar from
        # being treated as a good checkpoint.
        _log(f"orphan delete also failed: {e}")
    raise AssertionError(
        f"ckpt .ok upload failed after retries: {last_err}"
    ) from last_err
def _ckpt_sha256_matches(api: HfApi, tar_path: str, ok_path: str) -> bool:
    """Download the marker and tar, and check the tar against the marker's sha256.

    Returns False when the marker cannot be parsed or the digests differ.
    Download errors are NOT swallowed: a transient network failure must not
    silently demote the newest checkpoint in favour of an older one.
    """
    with tempfile.TemporaryDirectory() as td:
        marker_file = api.hf_hub_download(
            repo_id=BUCKET_REPO_ID,
            repo_type="dataset",
            filename=ok_path,
            local_dir=td,
        )
        try:
            expected = json.loads(Path(marker_file).read_text())["sha256"]
        except (ValueError, KeyError, TypeError) as e:
            _log(f"ckpt marker {ok_path} unreadable: {type(e).__name__}: {e}")
            return False
        tar_file = api.hf_hub_download(
            repo_id=BUCKET_REPO_ID,
            repo_type="dataset",
            filename=tar_path,
            local_dir=td,
        )
        actual = _sha256_of(Path(tar_file))
    if actual != expected:
        _log(f"ckpt {tar_path} sha256 mismatch: marker={expected} actual={actual}")
        return False
    return True


def _find_latest_good_ckpt(api: HfApi) -> tuple[int, str] | None:
    """Return (N, path) of the highest state-N that is fully committed.

    A checkpoint counts as committed when both `state-N.tar` and `state-N.ok`
    exist under BUCKET_PREFIX and the tar's sha256 equals the one recorded in
    the marker. Candidates are tried from the highest N downwards; a candidate
    that fails verification is logged and skipped. Returns None when no
    candidate passes.
    """
    files = api.list_repo_files(repo_id=BUCKET_REPO_ID, repo_type="dataset")
    tars: dict[int, str] = {}
    oks: dict[int, str] = {}
    stem = "state-"
    for f in files:
        if not f.startswith(BUCKET_PREFIX):
            continue
        base = f.rsplit("/", 1)[-1]
        if not base.startswith(stem):
            continue
        for suffix, found in ((".tar", tars), (".ok", oks)):
            if not base.endswith(suffix):
                continue
            try:
                found[int(base[len(stem):-len(suffix)])] = f
            except ValueError:
                pass  # e.g. "state-final.tar": not one of ours
            break
    for n in sorted(tars.keys() & oks.keys(), reverse=True):
        if _ckpt_sha256_matches(api, tars[n], oks[n]):
            return n, tars[n]
        _log(f"skipping ckpt N={n}: failed sha256 verification")
    return None
def _download_ckpt(api: HfApi, bucket_path: str, local_path: Path) -> None:
    """Fetch `bucket_path` from the workspace bucket and place it at `local_path`.

    The file is downloaded under `local_path`'s parent directory and then
    moved to its final name.
    """
    request = {
        "repo_id": BUCKET_REPO_ID,
        "repo_type": "dataset",
        "filename": bucket_path,
        "local_dir": str(local_path.parent),
    }
    fetched = api.hf_hub_download(**request)
    shutil.move(fetched, local_path)
def _quantizer_target_device(model: torch.nn.Module, q_path: str) -> torch.device:
    """Pick the device for the quantizer at `q_path` from its parent module.

    Uses the device of the parent's first parameter, else of its first buffer,
    else CPU. A path without a dot has the model itself as parent.
    """
    parent_path = q_path.rpartition(".")[0]
    parent = model.get_submodule(parent_path) if parent_path else model
    for param in parent.parameters():
        return param.device
    for buf in parent.buffers():
        return buf.device
    return torch.device("cpu")


def _restore_ckpt_into_model(model: torch.nn.Module, ckpt_path: Path) -> dict:
    """Restore modelopt state and amax values from `ckpt_path` into `model`.

    Steps:
      1. Load the checkpoint onto CPU and re-apply its modelopt state.
      2. For every saved `<quantizer>._amax`, move the value to the device of
         the quantizer's parent module and assign it through `quantizer.amax`.
         The saved tensors live on CPU while the model may be sharded across
         GPUs, so assigning them unmoved would (per the original author's
         note) leave the buffer on the wrong device.
      3. If the quantizer has a calibrator with a `_calib_amax` attribute, seed
         it with a copy of the same value so that further calibration merges
         with the restored amax instead of starting over.

    Returns:
        The loaded checkpoint dict.

    Raises:
        AssertionError: a saved buffer has an unexpected name, its shape
            differs from an existing `_amax`, or its quantizer is missing
            from the model.
    """
    # NOTE: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a bucket you control.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    restore_from_modelopt_state(model, state["modelopt_state"])

    suffix = "._amax"
    unresolved: list[str] = []
    for buf_name, saved in state["amax_buffers"].items():
        assert buf_name.endswith(suffix), f"unexpected buffer name: {buf_name}"
        quantizer_path = buf_name[: -len(suffix)]
        try:
            quantizer = model.get_submodule(quantizer_path)
        except AttributeError:
            # Collected and reported together after the loop.
            unresolved.append(buf_name)
            continue

        existing = getattr(quantizer, "_amax", None)
        if existing is not None and existing.shape != saved.shape:
            raise AssertionError(
                f"shape mismatch at {buf_name}: saved={saved.shape} current={existing.shape}"
            )

        device = _quantizer_target_device(model, quantizer_path)
        on_device = saved.detach().to(device)
        quantizer.amax = on_device

        calibrator = getattr(quantizer, "_calibrator", None)
        if calibrator is not None and hasattr(calibrator, "_calib_amax"):
            calibrator._calib_amax = on_device.clone()

    if unresolved:
        raise AssertionError(
            f"Restore missed {len(unresolved)}/{len(state['amax_buffers'])} amax buffers. "
            f"Sample: {unresolved[:3]}"
        )
    return state
# ---------- Calibration set construction -----------------------------------
def _extract_content_text(content) -> str:
    """Flatten a message `content` field into plain text.

    - None -> ""
    - str -> returned unchanged
    - list -> its usable items joined with newlines, where a usable item is a
      string, a dict with a string "text", or a dict whose "content" flattens
      (recursively) to non-empty text; anything else is skipped
    - any other type -> str(content)
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces: list[str] = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
            continue
        if not isinstance(item, dict):
            continue
        text = item.get("text")
        if isinstance(text, str):
            pieces.append(text)
        elif "content" in item:
            nested = _extract_content_text(item["content"])
            if nested:
                pieces.append(nested)
    return "\n".join(pieces)
def _join_message_texts(raw) -> str | None:
    """Join the text of every message in `raw` with newlines.

    `raw` may be a list of message dicts or a JSON string encoding one.
    Returns None when `raw` cannot be decoded, is not a list, or contains no
    message with non-blank text.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except (ValueError, RecursionError):
            return None
    if not isinstance(raw, list):
        return None
    parts = []
    for message in raw:
        if not isinstance(message, dict):
            continue
        sub = _extract_content_text(message.get("content"))
        if sub and sub.strip():
            parts.append(sub)
    return "\n".join(parts) if parts else None


def _pair_text(head, tail) -> str:
    """Return `head`, followed by a newline and `tail` when `tail` is truthy.

    Both values are coerced with str(); a None `head` counts as empty so it
    never turns into the literal text "None".
    """
    head_text = "" if head is None else str(head)
    return head_text + "\n" + str(tail) if tail else head_text


def _extract_text(row: dict) -> str | None:
    """Extract one calibration text from a dataset row, or None.

    Schemas are tried in this order; the first one that yields non-blank text
    wins, otherwise the next is tried:
      1. "messages": list of messages, or a JSON string of one
      2. "trajectory": same shape as "messages"
      3. "instruction" (+ optional "output")
      4. "query" (+ optional "answers", falling back to "answer")
    """
    for key in ("messages", "trajectory"):
        if key in row:
            joined = _join_message_texts(row[key])
            if joined:
                return joined
    if "instruction" in row:
        text = _pair_text(row.get("instruction", ""), row.get("output", ""))
        if text.strip():
            return text
    if "query" in row:
        answer = row.get("answers", row.get("answer", ""))
        text = _pair_text(row.get("query", ""), answer)
        if text.strip():
            return text
    return None
def _probe_calibration_datasets() -> None:
    """Schema probe: check that each calibration dataset still yields text.

    Loads the first row of every CALIB_DATASETS entry and runs `_extract_text`
    on it. All failures (load error, empty split, empty extraction) are
    collected and reported together; messages for extraction failures include
    the first row's keys so the schema change can be diagnosed from the log.

    Raises:
        RuntimeError: at least one dataset failed the probe.
        AssertionError: an entry requests streaming=True, which the
            `split[:1]` slicing used here does not support.
    """
    from datasets import load_dataset

    def _probe_one(name, kwargs):
        """Return a failure description for one dataset, or None if it is OK."""
        # `split="x[:1]"` slicing needs a sized dataset, so streaming entries
        # are rejected up front with a clear message.
        assert not kwargs.get("streaming"), (
            f"probe does not support streaming=True ({name}); remove streaming or "
            f"skip the probe for this entry"
        )
        one_row_kwargs = {**kwargs, "split": f"{kwargs['split']}[:1]"}
        try:
            ds = load_dataset(name, **one_row_kwargs)
        except Exception as e:
            return f"{name} (kwargs={kwargs}): load failed — {type(e).__name__}: {e}"
        if len(ds) == 0:
            return f"{name} (kwargs={kwargs}): empty split"
        row = ds[0]
        row_keys = list(row.keys())[:10]
        text = _extract_text(row)
        if not text or not text.strip():
            return (
                f"{name} (kwargs={kwargs}): _extract_text returned empty; "
                f"first-row keys={row_keys}"
            )
        _log(f" probe OK: {name} ({len(text)} chars, keys={row_keys[:6]})")
        return None

    _log(f"probing {len(CALIB_DATASETS)} calibration datasets (schema check)")
    failures: list[str] = []
    for ds_name, ds_kwargs in CALIB_DATASETS:
        problem = _probe_one(ds_name, ds_kwargs)
        if problem is not None:
            failures.append(problem)
    if failures:
        msg = "calibration dataset probe FAILED:\n " + "\n ".join(failures)
        _log(msg)
        raise RuntimeError(msg)
    _log(f"probe OK for all {len(CALIB_DATASETS)} datasets")
def _load_calibration_texts(tokenizer) -> list[str]:
    """Load up to NUM_CALIB_PER_DS calibration texts from each dataset.

    For every CALIB_DATASETS entry the first NUM_CALIB_PER_DS rows are
    requested via `split[:N]` slicing, run through `_extract_text`, and kept
    when the text is longer than 32 characters.

    Args:
        tokenizer: unused; kept so existing callers do not need to change.

    Returns:
        All kept texts, in dataset order.

    Raises:
        AssertionError: a dataset failed to load or yielded fewer than half of
            the requested texts. Raised explicitly rather than via `assert` so
            the coverage gate still fires under `python -O`.
    """
    from datasets import load_dataset

    texts: list[str] = []
    failed: list[str] = []
    for name, kwargs in CALIB_DATASETS:
        _log(f"loading {name}")
        load_kwargs = dict(kwargs)
        load_kwargs["split"] = f"{kwargs['split']}[:{NUM_CALIB_PER_DS}]"
        try:
            ds = load_dataset(name, **load_kwargs)
        except Exception as e:
            # Record and move on so every broken dataset is reported at once.
            _log(f" FAILED to load {name}: {e}")
            failed.append(f"{name}: {e}")
            continue
        first_keys = list(ds[0].keys())[:10] if len(ds) else []
        loaded_this_ds = 0
        for row in ds:
            text = _extract_text(row)
            if text and len(text) > 32:
                texts.append(text)
                loaded_this_ds += 1
        _log(f" extracted {loaded_this_ds} texts from {name}")
        if loaded_this_ds < NUM_CALIB_PER_DS // 2:
            failed.append(
                f"{name} (only {loaded_this_ds}/{NUM_CALIB_PER_DS}); first_keys={first_keys}"
            )
    if failed:
        raise AssertionError(f"calibration coverage failure: {failed}")
    _log(f"total calibration texts: {len(texts)}")
    return texts
def _tokenize_and_sort(tokenizer, texts: list[str]) -> list[torch.Tensor]:
    """Tokenize `texts`, drop very short samples, and order shortest-first.

    Each text is truncated to MAX_SEQ tokens. Samples whose token dimension
    (`shape[1]`) is under 32 are discarded. The result is sorted ascending by
    that dimension so the cheapest samples run first and the first checkpoint
    boundary is reached before the long samples are attempted.
    """
    encoded = (
        tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ).input_ids
        for text in texts
    )
    long_enough = [ids for ids in encoded if ids.shape[1] >= 32]
    return sorted(long_enough, key=lambda ids: ids.shape[1])
# ---------- Quantizer setup + calibration ----------------------------------
def _assert_config_is_pure_max(cfg: dict) -> None:
    """Refuse any config other than NVFP4_DEFAULT_CFG with algorithm max/None.

    The identity check rejects look-alike configs that would pass the
    algorithm check but quantize differently. Both checks raise explicitly
    rather than via `assert`, so the guard still works under `python -O`.

    Raises:
        AssertionError: `cfg` is not the `mtq.NVFP4_DEFAULT_CFG` object, or
            its "algorithm" entry is neither None nor "max".
    """
    if cfg is not mtq.NVFP4_DEFAULT_CFG:
        raise AssertionError(
            f"must use the exact NVFP4_DEFAULT_CFG object, got {cfg!r}"
        )
    algo = cfg.get("algorithm")
    if algo not in (None, "max"):
        raise AssertionError(
            f"NVFP4 config must use algorithm=max/None for split-sample safety; got {algo!r}"
        )
def _install_quantizers_with_dummy(model, dummy_batches):
    """Insert quantizers via `mtq.quantize`, calibrating on `dummy_batches`.

    The config is always `mtq.NVFP4_DEFAULT_CFG` and is validated first. Each
    dummy batch is moved to the device of the model's first parameter and fed
    through the model in eval mode without gradients. Returns whatever
    `mtq.quantize` returns. The ignore list is applied separately by the caller.
    """
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    _assert_config_is_pure_max(quant_cfg)

    @torch.no_grad()
    def _run_dummy_batches(m):
        m.eval()
        for ids in dummy_batches:
            device = next(m.parameters()).device
            m(input_ids=ids.to(device))

    return mtq.quantize(model, quant_cfg, _run_dummy_batches)
def _apply_ignore_list(model) -> None:
    """Disable the quantizers on the ignore list and log how many are off.

    Ignored: lm_head, the MoE router gate, the token embeddings, and the first
    and last decoder layers (the last index comes from
    `model.config.num_hidden_layers`).

    This function does not touch `_amax` on the quantizers it disables.
    NOTE(review): the original author states stale amax on a disabled
    quantizer is inert and that the disabled flag survives modelopt
    save/restore — confirm against the installed modelopt version.
    """
    for pattern in ("lm_head*", "*block_sparse_moe.gate", "*embed_tokens*"):
        mtq.disable_quantizer(model, pattern)

    n_layers = getattr(model.config, "num_hidden_layers", None)
    assert n_layers is not None, "cannot determine num_hidden_layers"
    last_layer = n_layers - 1
    for pattern in ("*model.layers.0.*", f"*model.layers.{last_layer}.*"):
        mtq.disable_quantizer(model, pattern)

    disabled = [
        module for module in model.modules()
        if isinstance(module, TensorQuantizer)
        and not getattr(module, "is_enabled", True)
    ]
    n_disabled = len(disabled)
    _log(f"ignore list applied (last layer={last_layer}); {n_disabled} quantizers disabled")
def _populate_starved_experts(model) -> list[str]:
    """Give every starved weight quantizer an amax computed from its weight.

    A weight quantizer is starved when it is enabled and its `amax` is None or
    all zeros. For each such quantizer the module's weight is pushed through
    the normal stats-collection path (reset if supported, enable collection,
    forward the weight, finish collection).

    Returns:
        Names of the modules whose weight quantizer was rescued.

    Raises:
        AssertionError: a rescued quantizer still has no amax afterwards.
    """
    rescued: list[str] = []
    for mod_name, mod in model.named_modules():
        quantizer = getattr(mod, "weight_quantizer", None)
        if quantizer is None or not getattr(quantizer, "is_enabled", True):
            continue
        current = getattr(quantizer, "amax", None)
        has_signal = current is not None and not torch.all(current == 0)
        if has_signal:
            continue
        with torch.no_grad():
            if hasattr(quantizer, "reset_amax"):
                quantizer.reset_amax()
            enable_stats_collection(quantizer)
            quantizer(mod.weight)
            finish_stats_collection(quantizer)
        assert quantizer.amax is not None, f"starved rescue populate failed for {mod_name}"
        rescued.append(mod_name)
    return rescued
def _count_total_quantizers(model) -> int:
    """Count the TensorQuantizer modules in `model` (enabled or not)."""
    quantizers = [m for m in model.modules() if isinstance(m, TensorQuantizer)]
    return len(quantizers)
# ---------- Phase A: calibrate + ckpt + defer ------------------------------
def _config_sanity_gate() -> None:
    """Verify INPUT_DIR holds the expected MiniMax-M2.7 checkpoint.

    Checks `config.json` for the architecture, expert count (256) and layer
    count (62), and counts the `model-*.safetensors` shards (125). Intended to
    catch a wrong mount before any calibration compute is spent.

    Every check raises explicitly rather than via `assert`, so the gate still
    works under `python -O`.

    Raises:
        AssertionError: `config.json` is missing or any expectation fails.
    """
    cfg_path = INPUT_DIR / "config.json"
    if not cfg_path.exists():
        raise AssertionError(f"{cfg_path} not found — is /model mounted?")
    cfg = json.loads(cfg_path.read_text())
    if cfg.get("architectures") != ["MiniMaxM2ForCausalLM"]:
        raise AssertionError(
            f"arch mismatch: {cfg.get('architectures')} (expected MiniMaxM2ForCausalLM)"
        )
    if cfg.get("num_local_experts") != 256:
        raise AssertionError(
            f"expert count mismatch: {cfg.get('num_local_experts')} (expected 256 for M2.7)"
        )
    if cfg.get("num_hidden_layers") != 62:
        raise AssertionError(
            f"layer count mismatch: {cfg.get('num_hidden_layers')} (expected 62 for M2.7)"
        )
    shards = [f for f in os.listdir(INPUT_DIR)
              if f.startswith("model-") and f.endswith(".safetensors")]
    if len(shards) != 125:
        raise AssertionError(f"shard count mismatch: {len(shards)} (expected 125)")
    _log("config sanity gate OK: MiniMax-M2.7 230B, 256 experts × 62 layers × 125 shards")
def _defer_remaining(batches: list, start: int, reason: str,
                     deferred: list, deferred_batches_payload: list) -> None:
    """Record batches[start:] as deferred with `reason`, keeping their tensors."""
    for j in range(start, len(batches)):
        deferred.append({"idx": j, "reason": reason})
        deferred_batches_payload.append({"idx": j, "input_ids": batches[j].detach().cpu()})


def _phase_a_main() -> int:
    """Run Phase A: calibrate with per-sample OOM deferral and checkpoints.

    Returns the process exit code:
        0  = complete (also when the inline export failed; see export_error
             in the sentinel)
        42 = some samples were deferred (OOM)
        43 = wallclock cap reached, remainder deferred
        44 = share of starved quantizers exceeded STARVED_EXPERT_PCT_ABORT

    Every sample ends up either completed or in `deferred`: a sample that OOMs
    is deferred, and when the loop stops early (wallclock cap or two
    consecutive OOMs) all samples not yet attempted are deferred as well, so a
    later phase can replay them.
    """
    _log(f"starting Phase A: input={INPUT_DIR} bucket={BUCKET_REPO_ID} ckpt_every={CKPT_EVERY}")
    _log(f"wallclock budget = {WALLCLOCK_BUDGET_S}s, max_seq={MAX_SEQ}")
    _config_sanity_gate()
    # Schema probe BEFORE the model is loaded, so dataset drift fails early
    # instead of after the expensive model load.
    _probe_calibration_datasets()
    watchdog = WallclockWatchdog(WALLCLOCK_BUDGET_S, name="phase-a")
    api = HfApi()
    tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )
    # Calibration set
    texts = _load_calibration_texts(tokenizer)
    batches = _tokenize_and_sort(tokenizer, texts)
    if not batches:
        # Without this the log line below dies with an opaque IndexError.
        raise RuntimeError("no calibration samples survived tokenization")
    _log(f"tokenized {len(batches)} samples, shortest={batches[0].shape[1]}, longest={batches[-1].shape[1]}")
    # Install quantizers with a 1-sample dummy (the real calibration happens
    # in the manual loop below, which we control for checkpointing).
    dummy = batches[:1]
    model = _install_quantizers_with_dummy(model, dummy)
    _apply_ignore_list(model)
    total_q = _count_total_quantizers(model)
    _log(f"{total_q} TensorQuantizer modules installed post-ignore-list")
    # Real calibration loop: per-sample try/except, checkpoint every N
    deferred: list[dict] = []
    # Persist the exact tensors so the replay in a later phase is deterministic.
    deferred_batches_payload: list[dict] = []
    completed = 0
    consecutive_oom = 0
    ckpt_counter = 0
    last_ckpt_at_completed = 0
    enable_stats_collection(model)
    try:
        for i, ids in enumerate(batches):
            if watchdog.expired:
                _log(f"WALLCLOCK EXPIRED at sample {i}/{len(batches)} — deferring remainder")
                _defer_remaining(batches, i, "wallclock-cap",
                                 deferred, deferred_batches_payload)
                break
            try:
                ids_dev = ids.to(next(model.parameters()).device)
                with torch.no_grad():
                    model(input_ids=ids_dev)
                completed += 1
                consecutive_oom = 0
            except torch.cuda.OutOfMemoryError:
                _log(f"OOM at sample {i} (len={ids.shape[1]}) — deferring")
                deferred.append({"idx": i, "reason": "OOM", "len": int(ids.shape[1])})
                deferred_batches_payload.append({"idx": i, "input_ids": ids.detach().cpu()})
                _free_gpu()
                # On FIRST OOM of a new streak, opportunistically checkpoint so
                # progress up to now is saved in case the next sample also OOMs.
                if consecutive_oom == 0 and completed > last_ckpt_at_completed:
                    ckpt_counter += 1
                    _checkpoint_commit(api, model, completed, deferred, ckpt_counter,
                                       deferred_batches_payload=deferred_batches_payload)
                    last_ckpt_at_completed = completed
                consecutive_oom += 1
                if consecutive_oom >= 2:
                    _log("2 consecutive OOMs — ABORTING Phase A (saving state)")
                    # The samples after this one were never attempted. Defer
                    # them too; otherwise they would be neither completed nor
                    # deferred and would silently drop out of calibration.
                    _defer_remaining(batches, i + 1, "oom-abort",
                                     deferred, deferred_batches_payload)
                    break
            if (completed > 0) and (completed % CKPT_EVERY == 0) and completed != last_ckpt_at_completed:
                ckpt_counter += 1
                _checkpoint_commit(api, model, completed, deferred, ckpt_counter,
                                   deferred_batches_payload=deferred_batches_payload)
                last_ckpt_at_completed = completed
            if (i + 1) % 25 == 0:
                _log(f"progress: {i+1}/{len(batches)} (completed={completed}, deferred={len(deferred)})")
    finally:
        finish_stats_collection(model)
        # Stops both watchdog threads, so the hard kill no longer applies to
        # the rescue / checkpoint / export steps below.
        watchdog.shutdown()
    # Starved-expert rescue
    _log("populating starved experts from weights")
    rescued = _populate_starved_experts(model)
    starved_pct = 100.0 * len(rescued) / max(total_q, 1)
    _log(f"starved rescue: {len(rescued)}/{total_q} ({starved_pct:.2f}%)")
    # Final checkpoint — carries both the latest calibration state AND the
    # rescue results. Bumps the counter so it is the newest N in the bucket.
    ckpt_counter += 1
    _checkpoint_commit(api, model, completed, deferred, ckpt_counter,
                       starved=rescued,
                       deferred_batches_payload=deferred_batches_payload)
    # Determine status. Precedence: excessive-starved > wallclock > deferred > complete.
    status = "complete"
    if deferred:
        status = "wallclock-cap" if any(d["reason"] == "wallclock-cap" for d in deferred) else "deferred"
    if starved_pct > STARVED_EXPERT_PCT_ABORT:
        status = "excessive-starved"
    sentinel = {
        "status": status,
        "deferred_count": len(deferred),
        "starved_count": len(rescued),
        "total_quantizers": total_q,
        "starved_pct": starved_pct,
        "last_ckpt": ckpt_counter,
        "elapsed_s": watchdog.elapsed_s(),
        "completed": completed,
    }
    # INLINE EXPORT: if status=complete, export + publish right here while the
    # model is still loaded. If the export fails, the calibration checkpoints
    # are already in the bucket and PHASE=C can recover from them.
    if status == "complete":
        try:
            _log("Phase A status=complete — exporting + publishing inline")
            info = _finalize_and_export(api, model, tokenizer, origin_phase="A")
            sentinel.update(info)
            sentinel["status"] = "complete-published"
            status = "complete-published"
        except Exception as e:
            _log(f"inline export FAILED: {type(e).__name__}: {e}")
            traceback.print_exc()
            sentinel["export_error"] = f"{type(e).__name__}: {e}"
            # Keep status=complete so the launcher/operator can kick off
            # PHASE=C recovery.
    _write_sentinel(api, "phase-a.result", sentinel)
    _log(f"Phase A DONE — status={sentinel['status']} deferred={len(deferred)} starved={len(rescued)}")
    if status == "excessive-starved":
        return 44
    if status == "wallclock-cap":
        return 43
    if status == "deferred":
        return 42
    return 0
def _checkpoint_commit(api: HfApi, model, completed: int, deferred: list,
                       ckpt_counter: int, starved: list | None = None,
                       deferred_batches_payload: list | None = None) -> None:
    """Write an amax-only checkpoint to a temp file and commit it to the bucket.

    The checkpoint is uploaded as `<BUCKET_PREFIX>state-<ckpt_counter>.tar`
    together with its `.ok` marker. The local temp file is removed afterwards,
    whether or not the upload succeeded.
    """
    # Only a unique *.tar path is needed; the handle is closed straight away
    # and the file is written by _save_ckpt_local.
    handle = tempfile.NamedTemporaryFile(suffix=".tar", delete=False)
    handle.close()
    local = Path(handle.name)
    try:
        size = _save_ckpt_local(
            model, local, completed, deferred,
            starved=starved,
            deferred_batches=deferred_batches_payload,
        )
        sha = _sha256_of(local)
        bucket_tar = f"{BUCKET_PREFIX}state-{ckpt_counter}.tar"
        _bucket_commit_ckpt(api, local, bucket_tar, sha)
        _log(f"ckpt {ckpt_counter} committed ({size // 1024} KB, sha={sha[:10]}, deferred={len(deferred)})")
    finally:
        local.unlink(missing_ok=True)
def _write_sentinel(api: HfApi, name: str, payload: dict) -> None:
    """Upload `payload` as indented JSON to `<BUCKET_PREFIX><name>` in the bucket.

    Callers are expected to invoke this only after their checkpoint uploads
    have succeeded, so that the sentinel's presence implies they exist.
    """
    encoded = json.dumps(payload, indent=2).encode()
    destination = f"{BUCKET_PREFIX}{name}"
    api.upload_file(
        path_or_fileobj=encoded,
        path_in_repo=destination,
        repo_id=BUCKET_REPO_ID,
        repo_type="dataset",
        commit_message=f"sentinel {name}",
    )
# ---------- Phase B: restore + process deferred ----------------------------
def _phase_b_main() -> int:
    """Resume from Phase A's latest good checkpoint, replay deferred samples,
    rescue starved experts, save the final state, then export inline.

    Flow:
      1. Find and restore the newest valid checkpoint from the bucket.
      2. Replay deferred samples, preferring the tokenized tensors stored in
         the checkpoint; otherwise rebuild the calibration set.
      3. Run one more starved-expert rescue pass.
      4. Commit `state-FINAL.tar`, attempt export + publish, and write the
         `phase-b.result` sentinel.

    Samples that still cannot be processed (OOM again, or missing from a
    rebuilt calibration set) are recorded in the final checkpoint and counted
    in the sentinel's `still_deferred` field.

    Returns:
        0 whenever the function runs to the end. An export failure is
        recorded in the sentinel under `export_error` instead of being raised.

    Raises:
        RuntimeError: if no valid checkpoint exists in the bucket.
    """
    _log("starting Phase B: h200x8 resume")
    api = HfApi()

    latest = _find_latest_good_ckpt(api)
    if latest is None:
        # Explicit raise rather than `assert`, so the guard is not stripped
        # when the interpreter runs with -O.
        raise RuntimeError("Phase B found no valid ckpt in bucket — aborting")
    n, bucket_path = latest
    _log(f"resuming from ckpt N={n} at {bucket_path}")

    tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )

    with tempfile.TemporaryDirectory() as td:
        local = Path(td) / "ckpt.tar"
        _download_ckpt(api, bucket_path, local)
        state = _restore_ckpt_into_model(model, local)

    deferred = state.get("deferred_samples", [])
    deferred_batches = state.get("deferred_batches") or []
    _log(f"restored state: completed={state.get('samples_completed', 0)}, deferred={len(deferred)}")

    # Samples that cannot be folded into the calibration during this phase.
    still_deferred: list[dict] = []

    # Prefer replay from the exact tokenized tensors saved in Phase A's ckpt.
    # Only fall back to dataset rebuild if the ckpt is from an older version
    # that didn't persist batches.
    if deferred_batches:
        _log(f"using {len(deferred_batches)} deferred batches from ckpt (deterministic replay)")
        replay = [(d["idx"], d["input_ids"]) for d in deferred_batches]
    else:
        _log("ckpt has no deferred_batches — rebuilding calibration set (non-deterministic risk)")
        # Legacy path: Phase A may have run days earlier with different
        # CALIB_DATASETS or schemas. Re-probe so drift fails in ~60s instead
        # of wasting the h200x8 slot we already paid to allocate.
        _probe_calibration_datasets()
        texts = _load_calibration_texts(tokenizer)
        batches = _tokenize_and_sort(tokenizer, texts)
        replay = []
        for d in deferred:
            idx = d["idx"]
            if idx < len(batches):
                replay.append((idx, batches[idx]))
            else:
                # The rebuilt set no longer contains this index. Record it
                # instead of dropping it silently, so the final checkpoint
                # and sentinel reflect that the sample never contributed.
                _log(f"deferred sample {idx} not present in rebuilt calibration set "
                     f"({len(batches)} batches) — cannot replay")
                still_deferred.append({"idx": idx, "reason": "missing-after-rebuild"})

    # Process deferred samples
    enable_stats_collection(model)
    processed = 0
    try:
        for idx, ids in replay:
            hit_oom = False
            try:
                ids_dev = ids.to(next(model.parameters()).device)
                with torch.no_grad():
                    model(input_ids=ids_dev)
                processed += 1
            except torch.cuda.OutOfMemoryError:
                # Only set a flag here. While the handler is active the
                # exception still references the failed forward's frames and
                # tensors, so cleanup inside it could not release them.
                hit_oom = True
            if hit_oom:
                ids_dev = None  # drop our own reference before cleanup
                _log(f"Phase B STILL OOM on sample {idx} — escalate manually")
                still_deferred.append({"idx": idx, "reason": "h200-OOM"})
                _free_gpu()
    finally:
        # Always leave stats-collection mode, even if a replay step raised.
        finish_stats_collection(model)

    # Any remaining starved experts get one more rescue pass
    rescued = _populate_starved_experts(model)
    _log(f"Phase B: processed {processed}/{len(deferred)} deferred, rescued {len(rescued)} starved")

    # Save final state to bucket as state-FINAL.tar
    with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tf:
        local = Path(tf.name)
    try:
        _save_ckpt_local(model, local, state.get("samples_completed", 0) + processed,
                         still_deferred, starved=rescued, phase="B")
        sha = _sha256_of(local)
        _bucket_commit_ckpt(api, local, f"{BUCKET_PREFIX}state-FINAL.tar", sha)
    finally:
        local.unlink(missing_ok=True)

    sentinel = {
        "status": "complete" if not still_deferred else "partial",
        "processed": processed,
        "still_deferred": len(still_deferred),
        "final_starved_rescued": len(rescued),
    }

    # Inline export + publish while the h200x8 still has the model loaded.
    # A successful export overwrites the status, but the `still_deferred`
    # count stays in the sentinel.
    try:
        _log("Phase B calibration complete — exporting + publishing inline")
        info = _finalize_and_export(api, model, tokenizer, origin_phase="B")
        sentinel.update(info)
        sentinel["status"] = "complete-published"
    except Exception as e:
        _log(f"Phase B inline export FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        sentinel["export_error"] = f"{type(e).__name__}: {e}"

    _write_sentinel(api, "phase-b.result", sentinel)
    _log(f"Phase B DONE — status={sentinel['status']}")
    return 0
# ---------- Phase C: restore + export + upload -----------------------------
def _cleanup_target_repo_shards(api: HfApi) -> None:
    """Remove previously published weight shards from the target model repo.

    `upload_folder` only overwrites paths with the same name and never
    removes pre-existing files. If a new export has fewer shards than an
    earlier upload, the leftovers would remain and break loading. Every
    `*.safetensors` file plus `model.safetensors.index.json` is therefore
    deleted in a single commit before publishing.

    Best-effort: if the repo listing fails, the failure is logged and the
    function returns without deleting anything.
    """
    try:
        remote_files = api.list_repo_files(repo_id=TARGET_REPO_ID, repo_type="model")
    except Exception as e:
        _log(f"target repo list failed (new repo?): {e}")
        return

    doomed = []
    for remote_path in remote_files:
        if remote_path == "model.safetensors.index.json" or remote_path.endswith(".safetensors"):
            doomed.append(remote_path)

    if doomed:
        _log(f"deleting {len(doomed)} stale files from target repo before publish")
        # Imported lazily: only needed when there is something to delete.
        from huggingface_hub import CommitOperationDelete
        api.create_commit(
            repo_id=TARGET_REPO_ID,
            repo_type="model",
            operations=[CommitOperationDelete(path_in_repo=p) for p in doomed],
            commit_message=f"cleanup {len(doomed)} stale shards before new publish",
        )
def _finalize_and_export(api: HfApi, model, tokenizer, origin_phase: str) -> dict:
    """Export NVFP4 checkpoint + publish to target repo + bucket backup.

    Called inline from Phase A (on status=complete) or Phase B (always) so the
    already-loaded GPU model is reused. Avoids reloading a 460GB model in a
    separate cpu-basic job (which can't fit it anyway).

    Args:
        api: Hub client used for repo creation, cleanup and uploads.
        model: Calibrated model to export.
        tokenizer: Tokenizer saved next to the exported weights.
        origin_phase: Label of the calling phase; embedded in commit messages
            and echoed back in the returned dict.

    Returns:
        Dict with `shards` (number of exported `*.safetensors` files),
        `size_gb` (size of the export directory, measured before the
        tokenizer and remote-code files are added) and `origin_phase`.

    Raises:
        RuntimeError: if the export wrote no `*.safetensors` files.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    _log(f"exporting to {OUTPUT_DIR}")
    with torch.inference_mode():
        export_hf_checkpoint(model, export_dir=str(OUTPUT_DIR))

    shards = list(OUTPUT_DIR.glob("*.safetensors"))
    if not shards:
        # Explicit raise rather than `assert`: under -O an assert is removed
        # and an empty export would be published.
        raise RuntimeError("export produced no safetensors shards")
    exp_gb = sum(p.stat().st_size for p in OUTPUT_DIR.rglob("*") if p.is_file()) / 1e9
    _log(f"export: {len(shards)} shards, {exp_gb:.1f} GB")

    tokenizer.save_pretrained(str(OUTPUT_DIR))
    # Carry over the source model's remote-code modules and chat template so
    # the published repo is loadable with trust_remote_code.
    for f in os.listdir(INPUT_DIR):
        is_remote_code = (
            f.startswith(("modeling_", "configuration_", "tokenization_"))
            and f.endswith(".py")
        )
        if is_remote_code or f == "chat_template.jinja":
            shutil.copy2(INPUT_DIR / f, OUTPUT_DIR / f)

    api.create_repo(TARGET_REPO_ID, private=False, exist_ok=True)
    # Stale shards are removed in their own commit before the new upload.
    _cleanup_target_repo_shards(api)
    _log(f"uploading to target: {TARGET_REPO_ID}")
    api.upload_folder(
        folder_path=str(OUTPUT_DIR),
        repo_id=TARGET_REPO_ID,
        commit_message=f"v3 AC NVFP4 (protected rerun, origin={origin_phase})",
    )
    _log("target upload DONE")

    # Second copy into the bucket dataset repo as a backup of the export.
    api.upload_folder(
        folder_path=str(OUTPUT_DIR),
        repo_id=BUCKET_REPO_ID,
        repo_type="dataset",
        path_in_repo=f"{BUCKET_PREFIX}export/",
        commit_message=f"export backup from phase-{origin_phase}",
    )
    return {"shards": len(shards), "size_gb": exp_gb, "origin_phase": origin_phase}
def _phase_c_main() -> int:
    """MANUAL RECOVERY ONLY — not auto-submitted by the launcher.

    If Phase A or Phase B crashed after calibration but before export, this
    phase reloads the model + restores the latest good ckpt + exports. Requires
    a GPU flavor (a100-large minimum) because the 230B BF16 model won't fit in
    cpu-basic RAM.

    Checkpoint selection: `state-FINAL.tar` is used when both it and its
    `state-FINAL.ok` marker are listed in the bucket; otherwise the latest
    good numbered checkpoint is used.

    Returns:
        0 on success. Export failures are not caught here and propagate.

    Raises:
        RuntimeError: if neither a committed `state-FINAL` nor any other
            valid checkpoint exists in the bucket.
    """
    _log("starting Phase C: MANUAL export-only recovery")
    api = HfApi()

    files = api.list_repo_files(repo_id=BUCKET_REPO_ID, repo_type="dataset")
    final = f"{BUCKET_PREFIX}state-FINAL.tar"
    # The tar is only trusted when its .ok marker is also present.
    if final in files and f"{BUCKET_PREFIX}state-FINAL.ok" in files:
        ckpt_bucket_path = final
        _log("using state-FINAL (Phase B was run)")
    else:
        latest = _find_latest_good_ckpt(api)
        if latest is None:
            # Explicit raise rather than `assert`: under -O the assert is
            # removed and the unpacking below fails with an opaque TypeError.
            raise RuntimeError("no valid ckpt for Phase C")
        _, ckpt_bucket_path = latest
        _log(f"using Phase A latest ckpt: {ckpt_bucket_path}")

    tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )

    with tempfile.TemporaryDirectory() as td:
        local = Path(td) / "ckpt.tar"
        _download_ckpt(api, ckpt_bucket_path, local)
        _restore_ckpt_into_model(model, local)

    # One more rescue pass in case the restored state still has starved
    # quantizers.
    rescued = _populate_starved_experts(model)
    if rescued:
        _log(f"Phase C extra rescue: {len(rescued)} quantizers")

    info = _finalize_and_export(api, model, tokenizer, origin_phase="C-recovery")
    _write_sentinel(api, "phase-c.result", {"status": "complete", **info})
    _log("Phase C DONE")
    return 0
# ---------- Entry ----------------------------------------------------------
def main() -> int:
    """Log the selected phase and modelopt version, then run that phase.

    "A" and "B" dispatch to their phase entry points; any other PHASE value
    runs Phase C.

    Returns:
        The exit code returned by the selected phase function.
    """
    _log(f"quantize-ac-protected.py phase={PHASE}")

    modelopt_module = __import__("modelopt")
    modelopt_version = getattr(modelopt_module, "__version__", "?")
    _log(f"modelopt={modelopt_version}")

    dispatch = {"A": _phase_a_main, "B": _phase_b_main}
    run_phase = dispatch.get(PHASE, _phase_c_main)
    return run_phase()
def _write_crash_sentinel(exc_type: str, exc_msg: str, tb: str) -> None:
    """Best-effort upload of a `crashed` sentinel for the current phase.

    Written on an uncaught exception so the launcher sees a result file
    instead of waiting for one that will never arrive. Any failure while
    writing the sentinel is logged and swallowed.

    Args:
        exc_type: Exception class name.
        exc_msg: Exception message; truncated to 2000 characters.
        tb: Formatted traceback; truncated to 4000 characters.
    """
    try:
        hub = HfApi()
        crash_report = dict(
            status="crashed",
            phase=PHASE,
            error_type=exc_type,
            error_message=exc_msg[:2000],
            traceback=tb[:4000],
        )
        sentinel_name = f"phase-{PHASE.lower()}.result"
        _write_sentinel(hub, sentinel_name, crash_report)
        _log(f"crash sentinel written for {sentinel_name}")
    except Exception as inner:
        _log(f"failed to write crash sentinel: {type(inner).__name__}: {inner}")
if __name__ == "__main__":
    # Script entry point: run the selected phase and exit with its return
    # code. Any uncaught Exception is logged, reported through a crash
    # sentinel, and mapped to exit code 2.
    try:
        exit_code = main()
    except Exception as exc:
        exc_name = type(exc).__name__
        _log(f"FATAL: {exc_name}: {exc}")
        traceback.print_exc()
        formatted_tb = traceback.format_exc()
        _write_crash_sentinel(exc_name, str(exc), formatted_tb)
        sys.exit(2)
    else:
        sys.exit(exit_code)