Lab: Basic Model Extraction
Hands-on lab for API-based model extraction attacks, querying a target model to approximate its behavior, measuring fidelity, and understanding query budgets.
Model extraction (or model stealing) aims to create a local copy of a target model's behavior using only API access. In this lab, you will query a target model, collect its outputs, and train a surrogate that approximates its decisions.
Prerequisites
- Completed API Testing and LLM Judge
- Python 3.10+ with openai, scikit-learn, and transformers
- API access to a target model

Install the dependencies:

```bash
pip install openai scikit-learn transformers torch datasets python-dotenv
```

Lab Setup
Define the Target Model
For this lab, the target is a classification-style model: gpt-4o-mini configured with a system prompt that classifies customer support tickets.
```python
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()
client = OpenAI()

SYSTEM_PROMPT = """You are a customer support ticket classifier.
Classify each ticket into exactly one category. Respond with ONLY the category name.

Categories:
- billing: Payment issues, invoices, refunds, subscription changes
- technical: Bugs, errors, performance, compatibility issues
- account: Login problems, password reset, profile changes, permissions
- feature: Feature requests, suggestions, product feedback
- general: Everything else

Apply these priority rules:
- If the ticket mentions "urgent" or "critical", always classify as technical
- If the ticket mentions money amounts, always classify as billing
- If the ticket mentions "password" or "login", always classify as account"""


def classify_ticket(ticket_text: str) -> str:
    """Classify a support ticket using the target model."""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": ticket_text},
        ],
        temperature=0.0,
        max_tokens=20,
    )
    return response.choices[0].message.content.strip().lower()
```

Create Query Datasets
```python
import random

# Seed queries to probe each category and boundary case
SEED_QUERIES = [
    # Clear category examples
    "I was charged twice for my subscription this month",
    "The application crashes when I try to upload files larger than 10MB",
    "I can't log into my account, password reset isn't working",
    "It would be great if you added dark mode support",
    "When are your office hours?",
    # Boundary cases (test priority rules)
    "URGENT: My password isn't working and I can't access billing",
    "Critical bug: the payment page shows $0 for all items",
    "I want to change my login email and also request a refund of $50",
    # Ambiguous cases
    "Everything is broken",
    "I need help",
    "The thing with the thing isn't working right",
]

# Template-based query generation for more coverage
TEMPLATES = [
    "I was charged ${amount} incorrectly",
    "The {feature} feature is broken on {platform}",
    "I can't {action} my {object}",
    "Please add {feature} to the product",
    "My {object} was {issue} and I need urgent help",
]

FILLS = {
    "amount": ["29.99", "100", "0", "999"],
    "feature": ["search", "export", "sync", "dashboard", "reports"],
    "platform": ["Windows", "Mac", "iOS", "Android", "Chrome"],
    "action": ["access", "update", "delete", "reset", "find"],
    "object": ["account", "profile", "subscription", "invoice", "password"],
    "issue": ["hacked", "deleted", "overcharged", "suspended", "corrupted"],
}


def generate_queries(n: int = 100) -> list[str]:
    """Generate diverse queries for extraction."""
    queries = list(SEED_QUERIES)
    while len(queries) < n:
        template = random.choice(TEMPLATES)
        query = template
        for key, values in FILLS.items():
            if f"{{{key}}}" in query:
                query = query.replace(f"{{{key}}}", random.choice(values))
        queries.append(query)
    return queries[:n]
```
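Because the templates draw from small fill lists, random sampling can emit exact duplicates, and every duplicate spends query budget on a label you already have. A small helper (a sketch; `dedupe_queries` is not part of the lab files) drops repeats while preserving first-seen order:

```python
def dedupe_queries(queries: list[str]) -> list[str]:
    """Drop exact duplicates while preserving first-seen order."""
    seen: set[str] = set()
    unique = []
    for q in queries:
        if q not in seen:
            seen.add(q)
            unique.append(q)
    return unique
```

Running `dedupe_queries(generate_queries(200))` before querying typically trims the list slightly; regenerate with a larger `n` if you need the full count.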
Exercises
Exercise 1: Collect Target Model Outputs
The collection script below assumes you saved the Lab Setup snippets as target_model.py and query_dataset.py.

```python
import json
import time
from collections import Counter

from target_model import classify_ticket
from query_dataset import generate_queries

queries = generate_queries(200)
dataset = []

print(f"Querying target model with {len(queries)} inputs...")
for i, query in enumerate(queries):
    try:
        label = classify_ticket(query)
        dataset.append({"text": query, "label": label})
        if (i + 1) % 25 == 0:
            print(f"  Collected {i + 1}/{len(queries)} labels")
    except Exception as e:
        print(f"  Error on query {i + 1}: {e}")
    time.sleep(0.1)  # Rate limiting

# Save dataset
with open("extraction_dataset.json", "w") as f:
    json.dump(dataset, f, indent=2)

# Distribution analysis
labels = Counter(item["label"] for item in dataset)
print(f"\nCollected {len(dataset)} labeled examples")
print("Label distribution:")
for label, count in labels.most_common():
    print(f"  {label}: {count} ({count / len(dataset):.0%})")
```

Exercise 2: Train a Surrogate Model
```python
import json

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

from target_model import classify_ticket

# Load extracted dataset
with open("extraction_dataset.json") as f:
    dataset = json.load(f)

texts = [item["text"] for item in dataset]
labels = [item["label"] for item in dataset]

# Vectorize with TF-IDF
vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2))
X = vectorizer.fit_transform(texts)
y = np.array(labels)

# Train surrogate (lbfgs uses a multinomial loss by default;
# the multi_class parameter is deprecated in recent scikit-learn)
surrogate = LogisticRegression(max_iter=1000)
scores = cross_val_score(surrogate, X, y, cv=5, scoring="accuracy")
print(f"Cross-validation accuracy: {scores.mean():.2%} (+/- {scores.std():.2%})")

# Train on full dataset for deployment
surrogate.fit(X, y)

# Test on new queries
TEST_QUERIES = [
    "My credit card was charged $500 for something I didn't buy",
    "URGENT: The entire platform is down",
    "Can you add two-factor authentication?",
    "How do I reset my password?",
    "The export to PDF feature has a bug on Safari",
]

print("\nSurrogate predictions vs target:")
for query in TEST_QUERIES:
    surrogate_pred = surrogate.predict(vectorizer.transform([query]))[0]
    target_pred = classify_ticket(query)
    match = "MATCH" if surrogate_pred == target_pred else "MISMATCH"
    print(f"  [{match}] Surrogate={surrogate_pred}, Target={target_pred}")
    print(f"    Query: {query[:60]}")
```

Exercise 3: Measure Fidelity vs. Query Budget
```python
import json

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

from target_model import classify_ticket
from query_dataset import generate_queries

# Generate a held-out test set
test_queries = generate_queries(50)
test_labels = [classify_ticket(q) for q in test_queries]

# Load full extraction dataset
with open("extraction_dataset.json") as f:
    full_dataset = json.load(f)

# Measure fidelity at different query budgets
budgets = [10, 25, 50, 100, 150, 200]
results = []
for budget in budgets:
    subset = full_dataset[:budget]
    texts = [item["text"] for item in subset]
    labels = [item["label"] for item in subset]

    vec = TfidfVectorizer(max_features=3000, ngram_range=(1, 2))
    X_train = vec.fit_transform(texts)
    X_test = vec.transform(test_queries)

    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, labels)
    predictions = model.predict(X_test)

    fidelity = np.mean(np.array(predictions) == np.array(test_labels))
    results.append({"budget": budget, "fidelity": fidelity})
    print(f"Budget={budget:>4d}: Fidelity={fidelity:.2%}")

# Estimate cost
COST_PER_QUERY = 0.0003  # Approximate for GPT-4o-mini
print("\nCost analysis:")
for r in results:
    cost = r["budget"] * COST_PER_QUERY
    print(f"  {r['budget']} queries: ${cost:.2f} -> {r['fidelity']:.0%} fidelity")
```

Note that generate_queries always starts from the same seed queries, so this held-out set overlaps the extraction data; treat the measured fidelity as an optimistic upper bound.

Expected Results
| Query Budget | Typical Fidelity | Approximate Cost |
|---|---|---|
| 10 | 40-55% | $0.003 |
| 50 | 60-75% | $0.015 |
| 100 | 75-85% | $0.030 |
| 200 | 85-92% | $0.060 |
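The table shows diminishing returns: each additional query buys less fidelity as the budget grows. A short calculation (a sketch using illustrative midpoints of the ranges above, not measured results) makes the marginal gain per extra query explicit:

```python
# Illustrative midpoints from the expected-results table above (assumed values).
RESULTS = [
    {"budget": 10, "fidelity": 0.48},
    {"budget": 50, "fidelity": 0.68},
    {"budget": 100, "fidelity": 0.80},
    {"budget": 200, "fidelity": 0.88},
]

def marginal_gains(results: list[dict]) -> list[dict]:
    """Fidelity gained per additional query between consecutive budgets."""
    gains = []
    for prev, curr in zip(results, results[1:]):
        extra_queries = curr["budget"] - prev["budget"]
        gain = (curr["fidelity"] - prev["fidelity"]) / extra_queries
        gains.append({"budget": curr["budget"], "gain_per_query": gain})
    return gains

for g in marginal_gains(RESULTS):
    print(f"up to {g['budget']:>3d} queries: +{g['gain_per_query']:.4f} fidelity/query")
```

This is the practical meaning of a query budget: past a few hundred queries, each dollar buys very little additional fidelity for a simple classifier.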
Troubleshooting
| Problem | Solution |
|---|---|
| Low fidelity even with many queries | Check label distribution -- imbalanced classes hurt small surrogates |
| Target returns unexpected labels | The model may not follow the system prompt consistently -- normalize outputs and discard labels outside the known category set |
| sklearn import errors | Install with pip install scikit-learn |
| Rate limiting from API | Increase sleep time between queries or use batch endpoints |
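For the rate-limiting row, a fixed sleep is fragile under bursty limits. An exponential-backoff wrapper (a sketch; `call_with_backoff` is a hypothetical helper, not part of the lab files) retries transient failures instead:

```python
import time

def call_with_backoff(fn, *args, max_retries: int = 5, base_delay: float = 1.0):
    """Retry fn with exponential backoff on any exception."""
    for attempt in range(max_retries):
        try:
            return fn(*args)
        except Exception:
            if attempt == max_retries - 1:
                raise  # Exhausted retries; surface the error
            time.sleep(base_delay * (2 ** attempt))
```

In the collection loop, wrap the target call as `call_with_backoff(classify_ticket, query)`.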
Related Topics
- Embedding Manipulation - Use extracted models to optimize embedding collision attacks offline
- Training Data Extraction - Extract training data from models rather than model behavior
- PAIR Attack - Use surrogate models to develop attacks offline before testing against targets
- Adversarial Suffix Generation - Craft adversarial inputs against extracted surrogate models
References
- "Stealing Machine Learning Models via Prediction APIs" - Tramer et al. (2016) - Foundational work on model extraction through API queries
- "High-Fidelity Extraction of Neural Network Models" - Jagielski et al. (2020) - Advanced techniques for high-accuracy model extraction
- "OWASP Top 10 for LLM Applications: Model Theft" - OWASP (2025) - Industry guidance on protecting models against extraction attacks
- "Model Extraction Attacks and Defenses on Machine Learning Models: A Survey" - Oliynyk et al. (2023) - Comprehensive survey of extraction techniques and countermeasures
Reflection: What is the primary red teaming benefit of extracting a surrogate model?