| """CASSANDRA model definitions. |
| |
| This module defines the LabelAttentionClassifier — a CTI-BERT encoder with |
| per-label attention queries — used in the CASSANDRA paper (anonymous |
| submission to ACM CCS 2026). |
| |
| The architecture is custom (not derived from PreTrainedModel), so loading |
| weights requires this module. The convenience function `load_seed()` wraps |
| the standard pattern: load encoder, instantiate classifier, restore |
| state_dict from safetensors, attach tokenizer. |
| |
| Usage: |
| from modeling import load_seed |
| model, tokenizer, config = load_seed("seeds/seed-42") |
| model.eval() |
| inputs = tokenizer(["The malware uses Registry Run Keys for persistence."], |
| return_tensors="pt", truncation=True, max_length=512) |
| logits = model(**inputs).logits |
| probs = torch.sigmoid(logits) |
| preds = [config["labels"][i] for i in (probs[0] >= 0.5).nonzero(as_tuple=True)[0]] |
| """ |
from __future__ import annotations

import json
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class LabelAttentionClassifier(nn.Module):
    """CTI-BERT + per-label attention queries.

    Each ATT&CK technique gets a learned 768-dim query vector that attends
    over the encoder's last_hidden_state. The attended representation is
    classified by a shared 1-output linear head, yielding one logit per
    technique. This replaces the standard CLS -> Linear head, removing the
    shared-representation bottleneck for multi-label classification with
    many rare classes.
    """

    def __init__(self, encoder, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size
        self.num_labels = num_labels
        # One learned query vector per label, initialised at a small scale (0.02),
        # matching the usual BERT initializer range.
        self.label_queries = nn.Parameter(torch.randn(num_labels, hidden) * 0.02)
        # Shared scoring head: maps each label's attended representation to one logit.
        self.classifier = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        h = outputs.last_hidden_state                          # (batch, seq_len, hidden)
        # Attention scores of every label query against every token.
        attn = torch.matmul(self.label_queries.unsqueeze(0),   # (1, num_labels, hidden)
                            h.transpose(1, 2))                 # -> (batch, num_labels, seq_len)
        if attention_mask is not None:
            # Mask out padding tokens before the softmax.
            mask = attention_mask.unsqueeze(1).expand_as(attn)
            attn = attn.masked_fill(mask == 0, -1e9)
        weights = F.softmax(attn, dim=-1)
        reps = torch.matmul(weights, h)                        # (batch, num_labels, hidden)
        reps = self.dropout(reps)
        logits = self.classifier(reps).squeeze(-1)             # (batch, num_labels)
        return SequenceClassifierOutput(logits=logits)
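

# Illustrative sanity check of the label-attention shapes above, kept as a minimal
# sketch: the toy encoder below is a stand-in (not CTI-BERT), and the sizes
# (hidden=8, 4 labels, batch of 2, sequence length 5) are arbitrary, chosen only to
# make the shape flow of forward() visible.
def _label_attention_shape_demo():
    from types import SimpleNamespace

    class _ToyEncoder(nn.Module):
        """Stand-in encoder exposing config.hidden_size and last_hidden_state."""

        def __init__(self, hidden_size: int):
            super().__init__()
            self.config = SimpleNamespace(hidden_size=hidden_size)

        def forward(self, input_ids=None, attention_mask=None):
            batch, seq_len = input_ids.shape
            return SimpleNamespace(
                last_hidden_state=torch.randn(batch, seq_len, self.config.hidden_size))

    model = LabelAttentionClassifier(_ToyEncoder(hidden_size=8), num_labels=4)
    input_ids = torch.zeros(2, 5, dtype=torch.long)          # (batch=2, seq_len=5)
    attention_mask = torch.ones(2, 5, dtype=torch.long)
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    assert logits.shape == (2, 4)                             # one logit per label per sentence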


def load_seed(seed_dir: str, device: str = "cpu") -> Tuple[LabelAttentionClassifier, AutoTokenizer, dict]:
    """Load a single CASSANDRA seed directory.

    Args:
        seed_dir: path containing model.safetensors, config.json, and tokenizer.* files.
        device: 'cpu' or 'cuda'.

    Returns:
        (model, tokenizer, config), where config is the parsed config.json dict.
    """
    with open(os.path.join(seed_dir, "config.json")) as f:
        config = json.load(f)

    encoder = AutoModel.from_pretrained(config["encoder_model_name"])
    model = LabelAttentionClassifier(
        encoder,
        num_labels=config["num_labels"],
        dropout=config.get("dropout", 0.1),
    )

    state_dict = load_file(os.path.join(seed_dir, "model.safetensors"), device=device)
    model.load_state_dict(state_dict)
    model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(seed_dir)
    return model, tokenizer, config
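

# For reference, a minimal config.json that load_seed() above can consume might look
# like the sketch below. The key names are exactly the ones this module reads
# (encoder_model_name, num_labels, dropout, labels); the values shown, including the
# encoder identifier and the technique IDs, are placeholders rather than the paper's
# released configuration.
#
# {
#   "encoder_model_name": "<CTI-BERT checkpoint name or path>",
#   "num_labels": 3,
#   "dropout": 0.1,
#   "labels": ["T1003", "T1059", "T1547.001"]
# }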


def load_ensemble(seed_dirs, device: str = "cpu"):
    """Load multiple seeds for ensemble inference.

    Returns a list of (model, tokenizer, config) tuples. The tokenizer and config
    are identical across seeds within a configuration, but are returned per seed
    for convenience.
    """
    return [load_seed(d, device=device) for d in seed_dirs]


def predict_ensemble(seeds, sentences, threshold: float = 0.5, max_length: int = 512,
                     batch_size: int = 32):
    """Average sigmoid probabilities across seeds, then threshold to predicted labels.

    Args:
        seeds: list returned by load_ensemble().
        sentences: list of strings.
        threshold: per-class probability cutoff. The paper's headline numbers
            use either tau=0.5 (uniform) or a dev-tuned tau (see model card).
        max_length: tokenizer max_length (the paper used 512).
        batch_size: per-model inference batch size.

    Returns:
        list of (sentence, predicted_labels) tuples, where predicted_labels are the
        entries of config["labels"] whose averaged probability meets the threshold.
    """
    if not seeds:
        raise ValueError("seeds is empty")

    all_probs = None
    config = seeds[0][2]
    tokenizer = seeds[0][1]
    device = next(seeds[0][0].parameters()).device

    for model, _, _ in seeds:
        model.eval()
        seed_probs = []
        with torch.no_grad():
            for i in range(0, len(sentences), batch_size):
                batch = sentences[i:i + batch_size]
                enc = tokenizer(batch, return_tensors="pt", truncation=True,
                                max_length=max_length, padding=True).to(device)
                logits = model(input_ids=enc["input_ids"],
                               attention_mask=enc["attention_mask"]).logits
                seed_probs.append(torch.sigmoid(logits).cpu())
        seed_probs = torch.cat(seed_probs, dim=0)             # (num_sentences, num_labels)
        all_probs = seed_probs if all_probs is None else all_probs + seed_probs

    avg_probs = all_probs / len(seeds)                        # mean probability over seeds
    labels = config["labels"]
    out = []
    for i, sentence in enumerate(sentences):
        predicted = [labels[j] for j in range(len(labels))
                     if avg_probs[i, j].item() >= threshold]
        out.append((sentence, predicted))
    return out
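

# Minimal end-to-end sketch tying the helpers above together. The seed directory
# names are placeholders (only "seeds/seed-42" appears in the module docstring);
# point them at wherever the released seeds actually live.
if __name__ == "__main__":
    seed_dirs = ["seeds/seed-42", "seeds/seed-43", "seeds/seed-44"]
    ensemble = load_ensemble(seed_dirs, device="cpu")
    examples = ["The malware uses Registry Run Keys for persistence."]
    for sentence, predicted in predict_ensemble(ensemble, examples, threshold=0.5):
        print(f"{sentence} -> {predicted}")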