import torch
import torch.nn as nn


# Llama 3.2 1B configuration
LLAMA32_CONFIG_1B = {
    "vocab_size": 128_256,      # Vocabulary size
    "context_length": 131_072,  # Maximum supported context length
    "emb_dim": 2048,            # Embedding dimension
    "n_heads": 32,              # Number of attention heads
    "n_layers": 16,             # Number of transformer blocks
    "hidden_dim": 8192,         # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,           # Key-value groups for grouped-query attention
    "rope_base": 500_000.0,     # The base in RoPE's "theta"
    "dtype": torch.bfloat16,    # Lower-precision dtype to reduce memory usage
    "rope_freq": {              # RoPE frequency scaling for context-length extension
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}


# Llama 3.2 3B configuration
LLAMA32_CONFIG_3B = {
    "vocab_size": 128_256,      # Vocabulary size
    "context_length": 131_072,  # Maximum supported context length
    "emb_dim": 3072,            # Embedding dimension
    "n_heads": 24,              # Number of attention heads
    "n_layers": 28,             # Number of transformer blocks
    "hidden_dim": 8192,         # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,           # Key-value groups for grouped-query attention
    "rope_base": 500_000.0,     # The base in RoPE's "theta"
    "dtype": torch.bfloat16,    # Lower-precision dtype to reduce memory usage
    "rope_freq": {              # RoPE frequency scaling for context-length extension
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}


class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Precompute the RoPE cos/sin tables once and share them across all blocks
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        # Causal mask: True above the diagonal marks future positions to be masked out
        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for the attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape: (batch_size, num_tokens, emb_dim)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for the feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        # SwiGLU: silu(fc1(x)) gates fc2(x), followed by the fc3 down-projection
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)


class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, d_out, num_heads, num_kv_groups, dtype=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Keys and values are projected to num_kv_groups * head_dim (instead of d_out)
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)       # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        # Apply rotary position embeddings (RoPE)
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Expand keys and values to match the number of query heads
        # Shape: (b, num_heads, num_tokens, head_dim)
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)
        # For example, before repeat_interleave along dim=1 (the key-value group dimension):
        #   [K1, K2]
        # After repeat_interleave, each group is repeated group_size times to serve its query heads:
        #   [K1, K1, K2, K2]
        # A regular repeat would instead give the wrong ordering:
        #   [K1, K2, K1, K2]

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        # Shape: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Use the mask to fill future positions with -inf
        attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # output projection

        return context_vec


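# A minimal sanity-check sketch (not from the original listing; the sizes below are made up
# for illustration): with num_kv_groups < num_heads, the key/value projections shrink by a
# factor of group_size relative to standard multi-head attention.
def _gqa_param_check():
    gqa = GroupedQueryAttention(d_in=64, d_out=64, num_heads=8, num_kv_groups=2)
    assert gqa.W_query.weight.shape == (64, 64)  # full-sized query projection
    assert gqa.W_key.weight.shape == (16, 64)    # num_kv_groups * head_dim = 2 * 8
    assert gqa.W_value.weight.shape == (16, 64)
    assert gqa.group_size == 4                   # each K/V group serves 4 query heads

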
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments (Llama 3.2 RoPE scaling for longer contexts)
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]   # Second half

    # Adjust cos and sin shapes to (1, 1, seq_len, head_dim) for broadcasting
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # Cast back to the input dtype after applying the rotation
    return x_rotated.to(dtype=x.dtype)


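# A minimal shape-check sketch (not from the original listing; the small sizes are made up
# for illustration): the cos/sin tables have shape (context_length, head_dim), and
# apply_rope preserves the shape of its input tensor.
def _rope_shape_check():
    head_dim, context_length = 16, 32
    cos, sin = compute_rope_params(head_dim=head_dim, context_length=context_length)
    assert cos.shape == (context_length, head_dim)
    x = torch.randn(1, 2, context_length, head_dim)  # (batch, heads, seq_len, head_dim)
    assert apply_rope(x, cos, sin).shape == x.shape

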
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # Get logits for the current context and only focus on the last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # Filter logits with top_k sampling
        if top_k is not None:
            # Keep only the top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise, pick the vocab entry with the highest logit (greedy decoding)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop generating early if the end-of-sequence token is encountered
        if eos_id is not None and idx_next == eos_id:
            break

        # Append the sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
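

# Minimal end-to-end usage sketch (not part of the original listing): the tiny config below
# is made up so the forward pass runs quickly on random weights; loading the real Llama 3.2
# checkpoints and tokenizer is outside the scope of this file.
if __name__ == "__main__":
    test_cfg = {
        "vocab_size": 100,
        "context_length": 64,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "n_kv_groups": 2,
        "rope_base": 500_000.0,
        "dtype": torch.float32,
        "rope_freq": None,
    }
    torch.manual_seed(123)
    model = Llama3Model(test_cfg)
    dummy_ids = torch.randint(0, test_cfg["vocab_size"], (1, 8))
    out = generate(model, dummy_ids, max_new_tokens=5, context_size=test_cfg["context_length"])
    print(out.shape)  # expected: torch.Size([1, 13])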