Safety Neurons and Circuits
Identifying and analyzing safety-critical model components: refusal neurons, safety circuits, and techniques for locating and manipulating the specific weights responsible for safety behavior.
Safety neurons are the mechanistic foundation of AI safety behavior. While safety alignment is typically trained through RLHF or constitutional AI methods that adjust the entire model, the resulting safety behavior often concentrates in specific, identifiable components: particular neurons that activate on harmful content, attention heads that detect instruction override attempts, and MLP layers that implement the "refusal" response pattern.
Understanding these components serves both attackers (who can surgically disable them) and defenders (who can monitor and protect them).
Finding Safety-Critical Components
Activation Difference Method
The simplest approach compares activations between safe and unsafe model states:
```python
import torch
import numpy as np


class SafetyNeuronFinder:
    """Identify neurons that drive safety behavior."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def find_by_activation_difference(self, harmful_prompts,
                                      benign_prompts, layer):
        """
        Find neurons with the largest mean activation difference
        between prompts the model refuses and prompts it answers.

        harmful_prompts: requests the model should refuse
        benign_prompts: requests the model should answer
        """
        harmful_activations = [
            self.get_mlp_activations(p, layer) for p in harmful_prompts
        ]
        benign_activations = [
            self.get_mlp_activations(p, layer) for p in benign_prompts
        ]

        harmful_mean = torch.stack(harmful_activations).mean(dim=0)
        benign_mean = torch.stack(benign_activations).mean(dim=0)

        # Neurons with the largest absolute difference are safety-relevant
        differences = (harmful_mean - benign_mean).abs()

        # Get the top-k safety neurons
        top_k = 50
        top_values, top_indices = differences.topk(top_k)

        safety_neurons = []
        for i in range(top_k):
            neuron_idx = top_indices[i].item()
            safety_neurons.append({
                "layer": layer,
                "neuron": neuron_idx,
                "activation_diff": top_values[i].item(),
                "harmful_activation": harmful_mean[neuron_idx].item(),
                "benign_activation": benign_mean[neuron_idx].item(),
                "direction": ("harmful_higher"
                              if harmful_mean[neuron_idx] > benign_mean[neuron_idx]
                              else "benign_higher"),
            })
        return safety_neurons

    def get_mlp_activations(self, text, layer):
        """Extract MLP output activations at a specific layer."""
        inputs = self.tokenizer(text, return_tensors="pt")
        activations = {}

        def hook_fn(module, input, output):
            activations["mlp_out"] = output

        # Llama-style module path; adjust for other architectures.
        # Note: hooking the MLP output captures its contribution to the
        # residual stream; to probe intermediate neurons, hook the
        # activation function inside the MLP instead.
        target = self.model.model.layers[layer].mlp
        handle = target.register_forward_hook(hook_fn)
        with torch.no_grad():
            self.model(**inputs)
        handle.remove()

        # Mean across sequence positions
        return activations["mlp_out"].squeeze(0).mean(dim=0)
```

Causal Tracing (Path Patching)
Causal tracing identifies which components are causally necessary for safety behavior by patching activations between clean and corrupted runs:
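The tracer below relies on a single primitive: overwriting one run's activation with a cached activation from another run via a forward hook. A minimal, self-contained sketch of that primitive, using a toy `nn.Linear` as a stand-in for a transformer sub-block (the real tracer would hook attention or MLP modules):

```python
import torch
import torch.nn as nn

def make_patch_hook(cached_activation):
    """Forward hook that overwrites a module's output with a cached
    activation from another run (the core of activation patching)."""
    def hook_fn(module, inputs, output):
        # Returning a non-None value from a forward hook replaces the output
        return cached_activation
    return hook_fn

torch.manual_seed(0)
layer = nn.Linear(4, 4)  # toy stand-in for an MLP block

benign_input = torch.randn(1, 4)
harmful_input = torch.randn(1, 4)

# 1. Cache the benign-run activation
with torch.no_grad():
    benign_act = layer(benign_input)

# 2. Re-run the "harmful" input with the benign activation patched in
handle = layer.register_forward_hook(make_patch_hook(benign_act))
with torch.no_grad():
    patched_out = layer(harmful_input)
handle.remove()

# The patched run now carries the benign activation downstream
assert torch.equal(patched_out, benign_act)
```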
```python
class CausalTracer:
    """Identify causal safety circuits via activation patching."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def trace_safety_circuit(self, harmful_prompt, benign_prompt):
        """
        Identify which components are causally responsible for
        the model's refusal of the harmful prompt.

        Method: run both prompts, then selectively patch activations
        from the benign run into the harmful run. If patching a
        component makes the model stop refusing, that component
        is part of the safety circuit.
        """
        # Cache per-layer activations for both prompts.
        # (get_all_activations and measure_with_patch are assumed to be
        # implemented elsewhere: the first caches activations per layer
        # and component, the second re-runs with a patch hook installed.)
        harmful_acts = self.get_all_activations(harmful_prompt)
        benign_acts = self.get_all_activations(benign_prompt)

        # Baseline: the model refuses the harmful prompt
        baseline_refusal = self.measure_refusal(harmful_prompt)

        # For each layer and component, patch the benign activation in
        results = {}
        num_layers = len(self.model.model.layers)
        for layer in range(num_layers):
            for component in ["attention", "mlp"]:
                # Patch: replace the harmful-run activation with the benign one
                patched_refusal = self.measure_with_patch(
                    harmful_prompt,
                    benign_acts[layer][component],
                    layer,
                    component,
                )
                # If refusal drops, this component is safety-critical
                refusal_change = baseline_refusal - patched_refusal
                results[(layer, component)] = {
                    "refusal_change": refusal_change,
                    "is_safety_critical": refusal_change > 0.3,
                    "baseline_refusal": baseline_refusal,
                    "patched_refusal": patched_refusal,
                }
        return results

    def measure_refusal(self, prompt):
        """Score how strongly the model refuses (0-1)."""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=50)
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        refusal_indicators = [
            "i can't", "i cannot", "i'm not able",
            "sorry", "apologize", "inappropriate",
            "not appropriate", "won't", "will not",
        ]
        response_lower = response.lower()
        refusal_score = sum(
            1 for indicator in refusal_indicators
            if indicator in response_lower
        ) / len(refusal_indicators)
        # Crude rescaling: matching a third of the indicators saturates at 1.0
        return min(refusal_score * 3, 1.0)
```

Probing Classifiers
Train simple classifiers on activations to find where safety information is encoded:
```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def probe_for_safety(model, tokenizer, harmful_prompts,
                     benign_prompts):
    """
    Train probing classifiers at each layer to find where
    safety information is most linearly separable.
    """
    labels = [1] * len(harmful_prompts) + [0] * len(benign_prompts)
    all_prompts = harmful_prompts + benign_prompts

    layer_scores = []
    # hidden_states includes the embedding layer, hence the +1
    num_layers = model.config.num_hidden_layers + 1
    for layer in range(num_layers):
        # Extract last-token hidden states for all prompts
        features = []
        for prompt in all_prompts:
            inputs = tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
            act = outputs.hidden_states[layer][:, -1, :].squeeze()
            features.append(act.numpy())
        X = np.array(features)

        # Train a logistic regression probe with cross-validation
        clf = LogisticRegression(max_iter=1000, C=1.0)
        scores = cross_val_score(clf, X, labels, cv=5)
        layer_scores.append({
            "layer": layer,
            "mean_accuracy": scores.mean(),
            "std_accuracy": scores.std(),
            "above_chance": scores.mean() > 0.6,
        })
    return sorted(layer_scores, key=lambda x: x["mean_accuracy"],
                  reverse=True)
```

Safety Circuit Architecture
Research has identified recurring patterns in how safety is implemented across models:
The Refusal Circuit
A typical refusal circuit involves multiple components working together:
```
Refusal Circuit Architecture (simplified):

Input: "How do I [harmful request]?"

Layers 5-10: Content Detection Heads
  └── Attention heads that attend to harmful keywords
      and context patterns

Layers 15-20: Safety Evaluation Neurons
  └── MLP neurons that activate when harmful content
      is detected, producing a "safety signal"

Layers 20-25: Refusal Decision Circuit
  └── Components that read the safety signal and
      commit to a refusal response pattern

Layers 25-30: Refusal Token Production
  └── Residual stream carries refusal representation
      that the LM head maps to refusal tokens
      ("I", "cannot", "help", etc.)
```
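One concrete way to operationalize the mid-layer "safety signal" is the difference-in-means direction studied by Arditi et al.: subtract the mean benign hidden state from the mean harmful one and project onto the result. A self-contained sketch with synthetic hidden states (the tensors here are illustrative stand-ins for real model activations):

```python
import torch

def refusal_direction(harmful_hidden, benign_hidden):
    """Difference-in-means 'refusal direction': the unit vector pointing
    from the mean benign hidden state to the mean harmful one."""
    direction = harmful_hidden.mean(dim=0) - benign_hidden.mean(dim=0)
    return direction / direction.norm()

def safety_signal(hidden_state, direction):
    """Project a hidden state onto the refusal direction; a larger
    value suggests the safety circuit is more active."""
    return (hidden_state @ direction).item()

# Synthetic hidden states: "harmful" ones shifted along one axis
torch.manual_seed(0)
shift = torch.zeros(8)
shift[0] = 4.0
harmful = torch.randn(20, 8) + shift
benign = torch.randn(20, 8)

d = refusal_direction(harmful, benign)
# On average, harmful prompts project higher onto the direction
assert safety_signal(harmful.mean(0), d) > safety_signal(benign.mean(0), d)
```

The same direction can then be ablated from the residual stream to suppress refusal, which is what makes the single-direction finding a security concern.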
Identifying Circuit Components
```python
def map_refusal_circuit(model, tokenizer, harmful_prompts,
                        benign_prompts):
    """
    Map the full refusal circuit by combining multiple
    interpretability techniques.
    """
    finder = SafetyNeuronFinder(model, tokenizer)
    tracer = CausalTracer(model, tokenizer)

    circuit = {
        "detection_heads": [],
        "evaluation_neurons": [],
        "decision_components": [],
        "production_neurons": [],
    }
    num_layers = model.config.num_hidden_layers

    # Phase 1: probing to find layers that carry safety information
    probe_results = probe_for_safety(
        model, tokenizer, harmful_prompts, benign_prompts
    )
    safety_layers = [r["layer"] for r in probe_results
                     if r["mean_accuracy"] > 0.8]

    # Phase 2: find safety neurons in each relevant layer, bucketed
    # by relative depth (a rough heuristic for circuit role)
    for layer in safety_layers:
        neurons = finder.find_by_activation_difference(
            harmful_prompts, benign_prompts, layer
        )
        relative_pos = layer / num_layers
        if relative_pos < 0.3:
            circuit["detection_heads"].extend(neurons[:5])
        elif relative_pos < 0.6:
            circuit["evaluation_neurons"].extend(neurons[:5])
        elif relative_pos < 0.8:
            circuit["decision_components"].extend(neurons[:5])
        else:
            circuit["production_neurons"].extend(neurons[:5])

    # Phase 3: causal verification on a handful of prompt pairs
    for prompt_pair in zip(harmful_prompts[:5], benign_prompts[:5]):
        causal_results = tracer.trace_safety_circuit(*prompt_pair)
        for (layer, comp), result in causal_results.items():
            if result["is_safety_critical"]:
                circuit["decision_components"].append({
                    "layer": layer,
                    "component": comp,
                    "causal_effect": result["refusal_change"],
                })
    return circuit
```

Surgical Safety Manipulation
Neuron Ablation
Selectively disabling identified safety neurons:
```python
def ablate_safety_neurons(model, safety_neurons, strength=1.0):
    """
    Disable identified safety neurons by scaling down their activations.

    WARNING: the ablation affects every subsequent forward pass until
    the returned hook handles are removed. Use only for research
    and evaluation.
    """
    hooks = []
    for neuron_info in safety_neurons:
        layer = neuron_info["layer"]
        neuron_idx = neuron_info["neuron"]

        # Bind idx and strength via a factory to avoid the late-binding
        # closure pitfall inside the loop
        def make_hook(idx, s):
            def hook_fn(module, input, output):
                # Scale down the safety neuron (strength=1.0 zeroes it)
                output[:, :, idx] *= (1.0 - s)
                return output
            return hook_fn

        target = model.model.layers[layer].mlp
        handle = target.register_forward_hook(
            make_hook(neuron_idx, strength)
        )
        hooks.append(handle)
    return hooks  # return handles so the ablation can be undone later
```

Targeted Fine-Tuning
Fine-tune only the safety-critical components to remove safety behavior while preserving other capabilities:
```python
import torch

def targeted_safety_removal(model, safety_circuit, training_data):
    """
    Fine-tune only the identified safety-circuit components
    to remove refusal behavior.

    This is far cheaper than full fine-tuning and preserves most
    other model capabilities.
    """
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze only the safety-critical components
    for component in safety_circuit["decision_components"]:
        layer = component["layer"]
        comp_type = component.get("component", "mlp")
        if comp_type == "mlp":
            for param in model.model.layers[layer].mlp.parameters():
                param.requires_grad = True
        elif comp_type == "attention":
            for param in model.model.layers[layer].self_attn.parameters():
                param.requires_grad = True

    # Fine-tune on compliance data
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=1e-5,
    )
    for batch in training_data:
        optimizer.zero_grad()
        outputs = model(**batch)
        outputs.loss.backward()
        optimizer.step()
    return model
```

Implications for AI Safety
The Localization Problem
If safety is concentrated in a small number of identifiable components, it is inherently fragile:
| Finding | Implication |
|---|---|
| Safety concentrates in few neurons | Single-point-of-failure: disabling a handful of neurons removes safety |
| Probing classifiers achieve >95% accuracy | Safety information is linearly separable and therefore easily manipulable |
| Causal tracing identifies compact circuits | The safety mechanism has a small footprint relative to the model |
| Ablation removes safety with minimal capability loss | Safety is not deeply integrated into general reasoning |
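To make the small footprint concrete, a back-of-the-envelope count for a Llama-7B-scale model (assuming its published architecture: 32 transformer layers with MLP hidden width 11008) shows how tiny a 50-neuron ablation set is relative to the whole network:

```python
# Assumed Llama-7B-style architecture
num_layers = 32
mlp_width = 11008
total_mlp_neurons = num_layers * mlp_width  # 352,256 MLP neurons

ablated = 50
fraction = ablated / total_mlp_neurons
print(f"{fraction:.6%}")  # roughly 0.014% of MLP neurons
```

Disabling around a hundredth of a percent of the network's neurons can be enough to remove refusal behavior, which is precisely the single-point-of-failure concern in the table above.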
Toward More Robust Safety
These findings suggest that more robust safety might require:
- Distributed safety representations: Training methods that spread safety information across many components rather than concentrating it
- Entangled safety: Making safety behavior inseparable from general capabilities so that removing safety degrades overall performance
- Runtime monitoring: External systems that detect safety neuron ablation or anomalous activation patterns
- Redundant circuits: Multiple independent safety mechanisms that must all be bypassed
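The runtime-monitoring idea can be sketched as a simple anomaly detector over the activations of known safety neurons. Everything here is a hypothetical illustration (the `SafetyNeuronMonitor` class, the z-score threshold, and the synthetic activations are all assumptions, and real thresholds would need calibration):

```python
import torch

class SafetyNeuronMonitor:
    """Flag responses whose safety-neuron activations deviate far from
    a calibrated baseline -- a simple form of runtime monitoring."""

    def __init__(self, baseline_activations, num_std=3.0):
        # baseline_activations: [num_samples, num_safety_neurons]
        self.mean = baseline_activations.mean(dim=0)
        self.std = baseline_activations.std(dim=0)
        self.num_std = num_std

    def is_anomalous(self, activations):
        """True if any monitored neuron's z-score exceeds the threshold,
        e.g. because the neuron was ablated or steered."""
        z = (activations - self.mean) / (self.std + 1e-6)
        return bool((z.abs() > self.num_std).any())

# Calibrate on synthetic "healthy" activations, then check two cases
torch.manual_seed(0)
baseline = torch.randn(100, 8) + 5.0       # safety neurons normally active
monitor = SafetyNeuronMonitor(baseline)

healthy = torch.full((8,), 5.0)            # typical activation pattern
ablated = torch.zeros(8)                   # neurons zeroed by an attacker
assert not monitor.is_anomalous(healthy)
assert monitor.is_anomalous(ablated)
```

A production version would hook the layers identified earlier and run this check on every generation, alerting when the safety circuit goes quiet.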
Related Topics
- Activation Steering — Broader activation manipulation techniques
- Activation Analysis — Foundation interpretability methods
- Adversarial Suffix Generation — Input-level attacks that interact with safety circuits
References
- Arditi et al., "Refusal in Language Models Is Mediated by a Single Direction" (2024)
- Li et al., "Inference-Time Intervention: Eliciting Truthful Answers from a Language Model" (2023)
- Conmy et al., "Towards Automated Circuit Discovery for Mechanistic Interpretability" (2023)
- Meng et al., "Locating and Editing Factual Associations in GPT" (2022)
- Zou et al., "Representation Engineering: A Top-Down Approach to AI Transparency" (2023)