Fix load_quantized.py: TurboQuant-aware dequantization (was broken rowwise_symmetric loader)
Browse files- load_quantized.py +221 -49
load_quantized.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
|
|
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
import torch
|
|
@@ -8,29 +10,230 @@ from safetensors.torch import load_file
|
|
| 8 |
from transformers import AutoProcessor, AutoTokenizer
|
| 9 |
|
| 10 |
MANIFEST_FILENAME = "quant_manifest.json"
|
| 11 |
-
|
| 12 |
|
| 13 |
|
| 14 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
packed = packed.reshape(-1).to(torch.uint8).cpu()
|
| 16 |
-
if bit_width =
|
| 17 |
-
return packed[:total_values].contiguous()
|
| 18 |
values_per_byte = 8 // bit_width
|
| 19 |
mask = (1 << bit_width) - 1
|
| 20 |
packed_i32 = packed.to(torch.int32)
|
| 21 |
parts = []
|
| 22 |
for index in range(values_per_byte):
|
| 23 |
parts.append(((packed_i32 >> (index * bit_width)) & mask).to(torch.uint8))
|
| 24 |
-
return torch.stack(parts, dim=1).reshape(-1)[:total_values].contiguous()
|
| 25 |
|
| 26 |
|
| 27 |
-
def
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def _create_empty_model(repo_dir: Path, loader_kind: str):
|
| 32 |
from transformers import AutoConfig
|
| 33 |
-
|
| 34 |
config = AutoConfig.from_pretrained(repo_dir, trust_remote_code=True)
|
| 35 |
if loader_kind == "causal-lm":
|
| 36 |
from transformers import AutoModelForCausalLM
|
|
@@ -41,63 +244,31 @@ def _create_empty_model(repo_dir: Path, loader_kind: str):
|
|
| 41 |
raise ValueError(f"Unsupported loader kind: {loader_kind}")
|
| 42 |
|
| 43 |
|
| 44 |
-
def _load_model_from_state_dict(repo_dir: Path, loader_kind: str,
|
|
|
|
| 45 |
model = _create_empty_model(repo_dir, loader_kind)
|
| 46 |
incompatible = model.load_state_dict(state_dict, strict=False, assign=True)
|
| 47 |
if hasattr(model, "tie_weights"):
|
| 48 |
model.tie_weights()
|
| 49 |
-
|
| 50 |
allowed_missing = {"lm_head.weight"}
|
| 51 |
allowed_unexpected_prefixes = ("mtp.",)
|
| 52 |
-
disallowed_missing = sorted(
|
|
|
|
|
|
|
| 53 |
disallowed_unexpected = sorted(
|
| 54 |
-
key for key in incompatible.unexpected_keys
|
|
|
|
| 55 |
)
|
| 56 |
if disallowed_missing or disallowed_unexpected:
|
| 57 |
raise RuntimeError(
|
| 58 |
-
"Unexpected state_dict mismatch
|
| 59 |
f"missing={disallowed_missing}, unexpected={disallowed_unexpected}"
|
| 60 |
)
|
| 61 |
return model
|
| 62 |
|
| 63 |
|
| 64 |
-
def load_quantized_state_dict(repo_dir: str | Path):
|
| 65 |
-
repo_dir = Path(repo_dir)
|
| 66 |
-
manifest = json.loads((repo_dir / MANIFEST_FILENAME).read_text(encoding="utf-8"))
|
| 67 |
-
stored = load_file(str(repo_dir / WEIGHTS_FILENAME), device="cpu")
|
| 68 |
-
state_dict = {}
|
| 69 |
-
for name, spec in manifest["parameter_specs"].items():
|
| 70 |
-
prefix = spec["storage_prefix"]
|
| 71 |
-
if spec["quantized"]:
|
| 72 |
-
quantization_scheme = spec.get("quantization_scheme", manifest.get("quantization_scheme", "rowwise_symmetric"))
|
| 73 |
-
bit_width = int(spec["bit_width"])
|
| 74 |
-
group_size = int(spec["group_size"])
|
| 75 |
-
num_groups = int(spec["num_groups"])
|
| 76 |
-
total_group_values = int(spec.get("total_group_values", num_groups * group_size))
|
| 77 |
-
dtype = _name_to_dtype(spec["original_dtype"])
|
| 78 |
-
unpacked = _unpack_codes(stored[f"{prefix}__packed"], bit_width, total_group_values).to(torch.float32)
|
| 79 |
-
unpacked = unpacked.view(num_groups, group_size)
|
| 80 |
-
scales = stored[f"{prefix}__scales"].to(torch.float32).view(num_groups, 1)
|
| 81 |
-
if quantization_scheme != "rowwise_symmetric":
|
| 82 |
-
raise ValueError(f"Unsupported quantization scheme: {quantization_scheme}")
|
| 83 |
-
if bit_width == 1:
|
| 84 |
-
restored_groups = torch.where(unpacked > 0, torch.ones_like(unpacked), -torch.ones_like(unpacked)) * scales
|
| 85 |
-
else:
|
| 86 |
-
qmax = (1 << (bit_width - 1)) - 1
|
| 87 |
-
restored_groups = (unpacked - float(qmax)) * scales
|
| 88 |
-
row_count = int(spec["row_count"])
|
| 89 |
-
row_length = int(spec["row_length"])
|
| 90 |
-
groups_per_row = int(spec["groups_per_row"])
|
| 91 |
-
padded_row_length = groups_per_row * group_size
|
| 92 |
-
restored = restored_groups.view(row_count, groups_per_row, group_size).reshape(row_count, padded_row_length)
|
| 93 |
-
restored = restored[:, :row_length].contiguous().view(tuple(int(value) for value in spec["shape"]))
|
| 94 |
-
state_dict[name] = restored.to(dtype)
|
| 95 |
-
continue
|
| 96 |
-
state_dict[name] = stored[f"{prefix}__passthrough"].to(_name_to_dtype(spec["original_dtype"]))
|
| 97 |
-
return state_dict, manifest
|
| 98 |
-
|
| 99 |
-
|
| 100 |
def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu"):
|
|
|
|
| 101 |
repo_dir = Path(repo_dir)
|
| 102 |
state_dict, manifest = load_quantized_state_dict(repo_dir)
|
| 103 |
model = _load_model_from_state_dict(repo_dir, manifest["loader_kind"], state_dict)
|
|
@@ -107,6 +278,7 @@ def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu
|
|
| 107 |
|
| 108 |
|
| 109 |
def load_tokenizer(repo_dir: str | Path):
|
|
|
|
| 110 |
repo_dir = Path(repo_dir)
|
| 111 |
if (repo_dir / "preprocessor_config.json").exists():
|
| 112 |
processor = AutoProcessor.from_pretrained(repo_dir, trust_remote_code=True)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
import math
|
| 5 |
+
from functools import lru_cache
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
import torch
|
|
|
|
| 10 |
from transformers import AutoProcessor, AutoTokenizer
|
| 11 |
|
| 12 |
MANIFEST_FILENAME = "quant_manifest.json"
|
| 13 |
+
TURBOQUANT_WEIGHTS_FILENAME = "turboquant_weights.safetensors"
|
| 14 |
|
| 15 |
|
| 16 |
+
def _name_to_dtype(name: str) -> torch.dtype:
|
| 17 |
+
return getattr(torch, name)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _unpack_indices(packed: torch.Tensor, bit_width: int, total_values: int) -> torch.Tensor:
|
| 21 |
packed = packed.reshape(-1).to(torch.uint8).cpu()
|
| 22 |
+
if bit_width >= 8:
|
| 23 |
+
return packed[:total_values].to(torch.int64).contiguous()
|
| 24 |
values_per_byte = 8 // bit_width
|
| 25 |
mask = (1 << bit_width) - 1
|
| 26 |
packed_i32 = packed.to(torch.int32)
|
| 27 |
parts = []
|
| 28 |
for index in range(values_per_byte):
|
| 29 |
parts.append(((packed_i32 >> (index * bit_width)) & mask).to(torch.uint8))
|
| 30 |
+
return torch.stack(parts, dim=1).reshape(-1)[:total_values].to(torch.int64).contiguous()
|
| 31 |
|
| 32 |
|
| 33 |
def _unpack_signs(packed: torch.Tensor, total_values: int) -> torch.Tensor:
    """Expand a bit-packed sign tensor into int8 values in {-1, +1}.

    A stored bit of 1 maps to +1 and a bit of 0 maps to -1, via the affine
    map ``2*b - 1``.
    """
    raw_bits = _unpack_indices(packed, 1, total_values).to(torch.int8)
    return raw_bits * 2 - 1
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@lru_cache(maxsize=None)
|
| 39 |
+
def _sphere_coordinate_density(dimension: int, grid_size: int):
|
| 40 |
+
"""Lemma 1: Beta distribution density on unit sphere coordinates.
|
| 41 |
+
|
| 42 |
+
f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
|
| 43 |
+
Reference: arXiv:2504.19874v1, Lemma 1
|
| 44 |
+
"""
|
| 45 |
+
eps = 1e-6
|
| 46 |
+
grid = torch.linspace(-1.0 + eps, 1.0 - eps, steps=grid_size, dtype=torch.float64)
|
| 47 |
+
exponent = 0.5 * float(dimension - 3)
|
| 48 |
+
log_const = (
|
| 49 |
+
torch.lgamma(torch.tensor(float(dimension) / 2.0, dtype=torch.float64))
|
| 50 |
+
- 0.5 * math.log(math.pi)
|
| 51 |
+
- torch.lgamma(torch.tensor(float(dimension - 1) / 2.0, dtype=torch.float64))
|
| 52 |
+
)
|
| 53 |
+
interior = torch.clamp(1.0 - grid.square(), min=1e-24)
|
| 54 |
+
density = torch.exp(log_const + exponent * torch.log(interior))
|
| 55 |
+
step = (2.0 - 2.0 * eps) / float(grid_size - 1)
|
| 56 |
+
weights = density * step
|
| 57 |
+
weights = weights / torch.sum(weights)
|
| 58 |
+
return grid, weights
|
| 59 |
+
|
| 60 |
+
|
| 61 |
@lru_cache(maxsize=None)
def _compute_codebook(dimension: int, bit_width: int, grid_size: int = 8193, iterations: int = 96):
    """Lloyd-Max optimal scalar codebook for the sphere-coordinate distribution.

    Solves Eq (4): argmin_{c_1,...,c_K} E[min_k |X - c_k|^2]
    Reference: arXiv:2504.19874v1, Section 3.1, Eq (4)

    Returns a sorted float64 tensor of ``2**bit_width`` centroids (empty for
    ``bit_width == 0``). Cached per argument tuple.
    """
    if bit_width == 0:
        return torch.empty(0, dtype=torch.float64)
    support, probs = _sphere_coordinate_density(dimension, grid_size)
    n_codes = 1 << bit_width
    # Initialize centroids at equally spaced quantiles of the CDF.
    cdf = torch.cumsum(probs, dim=0)
    cdf = cdf / cdf[-1]
    quantiles = (torch.arange(n_codes, dtype=torch.float64) + 0.5) / float(n_codes)
    start_positions = torch.clamp(torch.searchsorted(cdf, quantiles), max=support.numel() - 1)
    codes = torch.sort(support[start_positions]).values
    for _ in range(iterations):
        if codes.numel() == 1:
            break
        # Partition step: cell boundaries are midpoints between neighbors.
        edges = torch.empty(n_codes + 1, dtype=torch.float64)
        edges[0] = -1.0
        edges[-1] = 1.0
        edges[1:-1] = 0.5 * (codes[:-1] + codes[1:])
        cell_of = torch.bucketize(support, edges[1:-1])
        # Centroid step: each code moves to its cell's conditional mean.
        refined = codes.clone()
        for cell in range(n_codes):
            selected = cell_of == cell
            if not bool(torch.any(selected)):
                continue  # empty cell: keep previous centroid
            cell_probs = probs[selected]
            cell_mass = torch.sum(cell_probs)
            if float(cell_mass.item()) <= 0.0:
                continue
            refined[cell] = torch.sum(support[selected] * cell_probs) / cell_mass
        refined = torch.sort(refined).values
        # Converged: the (negligibly different) refinement is discarded on
        # purpose, matching the fixed-point definition of the iteration.
        if float(torch.max(torch.abs(refined - codes)).item()) < 1e-12:
            break
        codes = refined
    return codes
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@lru_cache(maxsize=None)
|
| 103 |
+
def _cached_rotation_matrix(dimension: int, seed: int):
|
| 104 |
+
"""Random rotation via QR decomposition of Gaussian matrix.
|
| 105 |
+
|
| 106 |
+
Reference: arXiv:2504.19874v1, Algorithm 1 (step: generate random rotation Pi)
|
| 107 |
+
"""
|
| 108 |
+
gen = torch.Generator(device="cpu")
|
| 109 |
+
gen.manual_seed(seed)
|
| 110 |
+
gaussian = torch.randn((dimension, dimension), generator=gen, dtype=torch.float64)
|
| 111 |
+
q, r = torch.linalg.qr(gaussian, mode="reduced")
|
| 112 |
+
diag = torch.sign(torch.diag(r))
|
| 113 |
+
diag = torch.where(diag == 0, torch.ones_like(diag), diag)
|
| 114 |
+
q = q * diag.unsqueeze(0)
|
| 115 |
+
return q
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@lru_cache(maxsize=None)
|
| 119 |
+
def _cached_projection_matrix(dimension: int, seed: int):
|
| 120 |
+
"""Random Gaussian projection matrix for QJL.
|
| 121 |
+
|
| 122 |
+
Reference: arXiv:2504.19874v1, Definition 1 (S with iid N(0,1) entries)
|
| 123 |
+
"""
|
| 124 |
+
gen = torch.Generator(device="cpu")
|
| 125 |
+
gen.manual_seed(seed)
|
| 126 |
+
return torch.randn((dimension, dimension), generator=gen, dtype=torch.float64)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
def _dequantize_mse(indices: torch.Tensor, norms: torch.Tensor, dimension: int,
                    bit_width: int, seed: int, grid_size: int = 8193,
                    iterations: int = 96):
    """TurboQuant_MSE dequantization (Algorithm 1).

    Dequant_mse(idx):
        y_j = c_{idx_j} (codebook lookup)
        x_hat = Pi^T * y (rotate back)
        return x_hat * ||x||

    Reference: arXiv:2504.19874v1, Algorithm 1

    Args:
        indices: int64 codebook indices, shape (rows, dimension).
        norms: per-row original L2 norms, shape (rows,).
        dimension: vector dimension d.
        bit_width: bits per coordinate (codebook has 2**bit_width entries).
        seed: seed used to regenerate the rotation matrix.
        grid_size, iterations: Lloyd-Max parameters; must match quantization.

    Returns:
        float32 tensor of reconstructed rows, shape (rows, dimension).
    """
    # Note: the original implementation computed an unused ``row_count`` here;
    # the row dimension is carried implicitly by ``indices``.
    codebook = _compute_codebook(dimension, bit_width, grid_size, iterations).to(torch.float32)
    rotation = _cached_rotation_matrix(dimension, seed).to(torch.float32)
    # Lookup yields the quantized rows in the rotated basis; right-multiplying
    # by the rotation maps them back to the original basis.
    quantized_rotated = codebook[indices]
    reconstructed = quantized_rotated @ rotation
    # Undo the unit-norm normalization applied before quantization.
    return reconstructed * norms.to(torch.float32).unsqueeze(-1)
| 147 |
+
|
| 148 |
+
|
| 149 |
def _dequantize_prod(mse_indices: torch.Tensor, qjl_signs: torch.Tensor,
                     norms: torch.Tensor, residual_norms: torch.Tensor,
                     dimension: int, bit_width: int, seed: int,
                     grid_size: int = 8193, iterations: int = 96):
    """TurboQuant_prod dequantization (Algorithm 2).

    Dequant_prod(idx, qjl, gamma):
        x_mse = Dequant_mse(idx)
        x_qjl = sqrt(pi/2)/d * S^T * qjl
        x_hat = x_mse + gamma * x_qjl
        return x_hat * ||x||

    Reference: arXiv:2504.19874v1, Algorithm 2

    Args:
        mse_indices: int64 coarse-stage indices, shape (rows, dimension).
        qjl_signs: {-1, +1} sign tensor, shape (rows, dimension).
        norms: per-row original L2 norms, shape (rows,).
        residual_norms: per-row residual scales gamma, shape (rows,).
        dimension: vector dimension d.
        bit_width: total bits per coordinate; one bit goes to the QJL stage,
            the remaining ``bit_width - 1`` to the MSE stage.
        seed: rotation seed; the QJL projection uses ``seed + 1``.
        grid_size, iterations: Lloyd-Max parameters; must match quantization.

    Returns:
        float32 tensor of reconstructed rows, shape (rows, dimension).
    """
    mse_bit_width = max(bit_width - 1, 0)
    if mse_bit_width > 0:
        # Coarse estimate: codebook lookup in the rotated basis, rotated back.
        # (Previously a dummy zeros(1) codebook and the rotation matrix were
        # built even when this branch was dead; both are now branch-local.)
        codebook = _compute_codebook(
            dimension, mse_bit_width, grid_size, iterations
        ).to(torch.float32)
        rotation = _cached_rotation_matrix(dimension, seed).to(torch.float32)
        mse_part = codebook[mse_indices] @ rotation
    else:
        # All bits went to the QJL stage; the coarse estimate is zero.
        mse_part = torch.zeros(mse_indices.shape[0], dimension, dtype=torch.float32)
    projection = _cached_projection_matrix(dimension, seed + 1).to(torch.float32)
    # QJL inner-product estimator: sqrt(pi/2)/d * S^T * sign(Sx).
    qjl_part = (math.sqrt(math.pi / 2.0) / float(dimension)) * (
        qjl_signs.to(torch.float32) @ projection
    )
    unit_estimate = mse_part + residual_norms.to(torch.float32).unsqueeze(-1) * qjl_part
    # Undo the unit-norm normalization applied before quantization.
    return unit_estimate * norms.to(torch.float32).unsqueeze(-1)
| 178 |
+
|
| 179 |
+
|
| 180 |
def load_quantized_state_dict(repo_dir: str | Path):
    """Load and dequantize a TurboQuant checkpoint into a plain state dict.

    Reads the JSON manifest and the packed safetensors file from ``repo_dir``
    and reconstructs every parameter at its original shape and dtype.

    Returns:
        ``(state_dict, manifest)`` where ``state_dict`` maps parameter names
        to dequantized tensors.

    Raises:
        ValueError: on an unrecognized ``quantizer_type`` in the manifest.
    """
    repo_dir = Path(repo_dir)
    manifest = json.loads((repo_dir / MANIFEST_FILENAME).read_text(encoding="utf-8"))
    tensors = load_file(str(repo_dir / TURBOQUANT_WEIGHTS_FILENAME), device="cpu")
    state_dict: dict[str, torch.Tensor] = {}
    for param_name, spec in manifest["parameter_specs"].items():
        prefix = spec["storage_prefix"]
        target_dtype = _name_to_dtype(spec["original_dtype"])
        if not spec["quantized"]:
            # Stored verbatim -- only the dtype needs restoring.
            state_dict[param_name] = tensors[f"{prefix}__passthrough"].to(target_dtype)
            continue
        quantizer_type = spec["quantizer_type"]
        dimension = int(spec["dimension"])
        row_count = int(spec["row_count"])
        bit_width = int(spec["bit_width"])
        seed = int(spec["seed"])
        grid_size = int(spec["grid_size"])
        iterations = int(spec["iterations"])
        target_shape = [int(v) for v in spec["shape"]]
        if quantizer_type == "turboquant_mse":
            total_idx = int(spec["total_index_values"])
            indices = _unpack_indices(
                tensors[f"{prefix}__packed_indices"], bit_width, total_idx
            ).reshape(row_count, dimension)
            norms = tensors[f"{prefix}__norms"].to(torch.float32)
            dequantized = _dequantize_mse(
                indices, norms, dimension, bit_width, seed, grid_size, iterations
            )
        elif quantizer_type == "turboquant_prod":
            mse_bit_width = max(bit_width - 1, 0)
            total_idx = int(spec["total_index_values"])
            total_signs = int(spec["total_sign_values"])
            if mse_bit_width > 0:
                mse_indices = _unpack_indices(
                    tensors[f"{prefix}__packed_indices"], mse_bit_width, total_idx
                ).reshape(row_count, dimension)
            else:
                # No coarse stage was stored; use all-zero indices.
                mse_indices = torch.zeros(row_count, dimension, dtype=torch.int64)
            qjl_signs = _unpack_signs(
                tensors[f"{prefix}__packed_signs"], total_signs
            ).reshape(row_count, dimension)
            norms = tensors[f"{prefix}__norms"].to(torch.float32)
            residual_norms = tensors[f"{prefix}__residual_norms"].to(torch.float32)
            dequantized = _dequantize_prod(
                mse_indices, qjl_signs, norms, residual_norms,
                dimension, bit_width, seed, grid_size, iterations
            )
        else:
            raise ValueError(f"Unknown quantizer_type: {quantizer_type}")
        state_dict[param_name] = dequantized.reshape(target_shape).to(target_dtype)
    return state_dict, manifest
| 233 |
|
| 234 |
|
| 235 |
def _create_empty_model(repo_dir: Path, loader_kind: str):
|
| 236 |
from transformers import AutoConfig
|
|
|
|
| 237 |
config = AutoConfig.from_pretrained(repo_dir, trust_remote_code=True)
|
| 238 |
if loader_kind == "causal-lm":
|
| 239 |
from transformers import AutoModelForCausalLM
|
|
|
|
| 244 |
raise ValueError(f"Unsupported loader kind: {loader_kind}")
|
| 245 |
|
| 246 |
|
| 247 |
def _load_model_from_state_dict(repo_dir: Path, loader_kind: str,
                                state_dict: dict[str, torch.Tensor]):
    """Materialize an empty model and populate it from ``state_dict``.

    A missing ``lm_head.weight`` is tolerated (restored by ``tie_weights``),
    as are unexpected keys under the ``mtp.`` prefix; any other mismatch
    raises ``RuntimeError``.
    """
    model = _create_empty_model(repo_dir, loader_kind)
    load_result = model.load_state_dict(state_dict, strict=False, assign=True)
    if hasattr(model, "tie_weights"):
        model.tie_weights()
    allowed_missing = {"lm_head.weight"}
    allowed_unexpected_prefixes = ("mtp.",)
    bad_missing = sorted(
        k for k in load_result.missing_keys if k not in allowed_missing
    )
    bad_unexpected = sorted(
        k for k in load_result.unexpected_keys
        if not k.startswith(allowed_unexpected_prefixes)
    )
    if bad_missing or bad_unexpected:
        raise RuntimeError(
            "Unexpected state_dict mismatch: "
            f"missing={bad_missing}, unexpected={bad_unexpected}"
        )
    return model
| 268 |
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu"):
|
| 271 |
+
"""Load the quantized model, dequantize weights, and return (model, manifest)."""
|
| 272 |
repo_dir = Path(repo_dir)
|
| 273 |
state_dict, manifest = load_quantized_state_dict(repo_dir)
|
| 274 |
model = _load_model_from_state_dict(repo_dir, manifest["loader_kind"], state_dict)
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
def load_tokenizer(repo_dir: str | Path):
|
| 281 |
+
"""Load the tokenizer from the repo directory."""
|
| 282 |
repo_dir = Path(repo_dir)
|
| 283 |
if (repo_dir / "preprocessor_config.json").exists():
|
| 284 |
processor = AutoProcessor.from_pretrained(repo_dir, trust_remote_code=True)
|