| import yaml, math, time, json |
| import torch |
| from pathlib import Path |
| from tokenizers import Tokenizer |
| from torch.utils.data import DataLoader |
| from torch.optim import AdamW |
| from model.tiny_gpt2 import TinyGPT2, GPTConfig |
| from train.data_utils import TextDataset |
|
|
def get_device(name):
    """Resolve the device string taken from the training config.

    ``"auto"`` resolves to ``"cuda"`` when ``torch.cuda.is_available()`` is
    true and to ``"cpu"`` otherwise; any other value is passed through
    unchanged (e.g. ``"cpu"``, ``"cuda:1"``).
    """
    if name != "auto":
        return name
    has_cuda = torch.cuda.is_available()
    return "cuda" if has_cuda else "cpu"
|
|
def cosine_lr(step, max_steps, base, min_lr, warmup):
    """Return the learning rate for ``step`` under warmup + cosine decay.

    For ``step < warmup`` the rate grows linearly as ``base * step / warmup``
    (0 at step 0). From ``step == warmup`` it follows half a cosine from
    ``base`` down to ``min_lr``, reached at ``step == max_steps``. Steps past
    ``max_steps`` stay at ``min_lr``.

    Args:
        step: current optimizer step.
        max_steps: step at which the decay reaches ``min_lr``.
        base: peak learning rate (reached at the end of warmup).
        min_lr: floor learning rate at the end of the schedule.
        warmup: number of linear-warmup steps.

    Returns:
        The learning rate as a float.
    """
    if step < warmup:
        # max(1, ...) guards against division by zero when warmup == 0.
        return base * step / max(1, warmup)
    progress = (step - warmup) / max(1, max_steps - warmup)
    # Clamp so steps beyond max_steps hold at min_lr instead of the cosine
    # wrapping around and driving the rate back up toward base.
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base - min_lr) * (1 + math.cos(math.pi * progress))
|
|
def main(config_path="train/config.yaml"):
    """Train TinyGPT2 on a text corpus as described by a YAML config, then save it.

    Steps:
      * Read the YAML mapping at ``config_path``.
      * Tokenize ``cfg["train_txt"]`` with the tokenizer at ``cfg["tokenizer_path"]``.
      * Train with AdamW, gradient clipping, and the warmup + cosine schedule
        from ``cosine_lr``, for ``cfg["max_steps"]`` optimizer steps.
      * Write ``model.pt`` (the model ``state_dict``) and ``gpt_config.json``
        into ``cfg["save_dir"]``.

    Raises:
        ValueError: if the corpus contains token ids ``>= cfg["vocab_size"]``,
            or is too small to form a single batch.
    """
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    device = get_device(cfg["device"])
    save_dir = Path(cfg["save_dir"])
    save_dir.mkdir(parents=True, exist_ok=True)

    tok = Tokenizer.from_file(cfg["tokenizer_path"])
    with open(cfg["train_txt"], "r", encoding="utf-8") as f:
        ids = tok.encode(f.read()).ids
    # Targets index the logits' class dimension (and, presumably, inputs index
    # an embedding sized by vocab_size), so an out-of-range id would only blow
    # up mid-training. Fail fast with a clear message instead.
    if ids and max(ids) >= cfg["vocab_size"]:
        raise ValueError(
            f"tokenizer produced id {max(ids)} but vocab_size is {cfg['vocab_size']}"
        )
    ds = TextDataset(ids, cfg["block_size"])
    dl = DataLoader(ds, batch_size=cfg["batch_size"], shuffle=True, drop_last=True)
    # With drop_last=True a corpus smaller than one batch yields no batches at
    # all; the training loop would then never advance and nothing would be saved.
    if len(dl) == 0:
        raise ValueError(
            f"dataset has {len(ds)} samples, fewer than batch_size={cfg['batch_size']}"
        )

    gcfg = GPTConfig(
        vocab_size=cfg["vocab_size"],
        n_layer=cfg["n_layer"],
        n_head=cfg["n_head"],
        n_embed=cfg["n_embed"],
        block_size=cfg["block_size"],
    )
    model = TinyGPT2(gcfg).to(device)

    opt = AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
    max_steps = cfg["max_steps"]
    step, t0 = 0, time.time()
    model.train()
    # Cycle over the dataloader (one pass per epoch) until max_steps is reached.
    while step < max_steps:
        for x, y in dl:
            step += 1
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # reshape (not view) also works if the model returns non-contiguous logits.
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), y.reshape(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            # The schedule is applied manually by overwriting each group's lr.
            lr = cosine_lr(step, max_steps, cfg["lr"], cfg["min_lr"], cfg["warmup_steps"])
            for group in opt.param_groups:
                group["lr"] = lr
            opt.step()
            opt.zero_grad(set_to_none=True)

            if step % 100 == 0:
                now = time.time()
                dt, t0 = now - t0, now
                print(f"step {step:6d} | loss {loss.item():.4f} | lr {lr:.2e} | {dt:.2f}s")

            if step >= max_steps:
                break

    torch.save(model.state_dict(), save_dir / "model.pt")
    with open(save_dir / "gpt_config.json", "w") as f:
        json.dump(gcfg.__dict__, f, indent=2)
    print("saved checkpoint. done.")


if __name__ == "__main__":
    main()
|
|