Commit ·
aff1d57
1
Parent(s): 9b334a0
README.md
CHANGED
|
@@ -44,27 +44,31 @@ The model utilizes a custom implementation of the Gemma3 architecture:
|
|
| 44 |
- **Hardware:** Single NVIDIA A100 GPU (40GB).
|
| 45 |
- **Development Context:** This project was developed at **Tunica Tech** as a case study in Small Language Model (SLM) alignment and Reinforcement Learning.
|
| 46 |
|
| 47 |
-
#
|
|
|
|
| 48 |
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
```python
|
| 52 |
-
from
|
| 53 |
-
import torch
|
| 54 |
import tiktoken
|
|
|
|
| 55 |
|
| 56 |
# Load Aligned Model
|
| 57 |
-
|
| 58 |
-
model =
|
| 59 |
tokenizer = tiktoken.get_encoding("gpt2")
|
|
|
|
| 60 |
|
|
|
|
|
|
|
| 61 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 62 |
-
model.to(device)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
|
| 69 |
-
print(tokenizer.decode(output.squeeze().tolist()))
|
| 70 |
```
|
|
|
|
- **Hardware:** Single NVIDIA A100 GPU (40GB).
- **Development Context:** This project was developed at **Tunica Tech** as a case study in Small Language Model (SLM) alignment and Reinforcement Learning.

# Requirements

pip install git+https://huggingface.co/Shubhamw11/Gemma-270M-TinyStories

## How to use

```python
from gemma3_tinystories import HFGemma3DPONegative, Gemma3Config
import tiktoken
import torch

# Load Aligned Model
config = Gemma3Config.from_pretrained("Shubhamw11/gemma-3-270m-dpo-negative")
model = HFGemma3DPONegative.from_pretrained("Shubhamw11/gemma-3-270m-dpo-negative", config=config).model
tokenizer = tiktoken.get_encoding("gpt2")
```

## Generate text

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

input_text = "Once upon a time, there was a little"
context = torch.tensor(tokenizer.encode(input_text), dtype=torch.long).unsqueeze(0).to(device)
model.to(device)
response = model.generate(context, max_new_tokens=200, temperature=1.1, top_k=5)

print(tokenizer.decode(response.squeeze().tolist()))
```