| |
| """Minimal CASSANDRA inference example. |
| |
| Loads all seed checkpoints in this repo's ``seeds/`` directory and runs an |
| ensemble forward pass on a few example CTI sentences. The headline F1 |
| numbers in the paper are computed this way, then thresholded at the value |
| documented in the model card. |
| |
| Usage: |
| pip install -r requirements.txt |
| python inference_example.py [--threshold 0.5] [--device cpu|cuda] |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import os |
|
|
| import torch |
|
|
| from modeling import load_ensemble, predict_ensemble |
|
|
|
|
# Demo inputs passed to the ensemble by main(); each entry is one
# cyber-threat-intelligence (CTI) sentence.
EXAMPLE_SENTENCES = [
    (
        "The malware uses Windows Command Shell to execute encoded scripts "
        "on the target host."
    ),
    (
        "After initial access, the threat actor established persistence via "
        "Registry Run Keys / Startup Folder."
    ),
    (
        "Data was exfiltrated over HTTPS to an attacker-controlled command "
        "and control server."
    ),
    (
        "The implant collected system information including hostname, "
        "OS version, and network interfaces."
    ),
]
|
|
|
|
def discover_seed_dirs(repo_root: str) -> list[str]:
    """Return the sorted paths of all sub-directories of ``<repo_root>/seeds``.

    Every immediate sub-directory counts as a seed directory, whatever its
    name; regular files inside ``seeds/`` are ignored.

    Args:
        repo_root: Directory expected to contain a ``seeds/`` folder.

    Returns:
        The sub-directory paths (each joined onto ``<repo_root>/seeds``),
        sorted lexicographically.

    Raises:
        FileNotFoundError: If ``seeds/`` is missing or is not a directory,
            or if it contains no sub-directories.
    """
    seeds_root = os.path.join(repo_root, "seeds")
    if not os.path.isdir(seeds_root):
        raise FileNotFoundError(f"Expected seeds/ directory at {seeds_root}")

    found = []
    for entry in os.listdir(seeds_root):
        candidate = os.path.join(seeds_root, entry)
        if os.path.isdir(candidate):
            found.append(candidate)

    if not found:
        raise FileNotFoundError(f"No seed-* directories found under {seeds_root}")

    # Sort so the result does not depend on os.listdir() ordering.
    found.sort()
    return found
|
|
|
|
def main():
    """Run the ensemble demo: parse CLI flags, load seeds, print predictions.

    Command-line flags:
        --threshold: Float probability cut-off forwarded to
            ``predict_ensemble`` (default 0.5).
        --device: Device string forwarded to ``load_ensemble``. When omitted
            or empty, "cuda" is used if ``torch.cuda.is_available()`` and
            "cpu" otherwise.

    Seed directories are discovered under the ``seeds/`` folder next to this
    file. All results are written to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Probability threshold for predicting a technique. "
             "See the model card for dev-tuned values.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="cpu or cuda. Default: cuda if available.",
    )
    options = parser.parse_args()
    tau = options.threshold

    # An explicit --device wins; otherwise probe for CUDA.
    if options.device:
        device = options.device
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    here = os.path.dirname(os.path.abspath(__file__))
    checkpoint_dirs = discover_seed_dirs(here)
    print(f"Loading {len(checkpoint_dirs)} seed(s) from {here}/seeds/ on {device}...")

    ensemble = load_ensemble(checkpoint_dirs, device=device)
    # NOTE(review): presumably each ensemble entry is a tuple whose third
    # item is a config mapping -- confirm against modeling.load_ensemble.
    first_config = ensemble[0][2]
    print(f" encoder: {first_config['encoder_model_name']}")
    print(f" num_labels: {first_config['num_labels']}")
    print()

    print(f"Running ensemble inference on {len(EXAMPLE_SENTENCES)} sentence(s) "
          f"(threshold={tau})...\n")
    predictions = predict_ensemble(ensemble, EXAMPLE_SENTENCES, threshold=tau)
    for sentence, techniques in predictions:
        print(f" {sentence}")
        outcome = (
            f" -> {', '.join(techniques)}"
            if techniques
            else f" -> (no techniques predicted at tau={tau})"
        )
        print(outcome)
        print()
|
|
|
|
# Run the demo only when this file is executed as a script, so importing
# the module has no side effects.
if __name__ == "__main__":
    main()
|
|