# Author: mortadhabbb
# Commit 9d540bc: "Update chatbot"
from __future__ import annotations
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class TrainingRun:
    """Immutable record of one training invocation.

    Instances are frozen, so attributes cannot be reassigned after
    construction. Note that ``command`` is stored as the list object that
    was passed in; it is not copied.
    """

    # Full argv that was (or will be) executed, interpreter first.
    command: list[str]
    # Process exit status; stays ``None`` when no run has been recorded.
    returncode: int | None = None
def training_script_path() -> Path:
    """Return the absolute location of the intent-extractor training script.

    The path is resolved relative to this module: two directory levels up
    from this file, then ``training/train_service_intent_extractor.py``.
    The file's existence is not checked here.
    """
    module_file = Path(__file__).resolve()
    project_root = module_file.parent.parent
    return project_root.joinpath('training', 'train_service_intent_extractor.py')
def build_training_command(
    *,
    dataset_path: str | Path,
    model_name_or_path: str | Path,
    output_dir: str | Path,
    quantized_output_dir: str | Path | None = None,
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 3e-5,
    max_source_length: int = 192,
    max_target_length: int = 256,
    skip_quantize: bool = False,
) -> list[str]:
    """Assemble the argv list that launches the training script.

    The command runs the script returned by ``training_script_path()`` under
    the current interpreter (``sys.executable``). Every option is forwarded
    as a ``--flag value`` pair with the value passed through ``str()``. The
    list is intended for ``subprocess`` without a shell.

    Args:
        dataset_path: Value for ``--dataset``.
        model_name_or_path: Value for ``--model-name-or-path``.
        output_dir: Value for ``--output-dir``.
        quantized_output_dir: Value for ``--quantized-output-dir``. It is only
            emitted when truthy and when ``skip_quantize`` is False.
        epochs: Value for ``--epochs``.
        batch_size: Value for ``--batch-size``.
        learning_rate: Value for ``--learning-rate``.
        max_source_length: Value for ``--max-source-length``.
        max_target_length: Value for ``--max-target-length``.
        skip_quantize: When True, appends ``--skip-quantize`` and ignores
            ``quantized_output_dir``.

    Returns:
        The argv list: interpreter, script path, then the flags.
    """
    # Order matters: the flags are emitted exactly in this sequence.
    flag_values = (
        ('--dataset', dataset_path),
        ('--model-name-or-path', model_name_or_path),
        ('--output-dir', output_dir),
        ('--epochs', epochs),
        ('--batch-size', batch_size),
        ('--learning-rate', learning_rate),
        ('--max-source-length', max_source_length),
        ('--max-target-length', max_target_length),
    )
    argv = [sys.executable, str(training_script_path())]
    for flag, value in flag_values:
        argv.extend((flag, str(value)))

    # The quantization options are mutually exclusive and skipping wins.
    # A None or empty-string output dir is simply omitted.
    if skip_quantize:
        argv.append('--skip-quantize')
    elif quantized_output_dir:
        argv.extend(('--quantized-output-dir', str(quantized_output_dir)))
    return argv
def run_training(command: list[str], *, cwd: str | Path | None = None) -> TrainingRun:
    """Run ``command`` to completion and record its exit status.

    The call blocks until the child process exits. A non-zero exit status
    does not raise, because ``check=False``; it is only reported in the
    returned record. No output capture is requested, so the child's
    stdout/stderr are not collected here. Failing to start the process
    (e.g. a missing executable) propagates whatever ``subprocess.run`` raises.

    Args:
        command: Argv list to execute, run without a shell.
        cwd: Working directory for the child. A falsy value (None or '')
            leaves the child in the caller's current directory.

    Returns:
        A ``TrainingRun`` holding the same ``command`` object and the
        process return code.
    """
    working_dir = None
    if cwd:
        working_dir = str(cwd)
    finished = subprocess.run(command, cwd=working_dir, check=False)
    return TrainingRun(command=command, returncode=finished.returncode)