import os
import pickle

import torch

from src.config import RippleConfig
from src.model import RippleGPT
|
|
| |
# --- Sampling configuration --------------------------------------------------
out_dir = 'out'            # directory holding training checkpoints
num_samples = 1            # samples drawn per prompt
max_new_tokens = 200       # tokens to generate beyond each prompt
temperature = 0.8          # < 1.0 sharpens the distribution, > 1.0 flattens it
top_k = 200                # restrict sampling to the k most likely tokens
# Prefer the Apple-silicon GPU (MPS) when available; otherwise fall back to CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
| |
|
|
def _load_model():
    """Build a RippleGPT from the best checkpoint in ``out_dir`` and move it to ``device``.

    Prefers 'ckpt_best.pt' and falls back to 'ckpt.pt' when it is absent.
    Returns the model in eval mode on ``device``.

    NOTE(security): ``weights_only=False`` unpickles arbitrary objects from the
    checkpoint file — only ever load checkpoints you produced yourself.
    """
    ckpt_path = os.path.join(out_dir, 'ckpt_best.pt')
    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
        print("⚠️ Aviso: 'ckpt_best.pt' não encontrado, usando o último 'ckpt.pt'")

    print(f"Loading checkpoint from {ckpt_path}...")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    gptconf = RippleConfig(**checkpoint['model_args'])
    model = RippleGPT(gptconf)

    # torch.compile() saves weights under an '_orig_mod.' prefix; strip it so
    # the keys match the uncompiled module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

    model.eval()
    model.to(device)
    return model


def _load_codec():
    """Return an ``(encode, decode)`` pair built from data/meta.pkl, or None if missing.

    ``encode`` maps a string to a list of token ids; ``decode`` maps ids back
    to a string. Characters outside the training vocabulary encode to '?'
    (or id 0 when '?' itself is not in the vocabulary) instead of raising.
    """
    meta_path = os.path.join('data', 'meta.pkl')
    if not os.path.exists(meta_path):
        print("❌ ERRO: meta.pkl não encontrado! Rode prepare_data.py primeiro.")
        return None

    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']

    unknown_token = stoi.get('?', 0)

    def encode(s):
        return [stoi.get(c, unknown_token) for c in s]

    def decode(ids):
        return ''.join(itos[i] for i in ids)

    return encode, decode


def main():
    """Sample from RippleGPT on four fixed multi-domain prompts.

    Prints each prompt followed by its continuation, with the newly generated
    suffix highlighted in ANSI blue. Aborts early (after printing an error)
    when data/meta.pkl is missing.
    """
    torch.manual_seed(1337)  # reproducible sampling across runs

    model = _load_model()
    codec = _load_codec()
    if codec is None:
        return  # meta.pkl missing; error message already printed
    encode, decode = codec

    test_cases = [
        {
            "domain": "🐍 PYTHON CODING",
            "prompt": "# Function to calculate factorial\ndef factorial(n):\n if n == 0:\n return 1\n else:\n return"
        },
        {
            "domain": "🧮 MATH LOGIC",
            "prompt": "Q: Solve 2x = 10\nA: x = 5\n\nQ: Solve -5k + 5 = -10\nA:"
        },
        {
            "domain": "📖 TINY STORY",
            "prompt": "Once upon a time, there was a little frog. The frog liked to jump. One day,"
        },
        {
            "domain": "⚔️ LITERATURE",
            "prompt": "The General looked at the map and shouted,"
        }
    ]

    print("\n" + "=" * 40)
    print("🤖 RIPPLE GPT: MULTI-DOMAIN TEST")
    print("=" * 40)

    with torch.no_grad():
        for case in test_cases:
            prompt = case["prompt"]
            domain = case["domain"]

            print(f"\n[{domain}] Prompt: {prompt.strip()}")
            print("-" * 20)

            # Shape (1, T): a batch of one prompt.
            start_ids = encode(prompt)
            x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            generated_text = decode(y[0].tolist())

            # Highlight only the newly generated suffix in ANSI blue.
            new_content = generated_text[len(prompt):]
            print(f"{prompt}\033[94m{new_content}\033[0m")
            print("-" * 40)
|
|
# Script entry point: run the multi-domain sampling demo.
if __name__ == '__main__':
    main()