LisaMegaWatts committed (verified)
Commit 7759445 · 1 Parent(s): c8115ca

Upload juliaslm_svd_model.py with huggingface_hub
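For context, a minimal sketch of the kind of huggingface_hub call that produces a commit like this one. The repo id below is an assumption for illustration, not read from this page:

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="juliaslm_svd_model.py",
    path_in_repo="juliaslm_svd_model.py",
    repo_id="LisaMegaWatts/JuliaSLM-compressed-svd",  # assumed repo id
    commit_message="Upload juliaslm_svd_model.py with huggingface_hub",
)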

Files changed (1)
  1. juliaslm_svd_model.py +267 -0
juliaslm_svd_model.py ADDED
@@ -0,0 +1,267 @@
+ """JuliaSLM-compressed-svd — SVD-compressed inference model.
+
+ LLaMA-style decoder with SVD-factored weight matrices. Each linear layer
+ stores low-rank factors A (out, rank) and B (rank, in) instead of the full
+ weight matrix, reducing parameter count while preserving model quality.
+
+ Architecture: MHA (4 heads), RMSNorm, SwiGLU, RoPE, weight-tied output.
+ Base config: d_model=256, n_layers=6, n_heads=4, head_dim=64, ctx=256,
+ vocab=2000, SVD-90 compression (~4.81M params).
+ """
+ import math
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Configuration
+ # ═══════════════════════════════════════════════════════════════════
+
+
+ @dataclass
+ class SVDConfig:
+     d_model: int = 256
+     n_layers: int = 6
+     n_heads: int = 4
+     head_dim: int = 64
+     ffn_inner: int = 640
+     context_length: int = 256
+     vocab_size: int = 2000
+     weight_tying: bool = True
+     rope_base: float = 10000.0
+     # Per-layer SVD ranks: list of dicts with keys wq, wk, wv, wo, w1, v, w2
+     layer_ranks: list = field(default_factory=list)
+
+     @staticmethod
+     def from_checkpoint(state_dict: dict) -> "SVDConfig":
+         """Build config by inspecting checkpoint tensor shapes."""
+         vocab_size, d_model = state_dict["tok_emb.weight"].shape
+         ctx_len = state_dict["rope.cos_cache"].shape[0]
+         head_dim = state_dict["rope.cos_cache"].shape[1] * 2  # cos_cache is half
+         n_heads = d_model // head_dim
+         ffn_inner = state_dict["blocks.0.ffn.w1.A"].shape[0]
+
+         n_layers = max(
+             int(k.split(".")[1])
+             for k in state_dict
+             if k.startswith("blocks.")
+         ) + 1
+
+         layer_ranks = []
+         for i in range(n_layers):
+             ranks = {}
+             for name in ("wq", "wk", "wv", "wo"):
+                 ranks[name] = state_dict[f"blocks.{i}.attn.{name}.A"].shape[1]
+             for name in ("w1", "v", "w2"):
+                 ranks[name] = state_dict[f"blocks.{i}.ffn.{name}.A"].shape[1]
+             layer_ranks.append(ranks)
+
+         return SVDConfig(
+             d_model=d_model,
+             n_layers=n_layers,
+             n_heads=n_heads,
+             head_dim=head_dim,
+             ffn_inner=ffn_inner,
+             context_length=ctx_len,
+             vocab_size=vocab_size,
+             layer_ranks=layer_ranks,
+         )
+
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Building blocks
+ # ═══════════════════════════════════════════════════════════════════
+
+
+ class SVDLinear(nn.Module):
+     """Linear layer stored as low-rank A @ B factorization.
+
+     Forward: x @ B^T @ A^T (equivalent to x @ (A @ B)^T = x @ W^T)
+     where W ≈ A @ B with A: (out, rank), B: (rank, in).
+     """
+
+     def __init__(self, out_features: int, rank: int, in_features: int):
+         super().__init__()
+         self.A = nn.Parameter(torch.empty(out_features, rank))
+         self.B = nn.Parameter(torch.empty(rank, in_features))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(F.linear(x, self.B), self.A)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return x / rms * self.weight
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, max_seq_len: int = 256, base: float = 10000.0):
+         super().__init__()
+         freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         positions = torch.arange(max_seq_len).float()
+         angles = torch.outer(positions, freqs)
+         self.register_buffer("cos_cache", angles.cos())
+         self.register_buffer("sin_cache", angles.sin())
+
+     def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
+         seq_len = x.size(2)
+         half = x.size(-1) // 2
+         x1, x2 = x[..., :half], x[..., half:]
+         cos = self.cos_cache[start_pos:start_pos + seq_len, :half].unsqueeze(0).unsqueeze(0)
+         sin = self.sin_cache[start_pos:start_pos + seq_len, :half].unsqueeze(0).unsqueeze(0)
+         return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+
+
+ class SVDSwiGLU(nn.Module):
+     """SwiGLU FFN with SVD-compressed linear layers."""
+
+     def __init__(self, d_model: int, inner_dim: int, ranks: dict):
+         super().__init__()
+         self.w1 = SVDLinear(inner_dim, ranks["w1"], d_model)
+         self.v = SVDLinear(inner_dim, ranks["v"], d_model)
+         self.w2 = SVDLinear(d_model, ranks["w2"], inner_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.w2(F.silu(self.w1(x)) * self.v(x))
+
+
+ class SVDCausalAttention(nn.Module):
+     """Multi-head attention with SVD-compressed projections and KV cache."""
+
+     def __init__(self, d_model: int, n_heads: int, head_dim: int, ranks: dict):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = head_dim
+         self.scale = 1.0 / math.sqrt(head_dim)
+
+         self.wq = SVDLinear(n_heads * head_dim, ranks["wq"], d_model)
+         self.wk = SVDLinear(n_heads * head_dim, ranks["wk"], d_model)
+         self.wv = SVDLinear(n_heads * head_dim, ranks["wv"], d_model)
+         self.wo = SVDLinear(d_model, ranks["wo"], n_heads * head_dim)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         rope: RotaryEmbedding,
+         mask: Optional[torch.Tensor],
+         kv_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         start_pos: int = 0,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         B, T, _ = x.shape
+         H, HD = self.n_heads, self.head_dim
+
+         q = self.wq(x).view(B, T, H, HD).transpose(1, 2)
+         k = self.wk(x).view(B, T, H, HD).transpose(1, 2)
+         v = self.wv(x).view(B, T, H, HD).transpose(1, 2)
+
+         q = rope(q, start_pos)
+         k = rope(k, start_pos)
+
+         if kv_cache is not None:
+             prev_k, prev_v = kv_cache
+             k = torch.cat([prev_k, k], dim=2)
+             v = torch.cat([prev_v, v], dim=2)
+         new_cache = (k, v)
+
+         attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+         if mask is not None:
+             attn = attn + mask
+         attn = F.softmax(attn, dim=-1)
+         out = torch.matmul(attn, v)
+
+         out = out.transpose(1, 2).contiguous().view(B, T, H * HD)
+         return self.wo(out), new_cache
+
+
+ # ═══════════════════════════════════════════════════════════════════
+ # Transformer block and model
+ # ═══════════════════════════════════════════════════════════════════
+
+
+ class SVDTransformerBlock(nn.Module):
+     def __init__(self, config: SVDConfig, layer_idx: int):
+         super().__init__()
+         ranks = config.layer_ranks[layer_idx]
+         self.ln1 = RMSNorm(config.d_model)
+         self.attn = SVDCausalAttention(
+             config.d_model, config.n_heads, config.head_dim, ranks
+         )
+         self.ln2 = RMSNorm(config.d_model)
+         self.ffn = SVDSwiGLU(config.d_model, config.ffn_inner, ranks)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         rope: RotaryEmbedding,
+         mask: Optional[torch.Tensor],
+         kv_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         start_pos: int = 0,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         attn_out, new_cache = self.attn(self.ln1(x), rope, mask, kv_cache, start_pos)
+         x = x + attn_out
+         x = x + self.ffn(self.ln2(x))
+         return x, new_cache
+
+
+ class JuliaSLM_SVD(nn.Module):
+     def __init__(self, config: SVDConfig):
+         super().__init__()
+         self.config = config
+         self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
+         self.rope = RotaryEmbedding(config.head_dim, config.context_length, config.rope_base)
+         self.blocks = nn.ModuleList(
+             [SVDTransformerBlock(config, i) for i in range(config.n_layers)]
+         )
+         self.ln_f = RMSNorm(config.d_model)
+
+         causal = torch.triu(
+             torch.full((config.context_length, config.context_length), float("-inf")),
+             diagonal=1,
+         )
+         self.register_buffer("causal_mask", causal)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         kv_caches: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+     ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+         """Forward pass with optional KV cache.
+
+         Without cache (prefill): processes full sequence with causal mask.
+         With cache (decode): processes only new token(s), O(1) per token.
+         """
+         B, T = input_ids.shape
+         x = self.tok_emb(input_ids)
+
+         if kv_caches is not None:
+             start_pos = kv_caches[0][0].size(2)
+             mask = None
+         else:
+             start_pos = 0
+             mask = self.causal_mask[:T, :T].to(dtype=x.dtype)
+             kv_caches = [None] * len(self.blocks)
+
+         new_caches = []
+         for block, cache in zip(self.blocks, kv_caches):
+             x, new_cache = block(x, self.rope, mask, cache, start_pos)
+             new_caches.append(new_cache)
+
+         x = self.ln_f(x)
+         # Weight-tied output projection
+         logits = F.linear(x, self.tok_emb.weight)
+
+         return logits, new_caches
+
+     @property
+     def num_parameters(self) -> int:
+         return sum(p.numel() for p in self.parameters())
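
A minimal usage sketch (not part of the uploaded file). It assumes a local checkpoint named juliaslm_svd.pt holding the state_dict this model expects, and uses placeholder token ids in place of a real tokenizer. Part (a) shows how a dense weight would be truncated-SVD-factored into the A/B pair that SVDLinear stores; part (b) shows prefill followed by cached greedy decoding with JuliaSLM_SVD.

import torch
from juliaslm_svd_model import SVDConfig, JuliaSLM_SVD

# (a) Truncated SVD of a dense weight W (out, in) into A (out, r) and B (r, in),
#     folding the singular values into A so that A @ B approximates W.
def svd_factorize(W: torch.Tensor, rank: int):
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    return U[:, :rank] * S[:rank], Vh[:rank, :]

# (b) Rebuild the config from checkpoint shapes, load weights, and generate.
state_dict = torch.load("juliaslm_svd.pt", map_location="cpu")  # assumed filename
config = SVDConfig.from_checkpoint(state_dict)
model = JuliaSLM_SVD(config)
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates buffer-only keys
model.eval()

prompt_ids = torch.tensor([[1, 42, 7]])  # placeholder token ids
with torch.no_grad():
    # Prefill: full prompt with the causal mask, no cache yet.
    logits, caches = model(prompt_ids)
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = [next_id]
    # Decode: one token per step, reusing the per-layer KV caches.
    for _ in range(20):
        logits, caches = model(next_id, kv_caches=caches)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_id)

print(torch.cat(generated, dim=1).tolist())

The greedy argmax is only to keep the sketch short; temperature or top-k sampling would slot in at the same point.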