# kira-gemma4-adapter / kira_train.py
# Uploaded by nothingsometimes with huggingface_hub (commit 9e62f0a, verified).
#!/usr/bin/env python3
"""QLoRA fine-tune of a local Gemma-4 26B-A4B snapshot into the "Kira" contract-review adapter.

Runs top to bottom as a script:
1. Load the 4-bit base model and attach LoRA adapters.
2. Render the JSONL train/val splits into chat-formatted text (cached on disk).
3. Train with TRL's SFTTrainer.
4. Save the adapter and tokenizer.
"""
import os
# These environment variables are set before unsloth/torch/transformers are
# imported below, so the libraries see them when they initialise.
os.environ["HF_HOME"] = "/workspace/hf_cache"  # Hugging Face cache root; MODEL below lives under it
os.environ["TRANSFORMERS_OFFLINE"] = "1"  # no Hub access: model and data are loaded from local paths
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # reduce CUDA allocator fragmentation
# Unsloth is imported ahead of trl/transformers so its patches are applied first.
import unsloth
from unsloth import FastLanguageModel, is_bfloat16_supported
import json
import shutil
import torch
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
# Local snapshot of unsloth/gemma-4-26B-A4B-it pinned to a specific revision hash.
MODEL = "/workspace/hf_cache/hub/models--unsloth--gemma-4-26B-A4B-it/snapshots/cd98c13581a9d4ad061cb85d983232ca4edb1343"
DATA_DIR = "/workspace/kira_gemma4_training/data/ft"  # expects train.jsonl and val.jsonl
OUTPUT_DIR = "/workspace/kira_gemma4_training/kira_out"  # trainer checkpoints + final adapter save
FINAL_DIR = "/workspace/kira_gemma4_training/kira_adapter"  # copy of OUTPUT_DIR's contents after training
LOG_DIR = "/workspace/kira_gemma4_training/kira_logs"
# NOTE: despite the name, this caches chat-formatted *text* datasets (a single
# "text" column), not token ids; tokenisation happens inside SFTTrainer.
TOKEN_CACHE = "/workspace/kira_gemma4_training/tokenized_cache"
MAX_SEQ = 1024  # used both when loading the model and as SFTConfig.max_seq_length
NUM_PROC = 16  # worker processes SFTTrainer uses for dataset preprocessing
# System turn for examples that do not already carry a "messages" list. The JSON
# keys listed here are the ones build_assistant_msg emits.
SYSTEM_PROMPT = (
    "You are Kira, a contract review assistant. Given a clause and its surrounding context, "
    "identify the issue type, assess severity, explain what is wrong, and describe the worst-case "
    "consequence. Respond in strict JSON with keys: issue_type, severity, severity_rationale, "
    "what_is_wrong, worst_case, evidence_span."
)
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    The file is decoded as UTF-8. Blank or whitespace-only lines are skipped.
    Every other line must hold one complete JSON value; otherwise
    ``json.JSONDecodeError`` propagates to the caller.
    """
    records = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records
def build_user_msg(ex):
    """Assemble the user turn for one raw example.

    Sections are joined by blank lines in this order:
    - clause type, left context, clause text, right context, question;
    - every section except the clause is included only when its field is truthy;
    - a missing clause text renders as an empty clause.
    """
    sections = [
        f"Clause type: {ex['clause_type']}" if ex.get("clause_type") else None,
        f"Left context:\n{ex['left_context']}" if ex.get("left_context") else None,
        f"Clause:\n{ex.get('clause_text', '')}",
        f"Right context:\n{ex['right_context']}" if ex.get("right_context") else None,
        f"Question: {ex['cuad_question']}" if ex.get("cuad_question") else None,
    ]
    return "\n\n".join(section for section in sections if section is not None)
def build_assistant_msg(ex):
    """Serialise the ``ds_*`` label fields of *ex* as the assistant's JSON answer.

    Output keys appear in a fixed order. Absent fields become ``null``.
    Non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    # (output key, source field) pairs in the order the answer is emitted.
    field_map = (
        ("issue_type", "ds_issue_type"),
        ("severity", "ds_severity"),
        ("severity_rationale", "ds_severity_rationale"),
        ("what_is_wrong", "ds_what_is_wrong"),
        ("worst_case", "ds_worst_case"),
        ("evidence_span", "ds_evidence_span"),
    )
    answer = {out_key: ex.get(src_key) for out_key, src_key in field_map}
    return json.dumps(answer, ensure_ascii=False)
# --- Model: 4-bit quantised base weights with trainable LoRA adapters (QLoRA) ---
print("Loading model...", flush=True)
_base_load_kwargs = dict(
    model_name=MODEL,
    max_seq_length=MAX_SEQ,
    dtype=None,  # None = let Unsloth auto-detect the compute dtype
    load_in_4bit=True,
    full_finetuning=False,  # only adapter weights are trained
)
model, tokenizer = FastLanguageModel.from_pretrained(**_base_load_kwargs)

print("Applying LoRA...", flush=True)
# NOTE(review): these projection names are assumed to match the Gemma-4 module
# names (including its MoE MLP layers) — confirm which modules actually get adapters.
_attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_projections = ["gate_proj", "up_proj", "down_proj"]
_lora_kwargs = dict(
    r=16,
    lora_alpha=32,  # alpha / r = 2 scaling of the LoRA update
    lora_dropout=0.0,
    bias="none",
    target_modules=_attention_projections + _mlp_projections,
    use_gradient_checkpointing="unsloth",  # Unsloth's gradient-checkpointing implementation
    random_state=3407,
)
model = FastLanguageModel.get_peft_model(model, **_lora_kwargs)
def format_sample(ex):
    """Render one example as a single chat-formatted training string.

    If the example already has a ``messages`` list, it is rendered as-is.
    Otherwise a system/user/assistant conversation is built from the raw fields:
    - system turn: ``SYSTEM_PROMPT``;
    - user turn: ``build_user_msg``;
    - assistant turn: ``build_assistant_msg``.
    The module-level ``tokenizer``'s chat template does the rendering.
    """
    conversation = ex["messages"] if "messages" in ex else [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_user_msg(ex)},
        {"role": "assistant", "content": build_assistant_msg(ex)},
    ]
    # tokenize=False returns the rendered text. No generation prompt is added
    # because the assistant turn is already part of the conversation.
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
# --- Data: chat-formatted text datasets, cached on disk --------------------------
_SPLITS = ("train", "val")
_CACHE_FINGERPRINT = TOKEN_CACHE + "/fingerprint.txt"


def _source_path(split):
    """Path of the raw JSONL file for *split* ("train" or "val")."""
    return f"{DATA_DIR}/{split}.jsonl"


def _cache_path(split):
    """Path of the saved ``datasets`` directory for *split*."""
    return f"{TOKEN_CACHE}/{split}"


def _inputs_fingerprint():
    """Return a SHA-256 hex digest of the inputs that determine the rendered text.

    The digest covers SYSTEM_PROMPT, the tokenizer's chat template and the raw
    bytes of both JSONL splits.

    Returns None when a raw JSONL file is missing. In that case an existing cache
    cannot be validated and is trusted as-is, as before. Changes to the
    build_*_msg code are NOT detected; delete TOKEN_CACHE after editing them.
    """
    import hashlib

    digest = hashlib.sha256()
    digest.update(SYSTEM_PROMPT.encode("utf-8"))
    digest.update(str(getattr(tokenizer, "chat_template", None)).encode("utf-8"))
    for split in _SPLITS:
        path = _source_path(split)
        if not os.path.isfile(path):
            return None
        with open(path, "rb") as fh:
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                digest.update(chunk)
    return digest.hexdigest()


def _read_cached_fingerprint():
    """Return the fingerprint stored with the cache, or None if there is none."""
    try:
        with open(_CACHE_FINGERPRINT, encoding="utf-8") as fh:
            return fh.read().strip()
    except FileNotFoundError:
        return None


print("Loading data...", flush=True)
_fingerprint = _inputs_fingerprint()
_cache_present = all(os.path.isdir(_cache_path(s)) for s in _SPLITS)
# Reuse the cache only if it matches the current inputs, or if the inputs are
# unavailable and the cache is the only source of data.
if _cache_present and (_fingerprint is None or _read_cached_fingerprint() == _fingerprint):
    train_data = Dataset.load_from_disk(_cache_path("train"))
    val_data = Dataset.load_from_disk(_cache_path("val"))
    print("Loaded token cache.", flush=True)
else:
    if _cache_present:
        print("Token cache is stale; rebuilding.", flush=True)
    train_data = Dataset.from_list(
        [{"text": format_sample(x)} for x in load_jsonl(_source_path("train"))]
    )
    val_data = Dataset.from_list(
        [{"text": format_sample(x)} for x in load_jsonl(_source_path("val"))]
    )
    os.makedirs(TOKEN_CACHE, exist_ok=True)
    # Remove the old fingerprint first, so an interrupted save is never mistaken
    # for a valid cache on the next run.
    try:
        os.remove(_CACHE_FINGERPRINT)
    except FileNotFoundError:
        pass
    for _split, _ds in (("train", train_data), ("val", val_data)):
        shutil.rmtree(_cache_path(_split), ignore_errors=True)  # drop stale shards
        _ds.save_to_disk(_cache_path(_split))
    if _fingerprint is not None:
        # Written last: its presence marks a complete, up-to-date cache.
        with open(_CACHE_FINGERPRINT, "w", encoding="utf-8") as fh:
            fh.write(_fingerprint)
print(f"Train: {len(train_data)} Val: {len(val_data)}", flush=True)
# --- Trainer -----------------------------------------------------------------------
for _needed_dir in (OUTPUT_DIR, LOG_DIR):
    os.makedirs(_needed_dir, exist_ok=True)

_use_bf16 = is_bfloat16_supported()  # fall back to fp16 on GPUs without bf16

# Effective batch per optimizer step (per device): 6 sequences x 8 accumulation steps = 48.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    run_name="kira-gemma4-26b-a4b-qlora",
    seed=3407,
    # Schedule / optimisation
    num_train_epochs=1,
    per_device_train_batch_size=6,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    optim="adamw_8bit",
    weight_decay=0.01,
    bf16=_use_bf16,
    fp16=not _use_bf16,
    # Evaluation / checkpointing: save_steps matches eval_steps so every saved
    # checkpoint has an eval_loss for load_best_model_at_end to compare.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging
    logging_steps=10,
    logging_dir=LOG_DIR,
    report_to=[],  # no external experiment trackers
    # Dataset handling
    dataset_text_field="text",
    max_seq_length=MAX_SEQ,
    packing=False,
    dataset_num_proc=NUM_PROC,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_args,
)
# --- Train, then publish the adapter ----------------------------------------------
print("Starting training...", flush=True)
trainer.train()

print("Saving adapter...", flush=True)
# With load_best_model_at_end=True the trainer restores the best-eval_loss
# checkpoint at the end of training. That should be the `model` object saved here.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
# OUTPUT_DIR also holds the trainer's intermediate checkpoint-* directories, which
# carry optimizer/scheduler state. Exclude them so FINAL_DIR contains only the
# adapter and tokenizer files. copytree creates FINAL_DIR if it does not exist.
shutil.copytree(
    OUTPUT_DIR,
    FINAL_DIR,
    ignore=shutil.ignore_patterns("checkpoint-*"),
    dirs_exist_ok=True,
)
print(f"Done. Adapter at {FINAL_DIR}", flush=True)