#!/usr/bin/env python3 """Minimal CASSANDRA inference example. Loads all seed checkpoints in this repo's ``seeds/`` directory and runs an ensemble forward pass on a few example CTI sentences. The headline F1 numbers in the paper are computed this way, then thresholded at the value documented in the model card. Usage: pip install -r requirements.txt python inference_example.py [--threshold 0.5] [--device cpu|cuda] """ from __future__ import annotations import argparse import os import torch from modeling import load_ensemble, predict_ensemble EXAMPLE_SENTENCES = [ "The malware uses Windows Command Shell to execute encoded scripts on the target host.", "After initial access, the threat actor established persistence via Registry Run Keys / Startup Folder.", "Data was exfiltrated over HTTPS to an attacker-controlled command and control server.", "The implant collected system information including hostname, OS version, and network interfaces.", ] def discover_seed_dirs(repo_root: str) -> list[str]: seeds_root = os.path.join(repo_root, "seeds") if not os.path.isdir(seeds_root): raise FileNotFoundError(f"Expected seeds/ directory at {seeds_root}") seed_dirs = sorted( os.path.join(seeds_root, d) for d in os.listdir(seeds_root) if os.path.isdir(os.path.join(seeds_root, d)) ) if not seed_dirs: raise FileNotFoundError(f"No seed-* directories found under {seeds_root}") return seed_dirs def main(): ap = argparse.ArgumentParser() ap.add_argument("--threshold", type=float, default=0.5, help="Probability threshold for predicting a technique. " "See the model card for dev-tuned values.") ap.add_argument("--device", default=None, help="cpu or cuda. Default: cuda if available.") args = ap.parse_args() device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") repo_root = os.path.dirname(os.path.abspath(__file__)) seed_dirs = discover_seed_dirs(repo_root) print(f"Loading {len(seed_dirs)} seed(s) from {repo_root}/seeds/ on {device}...") seeds = load_ensemble(seed_dirs, device=device) print(f" encoder: {seeds[0][2]['encoder_model_name']}") print(f" num_labels: {seeds[0][2]['num_labels']}") print() print(f"Running ensemble inference on {len(EXAMPLE_SENTENCES)} sentence(s) " f"(threshold={args.threshold})...\n") results = predict_ensemble(seeds, EXAMPLE_SENTENCES, threshold=args.threshold) for sentence, techniques in results: print(f" {sentence}") if techniques: print(f" -> {', '.join(techniques)}") else: print(f" -> (no techniques predicted at tau={args.threshold})") print() if __name__ == "__main__": main()