# fourmansyah — duplicated from hongminh54/BeatHeritage-v1 (commit 12a8e0f).
import hydra
import lightning
import torch
from omegaconf import DictConfig
from classifier.libs.utils import load_ckpt
from libs import (
get_dataloaders,
)
# Trade a little float32 matmul precision for speed: per the PyTorch docs,
# "high" lets matmuls use TensorFloat32 (or bfloat16-based) fast paths when
# the backend supports them.
torch.set_float32_matmul_precision(precision="high")
@hydra.main(config_path="configs", config_name="train_v1", version_base="1.1")
def main(args: DictConfig):
    """Evaluate a saved checkpoint on the validation dataloader.

    Loads the model and tokenizer from ``args.checkpoint_path``, builds the
    dataloaders from the same Hydra config, and runs Lightning's test loop on
    the validation loader only.

    Args:
        args: Hydra config composed from ``configs/train_v1``. Fields read
            directly here: ``checkpoint_path``, ``compile``, ``device`` and
            ``precision``; the whole config is also forwarded to
            ``get_dataloaders``.
    """
    # The stored model args returned by load_ckpt are not needed for evaluation.
    model, _model_args, tokenizer = load_ckpt(args.checkpoint_path, route_pickle=False)

    # Only the validation split is evaluated; the first (training) loader is discarded.
    _, val_dataloader = get_dataloaders(tokenizer, args)

    if args.compile:
        # Compile only the inner ``model.model`` attribute (presumably the wrapped
        # network — confirm against load_ckpt); the outer object is what is
        # handed to Lightning below.
        model.model = torch.compile(model.model)

    trainer = lightning.Trainer(
        accelerator=args.device,
        # Evaluate on a single device: with multi-device strategies such as DDP,
        # Lightning's DistributedSampler replicates samples to give every device
        # an equal batch count, which would skew the reported test metrics.
        devices=1,
        precision=args.precision,
    )
    trainer.test(model, dataloaders=val_dataloader)


if __name__ == "__main__":
    # The @hydra.main decorator builds the config and supplies ``args``,
    # so main() is invoked with no arguments.
    main()