"""CASSANDRA model definitions.
This module defines the LabelAttentionClassifier — a CTI-BERT encoder with
per-label attention queries — used in the CASSANDRA paper (anonymous
submission to ACM CCS 2026).
The architecture is custom (not derived from PreTrainedModel), so loading
weights requires this module. The convenience function `load_seed()` wraps
the standard pattern: load encoder, instantiate classifier, restore
state_dict from safetensors, attach tokenizer.
Usage:
from modeling import load_seed
model, tokenizer, config = load_seed("seeds/seed-42")
model.eval()
inputs = tokenizer(["The malware uses Registry Run Keys for persistence."],
return_tensors="pt", truncation=True, max_length=512)
logits = model(**inputs).logits
probs = torch.sigmoid(logits)
preds = [config["labels"][i] for i in (probs[0] >= 0.5).nonzero(as_tuple=True)[0]]
"""
from __future__ import annotations

import json
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class LabelAttentionClassifier(nn.Module):
    """CTI-BERT + per-label attention queries.

    Each ATT&CK technique gets a learned 768-dim query vector that attends
    over the encoder's last_hidden_state. The attended representation is
    classified by a shared 1-output linear head, yielding one logit per
    technique. This replaces the standard CLS -> Linear head, removing the
    shared-representation bottleneck for multi-label classification with
    many rare classes.
    """

    def __init__(self, encoder, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size
        self.num_labels = num_labels
        # One learned query vector per technique, BERT-style 0.02-std init.
        self.label_queries = nn.Parameter(torch.randn(num_labels, hidden) * 0.02)
        # Shared scoring head: each attended representation maps to one logit.
        self.classifier = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        h = outputs.last_hidden_state                      # [B, S, H]
        # Dot-product attention: every label query scores every token.
        attn = torch.matmul(self.label_queries.unsqueeze(0),
                            h.transpose(1, 2))             # [B, L, S]
        if attention_mask is not None:
            # Keep padding positions out of the softmax.
            mask = attention_mask.unsqueeze(1).expand_as(attn)
            attn = attn.masked_fill(mask == 0, -1e9)
        weights = F.softmax(attn, dim=-1)                  # [B, L, S]
        reps = torch.matmul(weights, h)                    # [B, L, H]
        reps = self.dropout(reps)
        logits = self.classifier(reps).squeeze(-1)         # [B, L]
        return SequenceClassifierOutput(logits=logits)
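

# Illustrative shape check, not part of the released artifact: a stub encoder
# stands in for CTI-BERT so the attention plumbing can be exercised without
# downloading weights. `_demo_shape_check` and `_StubEncoder` are hypothetical
# helper names introduced here only for this sketch.
def _demo_shape_check():
    from types import SimpleNamespace

    class _StubEncoder(nn.Module):
        def __init__(self, hidden_size=32, vocab_size=100):
            super().__init__()
            self.config = SimpleNamespace(hidden_size=hidden_size)
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, input_ids=None, attention_mask=None):
            # Mimic a transformers encoder output: expose last_hidden_state.
            return SimpleNamespace(last_hidden_state=self.embed(input_ids))

    model = LabelAttentionClassifier(_StubEncoder(), num_labels=5).eval()
    ids = torch.randint(0, 100, (2, 7))          # [B=2, S=7]
    mask = torch.ones(2, 7, dtype=torch.long)
    mask[1, 4:] = 0                              # pad out the second sequence
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits
    assert logits.shape == (2, 5)                # one logit per label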


def load_seed(seed_dir: str, device: str = "cpu") -> Tuple[LabelAttentionClassifier, AutoTokenizer, dict]:
    """Load a single CASSANDRA seed directory.

    Args:
        seed_dir: path containing model.safetensors, config.json, and the
            tokenizer.* files.
        device: 'cpu' or 'cuda'.

    Returns:
        (model, tokenizer, config), where config is the parsed config.json dict.
    """
    with open(os.path.join(seed_dir, "config.json")) as f:
        config = json.load(f)
    encoder = AutoModel.from_pretrained(config["encoder_model_name"])
    model = LabelAttentionClassifier(
        encoder,
        num_labels=config["num_labels"],
        dropout=config.get("dropout", 0.1),
    )
    # Overwrite the freshly loaded encoder (and head) with the fine-tuned weights.
    state_dict = load_file(os.path.join(seed_dir, "model.safetensors"), device=device)
    model.load_state_dict(state_dict)
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(seed_dir)
    return model, tokenizer, config


def load_ensemble(seed_dirs, device: str = "cpu"):
    """Load multiple seeds for ensemble inference.

    Returns a list of (model, tokenizer, config) tuples. The tokenizer and
    config are identical across seeds within a configuration, but are
    returned per-seed for convenience.
    """
    return [load_seed(d, device=device) for d in seed_dirs]


def predict_ensemble(seeds, sentences, threshold: float = 0.5, max_length: int = 512,
                     batch_size: int = 32):
    """Average sigmoid probabilities across seeds, then threshold into labels.

    Args:
        seeds: list returned by load_ensemble().
        sentences: list of strings.
        threshold: per-class probability cutoff. The paper's headline numbers
            use either tau=0.5 (uniform) or a dev-tuned tau (see model card).
        max_length: tokenizer max_length (the paper used 512).
        batch_size: per-model inference batch size.

    Returns:
        list of (sentence, [predicted_label_ids]) tuples.
    """
if not seeds:
raise ValueError("seeds is empty")
    all_probs = None
    config = seeds[0][2]
    tokenizer = seeds[0][1]
    # load_ensemble() places every seed on the same device, so reading the
    # device from the first model is sufficient.
    device = next(seeds[0][0].parameters()).device
    for model, _, _ in seeds:
        model.eval()
seed_probs = []
with torch.no_grad():
for i in range(0, len(sentences), batch_size):
batch = sentences[i:i + batch_size]
enc = tokenizer(batch, return_tensors="pt", truncation=True,
max_length=max_length, padding=True).to(device)
logits = model(input_ids=enc["input_ids"],
attention_mask=enc["attention_mask"]).logits
seed_probs.append(torch.sigmoid(logits).cpu())
        seed_probs = torch.cat(seed_probs, dim=0)
        # Running sum over seeds; divided by len(seeds) after the loop.
        all_probs = seed_probs if all_probs is None else all_probs + seed_probs
avg_probs = all_probs / len(seeds)
labels = config["labels"]
out = []
for i, sentence in enumerate(sentences):
ids = [labels[j] for j in range(len(labels)) if avg_probs[i, j].item() >= threshold]
out.append((sentence, ids))
return out
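

# Minimal end-to-end sketch. The seed paths below are placeholders: only
# "seeds/seed-42" appears in the usage docstring above, and the remaining
# directories are hypothetical. Run as `python modeling.py` from a checkout
# that contains the seed directories.
if __name__ == "__main__":
    seed_dirs = ["seeds/seed-42", "seeds/seed-43", "seeds/seed-44"]
    seeds = load_ensemble(seed_dirs, device="cpu")
    sentences = ["The malware uses Registry Run Keys for persistence."]
    for sentence, predicted in predict_ensemble(seeds, sentences, threshold=0.5):
        print(sentence, "->", predicted)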