| from __future__ import annotations |
|
|
| import subprocess |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
|
|
@dataclass(frozen=True)
class TrainingRun:
    """Immutable record describing one invocation of the training script.

    Attributes are set once at construction; the dataclass is frozen, so
    reassigning them afterwards raises ``dataclasses.FrozenInstanceError``.
    """

    # The argv list handed to the subprocess (stored as given, not copied).
    command: list[str]
    # Exit status of the subprocess, or None when no status was recorded.
    returncode: int | None = None
|
|
|
|
def training_script_path() -> Path:
    """Return the absolute path to ``train_service_intent_extractor.py``.

    The script is expected in a ``training`` directory located in the
    grandparent directory of this module (the parent of this module's own
    directory). Symlinks in this module's path are resolved first.
    """
    module_dir = Path(__file__).resolve().parent
    base_dir = module_dir.parent
    return base_dir.joinpath('training', 'train_service_intent_extractor.py')
|
|
|
|
def build_training_command(
    *,
    dataset_path: str | Path,
    model_name_or_path: str | Path,
    output_dir: str | Path,
    quantized_output_dir: str | Path | None = None,
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 3e-5,
    max_source_length: int = 192,
    max_target_length: int = 256,
    skip_quantize: bool = False,
) -> list[str]:
    """Assemble the argv list that launches the intent-extractor training script.

    The command starts with the current interpreter (``sys.executable``) and
    the script path from ``training_script_path()``. It is followed by one
    ``--flag value`` pair per option, with every value converted via ``str``.

    Quantization handling:
        * If *skip_quantize* is true, only ``--skip-quantize`` is appended and
          *quantized_output_dir* is ignored.
        * Otherwise ``--quantized-output-dir`` is appended only when
          *quantized_output_dir* is truthy (``None`` and ``''`` are skipped).

    Returns:
        A freshly built list of strings.
    """
    # Order matters: it mirrors the original argv layout exactly.
    flag_values = (
        ('--dataset', dataset_path),
        ('--model-name-or-path', model_name_or_path),
        ('--output-dir', output_dir),
        ('--epochs', epochs),
        ('--batch-size', batch_size),
        ('--learning-rate', learning_rate),
        ('--max-source-length', max_source_length),
        ('--max-target-length', max_target_length),
    )

    argv = [sys.executable, str(training_script_path())]
    for flag, value in flag_values:
        argv += [flag, str(value)]

    # Skipping quantization takes precedence over any output directory.
    if skip_quantize:
        argv.append('--skip-quantize')
        return argv

    if quantized_output_dir:
        argv += ['--quantized-output-dir', str(quantized_output_dir)]
    return argv
|
|
|
|
def run_training(command: list[str], *, cwd: str | Path | None = None) -> TrainingRun:
    """Run *command* as a child process and record its exit status.

    ``subprocess.run`` is called with ``check=False``, so a non-zero exit
    status is not raised. It is returned in ``TrainingRun.returncode``.

    Args:
        command: argv list to execute (for example from
            ``build_training_command``).
        cwd: Working directory for the child process. A falsy value is
            passed as ``None``, so the child inherits the current directory.

    Returns:
        A ``TrainingRun`` holding *command* and the process's return code.
    """
    working_dir = None
    if cwd:
        working_dir = str(cwd)

    finished = subprocess.run(command, cwd=working_dir, check=False)
    return TrainingRun(command=command, returncode=finished.returncode)