Text Generation
Transformers
Safetensors
English
Chinese
minimax_m2
minimax
nvfp4
4-bit precision
quantized
compressed-tensors
vllm
DGX-Spark
GB10
MoE
agentic
tool-use
code
conversational
custom_code
8-bit precision
Instructions to use saricles/MiniMax-M2.7-NVFP4-GB10-AC with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use saricles/MiniMax-M2.7-NVFP4-GB10-AC with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="saricles/MiniMax-M2.7-NVFP4-GB10-AC", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("saricles/MiniMax-M2.7-NVFP4-GB10-AC", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("saricles/MiniMax-M2.7-NVFP4-GB10-AC", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use saricles/MiniMax-M2.7-NVFP4-GB10-AC with vLLM:
Install from pip and serve model
```shell
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "saricles/MiniMax-M2.7-NVFP4-GB10-AC"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "saricles/MiniMax-M2.7-NVFP4-GB10-AC",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```
Use Docker
```shell
docker model run hf.co/saricles/MiniMax-M2.7-NVFP4-GB10-AC
```
- SGLang
How to use saricles/MiniMax-M2.7-NVFP4-GB10-AC with SGLang:
Install from pip and serve model
```shell
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "saricles/MiniMax-M2.7-NVFP4-GB10-AC" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "saricles/MiniMax-M2.7-NVFP4-GB10-AC",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```
Use Docker images
```shell
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "saricles/MiniMax-M2.7-NVFP4-GB10-AC" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "saricles/MiniMax-M2.7-NVFP4-GB10-AC",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```
- Docker Model Runner
How to use saricles/MiniMax-M2.7-NVFP4-GB10-AC with Docker Model Runner:
```shell
docker model run hf.co/saricles/MiniMax-M2.7-NVFP4-GB10-AC
```
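The vLLM and SGLang commands above serve an OpenAI-compatible `/v1/chat/completions` endpoint. vLLM listens on port 8000, and the SGLang commands use port 30000. The sketch below calls that endpoint from Python using only the standard library. `build_chat_request` and `chat` are hypothetical helper names, and the `max_tokens` and timeout defaults are arbitrary illustration values, not recommendations from this model card.

```python
import json
import urllib.request

MODEL_ID = "saricles/MiniMax-M2.7-NVFP4-GB10-AC"


def build_chat_request(base_url: str, prompt: str,
                       max_tokens: int = 256) -> urllib.request.Request:
    """Build (but do not send) a POST to an OpenAI-compatible chat endpoint."""
    payload = {
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,  # arbitrary illustration value
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(base_url: str, prompt: str, timeout_s: float = 600.0) -> str:
    """Send the request and return the first choice's message text."""
    req = build_chat_request(base_url, prompt)
    with urllib.request.urlopen(req, timeout=timeout_s) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

For example, `chat("http://localhost:30000", "What is the capital of France?")` targets the SGLang server, and the same call with port 8000 targets vLLM. Splitting request construction from sending lets you inspect the exact JSON before any network traffic happens.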
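The quantization script listed below commits each checkpoint in two steps: first `state-N.tar`, then a `state-N.ok` marker. A resuming job therefore only trusts checkpoints whose marker exists. The sketch below applies that selection rule to a plain list of repo paths, so it runs without the Hub. `latest_committed_ckpt` is a hypothetical name, not the script's `_find_latest_good_ckpt`. Like that function, it does not re-hash the tar against the sha256 stored in the marker.

```python
import re
from typing import Dict, List, Optional, Set, Tuple

# Matches bare file names like "state-12.tar" or "state-12.ok".
_CKPT_RE = re.compile(r"^state-(\d+)\.(tar|ok)$")


def latest_committed_ckpt(files: List[str],
                          prefix: str) -> Optional[Tuple[int, str]]:
    """Return (N, tar_path) for the highest N that has BOTH tar and .ok."""
    tars: Dict[int, str] = {}
    oks: Set[int] = set()
    for f in files:
        if not f.startswith(prefix):
            continue  # belongs to a different run
        m = _CKPT_RE.match(f.rsplit("/", 1)[-1])
        if m is None:
            continue  # unrelated or malformed file name
        n, kind = int(m.group(1)), m.group(2)
        if kind == "tar":
            tars[n] = f
        else:
            oks.add(n)
    committed = set(tars) & oks  # a tar without its marker is a torn upload
    if not committed:
        return None
    n = max(committed)
    return n, tars[n]
```

Parsing `N` as an integer matters here. A lexicographic sort would rank `state-9` above `state-10`.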
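The script's docstring calls out one non-obvious resume gotcha: `MaxCalibrator._calib_amax` must be seeded at restore time. The toy collector below illustrates the arithmetic behind that. It is a hypothetical stand-in, not modelopt's `MaxCalibrator`. A fresh collector takes its first batch's maximum as-is. Without seeding, a resumed phase therefore replaces the prior-phase amax instead of taking the maximum over both phases.

```python
from typing import Iterable, List, Optional


class ToyMaxCollector:
    """Hypothetical running-max collector (NOT modelopt's MaxCalibrator)."""

    def __init__(self) -> None:
        self.calib_amax: Optional[float] = None  # None until the first batch

    def collect(self, batch: List[float]) -> None:
        batch_amax = max(abs(v) for v in batch)
        if self.calib_amax is None:
            # First batch seen by a fresh collector: taken as-is.
            self.calib_amax = batch_amax
        else:
            # Later batches: max-merge with everything seen before.
            self.calib_amax = max(self.calib_amax, batch_amax)


def resume_amax(prior_amax: float, new_batches: Iterable[List[float]],
                seed: bool) -> Optional[float]:
    """Continue calibration after a restore, with or without seeding."""
    collector = ToyMaxCollector()
    if seed:
        collector.calib_amax = prior_amax  # the restore-time seeding step
    for batch in new_batches:
        collector.collect(batch)
    return collector.calib_amax
```

In the real script, the seeded value is the restored `_amax` tensor, moved to the quantizer's device, as `_restore_ckpt_into_model` does.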
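The `nvfp4` tag and the script's `mtq.NVFP4_DEFAULT_CFG` refer to NVIDIA's 4-bit floating-point format. Each value is an FP4 (E2M1) number, every 16 consecutive values share a scale, and a per-tensor scale sits on top. The `_amax` statistics the script calibrates and checkpoints are the absolute maxima that scale selection relies on.

The sketch below is a simplified pure-Python quantize-dequantize of the block-scaling idea. `fake_quant_nvfp4` is a hypothetical helper. It deliberately skips the FP8 (E4M3) rounding of block scales and the per-tensor scale, so its numbers will not match modelopt bit-for-bit.

```python
# Magnitudes representable by FP4 E2M1 (the sign is stored separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # one shared scale per 16 consecutive values


def _round_to_grid(x: float) -> float:
    # Nearest grid magnitude; on an exact tie, min() keeps the smaller one
    # (real converters typically use round-to-nearest-even instead).
    mag = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return mag if x >= 0 else -mag


def fake_quant_nvfp4(values: list[float]) -> list[float]:
    """Quantize then dequantize with one scale per 16-element block.

    Simplification: the block scale stays in full precision instead of being
    rounded to FP8 E4M3, and there is no second-level per-tensor scale.
    """
    out: list[float] = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Map the block's absolute max onto the largest FP4 magnitude (6.0).
        scale = amax / 6.0 if amax > 0 else 1.0
        out.extend(_round_to_grid(v / scale) * scale for v in block)
    return out
```

Each block is scaled independently. As a result, one outlier only coarsens its own 16 neighbors rather than the whole tensor.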
| #!/usr/bin/env python3 | |
| """Protected NVFP4 quantization for MiniMax-M2.7 (agentic + coder calibration). | |
| Recipe used to produce `MiniMax-M2.7-NVFP4-GB10-AC`: a 7-dataset agentic + coder | |
| calibration (896 samples queued, 888 tokenized @ 49,152 max-seq) with per-sample | |
| OOM-defer protection and amax-only checkpointing. Designed to run on Hugging | |
| Face Jobs or any equivalent GPU host with access to the BF16 source model. | |
| Three phases dispatched by PHASE env var. The happy path is a single Phase A | |
| invocation on a GPU flavor large enough to hold the 230B BF16 model: | |
| PHASE=A — Calibration with per-sample OOM defer, checkpoint every N samples | |
| (amax-only format, ~60 MB each — NOT a full state_dict, which would | |
| be ~460 GB). At end, populate starved-expert amax from weights and | |
| write phase-a.result sentinel. On status=complete, export and | |
| publish inline while the model is still GPU-loaded. | |
| PHASE=B — Fallback only: resume from the latest good checkpoint on a | |
| larger-memory GPU flavor; process deferred samples and any | |
| remaining starved experts, save final state, inline export. | |
| PHASE=C — Manual recovery: restore from saved checkpoint and re-export if | |
| Phase A/B completed calibration but crashed during export. Requires | |
| a GPU flavor large enough to host the BF16 model. | |
| Key properties: | |
| - All calibration samples contribute (no skipping on OOM — deferred instead). | |
| - Amax-only checkpoints via modelopt_state + `_amax` buffers, not mto.save(). | |
| - Two-phase bucket commit: state-N.tar then state-N.ok marker with sha256, | |
| so a torn upload never masquerades as a valid checkpoint. | |
| - MaxCalibrator._calib_amax is SEEDED at restore so resume max-merges rather | |
| than overwrites prior-phase amax (the non-obvious gotcha that breaks naive | |
| resumes). | |
| - Starved-expert amax (quantizers that never received traffic) is populated | |
| via the enable_stats_collection(wq); wq(weight); finish_stats_collection | |
| pattern, producing the correct per-channel shape. | |
| - Wallclock watchdog threads let Phase A exit cleanly before the compute budget is exhausted. | |
| Env vars: | |
| PHASE = A | B | C (required) | |
| INPUT_DIR = path to BF16 source model (default: /model) | |
| OUTPUT_DIR = export target path (default: /out) | |
| TARGET_REPO_ID = HF Hub model repo to publish to (required for Phase A | |
| inline export; Phase B/C also publish to it) | |
| BUCKET_REPO_ID = HF Hub dataset repo used as checkpoint workspace | |
| (required for all phases) | |
| BUCKET_PREFIX = path prefix inside the bucket repo (default: runs/ac/) | |
| NUM_CALIB_PER_DS = calibration samples per dataset (default: 128) | |
| MAX_SEQ = max sequence length, tokens (default: 49152 = 48K) | |
| CKPT_EVERY = checkpoint cadence in completed samples (default: 50) | |
| WALLCLOCK_BUDGET_S = soft deadline before Phase A starts deferring (default: | |
| 21600 = 6h; hard watchdog = budget + 1800s) | |
| STARVED_EXPERT_PCT_ABORT = abort if >this% of quantizers ended starved | |
| (default: 1.0) | |
| The sanity gate in `_config_sanity_gate` is model-specific (MiniMax-M2.7: 256 | |
| experts × 62 layers × 125 shards) and will need adjustment for a different MoE | |
| architecture. Everything else is model-agnostic. | |
| """ | |
| from __future__ import annotations | |
| # Conv1D shim (transformers 4.57+ moved it out of modeling_utils). | |
| import transformers.modeling_utils # noqa: E402 | |
| if not hasattr(transformers.modeling_utils, "Conv1D"): | |
| from transformers.pytorch_utils import Conv1D as _Conv1D # noqa: E402 | |
| transformers.modeling_utils.Conv1D = _Conv1D | |
| import gc | |
| import hashlib | |
| import json | |
| import os | |
| import random | |
| import shutil | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from huggingface_hub import HfApi | |
| import modelopt.torch.quantization as mtq | |
| import modelopt.torch.opt as mto | |
| from modelopt.torch.opt.conversion import modelopt_state, restore_from_modelopt_state | |
| from modelopt.torch.quantization.model_calib import ( | |
| enable_stats_collection, finish_stats_collection, | |
| ) | |
| from modelopt.torch.export import export_hf_checkpoint | |
| from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer | |
| # ---------- Config --------------------------------------------------------- | |
| PHASE = os.environ["PHASE"].upper() | |
| assert PHASE in {"A", "B", "C"}, f"PHASE must be A|B|C, got {PHASE!r}" | |
| INPUT_DIR = Path(os.environ.get("INPUT_DIR", "/model")) | |
| OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/out")) | |
| BUCKET_REPO_ID = os.environ.get("BUCKET_REPO_ID", "") # e.g. "<your-username>/quant-lab" | |
| BUCKET_PREFIX = os.environ.get("BUCKET_PREFIX", "runs/ac/").rstrip("/") + "/" | |
| TARGET_REPO_ID = os.environ.get("TARGET_REPO_ID", "") # e.g. "<your-username>/MiniMax-M2.7-NVFP4-AC" | |
| # Required-env validation: fail fast with a clear message rather than deep in | |
| # an HfApi call. All phases need the bucket; Phase A also needs the target for | |
| # its inline export on success. | |
| if not BUCKET_REPO_ID: | |
| raise RuntimeError( | |
| "BUCKET_REPO_ID env var is required (HF Hub dataset repo used for " | |
| "checkpoint workspace). Set it to e.g. '<your-username>/quant-lab'." | |
| ) | |
| if PHASE == "A" and not TARGET_REPO_ID: | |
| raise RuntimeError( | |
| "TARGET_REPO_ID env var is required for Phase A (inline export target). " | |
| "Set it to e.g. '<your-username>/MiniMax-M2.7-NVFP4-AC'." | |
| ) | |
| NUM_CALIB_PER_DS = int(os.environ.get("NUM_CALIB_PER_DS", "128")) | |
| MAX_SEQ = int(os.environ.get("MAX_SEQ", "49152")) | |
| CKPT_EVERY = int(os.environ.get("CKPT_EVERY", "50")) | |
| WALLCLOCK_BUDGET_S = int(os.environ.get("WALLCLOCK_BUDGET_S", "21600")) | |
| # 7-dataset agentic + coder calibration mix. Code + tool + math + science + | |
| # SWE-agent trajectories for the target workloads, plus a general-chat anchor | |
| # (ultrachat) to keep plain conversational activations in-range. | |
| # | |
| # Schemas handled by _extract_text: messages (list OR JSON string), trajectory, | |
| # instruction+output, query+answers. | |
| CALIB_DATASETS = [ | |
| ("theblackcat102/evol-codealpaca-v1", {"split": "train"}), | |
| ("Salesforce/xlam-function-calling-60k", {"split": "train"}), | |
| ("open-r1/Mixture-of-Thoughts", {"name": "code", "split": "train"}), | |
| ("open-r1/Mixture-of-Thoughts", {"name": "math", "split": "train"}), | |
| ("open-r1/Mixture-of-Thoughts", {"name": "science", "split": "train"}), | |
| ("SWE-bench/SWE-smith-trajectories", {"split": "tool"}), | |
| ("HuggingFaceH4/ultrachat_200k", {"split": "train_sft"}), | |
| ] | |
| DEVICE = "cuda" | |
| STARVED_EXPERT_PCT_ABORT = float(os.environ.get("STARVED_EXPERT_PCT_ABORT", "1.0")) | |
| def _log(msg: str) -> None: | |
| ts = time.strftime("%H:%M:%S") | |
| print(f"[{ts}] [phase-{PHASE}] {msg}", flush=True) | |
| def _free_gpu() -> None: | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # ---------- Wallclock watchdog --------------------------------------------- | |
| class WallclockWatchdog: | |
| """Monotonic-clock watchdog. Main loop polls .expired() between samples. | |
| Belt-and-suspenders: a second thread triggers `os._exit(45)` if budget + | |
| hard_cap_margin_s is exceeded, in case the main loop's expiry check path | |
| is broken (e.g., loop hangs in a forward pass). This is the budget-safety | |
| net against an otherwise $400 worst-case container-timeout burn. | |
| """ | |
| def __init__(self, budget_s: int, name: str = "phase-a", | |
| hard_cap_margin_s: int = 1800): | |
| self._budget_s = budget_s | |
| self._hard_cap_s = budget_s + hard_cap_margin_s | |
| self._start = time.monotonic() | |
| self._stop = threading.Event() | |
| self._expired = threading.Event() | |
| self._name = name | |
| self._soft_thread = threading.Thread(target=self._run_soft, daemon=True) | |
| self._hard_thread = threading.Thread(target=self._run_hard, daemon=True) | |
| self._soft_thread.start() | |
| self._hard_thread.start() | |
| def _run_soft(self): | |
| while not self._stop.wait(10): | |
| if time.monotonic() - self._start > self._budget_s: | |
| self._expired.set() | |
| return | |
| def _run_hard(self): | |
| while not self._stop.wait(30): | |
| if time.monotonic() - self._start > self._hard_cap_s: | |
| # Main loop failed to exit on soft cap — kill the process to | |
| # bound the HF Jobs bill. Phase A ckpts are durable; Phase B | |
| # can pick up from the latest good one. Exit code 45 to | |
| # distinguish hard-kill from graceful wallclock-cap (43). | |
| print(f"[watchdog] HARD CAP {self._hard_cap_s}s exceeded — os._exit(45)", | |
| flush=True) | |
| os._exit(45) | |
| def expired(self) -> bool: | |
| return self._expired.is_set() | |
| def elapsed_s(self) -> int: | |
| return int(time.monotonic() - self._start) | |
| def shutdown(self): | |
| self._stop.set() | |
| # ---------- Amax-only checkpoint I/O --------------------------------------- | |
| def _amax_buffers_cpu(model: torch.nn.Module) -> dict[str, torch.Tensor]: | |
| # `.cpu()` from a non-CPU device already copies; `.clone()` would be a | |
| # redundant second copy of thousands of small buffers. | |
| return { | |
| n: b.detach().cpu() | |
| for n, b in model.named_buffers() | |
| if n.endswith("._amax") and b is not None | |
| } | |
| _CACHED_MODELOPT_STATE: dict | None = None | |
| def _get_cached_modelopt_state(model: torch.nn.Module) -> dict: | |
| """modelopt_state(model) walks the full module tree — cache after first call | |
| since architecture is immutable post-install. | |
| """ | |
| global _CACHED_MODELOPT_STATE | |
| if _CACHED_MODELOPT_STATE is None: | |
| _CACHED_MODELOPT_STATE = modelopt_state(model) | |
| return _CACHED_MODELOPT_STATE | |
| def _save_ckpt_local(model: torch.nn.Module, path: Path, samples_done: int, | |
| deferred: list, starved: list | None = None, | |
| phase: str = PHASE, | |
| deferred_batches: list | None = None) -> int: | |
| """Write amax-only + bookkeeping to `path`. Returns bytes written. | |
| `deferred_batches` (optional): list of (idx, tokenized_tensor) pairs for the | |
| deferred samples, so Phase B can replay the EXACT same tensors without | |
| rebuilding the calibration set (avoids non-determinism if HF datasets drift). | |
| """ | |
| state = { | |
| "modelopt_state": _get_cached_modelopt_state(model), | |
| "amax_buffers": _amax_buffers_cpu(model), | |
| "samples_completed": samples_done, | |
| "deferred_samples": deferred, | |
| "deferred_batches": deferred_batches, | |
| "starved_experts": starved, | |
| "phase": phase, | |
| } | |
| torch.save(state, path) | |
| return path.stat().st_size | |
| def _sha256_of(path: Path) -> str: | |
| h = hashlib.sha256() | |
| with open(path, "rb") as f: | |
| for chunk in iter(lambda: f.read(1 << 20), b""): | |
| h.update(chunk) | |
| return h.hexdigest() | |
| def _bucket_commit_ckpt(api: HfApi, local_path: Path, bucket_path: str, | |
| sha256_hex: str) -> None: | |
| """Two-phase commit: upload state-N.tar, then upload state-N.ok marker. | |
| If the .ok upload fails after 3 attempts, DELETE the orphan tar so | |
| _find_latest_good_ckpt won't resurrect a torn state. | |
| """ | |
| api.upload_file( | |
| path_or_fileobj=str(local_path), | |
| path_in_repo=bucket_path, | |
| repo_id=BUCKET_REPO_ID, | |
| repo_type="dataset", | |
| commit_message=f"ckpt upload {bucket_path}", | |
| ) | |
| ok_path = bucket_path.rsplit(".tar", 1)[0] + ".ok" | |
| marker = json.dumps({ | |
| "sha256": sha256_hex, | |
| "timestamp": time.time(), | |
| "size": local_path.stat().st_size, | |
| }).encode() | |
| last_err: Exception | None = None | |
| for attempt in range(3): | |
| try: | |
| api.upload_file( | |
| path_or_fileobj=marker, | |
| path_in_repo=ok_path, | |
| repo_id=BUCKET_REPO_ID, | |
| repo_type="dataset", | |
| commit_message=f"ckpt marker {ok_path}", | |
| ) | |
| return | |
| except Exception as e: | |
| last_err = e | |
| _log(f"ckpt .ok upload attempt {attempt + 1}/3 failed: {e}; retrying") | |
| time.sleep(2 ** attempt) | |
| # Marker terminally failed — delete the orphan tar to avoid torn-state confusion | |
| _log(f"ckpt .ok TERMINAL FAIL; deleting orphan {bucket_path}") | |
| try: | |
| api.delete_file( | |
| path_in_repo=bucket_path, | |
| repo_id=BUCKET_REPO_ID, | |
| repo_type="dataset", | |
| commit_message=f"orphan cleanup after .ok failure on {bucket_path}", | |
| ) | |
| except Exception as e: | |
| _log(f"orphan delete also failed: {e}") | |
| raise AssertionError(f"ckpt .ok upload failed after retries: {last_err}") | |
| def _find_latest_good_ckpt(api: HfApi) -> tuple[int, str] | None: | |
| """List bucket, return (N, path) of the highest state-N where BOTH the tar | |
| and its .ok marker are present. The marker records the tar's sha256, but | |
| this function does not re-verify it (that would require downloading the tar). | |
| """ | |
| files = api.list_repo_files(repo_id=BUCKET_REPO_ID, repo_type="dataset") | |
| tars: dict[int, str] = {} | |
| oks: dict[int, str] = {} | |
| for f in files: | |
| if not f.startswith(BUCKET_PREFIX): | |
| continue | |
| base = f.rsplit("/", 1)[-1] | |
| if base.startswith("state-") and base.endswith(".tar"): | |
| try: | |
| n = int(base[len("state-"):-len(".tar")]) | |
| tars[n] = f | |
| except ValueError: | |
| continue | |
| elif base.startswith("state-") and base.endswith(".ok"): | |
| try: | |
| n = int(base[len("state-"):-len(".ok")]) | |
| oks[n] = f | |
| except ValueError: | |
| continue | |
| candidates = sorted(set(tars) & set(oks), reverse=True) | |
| for n in candidates: | |
| return n, tars[n] | |
| return None | |
| def _download_ckpt(api: HfApi, bucket_path: str, local_path: Path) -> None: | |
| """Download a specific file from the workspace bucket.""" | |
| downloaded = api.hf_hub_download( | |
| repo_id=BUCKET_REPO_ID, | |
| repo_type="dataset", | |
| filename=bucket_path, | |
| local_dir=str(local_path.parent), | |
| ) | |
| shutil.move(downloaded, local_path) | |
| def _restore_ckpt_into_model(model: torch.nn.Module, ckpt_path: Path) -> dict: | |
| """Restore modelopt state + amax buffers + seed calibrator _calib_amax. | |
| CRITICAL device handling: | |
| - `torch.load(map_location="cpu")` puts the saved amax on CPU | |
| - After `restore_from_modelopt_state`, quantizers are inserted but have NO | |
| _amax buffer yet (it's lazily registered on first assignment) | |
| - `tq.amax = val` path-1 of the setter (tensor_quantizer.py:224-225) calls | |
| `self.register_buffer("_amax", value.clone().detach())` — registering the | |
| buffer on value's device, which is CPU. The model's other params are on | |
| CUDA (via device_map="auto"). First forward → device mismatch. | |
| - Fix: determine each quantizer's target device from its PARENT module's | |
| parameters (already on the correct GPU) and move `val` there before the | |
| setter call. Then the setter registers the buffer on the right device. | |
| Also seeds `_calibrator._calib_amax` on the same target device so | |
| post-restore calibration max-merges rather than overwrites the restored amax. | |
| """ | |
| state = torch.load(ckpt_path, map_location="cpu", weights_only=False) | |
| restore_from_modelopt_state(model, state["modelopt_state"]) | |
| missing: list[str] = [] | |
| for name, val in state["amax_buffers"].items(): | |
| assert name.endswith("._amax"), f"unexpected buffer name: {name}" | |
| q_path = name[: -len("._amax")] | |
| try: | |
| tq = model.get_submodule(q_path) | |
| except AttributeError: | |
| missing.append(name) | |
| continue | |
| current = getattr(tq, "_amax", None) | |
| if current is not None and current.shape != val.shape: | |
| raise AssertionError( | |
| f"shape mismatch at {name}: saved={val.shape} current={current.shape}" | |
| ) | |
| # Resolve target device from parent module (device_map="auto" shards | |
| # layers across GPUs). Parent's own params are already on the right GPU. | |
| parent_path = q_path.rsplit(".", 1)[0] if "." in q_path else "" | |
| parent = model.get_submodule(parent_path) if parent_path else model | |
| target_device = next( | |
| (p.device for p in parent.parameters()), | |
| next((b.device for b in parent.buffers()), torch.device("cpu")), | |
| ) | |
| val_on_device = val.detach().to(target_device) | |
| tq.amax = val_on_device | |
| calibrator = getattr(tq, "_calibrator", None) | |
| if calibrator is not None and hasattr(calibrator, "_calib_amax"): | |
| calibrator._calib_amax = val_on_device.clone() | |
| if missing: | |
| raise AssertionError( | |
| f"Restore missed {len(missing)}/{len(state['amax_buffers'])} amax buffers. " | |
| f"Sample: {missing[:3]}" | |
| ) | |
| return state | |
| # ---------- Calibration set construction ----------------------------------- | |
| def _extract_content_text(content) -> str: | |
| """Flatten a message `content` field into plain text. | |
| Handles: string, list-of-dicts (OpenAI/Anthropic multimodal + tool-use | |
| blocks), list-of-strings, nested content, None. Ported from the proven | |
| extractor in quantize-nvfp4-gb10-AC.py (the SWE-smith tool-use messages | |
| carry text inside `{"type": "text", "text": "..."}` blocks, and content | |
| can nest one more level via `{"content": [...]}`). | |
| """ | |
| if content is None: | |
| return "" | |
| if isinstance(content, str): | |
| return content | |
| if isinstance(content, list): | |
| parts: list[str] = [] | |
| for block in content: | |
| if isinstance(block, str): | |
| parts.append(block) | |
| elif isinstance(block, dict): | |
| if isinstance(block.get("text"), str): | |
| parts.append(block["text"]) | |
| elif "content" in block: | |
| sub = _extract_content_text(block["content"]) | |
| if sub: | |
| parts.append(sub) | |
| return "\n".join(parts) | |
| return str(content) | |
| def _extract_text(row: dict) -> str | None: | |
| """Extract one calibration text per the proven 7 datasets' schemas: | |
| messages (ultrachat, Mixture-of-Thoughts × 3, SWE-smith — list OR JSON str) | |
| trajectory (defensive — earlier SWE-smith variants stored here) | |
| instruction+output (evol-codealpaca) | |
| query+answers (xlam-function-calling) | |
| """ | |
| if "messages" in row: | |
| msgs = row["messages"] | |
| if isinstance(msgs, str): | |
| try: | |
| msgs = json.loads(msgs) | |
| except Exception: | |
| msgs = None | |
| if isinstance(msgs, list): | |
| parts = [] | |
| for m in msgs: | |
| if isinstance(m, dict): | |
| sub = _extract_content_text(m.get("content")) | |
| if sub and sub.strip(): | |
| parts.append(sub) | |
| if parts: | |
| return "\n".join(parts) | |
| if "trajectory" in row: | |
| traj = row["trajectory"] | |
| if isinstance(traj, str): | |
| try: | |
| traj = json.loads(traj) | |
| except Exception: | |
| traj = None | |
| if isinstance(traj, list): | |
| parts = [] | |
| for m in traj: | |
| if isinstance(m, dict): | |
| sub = _extract_content_text(m.get("content")) | |
| if sub and sub.strip(): | |
| parts.append(sub) | |
| if parts: | |
| return "\n".join(parts) | |
| if "instruction" in row: | |
| out = row.get("output", "") | |
| instr = row.get("instruction", "") | |
| text = (instr + "\n" + str(out)) if out else str(instr) | |
| if text.strip(): | |
| return text | |
| if "query" in row: | |
| ans = row.get("answers", row.get("answer", "")) | |
| q = row.get("query", "") | |
| text = (str(q) + "\n" + str(ans)) if ans else str(q) | |
| if text.strip(): | |
| return text | |
| return None | |
| def _probe_calibration_datasets() -> None: | |
| """Pre-model-load schema probe. | |
| Loads 1 row from each of the 7 CALIB_DATASETS and verifies that | |
| `_extract_text` returns non-empty. Aborts Phase A/B BEFORE the 5-10min | |
| model download if any dataset's schema has drifted. Total probe cost | |
| on HF Jobs a100x8: ~$0.05 (vs ~$3 if we fail post-model-load). | |
| Error message includes the failing dataset's first-row keys so a fix | |
| can be made without re-probing manually. | |
| """ | |
| from datasets import load_dataset | |
| _log(f"probing {len(CALIB_DATASETS)} calibration datasets (schema check)") | |
| failures: list[str] = [] | |
| for name, kwargs in CALIB_DATASETS: | |
| # `split="x[:1]"` slicing is a sized-dataset feature — incompatible | |
| # with streaming. Guard here so a future streaming entry fails with a | |
| # clear message instead of a cryptic datasets error. | |
| assert not kwargs.get("streaming"), ( | |
| f"probe does not support streaming=True ({name}); remove streaming or " | |
| f"skip the probe for this entry" | |
| ) | |
| probe_kwargs = dict(kwargs) | |
| probe_kwargs["split"] = f"{kwargs['split']}[:1]" | |
| try: | |
| ds = load_dataset(name, **probe_kwargs) | |
| except Exception as e: | |
| failures.append( | |
| f"{name} (kwargs={kwargs}): load failed — {type(e).__name__}: {e}" | |
| ) | |
| continue | |
| if len(ds) == 0: | |
| failures.append(f"{name} (kwargs={kwargs}): empty split") | |
| continue | |
| row = ds[0] | |
| row_keys = list(row.keys())[:10] | |
| text = _extract_text(row) | |
| if not text or not text.strip(): | |
| failures.append( | |
| f"{name} (kwargs={kwargs}): _extract_text returned empty; " | |
| f"first-row keys={row_keys}" | |
| ) | |
| continue | |
| _log(f" probe OK: {name} ({len(text)} chars, keys={row_keys[:6]})") | |
| if failures: | |
| msg = "calibration dataset probe FAILED:\n " + "\n ".join(failures) | |
| _log(msg) | |
| raise RuntimeError(msg) | |
| _log(f"probe OK for all {len(CALIB_DATASETS)} datasets") | |
| def _load_calibration_texts(tokenizer) -> list[str]: | |
| """Load NUM_CALIB_PER_DS texts per dataset across CALIB_DATASETS. Abort on | |
| any dataset that yields <50% of requested texts (matches original STRICT | |
| coverage invariant). | |
| Uses `split[:NUM_CALIB_PER_DS]` slicing so only the first NUM_CALIB_PER_DS | |
| rows of each split are returned and iterated. Non-streaming load_dataset | |
| still downloads and prepares the full split (e.g. ultrachat_200k, ~200K rows). | |
| """ | |
| from datasets import load_dataset | |
| texts: list[str] = [] | |
| failed: list[str] = [] | |
| for name, kwargs in CALIB_DATASETS: | |
| _log(f"loading {name}") | |
| load_kwargs = dict(kwargs) | |
| load_kwargs["split"] = f"{kwargs['split']}[:{NUM_CALIB_PER_DS}]" | |
| try: | |
| ds = load_dataset(name, **load_kwargs) | |
| except Exception as e: | |
| _log(f" FAILED to load {name}: {e}") | |
| failed.append(f"{name}: {e}") | |
| continue | |
| first_keys = list(ds[0].keys())[:10] if len(ds) else [] | |
| loaded_this_ds = 0 | |
| for row in ds: | |
| text = _extract_text(row) | |
| if text and len(text) > 32: | |
| texts.append(text) | |
| loaded_this_ds += 1 | |
| _log(f" extracted {loaded_this_ds} texts from {name}") | |
| if loaded_this_ds < NUM_CALIB_PER_DS // 2: | |
| failed.append( | |
| f"{name} (only {loaded_this_ds}/{NUM_CALIB_PER_DS}); first_keys={first_keys}" | |
| ) | |
| assert not failed, f"calibration coverage failure: {failed}" | |
| _log(f"total calibration texts: {len(texts)}") | |
| return texts | |
| def _tokenize_and_sort(tokenizer, texts: list[str]) -> list[torch.Tensor]: | |
| """Tokenize each text, filter too-short, sort ascending by length. | |
| Sorting shortest-first keeps early samples cheap so we hit the first | |
| CKPT_EVERY boundary fast (checkpoint exists before big samples risk OOM). | |
| """ | |
| tokenized = [] | |
| for t in texts: | |
| ids = tokenizer(t, return_tensors="pt", truncation=True, max_length=MAX_SEQ).input_ids | |
| if ids.shape[1] < 32: | |
| continue | |
| tokenized.append(ids) | |
| tokenized.sort(key=lambda x: x.shape[1]) | |
| return tokenized | |
| # ---------- Quantizer setup + calibration ---------------------------------- | |
| def _assert_config_is_pure_max(cfg: dict) -> None: | |
| """ISC-1/2/3: refuse anything other than NVFP4_DEFAULT_CFG + algorithm=max. | |
| Identity check catches variants like NVFP4_KV_CFG that would also pass the | |
| algorithm=max check but change the quantization behavior. | |
| """ | |
| assert cfg is mtq.NVFP4_DEFAULT_CFG, ( | |
| f"must use the exact NVFP4_DEFAULT_CFG object, got {cfg!r}" | |
| ) | |
| algo = cfg.get("algorithm") | |
| assert algo in (None, "max"), ( | |
| f"NVFP4 config must use algorithm=max/None for split-sample safety; got {algo!r}" | |
| ) | |
| def _install_quantizers_with_dummy(model, dummy_batches): | |
| """Install quantizers via mtq.quantize with a small dummy forward_loop. | |
| Matches original's proven pattern. Ignore list is applied post-quantize. | |
| """ | |
| cfg = mtq.NVFP4_DEFAULT_CFG | |
| _assert_config_is_pure_max(cfg) | |
| def forward_loop(m): | |
| m.eval() | |
| with torch.no_grad(): | |
| for ids in dummy_batches: | |
| m(input_ids=ids.to(next(m.parameters()).device)) | |
| return mtq.quantize(model, cfg, forward_loop) | |
| def _apply_ignore_list(model) -> None: | |
| """Disable quantizers on the GB10 ignore list (lm_head + router gate + | |
| first/last layers + embeds). Matches original quantize-nvfp4-gb10-AC.py. | |
| Note: we deliberately do NOT clear `_amax` on disabled quantizers. Stale | |
| amax on a disabled quantizer is inert — the quantize/dequantize code path | |
| short-circuits on `_disabled=True`. modelopt's state save/restore preserves | |
| the disable flag (verified in 0.29.0 tensor_quantizer.py:1075 / 1127), | |
| so disabled quantizers stay disabled through Phase A/B/C transitions. | |
| """ | |
| mtq.disable_quantizer(model, "lm_head*") | |
| mtq.disable_quantizer(model, "*block_sparse_moe.gate") | |
| mtq.disable_quantizer(model, "*embed_tokens*") | |
| n_layers = getattr(model.config, "num_hidden_layers", None) | |
| assert n_layers is not None, "cannot determine num_hidden_layers" | |
| mtq.disable_quantizer(model, "*model.layers.0.*") | |
| mtq.disable_quantizer(model, f"*model.layers.{n_layers - 1}.*") | |
| n_disabled = sum( | |
| 1 for _, tq in model.named_modules() | |
| if isinstance(tq, TensorQuantizer) and not getattr(tq, "is_enabled", True) | |
| ) | |
| _log(f"ignore list applied (last layer={n_layers - 1}); {n_disabled} quantizers disabled") | |
| def _populate_starved_experts(model) -> list[str]: | |
| """Rescue quantizers with None/zero amax by routing weight through the | |
| calibration path. Produces correct per-channel amax shape (matches original | |
| Phase 2.5 pattern). | |
| """ | |
| rescued: list[str] = [] | |
| for name, module in model.named_modules(): | |
| wq = getattr(module, "weight_quantizer", None) | |
| if wq is None: | |
| continue | |
| if not getattr(wq, "is_enabled", True): | |
| continue | |
| amax = getattr(wq, "amax", None) | |
| if amax is not None and not torch.all(amax == 0): | |
| continue | |
| with torch.no_grad(): | |
| if hasattr(wq, "reset_amax"): | |
| wq.reset_amax() | |
| enable_stats_collection(wq) | |
| wq(module.weight) | |
| finish_stats_collection(wq) | |
| assert wq.amax is not None, f"starved rescue populate failed for {name}" | |
| rescued.append(name) | |
| return rescued | |
| def _count_total_quantizers(model) -> int: | |
| return sum(1 for _, m in model.named_modules() if isinstance(m, TensorQuantizer)) | |
| # ---------- Phase A: calibrate + ckpt + defer ------------------------------ | |
| def _config_sanity_gate() -> None: | |
| """Verify INPUT_DIR (default /model) is actually MiniMax-M2.7 230B (256 experts × 62 layers × 125 shards). | |
| Catches wrong-mount before burning calibration compute. | |
| """ | |
| cfg_path = INPUT_DIR / "config.json" | |
| assert cfg_path.exists(), f"{cfg_path} not found — is /model mounted?" | |
| cfg = json.loads(cfg_path.read_text()) | |
| assert cfg.get("architectures") == ["MiniMaxM2ForCausalLM"], ( | |
| f"arch mismatch: {cfg.get('architectures')} (expected MiniMaxM2ForCausalLM)" | |
| ) | |
| assert cfg.get("num_local_experts") == 256, ( | |
| f"expert count mismatch: {cfg.get('num_local_experts')} (expected 256 for M2.7)" | |
| ) | |
| assert cfg.get("num_hidden_layers") == 62, ( | |
| f"layer count mismatch: {cfg.get('num_hidden_layers')} (expected 62 for M2.7)" | |
| ) | |
| shards = [f for f in os.listdir(INPUT_DIR) | |
| if f.startswith("model-") and f.endswith(".safetensors")] | |
| assert len(shards) == 125, f"shard count mismatch: {len(shards)} (expected 125)" | |
| _log(f"config sanity gate OK: MiniMax-M2.7 230B, 256 experts × 62 layers × 125 shards") | |
| def _phase_a_main() -> int: | |
| """Return the exit status code: 0=complete, 42=deferred, | |
| 43=wallclock-cap, 44=starved-experts-excessive.""" | |
| _log(f"starting Phase A: input={INPUT_DIR} bucket={BUCKET_REPO_ID} ckpt_every={CKPT_EVERY}") | |
| _log(f"wallclock budget = {WALLCLOCK_BUDGET_S}s, max_seq={MAX_SEQ}") | |
| _config_sanity_gate() | |
| # Schema probe BEFORE model download — fail in ~60s on dataset drift | |
| # instead of ~10min into a100x8 compute. Cost of a false-pass probe is | |
| # still lower than one post-model-load surprise. | |
| _probe_calibration_datasets() | |
| watchdog = WallclockWatchdog(WALLCLOCK_BUDGET_S, name="phase-a") | |
| api = HfApi() | |
| tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| # Calibration set | |
| texts = _load_calibration_texts(tokenizer) | |
| batches = _tokenize_and_sort(tokenizer, texts) | |
| _log(f"tokenized {len(batches)} samples, shortest={batches[0].shape[1]}, longest={batches[-1].shape[1]}") | |
| # Install quantizers with a 1-sample dummy (the real calibration happens | |
| # in the manual loop below, which we control for checkpointing). | |
| dummy = batches[:1] | |
| model = _install_quantizers_with_dummy(model, dummy) | |
| _apply_ignore_list(model) | |
| total_q = _count_total_quantizers(model) | |
| _log(f"{total_q} TensorQuantizer modules installed post-ignore-list") | |
| # Real calibration loop: per-sample try/except, checkpoint every N | |
| deferred: list[dict] = [] | |
| deferred_batches_payload: list[dict] = [] # persist the exact tensors so Phase B is deterministic | |
| completed = 0 | |
| consecutive_oom = 0 | |
| ckpt_counter = 0 | |
| last_ckpt_at_completed = 0 | |
| enable_stats_collection(model) | |
| try: | |
| for i, ids in enumerate(batches): | |
| if watchdog.expired: | |
| _log(f"WALLCLOCK EXPIRED at sample {i}/{len(batches)} — deferring remainder") | |
| for j in range(i, len(batches)): | |
| deferred.append({"idx": j, "reason": "wallclock-cap"}) | |
| deferred_batches_payload.append({"idx": j, "input_ids": batches[j].detach().cpu()}) | |
| break | |
| try: | |
| ids_dev = ids.to(next(model.parameters()).device) | |
| with torch.no_grad(): | |
| model(input_ids=ids_dev) | |
| completed += 1 | |
| consecutive_oom = 0 | |
| except torch.cuda.OutOfMemoryError: | |
| _log(f"OOM at sample {i} (len={ids.shape[1]}) — deferring") | |
| deferred.append({"idx": i, "reason": "OOM", "len": int(ids.shape[1])}) | |
| deferred_batches_payload.append({"idx": i, "input_ids": ids.detach().cpu()}) | |
| _free_gpu() | |
| # On FIRST OOM of a new streak, opportunistically checkpoint so | |
| # progress up to now is saved in case the next sample also OOMs. | |
| if consecutive_oom == 0 and completed > last_ckpt_at_completed: | |
| ckpt_counter += 1 | |
| _checkpoint_commit(api, model, completed, deferred, ckpt_counter, | |
| deferred_batches_payload=deferred_batches_payload) | |
| last_ckpt_at_completed = completed | |
| consecutive_oom += 1 | |
| if consecutive_oom >= 2: | |
| _log("2 consecutive OOMs — ABORTING Phase A (saving state)") | |
| break | |
| if (completed > 0) and (completed % CKPT_EVERY == 0) and completed != last_ckpt_at_completed: | |
| ckpt_counter += 1 | |
| _checkpoint_commit(api, model, completed, deferred, ckpt_counter, | |
| deferred_batches_payload=deferred_batches_payload) | |
| last_ckpt_at_completed = completed | |
| if (i + 1) % 25 == 0: | |
| _log(f"progress: {i+1}/{len(batches)} (completed={completed}, deferred={len(deferred)})") | |
| finally: | |
| finish_stats_collection(model) | |
| watchdog.shutdown() | |
| # Starved-expert rescue | |
| _log("populating starved experts from weights") | |
| rescued = _populate_starved_experts(model) | |
| starved_pct = 100.0 * len(rescued) / max(total_q, 1) | |
| _log(f"starved rescue: {len(rescued)}/{total_q} ({starved_pct:.2f}%)") | |
| # Final checkpoint — carries both the latest calibration state AND the | |
| # rescue results. Bumps counter so Phase B sees the newest N. | |
| ckpt_counter += 1 | |
| _checkpoint_commit(api, model, completed, deferred, ckpt_counter, | |
| starved=rescued, | |
| deferred_batches_payload=deferred_batches_payload) | |
| # Determine status. Precedence: excessive-starved > wallclock > deferred > complete. | |
| status = "complete" | |
| if deferred: | |
| status = "wallclock-cap" if any(d["reason"] == "wallclock-cap" for d in deferred) else "deferred" | |
| if starved_pct > STARVED_EXPERT_PCT_ABORT: | |
| status = "excessive-starved" | |
| sentinel = { | |
| "status": status, | |
| "deferred_count": len(deferred), | |
| "starved_count": len(rescued), | |
| "total_quantizers": total_q, | |
| "starved_pct": starved_pct, | |
| "last_ckpt": ckpt_counter, | |
| "elapsed_s": watchdog.elapsed_s(), | |
| "completed": completed, | |
| } | |
| # INLINE EXPORT: if status=complete, export + publish RIGHT HERE while the | |
| # model is loaded and the container has GPUs. This avoids needing a | |
| # separate cpu-basic Phase C job (which couldn't fit the 460GB model | |
| # anyway). If export fails, we still have the calibration ckpts in the | |
| # bucket and can manually recover via PHASE=C on a GPU flavor. | |
| if status == "complete": | |
| try: | |
| _log("Phase A status=complete — exporting + publishing inline") | |
| info = _finalize_and_export(api, model, tokenizer, origin_phase="A") | |
| sentinel.update(info) | |
| sentinel["status"] = "complete-published" | |
| status = "complete-published" | |
| except Exception as e: | |
| _log(f"inline export FAILED: {type(e).__name__}: {e}") | |
| traceback.print_exc() | |
| sentinel["export_error"] = f"{type(e).__name__}: {e}" | |
| # Keep status=complete so launcher/operator can kick off PHASE=C | |
| # recovery on a fresh GPU flavor. | |
| _write_sentinel(api, "phase-a.result", sentinel) | |
| _log(f"Phase A DONE — status={sentinel['status']} deferred={len(deferred)} starved={len(rescued)}") | |
| if status == "excessive-starved": | |
| return 44 | |
| if status == "wallclock-cap": | |
| return 43 | |
| if status == "deferred": | |
| return 42 | |
| return 0 | |
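The status precedence and exit-code mapping at the end of `_phase_a_main` can be pulled out into pure functions, which makes the precedence easy to unit-test. This is a hypothetical refactor, not part of the script. The names `phase_a_status`, `exit_code` and `EXIT_CODES` are illustrative, and `abort_pct` stands in for `STARVED_EXPERT_PCT_ABORT`, whose value is not shown here.

```python
EXIT_CODES = {"excessive-starved": 44, "wallclock-cap": 43, "deferred": 42}


def phase_a_status(deferred_reasons: list[str], starved_pct: float, abort_pct: float) -> str:
    """Apply the precedence excessive-starved > wallclock-cap > deferred > complete."""
    status = "complete"
    if deferred_reasons:
        status = "wallclock-cap" if "wallclock-cap" in deferred_reasons else "deferred"
    # Checked last, so an excessive starved ratio overrides any deferral status.
    if starved_pct > abort_pct:
        status = "excessive-starved"
    return status


def exit_code(status: str) -> int:
    # "complete" and "complete-published" are not in the table, so both map to 0.
    return EXIT_CODES.get(status, 0)
```

Because the starved check runs last, a run that both hit the wallclock cap and exceeded the starved threshold exits with 44, not 43.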


def _checkpoint_commit(api: HfApi, model, completed: int, deferred: list,
                       ckpt_counter: int, starved: list | None = None,
                       deferred_batches_payload: list | None = None) -> None:
    """Save amax-only ckpt locally, upload tar + .ok marker."""
    with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tf:
        local = Path(tf.name)
    try:
        size = _save_ckpt_local(model, local, completed, deferred,
                                starved=starved,
                                deferred_batches=deferred_batches_payload)
        sha = _sha256_of(local)
        bucket_tar = f"{BUCKET_PREFIX}state-{ckpt_counter}.tar"
        _bucket_commit_ckpt(api, local, bucket_tar, sha)
        _log(f"ckpt {ckpt_counter} committed ({size // 1024} KB, sha={sha[:10]}, deferred={len(deferred)})")
    finally:
        local.unlink(missing_ok=True)
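`_sha256_of` and `_bucket_commit_ckpt` are defined elsewhere in the script and not shown in this excerpt. The sketch below illustrates the commit-then-mark pattern on a local filesystem. It assumes a chunked SHA-256 helper and an `.ok` marker named after the tar (`state-N.tar` becomes `state-N.ok`, consistent with the `state-FINAL.ok` check in Phase C) that stores the digest. The helper names are hypothetical, and the source does not show what the real marker contains.

```python
import hashlib
import shutil
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large checkpoints never sit fully in RAM."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def commit_with_marker(src: Path, dest: Path) -> str:
    """Copy src to dest, then write the .ok marker LAST so a half-written tar is never trusted."""
    shutil.copyfile(src, dest)
    digest = sha256_of(dest)
    dest.with_suffix(".ok").write_text(digest)
    return digest


def is_good(dest: Path) -> bool:
    """A checkpoint counts only if its marker exists and the digest still matches."""
    marker = dest.with_suffix(".ok")
    return dest.exists() and marker.exists() and marker.read_text() == sha256_of(dest)


def demo() -> tuple[bool, bool]:
    """Commit a fake checkpoint, then corrupt it and re-check."""
    with tempfile.TemporaryDirectory() as td:
        src = Path(td) / "local.tar"
        src.write_bytes(b"amax state")
        dest = Path(td) / "state-1.tar"
        commit_with_marker(src, dest)
        before = is_good(dest)
        dest.write_bytes(b"truncated")  # simulate a corrupted upload
        return before, is_good(dest)
```

Writing the marker after the payload means a reader that finds the marker can assume the payload upload finished. Storing the digest also lets the reader detect later corruption.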


def _write_sentinel(api: HfApi, name: str, payload: dict) -> None:
    """Upload a tiny JSON sentinel — written AFTER all ckpt uploads confirmed."""
    data = json.dumps(payload, indent=2).encode()
    api.upload_file(
        path_or_fileobj=data,
        path_in_repo=f"{BUCKET_PREFIX}{name}",
        repo_id=BUCKET_REPO_ID,
        repo_type="dataset",
        commit_message=f"sentinel {name}",
    )


# ---------- Phase B: restore + process deferred ----------------------------
def _phase_b_main() -> int:
    """Restore Phase A's latest good ckpt, process deferred samples + any
    starved experts, save the final state as state-FINAL, then export and
    publish inline. state-FINAL lets Phase C recover if the export fails.
    """
    _log("starting Phase B: h200x8 resume")
    api = HfApi()
    latest = _find_latest_good_ckpt(api)
    assert latest is not None, "Phase B found no valid ckpt in bucket — aborting"
    n, bucket_path = latest
    _log(f"resuming from ckpt N={n} at {bucket_path}")
    tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )
    with tempfile.TemporaryDirectory() as td:
        local = Path(td) / "ckpt.tar"
        _download_ckpt(api, bucket_path, local)
        state = _restore_ckpt_into_model(model, local)
    deferred = state.get("deferred_samples", [])
    deferred_batches = state.get("deferred_batches") or []
    _log(f"restored state: completed={state.get('samples_completed', 0)}, deferred={len(deferred)}")
    # Prefer replaying the exact tokenized tensors saved in Phase A's ckpt.
    # Fall back to rebuilding the dataset only if the ckpt comes from an older
    # version that didn't persist batches.
    if deferred_batches:
        _log(f"using {len(deferred_batches)} deferred batches from ckpt (deterministic replay)")
        replay = [(d["idx"], d["input_ids"]) for d in deferred_batches]
    else:
        _log("ckpt has no deferred_batches — rebuilding calibration set (non-deterministic risk)")
        # Legacy path: Phase A may have run days earlier with different
        # CALIB_DATASETS or schemas. Re-probe so drift fails in ~60s instead
        # of wasting the h200x8 slot we already paid to allocate.
        _probe_calibration_datasets()
        texts = _load_calibration_texts(tokenizer)
        batches = _tokenize_and_sort(tokenizer, texts)
        replay = [(d["idx"], batches[d["idx"]]) for d in deferred if d["idx"] < len(batches)]
    # Process deferred samples
    enable_stats_collection(model)
    processed = 0
    still_deferred: list[dict] = []
    try:
        for idx, ids in replay:
            try:
                ids_dev = ids.to(next(model.parameters()).device)
                with torch.no_grad():
                    model(input_ids=ids_dev)
                processed += 1
            except torch.cuda.OutOfMemoryError:
                _log(f"Phase B STILL OOM on sample {idx} — escalate manually")
                still_deferred.append({"idx": idx, "reason": "h200-OOM"})
                _free_gpu()
    finally:
        finish_stats_collection(model)
    # Any remaining starved experts get one more rescue pass
    rescued = _populate_starved_experts(model)
    _log(f"Phase B: processed {processed}/{len(replay)} deferred, rescued {len(rescued)} starved")
    # Save final state to bucket as state-FINAL.tar
    with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tf:
        local = Path(tf.name)
    try:
        _save_ckpt_local(model, local, state.get("samples_completed", 0) + processed,
                         still_deferred, starved=rescued, phase="B")
        sha = _sha256_of(local)
        _bucket_commit_ckpt(api, local, f"{BUCKET_PREFIX}state-FINAL.tar", sha)
    finally:
        local.unlink(missing_ok=True)
    sentinel = {
        "status": "complete" if not still_deferred else "partial",
        "processed": processed,
        "still_deferred": len(still_deferred),
        "final_starved_rescued": len(rescued),
    }
    # Inline export + publish while the h200x8 still has the model loaded.
    try:
        _log("Phase B calibration complete — exporting + publishing inline")
        info = _finalize_and_export(api, model, tokenizer, origin_phase="B")
        sentinel.update(info)
        sentinel["status"] = "complete-published"
    except Exception as e:
        _log(f"Phase B inline export FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        sentinel["export_error"] = f"{type(e).__name__}: {e}"
    _write_sentinel(api, "phase-b.result", sentinel)
    _log(f"Phase B DONE — status={sentinel['status']}")
    return 0
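Phase B's choice of replay source reduces to a small pure function. The sketch below is a hypothetical restatement; the name `build_replay` and the `rebuild` callback are illustrative. It makes two properties explicit. Persisted tensors always win, and the non-deterministic rebuild runs only when the checkpoint lacks them. In that case any deferred index the rebuilt set no longer covers is silently dropped.

```python
from typing import Any, Callable


def build_replay(deferred: list[dict], deferred_batches: list[dict],
                 rebuild: Callable[[], list[Any]]) -> list[tuple[int, Any]]:
    if deferred_batches:
        # Deterministic path: replay the exact tensors Phase A saved.
        return [(d["idx"], d["input_ids"]) for d in deferred_batches]
    # Legacy path: rebuild the calibration set; indices past its end are skipped.
    batches = rebuild()
    return [(d["idx"], batches[d["idx"]]) for d in deferred if d["idx"] < len(batches)]
```

Passing the rebuild step as a callback keeps the expensive dataset download and tokenization off the deterministic path entirely.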


# ---------- Phase C: restore + export + upload -----------------------------
def _cleanup_target_repo_shards(api: HfApi) -> None:
    """Before publishing new NVFP4 shards, delete any stale *.safetensors files
    from the target repo. `upload_folder` does NOT delete pre-existing files —
    it only overwrites same-named paths. If the new quant has fewer shards than
    a prior upload, old shards remain and break loading.
    """
    try:
        existing = api.list_repo_files(repo_id=TARGET_REPO_ID, repo_type="model")
    except Exception as e:
        _log(f"target repo list failed (new repo?): {e}")
        return
    stale = [f for f in existing if f.endswith(".safetensors") or f == "model.safetensors.index.json"]
    if not stale:
        return
    _log(f"deleting {len(stale)} stale files from target repo before publish")
    from huggingface_hub import CommitOperationDelete
    ops = [CommitOperationDelete(path_in_repo=f) for f in stale]
    api.create_commit(
        repo_id=TARGET_REPO_ID,
        repo_type="model",
        operations=ops,
        commit_message=f"cleanup {len(stale)} stale shards before new publish",
    )
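Concrete file names make the failure that `_cleanup_target_repo_shards` prevents easier to see. The sketch below uses hypothetical helper names and shard names that follow the usual `model-0000K-of-0000N.safetensors` pattern. It computes which weight files an overwrite-only upload would leave behind when a new export has fewer shards than the previous one.

```python
def stale_weight_files(existing: list[str]) -> list[str]:
    """Same selection rule as the script: every shard plus the shard index."""
    return [f for f in existing
            if f.endswith(".safetensors") or f == "model.safetensors.index.json"]


def leftovers_after_overwrite(existing: list[str], new_files: list[str]) -> list[str]:
    """Weight files an overwrite-only upload would leave behind in the repo."""
    new = set(new_files)
    return sorted(f for f in stale_weight_files(existing) if f not in new)
```

With a previous 3-shard upload and a new 2-shard export, all three `-of-00003` shards survive an overwrite-only upload because none of their names match the new `-of-00002` files.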


def _finalize_and_export(api: HfApi, model, tokenizer, origin_phase: str) -> dict:
    """Export NVFP4 checkpoint + publish to target repo + bucket backup.
    Phase A (on status=complete) and Phase B (always) call this inline so the
    already-loaded GPU model is reused, instead of reloading a 460GB model in
    a separate cpu-basic job (which can't fit it anyway). The manual Phase C
    recovery path also calls it after reloading the model on a GPU flavor.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    _log(f"exporting to {OUTPUT_DIR}")
    with torch.inference_mode():
        export_hf_checkpoint(model, export_dir=str(OUTPUT_DIR))
    shards = list(OUTPUT_DIR.glob("*.safetensors"))
    assert shards, "export produced no safetensors shards"
    exp_gb = sum(p.stat().st_size for p in OUTPUT_DIR.rglob("*") if p.is_file()) / 1e9
    _log(f"export: {len(shards)} shards, {exp_gb:.1f} GB")
    tokenizer.save_pretrained(str(OUTPUT_DIR))
    # Ship the custom modeling/configuration/tokenization code and the chat
    # template so the published repo still loads with trust_remote_code=True.
    for f in os.listdir(INPUT_DIR):
        if f.startswith(("modeling_", "configuration_", "tokenization_")) and f.endswith(".py"):
            shutil.copy2(INPUT_DIR / f, OUTPUT_DIR / f)
        if f == "chat_template.jinja":
            shutil.copy2(INPUT_DIR / f, OUTPUT_DIR / f)
    api.create_repo(TARGET_REPO_ID, private=False, exist_ok=True)
    _cleanup_target_repo_shards(api)
    _log(f"uploading to target: {TARGET_REPO_ID}")
    api.upload_folder(
        folder_path=str(OUTPUT_DIR),
        repo_id=TARGET_REPO_ID,
        commit_message=f"v3 AC NVFP4 (protected rerun, origin={origin_phase})",
    )
    _log("target upload DONE")
    api.upload_folder(
        folder_path=str(OUTPUT_DIR),
        repo_id=BUCKET_REPO_ID,
        repo_type="dataset",
        path_in_repo=f"{BUCKET_PREFIX}export/",
        commit_message=f"export backup from phase-{origin_phase}",
    )
    return {"shards": len(shards), "size_gb": exp_gb, "origin_phase": origin_phase}
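The copy loop in `_finalize_and_export` picks out the custom-code files and the chat template that a `trust_remote_code=True` load of the published repo needs. The sketch below expresses the same selection rule as a pure function. The name `files_to_copy` is hypothetical, and the example file names are illustrative.

```python
CODE_PREFIXES = ("modeling_", "configuration_", "tokenization_")


def files_to_copy(names: list[str]) -> list[str]:
    """Custom-code .py files plus the chat template; weights and configs are skipped."""
    return sorted(
        n for n in names
        if (n.startswith(CODE_PREFIXES) and n.endswith(".py")) or n == "chat_template.jinja"
    )
```

`str.startswith` accepts a tuple of prefixes, so one call covers all three custom-code families.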


def _phase_c_main() -> int:
    """MANUAL RECOVERY ONLY — not auto-submitted by the launcher.
    If Phase A or Phase B crashed after calibration but before export, this
    phase reloads the model + restores the latest good ckpt + exports. Requires
    a GPU flavor (a100-large minimum) because the 230B BF16 model won't fit in
    cpu-basic RAM.
    """
    _log("starting Phase C: MANUAL export-only recovery")
    api = HfApi()
    files = api.list_repo_files(repo_id=BUCKET_REPO_ID, repo_type="dataset")
    final = f"{BUCKET_PREFIX}state-FINAL.tar"
    if final in files and f"{BUCKET_PREFIX}state-FINAL.ok" in files:
        ckpt_bucket_path = final
        _log("using state-FINAL (Phase B was run)")
    else:
        latest = _find_latest_good_ckpt(api)
        assert latest is not None, "no valid ckpt for Phase C"
        _, ckpt_bucket_path = latest
        _log(f"using Phase A latest ckpt: {ckpt_bucket_path}")
    tokenizer = AutoTokenizer.from_pretrained(str(INPUT_DIR), trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        str(INPUT_DIR), torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )
    with tempfile.TemporaryDirectory() as td:
        local = Path(td) / "ckpt.tar"
        _download_ckpt(api, ckpt_bucket_path, local)
        _restore_ckpt_into_model(model, local)
    rescued = _populate_starved_experts(model)
    if rescued:
        _log(f"Phase C extra rescue: {len(rescued)} quantizers")
    info = _finalize_and_export(api, model, tokenizer, origin_phase="C-recovery")
    _write_sentinel(api, "phase-c.result", {"status": "complete", **info})
    _log("Phase C DONE")
    return 0
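`_find_latest_good_ckpt` is not shown in this excerpt. The sketch below is one plausible implementation of Phase C's selection rule, not the script's own. It assumes a checkpoint is "good" when `state-N.tar` has a matching `state-N.ok` marker, and the name `pick_ckpt` is hypothetical. It compares N numerically, because a plain string comparison would rank `state-2` above `state-10`.

```python
import re
from typing import Optional


def pick_ckpt(files: list[str], prefix: str) -> Optional[str]:
    names = set(files)
    # A committed state-FINAL (tar + marker) from Phase B always wins.
    final = f"{prefix}state-FINAL.tar"
    if final in names and f"{prefix}state-FINAL.ok" in names:
        return final
    # Otherwise take the highest-numbered state-N.tar that has its .ok marker.
    pattern = re.compile(re.escape(prefix) + r"state-(\d+)\.tar")
    best: Optional[tuple[int, str]] = None
    for name in names:
        m = pattern.fullmatch(name)
        if m and f"{prefix}state-{m.group(1)}.ok" in names:
            n = int(m.group(1))
            if best is None or n > best[0]:
                best = (n, name)
    return best[1] if best else None
```

A tar without its marker (for example an upload interrupted before the `.ok` was written) is ignored, so recovery falls back to the newest checkpoint that finished committing.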


# ---------- Entry ----------------------------------------------------------
def main() -> int:
    _log(f"quantize-ac-protected.py phase={PHASE}")
    _log(f"modelopt={getattr(__import__('modelopt'), '__version__', '?')}")
    if PHASE == "A":
        return _phase_a_main()
    if PHASE == "B":
        return _phase_b_main()
    if PHASE == "C":
        return _phase_c_main()
    # Fail loudly instead of silently falling through to the manual-only Phase C.
    raise ValueError(f"unknown PHASE={PHASE!r} (expected 'A', 'B' or 'C')")


def _write_crash_sentinel(exc_type: str, exc_msg: str, tb: str) -> None:
    """Always-attempt sentinel write on uncaught exception so the launcher
    doesn't spin for 21h waiting for a sentinel that'll never arrive.
    """
    try:
        api = HfApi()
        payload = {
            "status": "crashed",
            "phase": PHASE,
            "error_type": exc_type,
            "error_message": exc_msg[:2000],
            "traceback": tb[:4000],
        }
        _write_sentinel(api, f"phase-{PHASE.lower()}.result", payload)
        _log(f"crash sentinel written for phase-{PHASE.lower()}.result")
    except Exception as inner:
        _log(f"failed to write crash sentinel: {type(inner).__name__}: {inner}")


if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        _log(f"FATAL: {type(e).__name__}: {e}")
        tb = traceback.format_exc()
        traceback.print_exc()
        _write_crash_sentinel(type(e).__name__, str(e), tb)
        sys.exit(2)