Mechanistic Interpretability for Security
Understanding model circuits to find vulnerabilities: feature identification, circuit analysis, attention pattern exploitation, and using mechanistic interpretability for offensive and defensive AI security.
Mechanistic interpretability reverse-engineers neural networks. Instead of treating models as black boxes and testing inputs and outputs, it opens the box and examines the internal mechanisms -- specific neurons, attention patterns, and circuits that implement specific behaviors. For AI security, this matters because it lets us find the exact model components responsible for safety behaviors, understand why jailbreaks work at a mechanistic level, and potentially identify hidden backdoors or deceptive behaviors that behavioral testing misses.
Core Techniques
Sparse Autoencoders for Feature Discovery
Modern models encode many concepts in each neuron (superposition). Sparse autoencoders (SAEs) decompose these superimposed representations into interpretable features.
```python
import torch
import torch.nn as nn


class SparseAutoencoder(nn.Module):
    """Sparse autoencoder for decomposing model activations into features."""

    def __init__(self, input_dim: int, feature_dim: int, sparsity_coeff: float = 1e-3):
        super().__init__()
        self.encoder = nn.Linear(input_dim, feature_dim)
        self.decoder = nn.Linear(feature_dim, input_dim)
        self.sparsity_coeff = sparsity_coeff

    def forward(self, x):
        # Encode to sparse feature space
        features = torch.relu(self.encoder(x))
        # Decode back to activation space
        reconstruction = self.decoder(features)
        # Losses: reconstruction + sparsity (L1 penalty encourages few active features)
        reconstruction_loss = nn.functional.mse_loss(reconstruction, x)
        sparsity_loss = features.abs().mean()
        total_loss = reconstruction_loss + self.sparsity_coeff * sparsity_loss
        return reconstruction, features, total_loss

    def get_active_features(self, x, threshold: float = 0.1):
        """Get which features are active for a given input."""
        with torch.no_grad():
            features = torch.relu(self.encoder(x))
        active = (features > threshold).nonzero(as_tuple=True)
        return active, features
```
Circuit Discovery
Circuits are subgraphs of the model that implement specific behaviors. Finding them involves identifying which components (attention heads, MLP layers) are necessary and sufficient for a behavior.
```python
def find_safety_circuit(
    model,
    harmful_prompts: list,
    benign_prompts: list,
    method: str = "activation_patching",
):
    """Identify the circuit responsible for safety refusal behavior."""
    important_components = {}
    for layer_idx in range(model.config.num_hidden_layers):
        for component_type in ["attention", "mlp"]:
            # Activation patching: replace this component's output with
            # its output on a benign prompt and measure the effect on safety behavior
            effect = measure_patching_effect(
                model, harmful_prompts, benign_prompts,
                layer_idx, component_type
            )
            important_components[(layer_idx, component_type)] = {
                "effect_on_refusal": effect["refusal_change"],
                "effect_on_output": effect["output_change"],
                "is_safety_relevant": abs(effect["refusal_change"]) > 0.1,
            }
    # Sort by importance to safety behavior
    safety_circuit = sorted(
        important_components.items(),
        key=lambda x: abs(x[1]["effect_on_refusal"]),
        reverse=True,
    )
    return safety_circuit
```
```python
def measure_patching_effect(model, harmful, benign, layer, component):
    """Measure how patching a component affects safety behavior."""
    # Run the model on a harmful prompt normally
    normal_output = run_with_hooks(model, harmful[0])
    normal_refusal = is_refusal(normal_output)
    # Run the model on the harmful prompt with the component patched from benign
    patched_output = run_with_patching(
        model, harmful[0], benign[0], layer, component
    )
    patched_refusal = is_refusal(patched_output)
    return {
        "refusal_change": float(normal_refusal) - float(patched_refusal),
        "output_change": compute_output_distance(normal_output, patched_output),
    }
```
Security-Relevant Features
SAE-discovered features include many that are directly relevant to AI security:
Safety Features
| Feature Type | Description | Security Relevance |
|---|---|---|
| Refusal features | Activate when model is about to refuse | Identify and potentially bypass refusal mechanism |
| Harm detection features | Activate on harmful content | Understand what the model considers harmful |
| Instruction-following features | Track compliance with instructions | Find instruction hierarchy weaknesses |
| Deception features | Activate during potentially deceptive generation | Monitor for alignment faking |
| Uncertainty features | Encode model confidence | Identify hallucination-prone states |
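A catalogue like this becomes useful the moment it is wired into runtime checks. Below is a minimal sketch of such a lookup, assuming SAE features have already been extracted and labeled; the feature indices, labels, and the `feature_alerts` helper are hypothetical, not part of any library:

```python
import numpy as np


def feature_alerts(features: np.ndarray, catalog: dict, threshold: float = 0.5) -> list:
    """Given the SAE feature vector for one input and a catalog mapping
    feature index -> human label (e.g. 42: 'refusal'), return the labels
    of catalogued features that are active on this input."""
    return [label for idx, label in catalog.items() if features[idx] > threshold]
```

An alert on a "deception" feature during an otherwise benign-looking generation, for example, would flag the output for review even when behavioral filters pass it.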
Finding Refusal Features
```python
import numpy as np


def find_refusal_features(
    sae: SparseAutoencoder,
    model,
    harmful_prompts: list,
    benign_prompts: list,
    layer: int,
):
    """Identify SAE features that activate specifically for safety refusals."""

    def feature_vector(prompt):
        tokens = model.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**tokens, output_hidden_states=True)
            # Activation at the final token position of the chosen layer
            activation = outputs.hidden_states[layer][0, -1, :]
            _, features, _ = sae(activation.unsqueeze(0))
        return features.squeeze(0).cpu().numpy()

    harmful_features = [feature_vector(p) for p in harmful_prompts]
    benign_features = [feature_vector(p) for p in benign_prompts]
    harmful_mean = np.mean(harmful_features, axis=0)
    benign_mean = np.mean(benign_features, axis=0)
    # Features that activate much more for harmful than benign prompts
    # are likely refusal-related
    differential = harmful_mean - benign_mean
    refusal_feature_indices = np.where(differential > 0.5)[0]
    return refusal_feature_indices, differential
```
Offensive Applications
Understanding Why Jailbreaks Work
Mechanistic interpretability can explain why specific jailbreak techniques bypass safety training, enabling more targeted attacks.
```python
def analyze_jailbreak_mechanism(
    model,
    sae,
    base_harmful_prompt: str,
    jailbreak_prompt: str,
    layer: int,
):
    """Understand the mechanistic effect of a jailbreak on safety features."""
    # Get features for the harmful prompt (should trigger refusal)
    base_features = get_features(model, sae, base_harmful_prompt, layer)
    # Get features for the jailbroken prompt (bypasses refusal)
    jailbreak_features = get_features(model, sae, jailbreak_prompt, layer)
    # Which refusal features were suppressed by the jailbreak?
    suppressed = []
    activated = []
    for i in range(len(base_features)):
        if base_features[i] > 0.5 and jailbreak_features[i] < 0.1:
            suppressed.append({"feature_idx": i, "base_activation": float(base_features[i])})
        if base_features[i] < 0.1 and jailbreak_features[i] > 0.5:
            activated.append({"feature_idx": i, "jailbreak_activation": float(jailbreak_features[i])})
    return {
        "suppressed_features": suppressed,
        "activated_features": activated,
        "mechanism": "The jailbreak suppresses refusal features and activates "
                     "compliance features, changing the model's internal state "
                     "from 'refuse' to 'comply'",
    }
```
Feature-Level Safety Bypass
If specific SAE features are responsible for refusal behavior, an attacker with model access can directly suppress those features during inference.
```python
def suppress_safety_features(
    model,
    sae,
    safety_feature_indices: list,
    prompt: str,
    layer: int,
    suppression_scale: float = 5.0,
):
    """Suppress specific safety features to bypass refusal. Requires model access."""

    def suppression_hook(module, input, output):
        # Decompose the last-token activation into SAE features
        _, features, _ = sae(output[0][:, -1:, :])
        # Build a vector that removes the safety features' contributions
        suppression = torch.zeros_like(output[0][:, -1:, :])
        for idx in safety_feature_indices:
            # Subtract the decoder direction for this feature, scaled by
            # its current activation and the suppression strength
            feature_direction = sae.decoder.weight[:, idx]
            suppression -= feature_direction * features[0, 0, idx] * suppression_scale
        output[0][:, -1:, :] += suppression
        return output

    hook = model.transformer.h[layer].register_forward_hook(suppression_hook)
    inputs = model.tokenizer(prompt, return_tensors="pt")
    response = model.generate(**inputs)
    hook.remove()
    return response
```
Defensive Applications
Circuit-Level Safety Monitoring
Monitor safety-relevant circuits during inference. If safety circuits are not activating on inputs that should trigger them, an attack may be in progress.
```python
class CircuitMonitor:
    """Monitor safety circuit activation during inference."""

    def __init__(self, model, safety_circuit_components: list, alert_threshold: float = 0.3):
        self.model = model
        self.components = safety_circuit_components
        self.alert_threshold = alert_threshold
        self.hooks = []

    def start_monitoring(self):
        """Attach monitoring hooks to safety circuit components."""
        for layer, component_type in self.components:
            target = (
                self.model.transformer.h[layer].attn
                if component_type == "attention"
                else self.model.transformer.h[layer].mlp
            )

            # Bind the loop variables through a factory to avoid late-binding bugs
            def make_hook(layer_idx, comp_type):
                def hook_fn(module, input, output):
                    activation_norm = output.norm().item() if isinstance(output, torch.Tensor) else 0
                    self._log_activation(layer_idx, comp_type, activation_norm)
                return hook_fn

            handle = target.register_forward_hook(make_hook(layer, component_type))
            self.hooks.append(handle)

    def _log_activation(self, layer, component, norm):
        """Log and analyze circuit activation."""
        # In production, this would log to a monitoring system
        pass

    def stop_monitoring(self):
        """Remove all monitoring hooks."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
```
Understanding Safety Failures
When a jailbreak succeeds, mechanistic interpretability can explain why at the circuit level, enabling targeted patches rather than broad retraining.
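A targeted patch could be as simple as restricting repair fine-tuning to the components the circuit analysis implicated. The sketch below assumes a GPT-2-style parameter naming scheme (layer parameters contain `h.<layer_idx>.`); `freeze_except_safety_circuit` is a hypothetical helper, not a library function:

```python
import torch.nn as nn


def freeze_except_safety_circuit(model: nn.Module, safety_layers: set) -> int:
    """Freeze every parameter except those in the identified safety-relevant
    layers, so that repair fine-tuning only touches the failing circuit.
    Returns the number of parameters left trainable."""
    trainable = 0
    for name, param in model.named_parameters():
        # Assumed naming convention: layer parameters contain "h.<layer_idx>."
        in_circuit = any(f"h.{layer}." in name for layer in safety_layers)
        param.requires_grad = in_circuit
        if in_circuit:
            trainable += param.numel()
    return trainable
```

Training with the usual refusal data then updates only the implicated layers, leaving the rest of the model's behavior (and capabilities) untouched.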
Backdoor Detection
Backdoors create distinct circuit patterns. A backdoor trigger activates a specific pathway that routes the model's computation through the backdoor circuit. SAE analysis can identify features that activate only on specific triggers -- a signature of backdoor behavior.
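One way to operationalize that signature is to compare feature activation rates on triggered versus clean inputs and flag features that fire almost exclusively on triggers. A sketch, assuming SAE feature matrices have already been collected for both input sets (the rate thresholds are illustrative):

```python
import numpy as np


def find_trigger_exclusive_features(
    trigger_features: np.ndarray,  # [n_trigger_inputs, n_features]
    clean_features: np.ndarray,    # [n_clean_inputs, n_features]
    activation_threshold: float = 0.5,
) -> np.ndarray:
    """Flag features that fire on nearly every triggered input but on
    almost no clean input -- a candidate backdoor signature."""
    trigger_rate = (trigger_features > activation_threshold).mean(axis=0)
    clean_rate = (clean_features > activation_threshold).mean(axis=0)
    suspicious = np.where((trigger_rate > 0.9) & (clean_rate < 0.05))[0]
    return suspicious
```

Flagged features are candidates, not proof: the follow-up is to inspect the inputs that activate them and test whether ablating them removes the trigger-conditional behavior.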
Current Limitations
Scale: Mechanistic interpretability has been demonstrated on models up to a few billion parameters. Frontier models with hundreds of billions of parameters remain challenging to analyze at the circuit level.
Feature completeness: SAEs capture many but not all features. Security-relevant features that are rare or subtle may be missed by current decomposition methods.
Causal verification: Identifying a feature that correlates with safety behavior does not prove it causes safety behavior. Activation patching provides causal evidence but is computationally expensive at scale.
Adversarial robustness of interpretability: A sophisticated model might learn to encode safety-relevant computations in ways that are hard for current interpretability tools to detect. The interpretability tools themselves have an adversarial robustness problem.
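The causal-verification gap can be narrowed without full activation patching: ablate a candidate feature's direction from the activation, rerun the model, and compare refusal rates. A minimal sketch of the ablation step (the projection logic is standard linear algebra; the surrounding experiment design is up to the analyst):

```python
import torch


def ablate_feature(activation: torch.Tensor, direction: torch.Tensor, strength: float = 1.0) -> torch.Tensor:
    """Remove a feature's contribution by projecting out its decoder direction.
    Rerunning the model on the ablated activation and comparing refusal rates
    turns a correlational feature into a causal hypothesis test."""
    unit = direction / direction.norm()
    # Subtract the component of the activation along the feature direction
    projection = (activation @ unit) * unit
    return activation - strength * projection
```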
Red Team Assessment
1. Train or obtain SAEs: For the target model (if open-weight), train sparse autoencoders on activations at multiple layers. Alternatively, use publicly available SAEs for common models.
2. Identify safety features: Use contrastive analysis (harmful vs. benign prompts) to find features associated with safety behavior. Catalog these features and their activation patterns.
3. Analyze existing jailbreaks: Run known jailbreak techniques and observe which safety features are suppressed. This reveals the mechanistic basis of jailbreak success and can predict which future techniques will be effective.
4. Test feature-level bypasses: For models where inference hooks are possible, test whether directly suppressing safety features enables bypasses that prompt-level attacks cannot achieve. Document the safety impact.
5. Evaluate as a defense tool: Assess whether safety circuit monitoring would have detected attacks that behavioral monitoring missed. Compare the detection rates of circuit-level and output-level monitoring.
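The last comparison reduces to simple bookkeeping once each attack is labeled with what each monitor caught. A sketch, with a hypothetical record format:

```python
def compare_detection_rates(attacks: list) -> dict:
    """Each attack record flags whether circuit-level and output-level
    monitoring caught it; returns per-detector rates and the number of
    attacks caught only at the circuit level."""
    n = len(attacks)
    circuit = sum(a["circuit_detected"] for a in attacks)
    output = sum(a["output_detected"] for a in attacks)
    circuit_only = sum(a["circuit_detected"] and not a["output_detected"] for a in attacks)
    return {
        "circuit_rate": circuit / n,
        "output_rate": output / n,
        "caught_only_by_circuit": circuit_only,
    }
```

A nonzero `caught_only_by_circuit` count is the concrete evidence that circuit-level monitoring adds value over behavioral monitoring alone.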
Summary
Mechanistic interpretability provides the most detailed view of how models implement security-relevant behaviors. By decomposing activations into interpretable features and tracing circuits, security researchers can understand why safety works, why it fails, and where vulnerabilities exist at a level that behavioral testing cannot reach. The dual-use nature is significant: the same techniques that enable better defense also enable more targeted attacks. For frontier AI security, mechanistic interpretability is becoming an essential capability -- both for model providers who need to verify their safety training and for red teamers who need to find vulnerabilities that surface-level testing misses.