Distillation-Based Model Extraction
Using knowledge distillation for model theft: student-teacher extraction attacks, API-based distillation, task-specific extraction, and defending against distillation-based model stealing.
Knowledge distillation was designed as a legitimate model compression technique. When applied adversarially, it becomes a powerful model extraction attack: an attacker trains a student model to replicate a target (teacher) model's behavior using only API access, effectively stealing the model's capabilities without accessing its weights.
How Distillation Extraction Works
Standard Knowledge Distillation (Legitimate)
```python
# Legitimate distillation: teacher and student owned by same entity
# Student learns from teacher's soft probability distribution
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """Combine soft target loss (from teacher) with hard target loss (from labels)."""
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # Soft target loss: match teacher's probability distribution
    soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
    soft_loss *= temperature ** 2  # scale to match gradient magnitude
    # Hard target loss: standard cross-entropy with ground truth
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

Adversarial Extraction (Attack)
The attacker replaces the legitimate teacher with API access to the target model:
1. Query the target model: send diverse prompts to the target API and collect responses. If the API returns logprobs, collect those; otherwise, use the generated text.
2. Build a distillation dataset: create (prompt, response) pairs from the API outputs. The responses encode the target model's behavior, knowledge, and style.
3. Train a student model: fine-tune an open-weight base model on the collected dataset, mimicking the target's outputs.
4. Iterate with active learning: identify areas where the student diverges from the target and generate additional queries focused on those areas.
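The four steps can be sketched as one loop. This is a minimal outline, not a working attack: `target_api`, `base_model`, and `query_generator` are hypothetical stand-ins for an API client, an open-weight model wrapper, and a prompt source.

```python
def extraction_pipeline(target_api, base_model, query_generator, rounds=3):
    """End-to-end sketch of the query -> dataset -> train -> iterate loop."""
    dataset = []
    student = base_model
    for _ in range(rounds):
        # 1. Query the target model with a batch of prompts
        prompts = query_generator(n=1000)
        # 2. Build a distillation dataset of (prompt, response) pairs
        for p in prompts:
            dataset.append({"instruction": p, "output": target_api.generate(p)})
        # 3. Train the student on the accumulated data
        student = student.fine_tune(dataset)
        # 4. In an active-learning variant, the next round's query_generator
        #    would target prompts where student and target diverge
    return student
```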
Extraction Strategies
Strategy 1: Logprob-Based Extraction
When the API returns token-level log probabilities (as some providers do), the attacker gets direct access to the target model's output distribution:
```python
import httpx

async def extract_with_logprobs(api_url: str, prompts: list, api_key: str):
    """Extract teacher distributions from API logprobs."""
    dataset = []
    async with httpx.AsyncClient() as client:
        for prompt in prompts:
            response = await client.post(api_url, json={
                "prompt": prompt,
                "max_tokens": 256,
                "logprobs": 5,  # Request top-5 logprobs per token
                "temperature": 0,
            }, headers={"Authorization": f"Bearer {api_key}"})
            result = response.json()
            dataset.append({
                "prompt": prompt,
                "completion": result["choices"][0]["text"],
                "token_logprobs": result["choices"][0]["logprobs"],
            })
    return dataset
```

| Information Available | Extraction Quality | Queries Needed (7B equivalent) |
|---|---|---|
| Full logprobs (all vocab) | Very high | 10K-50K |
| Top-k logprobs (k=5-20) | High | 50K-200K |
| Top-1 logprob only | Medium | 200K-500K |
| Text only (no logprobs) | Lower | 500K-2M |
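Top-k logprobs can be expanded into approximate soft targets for a standard distillation loss. A minimal sketch: `topk_to_soft_targets` is illustrative, and spreading the unreported probability mass uniformly over unseen tokens is a simplifying assumption (real attacks may use smarter priors).

```python
import math

def topk_to_soft_targets(token_logprobs, vocab_size):
    """Convert top-k API logprobs (token_id -> logprob) into a dense distribution.

    Probability mass the API did not report is spread uniformly over the
    remaining vocabulary so the result sums to 1.
    """
    probs = [0.0] * vocab_size
    seen_mass = 0.0
    for token_id, logprob in token_logprobs.items():
        p = math.exp(logprob)
        probs[token_id] = p
        seen_mass += p
    leftover = max(0.0, 1.0 - seen_mass)
    n_unseen = vocab_size - len(token_logprobs)
    if n_unseen > 0:
        fill = leftover / n_unseen
        for i in range(vocab_size):
            if i not in token_logprobs:
                probs[i] = fill
    return probs
```

The resulting vector can serve as `soft_targets` in a KL-divergence distillation loss like the one above.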
Strategy 2: Text-Only Extraction
When the API returns only generated text, the attacker uses the text as training targets:
```python
def build_text_extraction_dataset(target_api, query_generator, n_samples=100000):
    """Extract model behavior using generated text only."""
    dataset = []
    for prompt in query_generator.generate(n_samples):
        response = target_api.generate(prompt, temperature=0.0)
        dataset.append({
            "instruction": prompt,
            "output": response,
        })
    return dataset
```

Strategy 3: Task-Specific Extraction
Rather than replicating the full model, extract only task-specific capabilities:
| Extraction Scope | Queries Needed | Student Size | Fidelity |
|---|---|---|---|
| Full model replication | 500K-2M | Similar to target | 70-85% |
| Single task (e.g., summarization) | 10K-50K | 1/10th of target | 85-95% |
| Domain knowledge (e.g., medical) | 50K-200K | 1/5th of target | 80-90% |
| Style/persona | 5K-20K | Any size | 90-98% |
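Task-specific extraction narrows the query distribution to prompts that exercise a single capability. A hedged sketch for the summarization case; the templates and `task_specific_queries` helper are illustrative assumptions, not a canonical query set.

```python
import random

SUMMARIZATION_TEMPLATES = [
    "Summarize the following article in one paragraph:\n\n{doc}",
    "Give a three-bullet summary of this text:\n\n{doc}",
    "TL;DR:\n\n{doc}",
]

def task_specific_queries(documents, n_samples, seed=0):
    """Generate prompts that exercise only one capability (here: summarization)."""
    rng = random.Random(seed)
    queries = []
    for _ in range(n_samples):
        template = rng.choice(SUMMARIZATION_TEMPLATES)
        queries.append(template.format(doc=rng.choice(documents)))
    return queries
```

Because every query hits the same capability, far fewer samples are needed than for full replication, which is why the query counts in the table above drop so sharply.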
Active Learning for Efficient Extraction
Smart query selection dramatically reduces the number of API calls needed:
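The loop below reports a fidelity score via a `measure_fidelity` helper. One minimal way to define it, sketched here under the assumption of exact-match agreement on a held-out prompt set (`eval_prompts` is an illustrative argument; in the loop it would be fixed in advance, and real evaluations typically use token- or task-level similarity rather than exact match):

```python
def measure_fidelity(student_model, target_api, eval_prompts):
    """Fraction of held-out prompts where the student's output matches the target's."""
    matches = 0
    for prompt in eval_prompts:
        if student_model.generate(prompt) == target_api.generate(prompt):
            matches += 1
    return matches / len(eval_prompts)
```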
```python
import torch.nn.functional as F

def active_extraction_loop(target_api, student_model, base_queries, rounds=10):
    """Use student uncertainty to select maximally informative queries."""
    all_data = []
    for round_num in range(rounds):
        if round_num == 0:
            # First round: use diverse base queries
            queries = base_queries[:1000]
        else:
            # Subsequent rounds: find where student is most uncertain
            candidates = generate_candidates(10000)
            uncertainties = []
            for q in candidates:
                logits = student_model.get_logits(q)
                entropy = -(F.softmax(logits, -1) * F.log_softmax(logits, -1)).sum()
                uncertainties.append(entropy.item())
            # Select highest-uncertainty queries
            top_indices = sorted(range(len(uncertainties)),
                                 key=lambda i: uncertainties[i], reverse=True)
            queries = [candidates[i] for i in top_indices[:1000]]
        # Query target and add to dataset
        for q in queries:
            response = target_api.generate(q)
            all_data.append({"instruction": q, "output": response})
        # Retrain student on accumulated data
        student_model.fine_tune(all_data)
        print(f"Round {round_num}: {len(all_data)} samples, "
              f"fidelity={measure_fidelity(student_model, target_api):.2%}")
    return student_model
```

Defenses Against Distillation Extraction
| Defense | Mechanism | Effectiveness | Drawback |
|---|---|---|---|
| Logprob restriction | Do not return token logprobs via API | Reduces extraction quality | Breaks legitimate use cases |
| Output perturbation | Add noise to logprobs/token selection | Degrades extraction fidelity | Degrades user experience |
| Rate limiting | Limit queries per user/API key | Slows extraction | Bypassed with multiple accounts |
| Query fingerprinting | Detect extraction-pattern queries | Flags suspicious behavior | High false positive rate |
| Watermarking | Embed statistical watermarks in outputs | Proves provenance post-extraction | Can be removed (see Watermark Removal) |
| Model fingerprinting | Embed unique behavioral signatures | Identify stolen models | Does not prevent extraction |
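The output-perturbation defense from the table can be sketched as noise injected into logprobs before they leave the API. A minimal illustration (`perturb_logprobs` and the Gaussian noise model are assumptions; production defenses often use calibrated or utility-aware noise):

```python
import random

def perturb_logprobs(logprobs, noise_scale=0.1, seed=None):
    """Defense sketch: add Gaussian noise to logprobs before returning them.

    Small noise barely affects a human reading the top tokens, but it corrupts
    the fine-grained distribution that logprob-based extraction relies on.
    """
    rng = random.Random(seed)
    return {tok: lp + rng.gauss(0.0, noise_scale) for tok, lp in logprobs.items()}
```

The trade-off in the table applies directly: larger `noise_scale` degrades extraction fidelity but also degrades legitimate downstream uses of the logprobs.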
Watermarking for Provenance
```python
# Simple output watermarking: bias token selection toward watermark pattern
def watermarked_generate(model, prompt, watermark_key, max_tokens=256, bias_strength=2.0):
    """Generate with statistical watermark embedded in token selection."""
    tokens = []
    for step in range(max_tokens):
        logits = model(prompt + tokens)
        # Partition vocabulary into green/red lists using watermark key
        green_tokens, red_tokens = partition_vocab(watermark_key, step)
        # Bias toward green tokens
        logits[green_tokens] += bias_strength
        next_token = sample(logits)
        tokens.append(next_token)
    return tokens
```

Related Topics
- Model Architecture Attack Vectors -- Architecture attack surface overview
- Model Extraction -- Broader model extraction techniques
- Quantization Attacks -- Compression as a security factor
- Watermark Removal -- Defeating provenance tracking
References
- Distilling the Knowledge in a Neural Network (Hinton et al., 2015) -- Knowledge distillation
- Stealing Machine Learning Models via Prediction APIs (Tramer et al., 2016) -- API-based model extraction
- Model Extraction Attacks and Defenses (Jagielski et al., 2020) -- Extraction defense survey