File size: 2,853 Bytes
4e766eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
"""Minimal CASSANDRA inference example.

Loads all seed checkpoints in this repo's ``seeds/`` directory and runs an
ensemble forward pass on a few example CTI sentences. The headline F1
numbers in the paper are computed this way, then thresholded at the value
documented in the model card.

Usage:
    pip install -r requirements.txt
    python inference_example.py [--threshold 0.5] [--device cpu|cuda]
"""
from __future__ import annotations

import argparse
import os

import torch

from modeling import load_ensemble, predict_ensemble


# Demo input for main(): short cyber-threat-intelligence (CTI) sentences that
# the loaded ensemble is asked to tag with techniques. Long literals are split
# with implicit string concatenation purely for line length; each element is
# still a single string.
EXAMPLE_SENTENCES = [
    "The malware uses Windows Command Shell to execute encoded scripts "
    "on the target host.",
    "After initial access, the threat actor established persistence "
    "via Registry Run Keys / Startup Folder.",
    "Data was exfiltrated over HTTPS to an attacker-controlled "
    "command and control server.",
    "The implant collected system information including hostname, "
    "OS version, and network interfaces.",
]


def discover_seed_dirs(repo_root: str) -> list[str]:
    """Return the seed checkpoint directories found under ``<repo_root>/seeds``.

    Every non-hidden subdirectory of ``seeds/`` is treated as one seed
    checkpoint. Plain files are ignored. Hidden entries (names starting with
    ``.``, e.g. a Jupyter ``.ipynb_checkpoints`` folder) are also ignored, so
    tooling artifacts are never passed to the loader as if they were seeds.

    Args:
        repo_root: Directory that contains the ``seeds/`` folder.

    Returns:
        Paths of the form ``os.path.join(repo_root, "seeds", name)``, sorted
        by plain string order (so ``seed-10`` sorts before ``seed-2``).

    Raises:
        FileNotFoundError: If ``seeds/`` does not exist, or contains no
            non-hidden subdirectories.
    """
    seeds_root = os.path.join(repo_root, "seeds")
    if not os.path.isdir(seeds_root):
        raise FileNotFoundError(f"Expected seeds/ directory at {seeds_root}")
    seed_dirs = sorted(
        os.path.join(seeds_root, name)
        for name in os.listdir(seeds_root)
        # Skip hidden entries (.ipynb_checkpoints, .cache, ...): they are
        # tooling artifacts, not checkpoints, and would break loading.
        if not name.startswith(".")
        and os.path.isdir(os.path.join(seeds_root, name))
    )
    if not seed_dirs:
        raise FileNotFoundError(f"No seed-* directories found under {seeds_root}")
    return seed_dirs


def main() -> None:
    """Command-line entry point: load every seed and print ensemble predictions.

    Parses ``--threshold`` and ``--device``, discovers the seed checkpoints in
    the ``seeds/`` directory next to this script, loads them with
    ``load_ensemble``, and prints the techniques ``predict_ensemble`` returns
    for each entry of ``EXAMPLE_SENTENCES``.

    Exits with an argparse usage error if ``--threshold`` is not a
    probability in [0, 1].
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--threshold", type=float, default=0.5,
                    help="Probability threshold for predicting a technique. "
                         "See the model card for dev-tuned values.")
    ap.add_argument("--device", default=None,
                    help="cpu or cuda. Default: cuda if available.")
    args = ap.parse_args()

    # type=float happily accepts "nan", "inf" and out-of-range values. A
    # probability threshold outside [0, 1] silently predicts every label or
    # none, and NaN fails every comparison, so reject them up front. The
    # chained comparison is False for NaN, so it is caught here too.
    if not 0.0 <= args.threshold <= 1.0:
        ap.error(f"--threshold must be a probability in [0, 1], got {args.threshold}")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    # Seeds are located relative to this script, not the current directory.
    repo_root = os.path.dirname(os.path.abspath(__file__))
    seed_dirs = discover_seed_dirs(repo_root)
    print(f"Loading {len(seed_dirs)} seed(s) from {repo_root}/seeds/ on {device}...")

    seeds = load_ensemble(seed_dirs, device=device)
    # NOTE(review): assumes element [2] of each loaded seed is its config
    # mapping (it must provide these two keys); only seed 0 is reported.
    # Confirm against modeling.load_ensemble.
    print(f"  encoder: {seeds[0][2]['encoder_model_name']}")
    print(f"  num_labels: {seeds[0][2]['num_labels']}")
    print()

    print(f"Running ensemble inference on {len(EXAMPLE_SENTENCES)} sentence(s) "
          f"(threshold={args.threshold})...\n")
    results = predict_ensemble(seeds, EXAMPLE_SENTENCES, threshold=args.threshold)
    # Each result is consumed as a (sentence, techniques) pair, where
    # techniques is an iterable of strings (it is passed to str.join).
    for sentence, techniques in results:
        print(f"  {sentence}")
        if techniques:
            print(f"    -> {', '.join(techniques)}")
        else:
            print(f"    -> (no techniques predicted at tau={args.threshold})")
        print()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()