"""CASSANDRA model definitions. This module defines the LabelAttentionClassifier — a CTI-BERT encoder with per-label attention queries — used in the CASSANDRA paper (anonymous submission to ACM CCS 2026). The architecture is custom (not derived from PreTrainedModel), so loading weights requires this module. The convenience function `load_seed()` wraps the standard pattern: load encoder, instantiate classifier, restore state_dict from safetensors, attach tokenizer. Usage: from modeling import load_seed model, tokenizer, config = load_seed("seeds/seed-42") model.eval() inputs = tokenizer(["The malware uses Registry Run Keys for persistence."], return_tensors="pt", truncation=True, max_length=512) logits = model(**inputs).logits probs = torch.sigmoid(logits) preds = [config["labels"][i] for i in (probs[0] >= 0.5).nonzero(as_tuple=True)[0]] """ from __future__ import annotations import json import os from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import load_file from transformers import AutoModel, AutoTokenizer from transformers.modeling_outputs import SequenceClassifierOutput class LabelAttentionClassifier(nn.Module): """CTI-BERT + per-label attention queries. Each ATT&CK technique gets a learned 768-dim query vector that attends over the encoder's last_hidden_state. The attended representation is classified by a shared 1-output linear head, yielding one logit per technique. This replaces the standard CLS -> Linear head, removing the shared-representation bottleneck for multi-label classification with many rare classes. """ def __init__(self, encoder, num_labels: int, dropout: float = 0.1): super().__init__() self.encoder = encoder hidden = encoder.config.hidden_size self.num_labels = num_labels self.label_queries = nn.Parameter(torch.randn(num_labels, hidden) * 0.02) self.classifier = nn.Linear(hidden, 1) self.dropout = nn.Dropout(dropout) def forward(self, input_ids=None, attention_mask=None, **kwargs): outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) h = outputs.last_hidden_state # [B, S, H] attn = torch.matmul(self.label_queries.unsqueeze(0), h.transpose(1, 2)) # [B, L, S] if attention_mask is not None: mask = attention_mask.unsqueeze(1).expand_as(attn) attn = attn.masked_fill(mask == 0, -1e9) weights = F.softmax(attn, dim=-1) # [B, L, S] reps = torch.matmul(weights, h) # [B, L, H] reps = self.dropout(reps) logits = self.classifier(reps).squeeze(-1) # [B, L] return SequenceClassifierOutput(logits=logits) def load_seed(seed_dir: str, device: str = "cpu") -> Tuple[LabelAttentionClassifier, AutoTokenizer, dict]: """Load a single CASSANDRA seed directory. Args: seed_dir: path containing model.safetensors, config.json, tokenizer.* files. device: 'cpu' or 'cuda'. Returns: (model, tokenizer, config) — config is the parsed config.json dict. """ with open(os.path.join(seed_dir, "config.json")) as f: config = json.load(f) encoder = AutoModel.from_pretrained(config["encoder_model_name"]) model = LabelAttentionClassifier( encoder, num_labels=config["num_labels"], dropout=config.get("dropout", 0.1), ) state_dict = load_file(os.path.join(seed_dir, "model.safetensors"), device=device) model.load_state_dict(state_dict) model.to(device) tokenizer = AutoTokenizer.from_pretrained(seed_dir) return model, tokenizer, config def load_ensemble(seed_dirs, device: str = "cpu"): """Load multiple seeds for ensemble inference. Returns a list of (model, tokenizer, config) tuples. 
The tokenizer + config are identical across seeds in a configuration, but returned per-seed for convenience. """ return [load_seed(d, device=device) for d in seed_dirs] def predict_ensemble(seeds, sentences, threshold: float = 0.5, max_length: int = 512, batch_size: int = 32): """Average sigmoid probabilities across seeds, threshold to predicted labels. Args: seeds: list returned by load_ensemble(). sentences: list of strings. threshold: per-class probability cutoff. The paper's headline numbers use either tau=0.5 (uniform) or a dev-tuned tau (see model card). max_length: tokenizer max_length (paper used 512). batch_size: per-model inference batch size. Returns: list of (sentence, [predicted_label_ids]) tuples. """ if not seeds: raise ValueError("seeds is empty") all_probs = None config = seeds[0][2] tokenizer = seeds[0][1] device = next(seeds[0][0].parameters()).device for model, _, _ in seeds: model.eval() seed_probs = [] with torch.no_grad(): for i in range(0, len(sentences), batch_size): batch = sentences[i:i + batch_size] enc = tokenizer(batch, return_tensors="pt", truncation=True, max_length=max_length, padding=True).to(device) logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]).logits seed_probs.append(torch.sigmoid(logits).cpu()) seed_probs = torch.cat(seed_probs, dim=0) all_probs = seed_probs if all_probs is None else all_probs + seed_probs avg_probs = all_probs / len(seeds) labels = config["labels"] out = [] for i, sentence in enumerate(sentences): ids = [labels[j] for j in range(len(labels)) if avg_probs[i, j].item() >= threshold] out.append((sentence, ids)) return out
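

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch of the ensemble path (load_ensemble followed by
# predict_ensemble), runnable as `python modeling.py <seed_dir> [<seed_dir> ...]`.
# This block is illustrative, not part of the released evaluation pipeline:
# the default "seeds/seed-42" path is only the placeholder used in the module
# docstring, and the example sentence is the one from that docstring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    # Assumed layout: each argument is a seed directory as described in load_seed().
    seed_dirs = sys.argv[1:] or ["seeds/seed-42"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    seeds = load_ensemble(seed_dirs, device=device)
    sentences = ["The malware uses Registry Run Keys for persistence."]

    # threshold=0.5 corresponds to the uniform tau mentioned in predict_ensemble's docstring.
    for sentence, predicted in predict_ensemble(seeds, sentences, threshold=0.5):
        print(f"{sentence} -> {predicted}")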