#!/usr/bin/env python3
"""Minimal CASSANDRA inference example.
Loads all seed checkpoints in this repo's ``seeds/`` directory and runs an
ensemble forward pass on a few example CTI sentences. The headline F1
numbers in the paper are computed this way, then thresholded at the value
documented in the model card.
Usage:
    pip install -r requirements.txt
    python inference_example.py [--threshold 0.5] [--device cpu|cuda]
"""
from __future__ import annotations
import argparse
import os
import torch
from modeling import load_ensemble, predict_ensemble
# Example CTI sentences that main() passes to predict_ensemble for the demo run.
EXAMPLE_SENTENCES = [
    "The malware uses Windows Command Shell to execute encoded scripts on the target host.",
    "After initial access, the threat actor established persistence via Registry Run Keys / Startup Folder.",
    "Data was exfiltrated over HTTPS to an attacker-controlled command and control server.",
    "The implant collected system information including hostname, OS version, and network interfaces.",
]
def discover_seed_dirs(repo_root: str) -> list[str]:
    """Return the sorted paths of every sub-directory of ``<repo_root>/seeds``.

    Args:
        repo_root: Directory expected to contain a ``seeds/`` folder.

    Returns:
        Paths (``seeds`` root joined with each entry name) of all entries
        that are directories, in ascending string order.

    Raises:
        FileNotFoundError: If ``seeds/`` is not a directory, or if it holds
            no sub-directories at all.
    """
    seeds_root = os.path.join(repo_root, "seeds")
    if not os.path.isdir(seeds_root):
        raise FileNotFoundError(f"Expected seeds/ directory at {seeds_root}")

    # Plain files sitting next to the seed folders are ignored.
    candidates = [os.path.join(seeds_root, name) for name in os.listdir(seeds_root)]
    found = [path for path in candidates if os.path.isdir(path)]
    found.sort()

    if not found:
        raise FileNotFoundError(f"No seed-* directories found under {seeds_root}")
    return found
def main():
    """Run the demo: load every seed checkpoint and print ensemble predictions.

    Parses ``--threshold`` and ``--device`` from the command line, loads the
    checkpoints found under ``seeds/`` next to this script via
    ``load_ensemble``, and prints whatever ``predict_ensemble`` returns for
    each entry of ``EXAMPLE_SENTENCES``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Probability threshold for predicting a technique. "
             "See the model card for dev-tuned values.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="cpu or cuda. Default: cuda if available.",
    )
    args = parser.parse_args()

    # An explicit --device wins; otherwise use CUDA only when torch reports it.
    if args.device:
        device = args.device
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Seeds are looked up relative to this script, not the working directory.
    repo_root = os.path.dirname(os.path.abspath(__file__))
    seed_dirs = discover_seed_dirs(repo_root)
    print(f"Loading {len(seed_dirs)} seed(s) from {repo_root}/seeds/ on {device}...")
    seeds = load_ensemble(seed_dirs, device=device)

    # Element 2 of the first loaded seed is read as a dict of model metadata.
    first_seed_meta = seeds[0][2]
    print(f" encoder: {first_seed_meta['encoder_model_name']}")
    print(f" num_labels: {first_seed_meta['num_labels']}")
    print()

    print(f"Running ensemble inference on {len(EXAMPLE_SENTENCES)} sentence(s) "
          f"(threshold={args.threshold})...\n")
    results = predict_ensemble(seeds, EXAMPLE_SENTENCES, threshold=args.threshold)
    for sentence, techniques in results:
        print(f" {sentence}")
        if not techniques:
            print(f" -> (no techniques predicted at tau={args.threshold})")
        else:
            print(f" -> {', '.join(techniques)}")
        print()
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|