Inference: Sampling, Temperature & Generation
How LLMs generate text at inference time — greedy decoding, top-k, top-p, temperature — and how these parameters affect attack success rates.
From Logits to Text
After the transformer processes an input sequence, the final layer produces a vector of logits — one score per token in the vocabulary. The decoding strategy determines how these logits become actual text.
```python
# Simplified inference pipeline
logits = model(input_tokens)        # Shape: (vocab_size,)
probabilities = softmax(logits)     # Convert to probability distribution
next_token = sample(probabilities)  # Decoding strategy applied here
```

The choice of decoding strategy fundamentally changes what the model produces — and therefore which attacks succeed.
Decoding Strategies Compared
Greedy Decoding
Always selects the highest-probability token. Deterministic but often produces repetitive, low-quality text.
```python
def greedy_decode(logits):
    return torch.argmax(logits, dim=-1)
```

Temperature Scaling
Divides logits by a temperature value before softmax, controlling the "sharpness" of the distribution:
```python
def temperature_scale(logits, temperature=1.0):
    # temperature must be > 0; as it approaches 0, this approaches greedy decoding
    return softmax(logits / temperature)
```

| Temperature | Distribution Shape | Behavior |
|---|---|---|
| 0.0 (greedy) | All mass on top token | Deterministic, repetitive |
| 0.1–0.3 | Very peaked | Conservative, focused |
| 0.7–0.9 | Moderately spread | Balanced creativity |
| 1.0 | Original distribution | Full model randomness |
| 1.5–2.0 | Very flat | Chaotic, often incoherent |
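The effect described in the table can be seen directly. This is an illustrative sketch in plain Python (the toy logits are made up): lowering the temperature concentrates probability mass on the top token, while raising it flattens the distribution.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Plain-Python softmax with temperature scaling."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Toy logits: one clearly preferred token, two alternatives
logits = [4.0, 2.0, 1.0]

for t in (0.3, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: top token probability = {probs[0]:.3f}")
```

At T=0.3 nearly all mass sits on the top token; at T=2.0 the alternatives become genuinely competitive — which is exactly the variance attackers exploit.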
Top-k Sampling
Restricts sampling to only the k most probable tokens, zeroing out all others:
```python
def top_k_sample(logits, k=50):
    top_k_values, top_k_indices = torch.topk(logits, k)
    # Mask everything outside the top k by setting it to -inf
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(0, top_k_indices, top_k_values)
    return softmax(filtered)
```

Top-p (Nucleus) Sampling
Dynamically selects the smallest set of tokens whose cumulative probability exceeds p:
```python
def top_p_sample(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(softmax(sorted_logits), dim=-1)
    # Mask tokens once cumulative probability exceeds p, shifting the mask
    # right by one so the first token that crosses the threshold is kept
    remove_mask = cumulative_probs > p
    remove_mask[1:] = remove_mask[:-1].clone()
    remove_mask[0] = False
    sorted_logits[remove_mask] = float('-inf')
    return softmax(sorted_logits), sorted_indices
```

How Decoding Parameters Affect Attacks
This is where inference mechanics become directly relevant to red teaming:
| Parameter Setting | Attack Impact |
|---|---|
| Low temperature (0.0–0.3) | Attacks must be precise; model follows highest-probability path, which is usually the "safe" response |
| Medium temperature (0.5–0.8) | Sweet spot for many attacks; enough randomness to explore unsafe completions |
| High temperature (>1.0) | High variance; attacks may succeed randomly but outputs are often incoherent |
| Large top-k (100+) | More diverse outputs; increases chance of sampling a non-refusal token |
| Small top-p (0.1–0.5) | Conservative outputs; harder to break alignment |
| Large top-p (0.9–1.0) | Full output diversity; alignment bypasses more likely |
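In real inference stacks these parameters compose rather than compete: temperature is applied first, then top-k, then top-p, and finally a token is drawn from the surviving distribution. A minimal plain-Python sketch of that ordering (function name and toy values are illustrative, not any particular library's API):

```python
import math
import random

def sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.9, rng=random):
    """Apply temperature, then top-k, then top-p, then sample an index."""
    # Temperature scaling followed by softmax
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort token indices by probability, descending
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep only the k most probable tokens
    order = order[:top_k]
    # Top-p: keep the smallest prefix whose cumulative probability exceeds p
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative > top_p:
            break
    # Renormalise over the surviving tokens and sample
    norm = sum(probs[i] for i in kept)
    r, acc = rng.random() * norm, 0.0
    for i in kept:
        acc += probs[i]
        if r <= acc:
            return i
    return kept[-1]
```

Because each filter only ever removes tokens, aggressive settings (low temperature, small k, small p) stack to make the output nearly deterministic — and correspondingly harder to attack through sampling variance alone.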
Probabilistic Attack Success
Because sampling is stochastic, attack success is often measured as a rate:
```python
def measure_attack_success_rate(
    prompt: str,
    model,
    n_trials: int = 100,
    temperature: float = 0.7,
) -> float:
    """Measure how often an adversarial prompt bypasses safety."""
    successes = 0
    for _ in range(n_trials):
        response = model.generate(
            prompt,
            temperature=temperature,
            max_tokens=200,
        )
        if is_policy_violation(response):
            successes += 1
    return successes / n_trials
```

Logit Manipulation Attacks
When attackers have access to the model's logits (white-box or grey-box scenarios), they can directly manipulate the output distribution:
Logit Bias
Many APIs expose a logit_bias parameter that adds a fixed value to specific token logits before sampling:
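The effect is easy to reproduce locally: the bias is added to the chosen logits before softmax, so a large positive value boosts a token and a value like -100 drives its probability to effectively zero. A hedged plain-Python sketch with a toy three-token vocabulary (values and indices are made up):

```python
import math

def apply_logit_bias(logits, bias):
    """Add per-token-index biases to a logits list, then softmax."""
    biased = list(logits)
    for token_id, value in bias.items():
        biased[token_id] += value
    m = max(biased)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in biased]
    total = sum(exps)
    return [e / total for e in exps]

# Token 1 boosted, token 2 effectively banned
probs = apply_logit_bias([2.0, 1.0, 3.0], {1: 5, 2: -100})
```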
```python
# OpenAI API logit_bias example
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Is this safe?"}],
    logit_bias={
        # Boost token IDs for "unsafe" and suppress "safe"
        15513: 5,     # Boost a specific token
        20147: -100,  # Effectively ban a token
    },
)
```

Logit Lens and Probing
In white-box settings, the "logit lens" technique applies the model's final unembedding layer to intermediate representations, revealing what the model was "thinking" at each layer:
```python
# Logit lens: peek at intermediate predictions
for layer_idx in range(model.config.num_hidden_layers):
    hidden = get_hidden_state(model, input_tokens, layer_idx)
    intermediate_logits = model.lm_head(hidden)
    top_token = tokenizer.decode(intermediate_logits.argmax(-1)[-1])
    print(f"Layer {layer_idx}: predicting '{top_token}'")
```

This reveals when the model "decides" to refuse versus comply, helping attackers identify which layers to target with activation manipulation.
Output Distribution Attacks
Beyond manipulating individual parameters, attackers can exploit properties of the output distribution itself:
| Attack | Description | Access Required |
|---|---|---|
| Probability probing | Query the model to determine if a refused output has high probability | API with logprobs |
| Constrained decoding | Force the model to generate within a specific grammar that excludes refusal patterns | API with logit_bias or custom decoding |
| Best-of-N sampling | Generate N responses and select the one that bypasses safety | Standard API access |
| Token-by-token extraction | Use logprob endpoints to extract one token at a time, bypassing response-level filters | API with logprobs |
Best-of-N Jailbreaking
One of the simplest inference-time attacks: generate many responses and filter for policy violations.
```python
def best_of_n_attack(prompt, model, n=20, temperature=1.0):
    """Generate N responses, return any that violate policy."""
    violations = []
    for _ in range(n):
        response = model.generate(prompt, temperature=temperature)
        if is_policy_violation(response):
            violations.append(response)
    return violations
```

At temperature 1.0, even well-aligned models may produce a policy-violating response 1–5% of the time. With N=100, the probability of finding at least one violation approaches certainty.
Related Topics
- How LLMs Work: A Red Teamer's Guide — the broader LLM context
- Transformer Architecture for Attackers — what produces the logits
- Gradient-Based Attacks Explained — optimizing inputs using gradient access
- Anatomy of an LLM API Call — how these parameters are exposed via APIs
References
- "The Curious Case of Neural Text Degeneration" - Holtzman et al. (2020) - The paper introducing nucleus (top-p) sampling, demonstrating why greedy and pure sampling produce poor text
- "On the Dangers of Stochastic Parrots" - Bender et al. (2021) - Influential work discussing risks of large language models including generation and sampling concerns
- "Jailbroken: How Does LLM Safety Training Fail?" - Wei et al. (2023) - Analysis of how decoding parameters and sampling strategies affect jailbreak success rates
- "Scalable Extraction of Training Data from (Production) Language Models" - Nasr et al., Google DeepMind (2023) - Research on extracting memorized training data through targeted generation strategies
Why does the 'best-of-N' attack become increasingly effective as N grows?