lew96123 committed on
Commit
919819e
·
verified ·
1 Parent(s): d78669f

Fix load_quantized.py: TurboQuant-aware dequantization (was broken rowwise_symmetric loader)

Browse files
Files changed (1) hide show
  1. load_quantized.py +221 -49
load_quantized.py CHANGED
@@ -1,6 +1,8 @@
1
  from __future__ import annotations
2
 
3
  import json
 
 
4
  from pathlib import Path
5
 
6
  import torch
@@ -8,29 +10,230 @@ from safetensors.torch import load_file
8
  from transformers import AutoProcessor, AutoTokenizer
9
 
10
  MANIFEST_FILENAME = "quant_manifest.json"
11
- WEIGHTS_FILENAME = "packed_weights.safetensors"
12
 
13
 
14
- def _unpack_codes(packed: torch.Tensor, bit_width: int, total_values: int) -> torch.Tensor:
 
 
 
 
15
  packed = packed.reshape(-1).to(torch.uint8).cpu()
16
- if bit_width == 8:
17
- return packed[:total_values].contiguous()
18
  values_per_byte = 8 // bit_width
19
  mask = (1 << bit_width) - 1
20
  packed_i32 = packed.to(torch.int32)
21
  parts = []
22
  for index in range(values_per_byte):
23
  parts.append(((packed_i32 >> (index * bit_width)) & mask).to(torch.uint8))
24
- return torch.stack(parts, dim=1).reshape(-1)[:total_values].contiguous()
25
 
26
 
27
- def _name_to_dtype(name: str) -> torch.dtype:
28
- return getattr(torch, name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def _create_empty_model(repo_dir: Path, loader_kind: str):
32
  from transformers import AutoConfig
33
-
34
  config = AutoConfig.from_pretrained(repo_dir, trust_remote_code=True)
35
  if loader_kind == "causal-lm":
36
  from transformers import AutoModelForCausalLM
@@ -41,63 +244,31 @@ def _create_empty_model(repo_dir: Path, loader_kind: str):
41
  raise ValueError(f"Unsupported loader kind: {loader_kind}")
42
 
43
 
44
- def _load_model_from_state_dict(repo_dir: Path, loader_kind: str, state_dict: dict[str, torch.Tensor]):
 
45
  model = _create_empty_model(repo_dir, loader_kind)
46
  incompatible = model.load_state_dict(state_dict, strict=False, assign=True)
47
  if hasattr(model, "tie_weights"):
48
  model.tie_weights()
49
-
50
  allowed_missing = {"lm_head.weight"}
51
  allowed_unexpected_prefixes = ("mtp.",)
52
- disallowed_missing = sorted(key for key in incompatible.missing_keys if key not in allowed_missing)
 
 
53
  disallowed_unexpected = sorted(
54
- key for key in incompatible.unexpected_keys if not key.startswith(allowed_unexpected_prefixes)
 
55
  )
56
  if disallowed_missing or disallowed_unexpected:
57
  raise RuntimeError(
58
- "Unexpected state_dict mismatch while loading quantized model: "
59
  f"missing={disallowed_missing}, unexpected={disallowed_unexpected}"
60
  )
61
  return model
62
 
63
 
64
- def load_quantized_state_dict(repo_dir: str | Path):
65
- repo_dir = Path(repo_dir)
66
- manifest = json.loads((repo_dir / MANIFEST_FILENAME).read_text(encoding="utf-8"))
67
- stored = load_file(str(repo_dir / WEIGHTS_FILENAME), device="cpu")
68
- state_dict = {}
69
- for name, spec in manifest["parameter_specs"].items():
70
- prefix = spec["storage_prefix"]
71
- if spec["quantized"]:
72
- quantization_scheme = spec.get("quantization_scheme", manifest.get("quantization_scheme", "rowwise_symmetric"))
73
- bit_width = int(spec["bit_width"])
74
- group_size = int(spec["group_size"])
75
- num_groups = int(spec["num_groups"])
76
- total_group_values = int(spec.get("total_group_values", num_groups * group_size))
77
- dtype = _name_to_dtype(spec["original_dtype"])
78
- unpacked = _unpack_codes(stored[f"{prefix}__packed"], bit_width, total_group_values).to(torch.float32)
79
- unpacked = unpacked.view(num_groups, group_size)
80
- scales = stored[f"{prefix}__scales"].to(torch.float32).view(num_groups, 1)
81
- if quantization_scheme != "rowwise_symmetric":
82
- raise ValueError(f"Unsupported quantization scheme: {quantization_scheme}")
83
- if bit_width == 1:
84
- restored_groups = torch.where(unpacked > 0, torch.ones_like(unpacked), -torch.ones_like(unpacked)) * scales
85
- else:
86
- qmax = (1 << (bit_width - 1)) - 1
87
- restored_groups = (unpacked - float(qmax)) * scales
88
- row_count = int(spec["row_count"])
89
- row_length = int(spec["row_length"])
90
- groups_per_row = int(spec["groups_per_row"])
91
- padded_row_length = groups_per_row * group_size
92
- restored = restored_groups.view(row_count, groups_per_row, group_size).reshape(row_count, padded_row_length)
93
- restored = restored[:, :row_length].contiguous().view(tuple(int(value) for value in spec["shape"]))
94
- state_dict[name] = restored.to(dtype)
95
- continue
96
- state_dict[name] = stored[f"{prefix}__passthrough"].to(_name_to_dtype(spec["original_dtype"]))
97
- return state_dict, manifest
98
-
99
-
100
  def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu"):
 
101
  repo_dir = Path(repo_dir)
102
  state_dict, manifest = load_quantized_state_dict(repo_dir)
103
  model = _load_model_from_state_dict(repo_dir, manifest["loader_kind"], state_dict)
@@ -107,6 +278,7 @@ def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu
107
 
108
 
109
  def load_tokenizer(repo_dir: str | Path):
 
110
  repo_dir = Path(repo_dir)
111
  if (repo_dir / "preprocessor_config.json").exists():
112
  processor = AutoProcessor.from_pretrained(repo_dir, trust_remote_code=True)
 
1
  from __future__ import annotations
2
 
3
  import json
4
+ import math
5
+ from functools import lru_cache
6
  from pathlib import Path
7
 
8
  import torch
 
10
  from transformers import AutoProcessor, AutoTokenizer
11
 
12
  MANIFEST_FILENAME = "quant_manifest.json"
13
+ TURBOQUANT_WEIGHTS_FILENAME = "turboquant_weights.safetensors"
14
 
15
 
16
+ def _name_to_dtype(name: str) -> torch.dtype:
17
+ return getattr(torch, name)
18
+
19
+
20
+ def _unpack_indices(packed: torch.Tensor, bit_width: int, total_values: int) -> torch.Tensor:
21
  packed = packed.reshape(-1).to(torch.uint8).cpu()
22
+ if bit_width >= 8:
23
+ return packed[:total_values].to(torch.int64).contiguous()
24
  values_per_byte = 8 // bit_width
25
  mask = (1 << bit_width) - 1
26
  packed_i32 = packed.to(torch.int32)
27
  parts = []
28
  for index in range(values_per_byte):
29
  parts.append(((packed_i32 >> (index * bit_width)) & mask).to(torch.uint8))
30
+ return torch.stack(parts, dim=1).reshape(-1)[:total_values].to(torch.int64).contiguous()
31
 
32
 
33
def _unpack_signs(packed: torch.Tensor, total_values: int) -> torch.Tensor:
    """Decode 1-bit packed flags into a +/-1 sign tensor of dtype int8."""
    flags = _unpack_indices(packed, 1, total_values)
    ones = torch.ones_like(flags, dtype=torch.int8)
    # Bit 1 -> +1, bit 0 -> -1.
    return torch.where(flags != 0, ones, -ones)
36
+
37
+
38
+ @lru_cache(maxsize=None)
39
+ def _sphere_coordinate_density(dimension: int, grid_size: int):
40
+ """Lemma 1: Beta distribution density on unit sphere coordinates.
41
+
42
+ f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
43
+ Reference: arXiv:2504.19874v1, Lemma 1
44
+ """
45
+ eps = 1e-6
46
+ grid = torch.linspace(-1.0 + eps, 1.0 - eps, steps=grid_size, dtype=torch.float64)
47
+ exponent = 0.5 * float(dimension - 3)
48
+ log_const = (
49
+ torch.lgamma(torch.tensor(float(dimension) / 2.0, dtype=torch.float64))
50
+ - 0.5 * math.log(math.pi)
51
+ - torch.lgamma(torch.tensor(float(dimension - 1) / 2.0, dtype=torch.float64))
52
+ )
53
+ interior = torch.clamp(1.0 - grid.square(), min=1e-24)
54
+ density = torch.exp(log_const + exponent * torch.log(interior))
55
+ step = (2.0 - 2.0 * eps) / float(grid_size - 1)
56
+ weights = density * step
57
+ weights = weights / torch.sum(weights)
58
+ return grid, weights
59
+
60
+
61
@lru_cache(maxsize=None)
def _compute_codebook(dimension: int, bit_width: int, grid_size: int = 8193, iterations: int = 96):
    """Lloyd-Max optimal scalar quantization on Beta distribution.

    Solves Eq (4): argmin_{c_1,...,c_K} E[min_k |X - c_k|^2]
    Reference: arXiv:2504.19874v1, Section 3.1, Eq (4)

    Returns a sorted float64 tensor of 2**bit_width centroids over (-1, 1),
    or an empty tensor when bit_width == 0. Cached per
    (dimension, bit_width, grid_size, iterations).
    """
    if bit_width == 0:
        # No codebook component (e.g. pure-QJL path uses bit_width - 1 == 0).
        return torch.empty(0, dtype=torch.float64)
    grid, weights = _sphere_coordinate_density(dimension, grid_size)
    codebook_size = 1 << bit_width
    # Initialize centroids at the quantiles of the discretized distribution
    # (inverse-CDF sampling at the K mid-bucket targets).
    cumulative = torch.cumsum(weights, dim=0)
    cumulative = cumulative / cumulative[-1]
    targets = (torch.arange(codebook_size, dtype=torch.float64) + 0.5) / float(codebook_size)
    init_idx = torch.clamp(torch.searchsorted(cumulative, targets), max=grid.numel() - 1)
    centroids = torch.sort(grid[init_idx]).values
    for _ in range(iterations):
        if centroids.numel() == 1:
            break
        # Decision boundaries are midpoints between adjacent centroids.
        boundaries = torch.empty(codebook_size + 1, dtype=torch.float64)
        boundaries[0] = -1.0
        boundaries[-1] = 1.0
        boundaries[1:-1] = 0.5 * (centroids[:-1] + centroids[1:])
        bucket_ids = torch.bucketize(grid, boundaries[1:-1])
        updated = centroids.clone()
        for bucket in range(codebook_size):
            mask = bucket_ids == bucket
            if not bool(torch.any(mask)):
                # Empty cell: keep the previous centroid unchanged.
                continue
            w = weights[mask]
            ws = torch.sum(w)
            if float(ws.item()) <= 0.0:
                continue
            # Centroid update: probability-weighted mean of the cell.
            updated[bucket] = torch.sum(grid[mask] * w) / ws
        updated = torch.sort(updated).values
        # Converged: stop before committing an (identical) update.
        if float(torch.max(torch.abs(updated - centroids)).item()) < 1e-12:
            break
        centroids = updated
    return centroids
100
+
101
+
102
+ @lru_cache(maxsize=None)
103
+ def _cached_rotation_matrix(dimension: int, seed: int):
104
+ """Random rotation via QR decomposition of Gaussian matrix.
105
+
106
+ Reference: arXiv:2504.19874v1, Algorithm 1 (step: generate random rotation Pi)
107
+ """
108
+ gen = torch.Generator(device="cpu")
109
+ gen.manual_seed(seed)
110
+ gaussian = torch.randn((dimension, dimension), generator=gen, dtype=torch.float64)
111
+ q, r = torch.linalg.qr(gaussian, mode="reduced")
112
+ diag = torch.sign(torch.diag(r))
113
+ diag = torch.where(diag == 0, torch.ones_like(diag), diag)
114
+ q = q * diag.unsqueeze(0)
115
+ return q
116
+
117
+
118
+ @lru_cache(maxsize=None)
119
+ def _cached_projection_matrix(dimension: int, seed: int):
120
+ """Random Gaussian projection matrix for QJL.
121
+
122
+ Reference: arXiv:2504.19874v1, Definition 1 (S with iid N(0,1) entries)
123
+ """
124
+ gen = torch.Generator(device="cpu")
125
+ gen.manual_seed(seed)
126
+ return torch.randn((dimension, dimension), generator=gen, dtype=torch.float64)
127
+
128
+
129
def _dequantize_mse(indices: torch.Tensor, norms: torch.Tensor, dimension: int,
                    bit_width: int, seed: int, grid_size: int = 8193,
                    iterations: int = 96) -> torch.Tensor:
    """TurboQuant_MSE dequantization (Algorithm 1).

    Dequant_mse(idx):
        y_j = c_{idx_j} (codebook lookup)
        x_hat = Pi^T * y (rotate back)
        return x_hat * ||x||

    Reference: arXiv:2504.19874v1, Algorithm 1

    Args:
        indices: (rows, dimension) integer codebook indices.
        norms: per-row stored L2 norms, shape (rows,).
        dimension / bit_width / seed / grid_size / iterations: must match the
            values used at quantization time so the cached codebook and
            rotation are reproduced exactly.

    Returns:
        float32 tensor of shape (rows, dimension).
    """
    # Fix: dropped the unused local ``row_count`` from the original.
    codebook = _compute_codebook(dimension, bit_width, grid_size, iterations).to(torch.float32)
    rotation = _cached_rotation_matrix(dimension, seed).to(torch.float32)
    # Row-vector convention: y @ Q applies Q^T to each row (the inverse rotation).
    unit_rows = codebook[indices] @ rotation
    return unit_rows * norms.to(torch.float32).unsqueeze(-1)
147
+
148
+
149
def _dequantize_prod(mse_indices: torch.Tensor, qjl_signs: torch.Tensor,
                     norms: torch.Tensor, residual_norms: torch.Tensor,
                     dimension: int, bit_width: int, seed: int,
                     grid_size: int = 8193, iterations: int = 96) -> torch.Tensor:
    """TurboQuant_prod dequantization (Algorithm 2).

    Dequant_prod(idx, qjl, gamma):
        x_mse = Dequant_mse(idx)
        x_qjl = sqrt(pi/2)/d * S^T * qjl
        x_hat = x_mse + gamma * x_qjl
        return x_hat * ||x||

    Reference: arXiv:2504.19874v1, Algorithm 2

    Args:
        mse_indices: (rows, dimension) codebook indices for the MSE component
            (ignored when bit_width <= 1, i.e. no codebook bits remain).
        qjl_signs: (rows, dimension) +/-1 QJL sign bits.
        norms: per-row original L2 norms.
        residual_norms: per-row norms of the MSE residual (gamma above).

    Returns:
        float32 tensor of shape (rows, dimension).
    """
    mse_bw = max(bit_width - 1, 0)  # one bit per coordinate goes to the QJL sign
    if mse_bw > 0:
        codebook = _compute_codebook(dimension, mse_bw, grid_size, iterations).to(torch.float32)
        rotation = _cached_rotation_matrix(dimension, seed).to(torch.float32)
        mse_reconstructed = codebook[mse_indices] @ rotation
    else:
        # Fix: the original built an unused dummy codebook and computed the
        # rotation unconditionally; with no codebook bits the MSE component
        # is simply zero.
        mse_reconstructed = torch.zeros(mse_indices.shape[0], dimension, dtype=torch.float32)
    # seed + 1 decorrelates the projection from the rotation above.
    projection = _cached_projection_matrix(dimension, seed + 1).to(torch.float32)
    qjl_reconstructed = (math.sqrt(math.pi / 2.0) / float(dimension)) * (
        qjl_signs.to(torch.float32) @ projection
    )
    unit_rows = (mse_reconstructed
                 + residual_norms.to(torch.float32).unsqueeze(-1) * qjl_reconstructed)
    return unit_rows * norms.to(torch.float32).unsqueeze(-1)
178
+
179
+
180
def _restore_mse_param(stored, prefix: str, spec: dict) -> torch.Tensor:
    """Rebuild one turboquant_mse tensor from its packed indices and norms."""
    dimension = int(spec["dimension"])
    row_count = int(spec["row_count"])
    bit_width = int(spec["bit_width"])
    indices = _unpack_indices(
        stored[f"{prefix}__packed_indices"], bit_width, int(spec["total_index_values"])
    ).reshape(row_count, dimension)
    norms = stored[f"{prefix}__norms"].to(torch.float32)
    return _dequantize_mse(
        indices, norms, dimension, bit_width, int(spec["seed"]),
        int(spec["grid_size"]), int(spec["iterations"])
    )


def _restore_prod_param(stored, prefix: str, spec: dict) -> torch.Tensor:
    """Rebuild one turboquant_prod tensor from packed indices, signs and norms."""
    dimension = int(spec["dimension"])
    row_count = int(spec["row_count"])
    bit_width = int(spec["bit_width"])
    mse_bw = max(bit_width - 1, 0)  # mirror of the split in _dequantize_prod
    if mse_bw > 0:
        mse_indices = _unpack_indices(
            stored[f"{prefix}__packed_indices"], mse_bw, int(spec["total_index_values"])
        ).reshape(row_count, dimension)
    else:
        # No codebook bits stored; the dequantizer ignores these zeros.
        mse_indices = torch.zeros(row_count, dimension, dtype=torch.int64)
    qjl_signs = _unpack_signs(
        stored[f"{prefix}__packed_signs"], int(spec["total_sign_values"])
    ).reshape(row_count, dimension)
    norms = stored[f"{prefix}__norms"].to(torch.float32)
    residual_norms = stored[f"{prefix}__residual_norms"].to(torch.float32)
    return _dequantize_prod(
        mse_indices, qjl_signs, norms, residual_norms,
        dimension, bit_width, int(spec["seed"]),
        int(spec["grid_size"]), int(spec["iterations"])
    )


# Dispatch table: manifest quantizer_type -> restore helper.
_PARAM_RESTORERS = {
    "turboquant_mse": _restore_mse_param,
    "turboquant_prod": _restore_prod_param,
}


def load_quantized_state_dict(repo_dir: str | Path):
    """Load and dequantize a TurboQuant checkpoint.

    Reads ``quant_manifest.json`` and ``turboquant_weights.safetensors`` from
    *repo_dir* and returns ``(state_dict, manifest)`` with every quantized
    parameter reconstructed to its original shape and dtype.

    Raises:
        ValueError: if a spec carries an unrecognized ``quantizer_type``.
    """
    repo_dir = Path(repo_dir)
    manifest = json.loads((repo_dir / MANIFEST_FILENAME).read_text(encoding="utf-8"))
    stored = load_file(str(repo_dir / TURBOQUANT_WEIGHTS_FILENAME), device="cpu")
    state_dict: dict[str, torch.Tensor] = {}
    for name, spec in manifest["parameter_specs"].items():
        prefix = spec["storage_prefix"]
        original_dtype = _name_to_dtype(spec["original_dtype"])
        if not spec["quantized"]:
            # Non-quantized tensors are stored verbatim under __passthrough.
            state_dict[name] = stored[f"{prefix}__passthrough"].to(original_dtype)
            continue
        quantizer_type = spec["quantizer_type"]
        try:
            restore = _PARAM_RESTORERS[quantizer_type]
        except KeyError:
            raise ValueError(f"Unknown quantizer_type: {quantizer_type}") from None
        original_shape = [int(v) for v in spec["shape"]]
        state_dict[name] = restore(stored, prefix, spec).reshape(original_shape).to(original_dtype)
    return state_dict, manifest
233
 
234
 
235
  def _create_empty_model(repo_dir: Path, loader_kind: str):
236
  from transformers import AutoConfig
 
237
  config = AutoConfig.from_pretrained(repo_dir, trust_remote_code=True)
238
  if loader_kind == "causal-lm":
239
  from transformers import AutoModelForCausalLM
 
244
  raise ValueError(f"Unsupported loader kind: {loader_kind}")
245
 
246
 
247
def _load_model_from_state_dict(repo_dir: Path, loader_kind: str,
                                state_dict: dict[str, torch.Tensor]):
    """Instantiate the model skeleton and load the dequantized weights into it.

    A missing "lm_head.weight" is tolerated (presumably covered by
    tie_weights — confirm against the checkpoint) and unexpected keys under
    the "mtp." prefix are ignored; any other mismatch raises RuntimeError.
    """
    model = _create_empty_model(repo_dir, loader_kind)
    load_result = model.load_state_dict(state_dict, strict=False, assign=True)
    if hasattr(model, "tie_weights"):
        model.tie_weights()
    allowed_missing = {"lm_head.weight"}
    allowed_unexpected_prefixes = ("mtp.",)
    bad_missing = [key for key in load_result.missing_keys
                   if key not in allowed_missing]
    bad_missing.sort()
    bad_unexpected = [key for key in load_result.unexpected_keys
                      if not key.startswith(allowed_unexpected_prefixes)]
    bad_unexpected.sort()
    if bad_missing or bad_unexpected:
        raise RuntimeError(
            "Unexpected state_dict mismatch: "
            f"missing={bad_missing}, unexpected={bad_unexpected}"
        )
    return model
268
 
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  def load_quantized_model(repo_dir: str | Path, device: str | torch.device = "cpu"):
271
+ """Load the quantized model, dequantize weights, and return (model, manifest)."""
272
  repo_dir = Path(repo_dir)
273
  state_dict, manifest = load_quantized_state_dict(repo_dir)
274
  model = _load_model_from_state_dict(repo_dir, manifest["loader_kind"], state_dict)
 
278
 
279
 
280
  def load_tokenizer(repo_dir: str | Path):
281
+ """Load the tokenizer from the repo directory."""
282
  repo_dir = Path(repo_dir)
283
  if (repo_dir / "preprocessor_config.json").exists():
284
  processor = AutoProcessor.from_pretrained(repo_dir, trust_remote_code=True)