Commit ·
8e6eafa
0
Parent(s):
- .gitattributes +35 -0
- MANIFEST.in +4 -0
- README.md +97 -0
- config.json +44 -0
- gemma3_tinystories/__init__.py +3 -0
- gemma3_tinystories/modeling_gemma3.py +229 -0
- pytorch_model.bin +3 -0
- setup.py +23 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include *.json
|
| 2 |
+
include README.md
|
| 3 |
+
include pytorch_model.bin
|
| 4 |
+
recursive-include gemma3_tinystories *.py
|
README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language: en
|
| 4 |
+
tags:
|
| 5 |
+
- gemma3
|
| 6 |
+
- slm
|
| 7 |
+
- tinystories
|
| 8 |
+
model_type: gemma3
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Gemma3 270M: Small Language Model Implementation from Scratch
|
| 12 |
+
|
| 13 |
+
This repository contains a **PyTorch implementation** of the **Google's Gemma3 model** with **270 million parameters**, trained from scratch on the **TinyStories** dataset.
|
| 14 |
+
|
| 15 |
+
[Github Repo Link](https://github.com/ShubhamWaghmare11/Gemma3_270M_Scratch)
|
| 16 |
+
|
| 17 |
+
## About
|
| 18 |
+
|
| 19 |
+
This model is based on Google DeepMind's Gemma3 architecture and was built from the ground up to explore training dynamics, architecture design, and generation quality of small LLMs. It includes advanced components such as:
|
| 20 |
+
|
| 21 |
+
- Sliding Window Attention (512-token window)
|
| 22 |
+
|
| 23 |
+
- Rotary Positional Embeddings (RoPE)
|
| 24 |
+
|
| 25 |
+
- RMSNorm for stable training
|
| 26 |
+
|
| 27 |
+
- Grouped Key-Value Attention (1 KV group)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## Model Architecture
|
| 31 |
+
|
| 32 |
+
- Parameters: 270M total (170M embedding + 100M transformer)
|
| 33 |
+
- Layers: 18 transformer blocks
|
| 34 |
+
- Attention Heads: 4 Query Heads, 1 KV Group
|
| 35 |
+
- Hidden Dimension: 2048
|
| 36 |
+
- Embedding Dimension: 640
|
| 37 |
+
- Head Dimension: 256
|
| 38 |
+
- Vocabulary Size: 50,257 (GPT-2 tokenizer)
|
| 39 |
+
- Context Length: 32,768 tokens (trained with 128 block size)
|
| 40 |
+
- Sliding Window: 512 tokens
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
## Training Details
|
| 45 |
+
|
| 46 |
+
- Dataset: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) by Roneneldan
|
| 47 |
+
|
| 48 |
+
- Steps: 150,000 steps (not epochs)
|
| 49 |
+
|
| 50 |
+
- Batch Size: 32
|
| 51 |
+
|
| 52 |
+
- Loss Function: Cross-Entropy
|
| 53 |
+
|
| 54 |
+
- Optimizer: AdamW
|
| 55 |
+
|
| 56 |
+
- LR Scheduler: Linear Warmup + Cosine Decay
|
| 57 |
+
|
| 58 |
+
- Hardware: Single NVIDIA A100 GPU
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## Requirements
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
pip install git+https://huggingface.co/Shubhamw11/Gemma-270M-TinyStories
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## How to use
|
| 72 |
+
|
| 73 |
+
You can now load and use the model from the `gemma3_tinystories` library:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from gemma3_tinystories import HFGemma3DPONegative, Gemma3Config
|
| 77 |
+
import tiktoken
|
| 78 |
+
import torch
|
| 79 |
+
|
| 80 |
+
config = Gemma3Config.from_pretrained("Shubhamw11/Gemma-270M-TinyStories")
|
| 81 |
+
model = HFGemma3DPONegative.from_pretrained("Shubhamw11/Gemma-270M-TinyStories", config=config).model
|
| 82 |
+
tokenizer = tiktoken.get_encoding("gpt2")
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
# Generate text
|
| 86 |
+
```python
|
| 87 |
+
|
| 88 |
+
#define the device
|
| 89 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 90 |
+
|
| 91 |
+
input_text = "Once upon a time, there was a little"
|
| 92 |
+
context = torch.tensor(tokenizer.encode(input_text), dtype=torch.long).unsqueeze(0).to(device)
|
| 93 |
+
model.to(device)
|
| 94 |
+
response = model.generate(context, max_new_tokens=200, temperature=1.1, top_k=5)
|
| 95 |
+
|
| 96 |
+
print(tokenizer.decode(response.squeeze().tolist()))
|
| 97 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"context_length": 32768,
|
| 3 |
+
"dtype": "bfloat16",
|
| 4 |
+
"emb_dim": 640,
|
| 5 |
+
"finetuning_task": null,
|
| 6 |
+
"head_dim": 256,
|
| 7 |
+
"hidden_dim": 2048,
|
| 8 |
+
"layer_types": [
|
| 9 |
+
"sliding_attention",
|
| 10 |
+
"sliding_attention",
|
| 11 |
+
"sliding_attention",
|
| 12 |
+
"sliding_attention",
|
| 13 |
+
"sliding_attention",
|
| 14 |
+
"full_attention",
|
| 15 |
+
"sliding_attention",
|
| 16 |
+
"sliding_attention",
|
| 17 |
+
"sliding_attention",
|
| 18 |
+
"sliding_attention",
|
| 19 |
+
"sliding_attention",
|
| 20 |
+
"full_attention",
|
| 21 |
+
"sliding_attention",
|
| 22 |
+
"sliding_attention",
|
| 23 |
+
"sliding_attention",
|
| 24 |
+
"sliding_attention",
|
| 25 |
+
"sliding_attention",
|
| 26 |
+
"full_attention"
|
| 27 |
+
],
|
| 28 |
+
"n_heads": 4,
|
| 29 |
+
"n_kv_groups": 1,
|
| 30 |
+
"n_layers": 18,
|
| 31 |
+
"num_labels": 2,
|
| 32 |
+
"output_attentions": false,
|
| 33 |
+
"output_hidden_states": false,
|
| 34 |
+
"output_past": true,
|
| 35 |
+
"pruned_heads": {},
|
| 36 |
+
"qk_norm": true,
|
| 37 |
+
"query_pre_attn_scalar": 256,
|
| 38 |
+
"rope_base": 1000000.0,
|
| 39 |
+
"rope_local_base": 10000.0,
|
| 40 |
+
"sliding_window": 512,
|
| 41 |
+
"torchscript": false,
|
| 42 |
+
"use_bfloat16": false,
|
| 43 |
+
"vocab_size": 50257
|
| 44 |
+
}
|
gemma3_tinystories/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API of the gemma3_tinystories package."""

from .modeling_gemma3 import Gemma3Config, HFGemma3DPONegative

__all__ = ["HFGemma3DPONegative", "Gemma3Config"]
|
gemma3_tinystories/modeling_gemma3.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# ------------------- Config -------------------
|
| 7 |
+
class Gemma3Config(PretrainedConfig):
    """Hugging Face configuration for the from-scratch Gemma3 model.

    Stores all hyper-parameters consumed by ``Gemma3Model`` (passed on as a
    plain dict via ``config.to_dict()``).
    """

    model_type = "gemma3"

    def __init__(
        self,
        vocab_size=50257,
        emb_dim=512,
        n_layers=12,
        n_heads=8,
        n_kv_groups=2,
        head_dim=None,
        context_length=1024,
        sliding_window=64,
        rope_base=10000,
        rope_local_base=10000,
        layer_types=None,
        qk_norm=False,
        query_pre_attn_scalar=None,
        dtype="float32",
        hidden_dim=2048,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Embedding / vocabulary sizes.
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        # Depth and attention-head layout.
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.head_dim = head_dim
        # Context window and sliding-attention span.
        self.context_length = context_length
        self.sliding_window = sliding_window
        # RoPE bases: global layers vs. local (sliding) layers.
        self.rope_base = rope_base
        self.rope_local_base = rope_local_base
        # One entry per layer; defaults to sliding attention everywhere.
        self.layer_types = layer_types or ["sliding_attention"] * n_layers
        self.qk_norm = qk_norm
        self.query_pre_attn_scalar = query_pre_attn_scalar
        # Stored as a string (e.g. "bfloat16"); resolved where needed.
        self.dtype = dtype
|
| 45 |
+
|
| 46 |
+
# ------------------- Utilities -------------------
|
| 47 |
+
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute RoPE cosine/sine tables.

    Returns two ``(context_length, head_dim)`` tensors: the cosine and sine
    of every position/frequency pair, with the frequency axis duplicated so
    it lines up with the (first-half, second-half) rotation in ``apply_rope``.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"
    # One inverse frequency per pair of channels.
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    pos = torch.arange(context_length, dtype=dtype)
    ang = pos[:, None] * inv_freq[None, :]
    # Duplicate so cos/sin cover the full head dimension.
    ang = torch.cat([ang, ang], dim=1)
    return torch.cos(ang), torch.sin(ang)
|
| 54 |
+
|
| 55 |
+
def apply_rope(x, cos, sin):
    """Rotate ``x`` (shape ``(batch, heads, seq, head_dim)``) with RoPE.

    ``cos``/``sin`` are tables from :func:`compute_rope_params`; only the
    first ``seq`` rows are used.
    """
    _, _, seq_len, head_dim = x.shape
    half = head_dim // 2
    first, second = x[..., :half], x[..., half:]
    # Broadcast the position tables over batch and head dimensions.
    cos_t = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin_t = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-second, first), dim=-1)
    return ((x * cos_t) + (rotated * sin_t)).to(dtype=x.dtype)
|
| 62 |
+
|
| 63 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm, Gemma style.

    The weight is stored as an offset from one (``output = x_norm * (1 + scale)``),
    so a zero-initialized ``scale`` is the identity gain.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.zeros(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        orig_dtype = x.dtype
        # Normalize in float32 for numerical stability, cast back at the end.
        xf = x.float()
        mean_sq = xf.pow(2).mean(dim=-1, keepdim=True)
        normed = xf * torch.rsqrt(mean_sq + self.eps)
        result = normed * (1.0 + self.scale.float())
        if self.shift is not None:
            result = result + self.shift.float()
        return result.to(orig_dtype)
|
| 79 |
+
|
| 80 |
+
# ------------------- Core Layers -------------------
|
| 81 |
+
class GroupedQueryAttention(nn.Module):
    """Multi-head attention with grouped key/value heads (GQA).

    Queries use ``num_heads`` heads while keys/values use only
    ``num_kv_groups`` heads, which are repeated to cover every query head —
    fewer KV parameters for the same query resolution.
    """

    def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False,
                 query_pre_attn_scalar=None, dtype=torch.float32):
        super().__init__()
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        # How many query heads share one KV head.
        self.group_size = num_heads // num_kv_groups
        if head_dim is None:
            head_dim = d_in // num_heads
        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        # Optional per-head RMS normalization of queries and keys.
        self.q_norm = RMSNorm(head_dim) if qk_norm else None
        self.k_norm = RMSNorm(head_dim) if qk_norm else None
        # Gemma scales queries by a fixed scalar when configured, otherwise
        # the usual 1/sqrt(head_dim).
        self.scaling = (query_pre_attn_scalar**-0.5 if query_pre_attn_scalar else head_dim**-0.5)

    def forward(self, x, mask, cos, sin):
        batch, n_tok, _ = x.shape
        queries = self.W_query(x).view(batch, n_tok, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(batch, n_tok, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(batch, n_tok, self.num_kv_groups, self.head_dim).transpose(1, 2)
        if self.q_norm:
            queries = self.q_norm(queries)
        if self.k_norm:
            keys = self.k_norm(keys)
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)
        # Duplicate each KV head so every query head has a matching KV head.
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)
        queries = queries * self.scaling
        scores = queries @ keys.transpose(2, 3)
        # ``mask`` is True at disallowed positions.
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        context = (weights @ values).transpose(1, 2).reshape(batch, n_tok, self.d_out)
        return self.out_proj(context)
|
| 119 |
+
|
| 120 |
+
class FeedForward(nn.Module):
    """Gated (GeGLU-style) feed-forward block: fc3(gelu(fc1(x)) * fc2(x))."""

    def __init__(self, cfg):
        super().__init__()
        dtype = cfg["dtype"]
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        # fc1 produces the gate, fc2 the value, fc3 projects back down.
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False, dtype=dtype)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], bias=False, dtype=dtype)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], bias=False, dtype=dtype)

    def forward(self, x):
        gate = F.gelu(self.fc1(x), approximate="tanh")
        return self.fc3(gate * self.fc2(x))
|
| 130 |
+
|
| 131 |
+
class TransformerBlock(nn.Module):
    """One Gemma3 decoder block: pre/post-normed attention + feed-forward.

    ``attn_type`` selects which mask and RoPE tables are used:
    "sliding_attention" layers get the local mask/tables, everything else
    gets the global ones.
    """

    def __init__(self, cfg, attn_type):
        super().__init__()
        self.attn_type = attn_type
        self.att = GroupedQueryAttention(cfg["emb_dim"], cfg["n_heads"], cfg["n_kv_groups"],
                                         head_dim=cfg["head_dim"], qk_norm=cfg["qk_norm"],
                                         query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
                                         dtype=cfg["dtype"])
        self.ff = FeedForward(cfg)
        # Gemma normalizes both before and after each sub-layer.
        self.input_layernorm = RMSNorm(cfg["emb_dim"])
        self.post_attention_layernorm = RMSNorm(cfg["emb_dim"])
        self.pre_feedforward_layernorm = RMSNorm(cfg["emb_dim"])
        self.post_feedforward_layernorm = RMSNorm(cfg["emb_dim"])

    def forward(self, x, mask_global, mask_local, cos_global, sin_global, cos_local, sin_local):
        if self.attn_type == "sliding_attention":
            attn_mask, cos, sin = mask_local, cos_local, sin_local
        else:
            attn_mask, cos, sin = mask_global, cos_global, sin_global
        # Attention sub-layer with residual connection.
        residual = x
        h = self.input_layernorm(x)
        h = self.post_attention_layernorm(self.att(h, attn_mask, cos, sin))
        x = residual + h
        # Feed-forward sub-layer with residual connection.
        residual = x
        h = self.ff(self.pre_feedforward_layernorm(x))
        return residual + self.post_feedforward_layernorm(h)
|
| 158 |
+
|
| 159 |
+
# ------------------- Gemma3 Model -------------------
|
| 160 |
+
class Gemma3Model(nn.Module):
    """Gemma3 decoder-only transformer built from a plain config dict.

    ``cfg`` is the dict form of :class:`Gemma3Config`; ``cfg["dtype"]`` may be
    either a string such as "bfloat16" or a ``torch.dtype``.
    """

    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]
        dtype = cfg["dtype"]
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        # Resolve the dtype once so forward() works whether cfg["dtype"] was
        # a string or a torch.dtype (the original re-did getattr(torch, ...)
        # on every forward and crashed on a torch.dtype).
        self._dtype = dtype
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=dtype)
        self.blocks = nn.ModuleList([TransformerBlock(cfg, t) for t in cfg["layer_types"]])
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=dtype)
        self.cfg = cfg
        # RoPE tables: local base for sliding layers, global base otherwise.
        cos_l, sin_l = compute_rope_params(cfg["head_dim"], cfg["rope_local_base"], cfg["context_length"])
        cos_g, sin_g = compute_rope_params(cfg["head_dim"], cfg["rope_base"], cfg["context_length"])
        self.register_buffer("cos_local", cos_l, persistent=False)
        self.register_buffer("sin_local", sin_l, persistent=False)
        self.register_buffer("cos_global", cos_g, persistent=False)
        self.register_buffer("sin_global", sin_g, persistent=False)

    def _create_masks(self, seq_len, device):
        """Return (global, local) boolean masks; True marks blocked positions.

        The global mask blocks future tokens (strict causal); the local mask
        additionally blocks tokens further back than ``sliding_window``.
        """
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        mask_global = torch.triu(ones, 1)
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T
        mask_local = mask_global | far_past
        return mask_global, mask_local

    def forward(self, input_ids, targets=None):
        """Return ``(logits, loss)``; loss is None unless ``targets`` given."""
        _, seq_len = input_ids.shape
        # Gemma scales token embeddings by sqrt(emb_dim).
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)
        for block in self.blocks:
            x = block(x, mask_global, mask_local, self.cos_global, self.sin_global, self.cos_local, self.sin_local)
        x = self.final_norm(x)
        # Cast back to the model dtype before the output projection (RMSNorm
        # may have returned a different dtype on mixed-precision inputs).
        logits = self.out_head(x.to(self._dtype))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Applies temperature scaling and optional top-k filtering, sampling
        from the resulting distribution at each step.
        """
        ctx_len = self.cfg["context_length"]  # invariant; hoisted out of the loop
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum length.
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
|
| 212 |
+
|
| 213 |
+
# ------------------- Hugging Face Wrapper -------------------
|
| 214 |
+
class HFGemma3DPONegative(PreTrainedModel):
    """Thin Hugging Face wrapper around :class:`Gemma3Model`.

    Exposes ``from_pretrained``/``save_pretrained`` via ``PreTrainedModel``
    while delegating all computation to the plain PyTorch model stored in
    ``self.model``.
    """

    config_class = Gemma3Config

    def __init__(self, config):
        super().__init__(config)
        # Gemma3Model accepts cfg["dtype"] as a string, so the config dict is
        # passed through unchanged. (The original had a dead no-op branch
        # re-assigning cfg_dict["dtype"] to itself here.)
        self.model = Gemma3Model(config.to_dict())

    def forward(self, *args, **kwargs):
        """Delegate to ``Gemma3Model.forward`` — returns ``(logits, loss)``."""
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """Delegate to ``Gemma3Model.generate``."""
        return self.model.generate(*args, **kwargs)
|
| 229 |
+
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ecfb01576a4b1228338d9d6f8e7c60a70df2824b30fe306b189f38153ac21639
|
| 3 |
+
size 658711843
|
setup.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages

# Read the long description up front with an explicit encoding and a context
# manager (the original used a bare open() that leaked the file handle and
# depended on the platform default encoding).
with open("README.md", encoding="utf-8") as readme:
    long_description = readme.read()

setup(
    name="gemma3_tinystories_dpo",
    version="0.1.1",
    author="Shubham Waghmare",
    description="DPO-aligned Gemma3-270M for negative sentiment narratives",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://huggingface.co/Shubhamw11/gemma-3-270m-dpo-negative",
    packages=find_packages(),
    install_requires=[
        "torch",
        "transformers",
        "tiktoken",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.8',
)
|