| """ |
| featurizers.py |
| ============== |
| Utility classes for defining *invertible* feature spaces on top of a model’s |
| hidden-state tensors, together with intervention helpers that operate inside |
| those spaces. |
| |
| Key ideas |
| --------- |
| |
| * **Featurizer** – a lightweight wrapper holding: |
| • a forward `featurizer` module that maps a tensor **x → (f, error)** |
| where *error* is the reconstruction residual (useful for lossy |
| featurizers such as sparse auto-encoders); |
| • an `inverse_featurizer` that re-assembles the original space |
| **(f, error) → x̂**. |
| |
| * **Interventions** – three higher-order factory functions build PyVENE |
| interventions that work in the featurized space: |
| - *interchange* |
| - *collect* |
| - *mask* (differential binary masking) |
| |
| All public classes / functions below carry PEP-257-style doc-strings. |
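
Example
-------

A minimal sketch using the identity featurizer (no learned feature space);
the tensor shape below is illustrative only::

    >>> import torch
    >>> feat = Featurizer()                 # identity featurizer by default
    >>> x = torch.randn(2, 16)
    >>> f, err = feat.featurize(x)          # f == x, err is None
    >>> torch.equal(feat.inverse_featurize(f, err), x)
    True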
| """ |
|
|
| from typing import Optional, Tuple |
|
|
| import torch |
| import pyvene as pv |
|
|
|
|
| |
| |
| |
class IdentityFeaturizerModule(torch.nn.Module):
    """A no-op featurizer: *x → (x, None)*."""

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return x, None


class IdentityInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`IdentityFeaturizerModule`."""

    def forward(self, x: torch.Tensor, error: None) -> torch.Tensor:
        return x


class Featurizer:
    """Container object holding paired featurizer and inverse modules.

    Parameters
    ----------
    featurizer :
        A `torch.nn.Module` mapping **x → (features, error)**.
    inverse_featurizer :
        A `torch.nn.Module` mapping **(features, error) → x̂**.
    n_features :
        Dimensionality of the feature space. **Required** when you intend to
        build a *mask* intervention; optional otherwise.
    id :
        Human-readable identifier used by `__str__` methods of the generated
        interventions.
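
    Examples
    --------
    A small sketch of the intervention factories; the feature dimensionality
    below is illustrative::

        >>> feat = Featurizer(n_features=16, id="example")
        >>> interchange_cls = feat.get_interchange_intervention()
        >>> mask_cls = feat.get_mask_intervention()   # requires ``n_features``

    Each getter returns an intervention *class* (not an instance) that can be
    handed to PyVENE as an intervention type.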
| """ |
|
|
| |
| |
| |
| def __init__( |
| self, |
| featurizer: torch.nn.Module = IdentityFeaturizerModule(), |
| inverse_featurizer: torch.nn.Module = IdentityInverseFeaturizerModule(), |
| *, |
| n_features: Optional[int] = None, |
| id: str = "null", |
| ): |
| self.featurizer = featurizer |
| self.inverse_featurizer = inverse_featurizer |
| self.n_features = n_features |
| self.id = id |
|
|

    def get_interchange_intervention(self):
        if not hasattr(self, "_interchange_intervention"):
            self._interchange_intervention = build_feature_interchange_intervention(
                self.featurizer, self.inverse_featurizer, self.id
            )
        return self._interchange_intervention

    def get_collect_intervention(self):
        if not hasattr(self, "_collect_intervention"):
            self._collect_intervention = build_feature_collect_intervention(
                self.featurizer, self.id
            )
        return self._collect_intervention

    def get_mask_intervention(self):
        if self.n_features is None:
            raise ValueError(
                "`n_features` must be provided on the Featurizer "
                "to construct a mask intervention."
            )
        if not hasattr(self, "_mask_intervention"):
            self._mask_intervention = build_feature_mask_intervention(
                self.featurizer,
                self.inverse_featurizer,
                self.n_features,
                self.id,
            )
        return self._mask_intervention

    def featurize(self, x: torch.Tensor):
        return self.featurizer(x)

    def inverse_featurize(self, x: torch.Tensor, error):
        return self.inverse_featurizer(x, error)

    def save_modules(self, path: str) -> Tuple[str, str]:
        """Serialise featurizer & inverse to `<path>_{featurizer, inverse}`.

        Notes
        -----
        * **SAE featurizers** are *not* serialisable: a
          :class:`NotImplementedError` is raised.
        * Existing files will be *silently overwritten*.
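
        Example
        -------
        A hedged sketch of the save/load round-trip (the path and shape are
        illustrative)::

            >>> feat = SubspaceFeaturizer(shape=(64, 4), id="demo")
            >>> f_path, inv_path = feat.save_modules("/tmp/demo")
            >>> restored = Featurizer.load_modules("/tmp/demo")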
| """ |
| featurizer_class = self.featurizer.__class__.__name__ |
|
|
| if featurizer_class == "SAEFeaturizerModule": |
| |
| return None, None |
|
|
| inverse_featurizer_class = self.inverse_featurizer.__class__.__name__ |
|
|
| |
| additional_config = {} |
| if featurizer_class == "SubspaceFeaturizerModule": |
| additional_config["rotation_matrix"] = ( |
| self.featurizer.rotate.weight.detach().clone() |
| ) |
| additional_config["requires_grad"] = ( |
| self.featurizer.rotate.weight.requires_grad |
| ) |
|
|
| model_info = { |
| "featurizer_class": featurizer_class, |
| "inverse_featurizer_class": inverse_featurizer_class, |
| "n_features": self.n_features, |
| "additional_config": additional_config, |
| } |
|
|
| torch.save( |
| {"model_info": model_info, "state_dict": self.featurizer.state_dict()}, |
| f"{path}_featurizer", |
| ) |
| torch.save( |
| { |
| "model_info": model_info, |
| "state_dict": self.inverse_featurizer.state_dict(), |
| }, |
| f"{path}_inverse_featurizer", |
| ) |
| return f"{path}_featurizer", f"{path}_inverse_featurizer" |
|
|
    @classmethod
    def load_modules(cls, path: str) -> "Featurizer":
        """Inverse of :meth:`save_modules`.

        Returns
        -------
        Featurizer
            A *new* instance with reconstructed modules and metadata.
        """
        featurizer_data = torch.load(f"{path}_featurizer")
        inverse_data = torch.load(f"{path}_inverse_featurizer")

        model_info = featurizer_data["model_info"]
        featurizer_class = model_info["featurizer_class"]

        if featurizer_class == "SubspaceFeaturizerModule":
            rot = model_info["additional_config"]["rotation_matrix"]
            requires_grad = model_info["additional_config"]["requires_grad"]

            in_dim, out_dim = rot.shape
            rotate_layer = pv.models.layers.LowRankRotateLayer(
                in_dim, out_dim, init_orth=False
            )
            rotate_layer.weight.data.copy_(rot)
            rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
            rotate_layer.requires_grad_(requires_grad)

            featurizer = SubspaceFeaturizerModule(rotate_layer)
            inverse = SubspaceInverseFeaturizerModule(rotate_layer)

            assert (
                featurizer.rotate.weight.shape == rot.shape
            ), "Rotation-matrix shape mismatch after deserialisation."
        elif featurizer_class == "IdentityFeaturizerModule":
            featurizer = IdentityFeaturizerModule()
            inverse = IdentityInverseFeaturizerModule()
        else:
            raise ValueError(f"Unknown featurizer class '{featurizer_class}'.")

        featurizer.load_state_dict(featurizer_data["state_dict"])
        inverse.load_state_dict(inverse_data["state_dict"])

        return cls(
            featurizer,
            inverse,
            n_features=model_info["n_features"],
            id=model_info.get("featurizer_id", "loaded"),
        )


def build_feature_interchange_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    featurizer_id: str,
):
    """Return a class implementing PyVENE’s TrainableIntervention."""

    class FeatureInterchangeIntervention(
        pv.TrainableIntervention, pv.DistributedRepresentationIntervention
    ):
        """Swap features between *base* and *source* in the featurized space."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

        def forward(self, base, source, subspaces=None):
            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            if subspaces is None or _subspace_is_all_none(subspaces):
                f_out = f_src
            else:
                f_out = pv.models.intervention_utils._do_intervention_by_swap(
                    f_base,
                    f_src,
                    "interchange",
                    self.interchange_dim,
                    subspaces,
                    subspace_partition=self.subspace_partition,
                    use_fast=self.use_fast,
                )
            return self._inverse(f_out, base_err).to(base.dtype)

        def __str__(self):
            return f"FeatureInterchangeIntervention(id={featurizer_id})"

    return FeatureInterchangeIntervention
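
# Illustrative sketch (not part of the library API): the factory above returns a
# *class*; PyVENE is normally responsible for instantiating it. The constructor
# kwarg shown here (``embed_dim``) and ``hidden_size`` are assumptions about
# PyVENE's usual intervention interface, not guarantees.
#
#     InterchangeCls = Featurizer().get_interchange_intervention()
#     intervention = InterchangeCls(embed_dim=hidden_size)
#     patched = intervention(base_hidden, source_hidden)  # swaps all features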


def build_feature_collect_intervention(
    featurizer: torch.nn.Module, featurizer_id: str
):
    """Return a `CollectIntervention` operating in feature space."""

    class FeatureCollectIntervention(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer

        def forward(self, base, source=None, subspaces=None):
            f_base, _ = self._featurizer(base)
            return pv.models.intervention_utils._do_intervention_by_swap(
                f_base,
                source,
                "collect",
                self.interchange_dim,
                subspaces,
                subspace_partition=self.subspace_partition,
                use_fast=self.use_fast,
            )

        def __str__(self):
            return f"FeatureCollectIntervention(id={featurizer_id})"

    return FeatureCollectIntervention


def build_feature_mask_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    n_features: int,
    featurizer_id: str,
):
    """Return a trainable mask intervention."""

    class FeatureMaskIntervention(pv.TrainableIntervention):
        """Differential-binary masking in the featurized space."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

            self.mask = torch.nn.Parameter(torch.zeros(n_features), requires_grad=True)
            self.temperature: Optional[torch.Tensor] = None

        def get_temperature(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            return self.temperature

        def set_temperature(self, temp: float | torch.Tensor):
            self.temperature = (
                torch.as_tensor(temp, dtype=self.mask.dtype).to(self.mask.device)
            )

        def _nonlinear_transform(self, f: torch.Tensor) -> torch.Tensor:
            return torch.tanh(f)

        def forward(self, base, source, subspaces=None):
            if self.temperature is None:
                raise ValueError("Cannot run forward without a temperature.")

            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            mask = self.mask.to(f_base.device)
            temp = self.temperature.to(f_base.device)

            f_base = f_base.to(mask.dtype)
            f_src = f_src.to(mask.dtype)
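
            # Soft gate during training (temperature-scaled sigmoid relaxation of a
            # binary mask); hard 0/1 gate at evaluation time. Gated positions take
            # the source features, the remaining positions keep the base features.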
            if self.training:
                gate = torch.sigmoid(mask / temp)
            else:
                gate = (torch.sigmoid(mask) > 0.5).float()

            f_out = (1.0 - gate) * f_base + gate * f_src

            return self._inverse(f_out.to(base.dtype), base_err).to(base.dtype)

        def get_sparsity_loss(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            gate = torch.sigmoid(self.mask / self.temperature)
            return torch.norm(gate, p=1)

        def __str__(self):
            return f"FeatureMaskIntervention(id={featurizer_id})"

    return FeatureMaskIntervention
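
# Illustrative training sketch for the mask intervention (assumptions: PyVENE
# instantiates the class and accepts an ``embed_dim`` constructor kwarg;
# ``hidden_size``, ``task_loss`` and the annealing schedule are hypothetical).
#
#     MaskCls = Featurizer(n_features=hidden_size).get_mask_intervention()
#     masker = MaskCls(embed_dim=hidden_size)
#     masker.set_temperature(1.0)
#     out = masker(base_hidden, source_hidden)            # soft gating while training
#     loss = task_loss + 1e-3 * masker.get_sparsity_loss()
#     # anneal: masker.set_temperature(new_temp) each epoch, then masker.eval()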


class SubspaceFeaturizerModule(torch.nn.Module):
    """Linear projector onto an orthogonal *rotation* sub-space."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, x: torch.Tensor):
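        # ``rotate.weight`` has orthonormal columns spanning the feature sub-space.
        # ``f`` holds the coordinates of ``x`` in that basis, and ``error`` is the
        # component of ``x`` orthogonal to the sub-space (lost by the projection).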
        r = self.rotate.weight.T
        f = x.to(r.dtype) @ r.T
        error = x - (f @ r).to(x.dtype)
        return f, error


class SubspaceInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`SubspaceFeaturizerModule`."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, f, error):
        r = self.rotate.weight.T
        return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)


class SubspaceFeaturizer(Featurizer):
    """Orthogonal linear sub-space featurizer."""

    def __init__(
        self,
        *,
        shape: Tuple[int, int] | None = None,
        rotation_subspace: torch.Tensor | None = None,
        trainable: bool = True,
        id: str = "subspace",
    ):
        assert (
            shape is not None or rotation_subspace is not None
        ), "Provide either `shape` or `rotation_subspace`."

        if shape is not None:
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
        else:
            shape = rotation_subspace.shape
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
            rotate.weight.data.copy_(rotation_subspace)

        rotate = torch.nn.utils.parametrizations.orthogonal(rotate)
        rotate.requires_grad_(trainable)

        super().__init__(
            SubspaceFeaturizerModule(rotate),
            SubspaceInverseFeaturizerModule(rotate),
            n_features=rotate.weight.shape[1],
            id=id,
        )
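
# Illustrative usage (dimensions are arbitrary): project a 768-dimensional
# hidden state onto a trainable 16-dimensional orthogonal sub-space.
#
#     feat = SubspaceFeaturizer(shape=(768, 16))
#     f, err = feat.featurize(torch.randn(2, 768))   # f: (2, 16), err: (2, 768)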


class SAEFeaturizerModule(torch.nn.Module):
    """Wrapper around a *Sparse Autoencoder*’s encode() / decode() pair."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, x):
        features = self.sae.encode(x.to(self.sae.dtype))
        error = x - self.sae.decode(features).to(x.dtype)
        return features.to(x.dtype), error


class SAEInverseFeaturizerModule(torch.nn.Module):
    """Inverse for :class:`SAEFeaturizerModule`."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, features, error):
        return (
            self.sae.decode(features.to(self.sae.dtype)).to(features.dtype)
            + error.to(features.dtype)
        )


class SAEFeaturizer(Featurizer):
    """Featurizer backed by a pre-trained sparse auto-encoder.

    Notes
    -----
    Serialisation is *disabled* for SAE featurizers – saving will raise
    ``NotImplementedError``.
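
    Example
    -------
    A hedged sketch; it assumes an SAE object exposing ``encode``, ``decode``,
    ``dtype`` and a ``cfg`` with ``d_sae`` (for instance one loaded from an
    external SAE library)::

        >>> # sae = <load a pretrained sparse auto-encoder here>
        >>> # feat = SAEFeaturizer(sae, trainable=False)
        >>> # f, err = feat.featurize(hidden_states)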
| """ |
|
|
| def __init__(self, sae, *, trainable: bool = False): |
| sae.requires_grad_(trainable) |
| super().__init__( |
| SAEFeaturizerModule(sae), |
| SAEInverseFeaturizerModule(sae), |
| n_features=sae.cfg.to_dict()["d_sae"], |
| id="sae", |
| ) |
|
|
|
|
| |
| |
| |
def _subspace_is_all_none(subspaces) -> bool:
    """Return ``True`` if *subspaces* is ``None`` or every element of it is ``None``."""
    return subspaces is None or all(
        inner is None or all(elem is None for elem in inner) for inner in subspaces
    )