Shubhamw11 commited on
Commit
8e6eafa
·
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
MANIFEST.in ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ include *.json
2
+ include README.md
3
+ include pytorch_model.bin
4
+ recursive-include gemma3_tinystories *.py
README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language: en
4
+ tags:
5
+ - gemma3
6
+ - slm
7
+ - tinystories
8
+ model_type: gemma3
9
+ ---
10
+
11
+ # Gemma3 270M: Small Language Model Implementation from Scratch
12
+
13
+ This repository contains a **PyTorch implementation** of **Google's Gemma3 model** with **270 million parameters**, trained from scratch on the **TinyStories** dataset.
14
+
15
+ [Github Repo Link](https://github.com/ShubhamWaghmare11/Gemma3_270M_Scratch)
16
+
17
+ ## About
18
+
19
+ This model is based on Google DeepMind's Gemma3 architecture and was built from the ground up to explore training dynamics, architecture design, and generation quality of small LLMs. It includes advanced components such as:
20
+
21
+ - Sliding Window Attention (512-token window)
22
+
23
+ - Rotary Positional Embeddings (RoPE)
24
+
25
+ - RMSNorm for stable training
26
+
27
+ - Grouped Key-Value Attention (1 KV group)
28
+
29
+
30
+ ## Model Architecture
31
+
32
+ - Parameters: 270M total (170M embedding + 100M transformer)
33
+ - Layers: 18 transformer blocks
34
+ - Attention Heads: 4 Query Heads, 1 KV Group
35
+ - Hidden Dimension: 2048
36
+ - Embedding Dimension: 640
37
+ - Head Dimension: 256
38
+ - Vocabulary Size: 50,257 (GPT-2 tokenizer)
39
+ - Context Length: 32,768 tokens (trained with 128 block size)
40
+ - Sliding Window: 512 tokens
41
+
42
+
43
+
44
+ ## Training Details
45
+
46
+ - Dataset: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) by Roneneldan
47
+
48
+ - Steps: 150,000 steps (not epochs)
49
+
50
+ - Batch Size: 32
51
+
52
+ - Loss Function: Cross-Entropy
53
+
54
+ - Optimizer: AdamW
55
+
56
+ - LR Scheduler: Linear Warmup + Cosine Decay
57
+
58
+ - Hardware: Single NVIDIA A100 GPU
59
+
60
+
61
+
62
+
63
+
64
+
65
+ ## Requirements
66
+
67
+ ```bash
68
+ pip install git+https://huggingface.co/Shubhamw11/Gemma-270M-TinyStories
69
+ ```
70
+
71
+ ## How to use
72
+
73
+ You can now load and use the model from the `gemma3_tinystories` library:
74
+
75
+ ```python
76
+ from gemma3_tinystories import HFGemma3DPONegative, Gemma3Config
77
+ import tiktoken
78
+ import torch
79
+
80
+ config = Gemma3Config.from_pretrained("Shubhamw11/Gemma-270M-TinyStories")
81
+ model = HFGemma3DPONegative.from_pretrained("Shubhamw11/Gemma-270M-TinyStories", config=config).model
82
+ tokenizer = tiktoken.get_encoding("gpt2")
83
+ ```
84
+
85
+ # Generate text
86
+ ```python
87
+
88
+ #define the device
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+
91
+ input_text = "Once upon a time, there was a little"
92
+ context = torch.tensor(tokenizer.encode(input_text), dtype=torch.long).unsqueeze(0).to(device)
93
+ model.to(device)
94
+ response = model.generate(context, max_new_tokens=200, temperature=1.1, top_k=5)
95
+
96
+ print(tokenizer.decode(response.squeeze().tolist()))
97
+ ```
config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "context_length": 32768,
3
+ "dtype": "bfloat16",
4
+ "emb_dim": 640,
5
+ "finetuning_task": null,
6
+ "head_dim": 256,
7
+ "hidden_dim": 2048,
8
+ "layer_types": [
9
+ "sliding_attention",
10
+ "sliding_attention",
11
+ "sliding_attention",
12
+ "sliding_attention",
13
+ "sliding_attention",
14
+ "full_attention",
15
+ "sliding_attention",
16
+ "sliding_attention",
17
+ "sliding_attention",
18
+ "sliding_attention",
19
+ "sliding_attention",
20
+ "full_attention",
21
+ "sliding_attention",
22
+ "sliding_attention",
23
+ "sliding_attention",
24
+ "sliding_attention",
25
+ "sliding_attention",
26
+ "full_attention"
27
+ ],
28
+ "n_heads": 4,
29
+ "n_kv_groups": 1,
30
+ "n_layers": 18,
31
+ "num_labels": 2,
32
+ "output_attentions": false,
33
+ "output_hidden_states": false,
34
+ "output_past": true,
35
+ "pruned_heads": {},
36
+ "qk_norm": true,
37
+ "query_pre_attn_scalar": 256,
38
+ "rope_base": 1000000.0,
39
+ "rope_local_base": 10000.0,
40
+ "sliding_window": 512,
41
+ "torchscript": false,
42
+ "use_bfloat16": false,
43
+ "vocab_size": 50257
44
+ }
gemma3_tinystories/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modeling_gemma3 import HFGemma3DPONegative, Gemma3Config
2
+
3
+ __all__ = ["HFGemma3DPONegative", "Gemma3Config"]
gemma3_tinystories/modeling_gemma3.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
# ------------------- Config -------------------
class Gemma3Config(PretrainedConfig):
    """Hyper-parameter container for the from-scratch Gemma3 model.

    Mirrors the knobs consumed by ``Gemma3Model`` (embedding width,
    attention layout, RoPE bases, sliding window, ...).
    """

    model_type = "gemma3"

    def __init__(
        self,
        vocab_size=50257,
        emb_dim=512,
        n_layers=12,
        n_heads=8,
        n_kv_groups=2,
        head_dim=None,
        context_length=1024,
        sliding_window=64,
        rope_base=10000,
        rope_local_base=10000,
        layer_types=None,
        qk_norm=False,
        query_pre_attn_scalar=None,
        dtype="float32",
        hidden_dim=2048,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Vocabulary / width settings.
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        # Depth and attention head layout.
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.head_dim = head_dim
        # Context handling.
        self.context_length = context_length
        self.sliding_window = sliding_window
        # Rotary-embedding bases: global vs. local (sliding) layers.
        self.rope_base = rope_base
        self.rope_local_base = rope_local_base
        # Default every layer to sliding attention when no layout is given.
        self.layer_types = layer_types or ["sliding_attention"] * n_layers
        self.qk_norm = qk_norm
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.dtype = dtype
45
+
46
# ------------------- Utilities -------------------
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute RoPE cosine/sine tables.

    Args:
        head_dim: per-head dimension; must be even so channels pair up.
        theta_base: frequency base (larger -> slower-rotating frequencies).
        context_length: number of positions to precompute.
        dtype: dtype used for the angle computation.

    Returns:
        ``(cos, sin)`` tensors, each of shape ``(context_length, head_dim)``.
    """
    # Original message said "Embedding dimension" but the check is on head_dim.
    assert head_dim % 2 == 0, "Head dimension must be even"
    # One inverse frequency per channel pair.
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    # Duplicate so the table covers the full head dimension.
    angles = torch.cat([angles, angles], dim=1)       # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)
54
+
55
def apply_rope(x, cos, sin):
    """Rotate ``x`` of shape (batch, heads, seq, head_dim) by precomputed RoPE tables."""
    seq_len = x.shape[2]
    half = x.shape[-1] // 2
    # Broadcast the (seq, head_dim) tables over the batch and head axes.
    cos_t = cos[:seq_len, :][None, None, :, :]
    sin_t = sin[:seq_len, :][None, None, :, :]
    # Pair channel i with channel i + half and rotate each pair.
    lower, upper = x[..., :half], x[..., half:]
    rotated = torch.cat((-upper, lower), dim=-1)
    out = (x * cos_t) + (rotated * sin_t)
    return out.to(dtype=x.dtype)
62
+
63
class RMSNorm(nn.Module):
    """Root-mean-square norm with a (1 + scale) gain, computed in float32."""

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()
        self.eps = eps
        # Zero-initialised so the initial effective gain is exactly 1.
        self.scale = nn.Parameter(torch.zeros(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        orig_dtype = x.dtype
        # Normalise in float32 for numerical stability, then cast back.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        y = x32 * inv_rms * (1.0 + self.scale.float())
        if self.shift is not None:
            y = y + self.shift.float()
        return y.to(orig_dtype)
79
+
80
# ------------------- Core Layers -------------------
class GroupedQueryAttention(nn.Module):
    """Multi-head attention where several query heads share one K/V group."""

    def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False,
                 query_pre_attn_scalar=None, dtype=torch.float32):
        super().__init__()
        # Accept the dtype either as a torch.dtype or its string name.
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        # How many query heads read from each K/V group.
        self.group_size = num_heads // num_kv_groups
        self.head_dim = d_in // num_heads if head_dim is None else head_dim
        self.d_out = num_heads * self.head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        # Optional per-head RMS normalisation of queries/keys.
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else None
        # Attention temperature: configured scalar, else 1/sqrt(head_dim).
        if query_pre_attn_scalar:
            self.scaling = query_pre_attn_scalar ** -0.5
        else:
            self.scaling = self.head_dim ** -0.5

    def forward(self, x, mask, cos, sin):
        """x: (batch, seq, d_in); mask: boolean, True marks blocked positions."""
        batch, seq_len, _ = x.shape

        # Project, then split into heads / K-V groups.
        queries = self.W_query(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(batch, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(batch, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)

        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys = self.k_norm(keys)

        # Add rotary position information, then expand K/V so each query
        # head has its own copy.
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        scores = (queries * self.scaling) @ keys.transpose(2, 3)
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        context = (weights @ values).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(context)
119
+
120
class FeedForward(nn.Module):
    """Gated feed-forward block: fc3(gelu(fc1(x)) * fc2(x))."""

    def __init__(self, cfg):
        super().__init__()
        # Accept the dtype either as a torch.dtype or its string name.
        dtype = cfg["dtype"]
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        emb_dim, hidden_dim = cfg["emb_dim"], cfg["hidden_dim"]
        self.fc1 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)  # gate projection
        self.fc2 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)  # up projection
        self.fc3 = nn.Linear(hidden_dim, emb_dim, bias=False, dtype=dtype)  # down projection

    def forward(self, x):
        gate = F.gelu(self.fc1(x), approximate="tanh")
        return self.fc3(gate * self.fc2(x))
130
+
131
class TransformerBlock(nn.Module):
    """One Gemma3 block: normalised attention + gated feed-forward, each residual."""

    def __init__(self, cfg, attn_type):
        super().__init__()
        # "sliding_attention" layers use the local mask and local RoPE tables.
        self.attn_type = attn_type
        self.att = GroupedQueryAttention(cfg["emb_dim"], cfg["n_heads"], cfg["n_kv_groups"],
                                         head_dim=cfg["head_dim"], qk_norm=cfg["qk_norm"],
                                         query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
                                         dtype=cfg["dtype"])
        self.ff = FeedForward(cfg)
        # Norms both before and after each sub-layer ("sandwich" norm).
        self.input_layernorm = RMSNorm(cfg["emb_dim"])
        self.post_attention_layernorm = RMSNorm(cfg["emb_dim"])
        self.pre_feedforward_layernorm = RMSNorm(cfg["emb_dim"])
        self.post_feedforward_layernorm = RMSNorm(cfg["emb_dim"])

    def forward(self, x, mask_global, mask_local, cos_global, sin_global, cos_local, sin_local):
        # Pick the mask and RoPE tables for this layer's attention type.
        if self.attn_type == "sliding_attention":
            mask, cos, sin = mask_local, cos_local, sin_local
        else:
            mask, cos, sin = mask_global, cos_global, sin_global

        # Attention sub-layer with residual connection.
        residual = x
        hidden = self.input_layernorm(x)
        hidden = self.att(hidden, mask, cos, sin)
        x = residual + self.post_attention_layernorm(hidden)

        # Feed-forward sub-layer with residual connection.
        residual = x
        hidden = self.pre_feedforward_layernorm(x)
        hidden = self.ff(hidden)
        return residual + self.post_feedforward_layernorm(hidden)
158
+
159
# ------------------- Gemma3 Model -------------------
class Gemma3Model(nn.Module):
    """Decoder-only Gemma3 model built from a plain config dict.

    ``cfg`` must supply: vocab_size, emb_dim, hidden_dim, n_layers, n_heads,
    n_kv_groups, head_dim, context_length, sliding_window, rope_base,
    rope_local_base, layer_types, qk_norm, query_pre_attn_scalar, dtype.
    """

    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]
        # Resolve dtype once; accept either a torch.dtype or its string name.
        # (Previously forward() re-did getattr(torch, cfg["dtype"]) and would
        # crash when cfg["dtype"] was already a torch.dtype.)
        dtype = cfg["dtype"]
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        self._dtype = dtype

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=dtype)
        self.blocks = nn.ModuleList([TransformerBlock(cfg, t) for t in cfg["layer_types"]])
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=dtype)
        self.cfg = cfg

        # RoPE tables: local base for sliding layers, global base for full layers.
        cos_l, sin_l = compute_rope_params(cfg["head_dim"], cfg["rope_local_base"], cfg["context_length"])
        cos_g, sin_g = compute_rope_params(cfg["head_dim"], cfg["rope_base"], cfg["context_length"])
        self.register_buffer("cos_local", cos_l, persistent=False)
        self.register_buffer("sin_local", sin_l, persistent=False)
        self.register_buffer("cos_global", cos_g, persistent=False)
        self.register_buffer("sin_global", sin_g, persistent=False)

    def _create_masks(self, seq_len, device):
        """Return (mask_global, mask_local) boolean masks; True = blocked."""
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        mask_global = torch.triu(ones, 1)  # causal: block attention to the future
        # Additionally block positions further back than the sliding window.
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T
        mask_local = mask_global | far_past
        return mask_global, mask_local

    def forward(self, input_ids, targets=None):
        """Return ``(logits, loss)``; loss is None unless ``targets`` is given."""
        _, seq_len = input_ids.shape
        # Gemma scales token embeddings by sqrt(emb_dim).
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)
        for block in self.blocks:
            x = block(x, mask_global, mask_local,
                      self.cos_global, self.sin_global, self.cos_local, self.sin_local)
        x = self.final_norm(x)
        # Cast to the resolved model dtype before the output projection.
        logits = self.out_head(x.to(self._dtype))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens appended to ``idx``.

        Args:
            idx: (batch, seq) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to sample.
            temperature: logits divisor; must be non-zero.
            top_k: if truthy, restrict sampling to the k highest logits.
        """
        ctx_len = self.cfg["context_length"]  # loop-invariant: hoisted
        for _ in range(max_new_tokens):
            # Crop the context to the trained window.
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
212
+
213
# ------------------- Hugging Face Wrapper -------------------
class HFGemma3DPONegative(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around the scratch ``Gemma3Model``.

    Enables ``from_pretrained`` / ``save_pretrained``; the underlying
    network is exposed as ``self.model``.
    """

    config_class = Gemma3Config

    def __init__(self, config):
        super().__init__(config)
        # Gemma3Model consumes a plain dict of hyper-parameters.
        # (Removed a dead no-op that reassigned cfg_dict["dtype"] to itself.)
        self.model = Gemma3Model(config.to_dict())

    def forward(self, *args, **kwargs):
        """Delegate to ``Gemma3Model.forward`` -> (logits, loss)."""
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """Delegate sampling to ``Gemma3Model.generate``."""
        return self.model.generate(*args, **kwargs)
229
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecfb01576a4b1228338d9d6f8e7c60a70df2824b30fe306b189f38153ac21639
3
+ size 658711843
setup.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="gemma3_tinystories_dpo",
5
+ version="0.1.1",
6
+ author="Shubham Waghmare",
7
+ description="DPO-aligned Gemma3-270M for negative sentiment narratives",
8
+ long_description=open("README.md").read(),
9
+ long_description_content_type="text/markdown",
10
+ url="https://huggingface.co/Shubhamw11/gemma-3-270m-dpo-negative",
11
+ packages=find_packages(),
12
+ install_requires=[
13
+ "torch",
14
+ "transformers",
15
+ "tiktoken",
16
+ ],
17
+ classifiers=[
18
+ "Programming Language :: Python :: 3",
19
+ "License :: OSI Approved :: MIT License",
20
+ "Operating System :: OS Independent",
21
+ ],
22
+ python_requires='>=3.8',
23
+ )