#!/usr/bin/env python3
"""QLoRA fine-tuning script for the Kira contract-review assistant.

Pipeline, executed top to bottom when the file is run:

1. Load the base model in 4-bit from a local snapshot and attach LoRA adapters.
2. Render every train/val JSONL record to a single chat-formatted string
   (cached on disk under ``TOKEN_CACHE``).
3. Train for one epoch with TRL's ``SFTTrainer``, evaluating and checkpointing
   every 500 steps and reloading the best checkpoint (lowest ``eval_loss``)
   at the end.
4. Save the adapter and tokenizer to ``OUTPUT_DIR`` and mirror them to
   ``FINAL_DIR``.
"""

import os

# These assignments are deliberately placed BEFORE the ML imports below so
# they are already in the process environment when those libraries are
# imported.
os.environ["HF_HOME"] = "/workspace/hf_cache"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): unsloth is imported ahead of trl, as in the original; this
# ordering is presumably required for Unsloth's patches -- do not reorder.
import unsloth  # noqa: E402,F401
from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402

import json  # noqa: E402
import shutil  # noqa: E402

import torch  # noqa: E402,F401
from datasets import Dataset  # noqa: E402
from trl import SFTTrainer, SFTConfig  # noqa: E402

MODEL = "/workspace/hf_cache/hub/models--unsloth--gemma-4-26B-A4B-it/snapshots/cd98c13581a9d4ad061cb85d983232ca4edb1343"
DATA_DIR = "/workspace/kira_gemma4_training/data/ft"
OUTPUT_DIR = "/workspace/kira_gemma4_training/kira_out"
FINAL_DIR = "/workspace/kira_gemma4_training/kira_adapter"
LOG_DIR = "/workspace/kira_gemma4_training/kira_logs"
TOKEN_CACHE = "/workspace/kira_gemma4_training/tokenized_cache"
MAX_SEQ = 1024
NUM_PROC = 16

SYSTEM_PROMPT = (
    "You are Kira, a contract review assistant. Given a clause and its surrounding context, "
    "identify the issue type, assess severity, explain what is wrong, and describe the worst-case "
    "consequence. Respond in strict JSON with keys: issue_type, severity, severity_rationale, "
    "what_is_wrong, worst_case, evidence_span."
)


def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank / whitespace-only lines are skipped.

    Args:
        path: Path of the ``.jsonl`` file (read as UTF-8).

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``path:lineno`` so the bad record can be located.
        OSError: If the file cannot be opened.
    """
    records = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, but say where it happened.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records


def build_user_msg(ex):
    """Assemble the user turn for one raw example.

    Optional fields (``clause_type``, ``left_context``, ``right_context``,
    ``cuad_question``) are included only when present and truthy. The clause
    section is always emitted, falling back to an empty string when
    ``clause_text`` is missing. Sections are separated by a blank line.

    Args:
        ex: Mapping holding the raw example fields.

    Returns:
        The prompt text for the user role.
    """
    parts = []
    if ex.get("clause_type"):
        parts.append(f"Clause type: {ex['clause_type']}")
    if ex.get("left_context"):
        parts.append(f"Left context:\n{ex['left_context']}")
    parts.append(f"Clause:\n{ex.get('clause_text','')}")
    if ex.get("right_context"):
        parts.append(f"Right context:\n{ex['right_context']}")
    if ex.get("cuad_question"):
        parts.append(f"Question: {ex['cuad_question']}")
    return "\n\n".join(parts)


def build_assistant_msg(ex):
    """Serialize the target answer for one raw example as a JSON string.

    The ``ds_*`` fields of *ex* are mapped onto the six keys named in
    ``SYSTEM_PROMPT``. Missing fields become JSON ``null``. Non-ASCII
    characters are written as-is (``ensure_ascii=False``).

    Args:
        ex: Mapping holding the raw example fields.

    Returns:
        A JSON object string used as the assistant turn.
    """
    out = {
        "issue_type": ex.get("ds_issue_type"),
        "severity": ex.get("ds_severity"),
        "severity_rationale": ex.get("ds_severity_rationale"),
        "what_is_wrong": ex.get("ds_what_is_wrong"),
        "worst_case": ex.get("ds_worst_case"),
        "evidence_span": ex.get("ds_evidence_span"),
    }
    return json.dumps(out, ensure_ascii=False)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
print("Loading model...", flush=True)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_SEQ,
    dtype=None,
    load_in_4bit=True,
    full_finetuning=False,
)

print("Applying LoRA...", flush=True)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)


def format_sample(ex):
    """Render one example to a chat-templated training string.

    If *ex* already carries a ``"messages"`` list it is used verbatim;
    otherwise a system/user/assistant conversation is built from the raw
    fields. Relies on the module-level ``tokenizer`` created above.

    Args:
        ex: Mapping holding either ``"messages"`` or the raw example fields.

    Returns:
        The conversation rendered by the tokenizer's chat template, as text
        (not token ids), without a trailing generation prompt.
    """
    if "messages" in ex:
        msgs = ex["messages"]
    else:
        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_msg(ex)},
            {"role": "assistant", "content": build_assistant_msg(ex)},
        ]
    # NOTE(review): if this chat template emits a leading BOS token and the
    # trainer adds one again when tokenizing the "text" field, samples would
    # start with a duplicated BOS -- confirm against a tokenized sample.
    return tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=False,
    )


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
print("Loading data...", flush=True)
# NOTE(review): despite the name, this cache stores the rendered "text"
# column, not token ids, and is never invalidated -- delete the directory
# after changing the data files, SYSTEM_PROMPT, or the chat template.
if os.path.isdir(TOKEN_CACHE + "/train") and os.path.isdir(TOKEN_CACHE + "/val"):
    train_data = Dataset.load_from_disk(TOKEN_CACHE + "/train")
    val_data = Dataset.load_from_disk(TOKEN_CACHE + "/val")
    print("Loaded token cache.", flush=True)
else:
    train_data = Dataset.from_list(
        [{"text": format_sample(x)} for x in load_jsonl(DATA_DIR + "/train.jsonl")]
    )
    val_data = Dataset.from_list(
        [{"text": format_sample(x)} for x in load_jsonl(DATA_DIR + "/val.jsonl")]
    )
    os.makedirs(TOKEN_CACHE, exist_ok=True)
    train_data.save_to_disk(TOKEN_CACHE + "/train")
    val_data.save_to_disk(TOKEN_CACHE + "/val")

print(f"Train: {len(train_data)} Val: {len(val_data)}", flush=True)

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=6,
        gradient_accumulation_steps=8,
        warmup_ratio=0.03,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        # Exactly one of bf16 / fp16 is enabled, depending on hardware.
        bf16=is_bfloat16_supported(),
        fp16=not is_bfloat16_supported(),
        optim="adamw_8bit",
        weight_decay=0.01,
        logging_steps=10,
        # eval and save cadence are kept identical (500 steps each)
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        logging_dir=LOG_DIR,
        dataset_text_field="text",
        max_seq_length=MAX_SEQ,
        packing=False,
        seed=3407,
        dataset_num_proc=NUM_PROC,
        report_to=[],
        run_name="kira-gemma4-26b-a4b-qlora",
    ),
)

print("Starting training...", flush=True)
trainer.train()

# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
print("Saving adapter...", flush=True)
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

os.makedirs(FINAL_DIR, exist_ok=True)
# OUTPUT_DIR is also the trainer's output_dir, so it holds the intermediate
# "checkpoint-*" directories. Leave those out so FINAL_DIR contains only the
# final adapter and tokenizer files.
shutil.copytree(
    OUTPUT_DIR,
    FINAL_DIR,
    dirs_exist_ok=True,
    ignore=shutil.ignore_patterns("checkpoint-*"),
)
print(f"Done. Adapter at {FINAL_DIR}", flush=True)