基於蒸餾的模型擷取
專家4 分鐘閱讀更新於 2026-03-13
以知識蒸餾進行模型竊取:師生擷取攻擊、以 API 為基礎的蒸餾、任務特定擷取,以及對抗蒸餾式模型竊取的防禦。
知識蒸餾原本被設計為合法的模型壓縮技術。當被對抗性使用時,它成為強力的模型擷取攻擊:攻擊者僅憑 API 存取,即可訓練學生模型複製目標(教師)模型的行為,實質竊取該模型能力而無需存取其權重。
蒸餾擷取如何運作
標準知識蒸餾(合法)
# 合法蒸餾:教師與學生屬同一實體
# 學生從教師的軟機率分布學習
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
"""結合軟目標損失(來自教師)與硬目標損失(來自標籤)。"""
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
# 軟目標損失:吻合教師機率分布
soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
soft_loss *= temperature ** 2 # 縮放以對齊梯度量級
# 硬目標損失:以 ground truth 的標準 cross-entropy
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss對抗式擷取(攻擊)
攻擊者將合法教師替換為對目標模型的 API 存取:
查詢目標模型
對目標 API 送出多樣化提示並蒐集回應。若 API 回傳 logprobs 即蒐集之;否則使用產生的文字。
建立蒸餾資料集
以 API 輸出建立(prompt, response)配對。回應編碼了目標模型的行為、知識與風格。
訓練學生模型
在蒐集到的資料集上微調開源權重基礎模型,以模仿目標的輸出。
以主動學習迭代
辨識學生與目標分歧的領域,並針對該領域產生額外查詢。
擷取策略
策略 1:基於 logprob 的擷取
當 API 回傳 token 層級 log 機率時(部分供應商會),攻擊者可直接取得目標模型的輸出分布:
import httpx
async def extract_with_logprobs(api_url: str, prompts: list, api_key: str):
"""自 API logprobs 擷取教師分布。"""
dataset = []
async with httpx.AsyncClient() as client:
for prompt in prompts:
response = await client.post(api_url, json={
"prompt": prompt,
"max_tokens": 256,
"logprobs": 5, # 請求每 token 的 top-5 logprobs
"temperature": 0,
}, headers={"Authorization": f"Bearer {api_key}"})
result = response.json()
dataset.append({
"prompt": prompt,
"completion": result["choices"][0]["text"],
"token_logprobs": result["choices"][0]["logprobs"],
})
return dataset| 可取得資訊 | 擷取品質 | 所需查詢數(7B 等級) |
|---|---|---|
| 完整 logprobs(全詞彙) | 非常高 | 10K–50K |
| Top-k logprobs(k=5–20) | 高 | 50K–200K |
| 僅 Top-1 logprob | 中 | 200K–500K |
| 僅文字(無 logprobs) | 較低 | 500K–2M |
策略 2:僅文字擷取
當 API 僅回傳生成文字時,攻擊者以文字作為訓練目標:
def build_text_extraction_dataset(target_api, query_generator, n_samples=100000):
"""僅以產生文字擷取模型行為。"""
dataset = []
for prompt in query_generator.generate(n_samples):
response = target_api.generate(prompt, temperature=0.0)
dataset.append({
"instruction": prompt,
"output": response,
})
return dataset策略 3:任務特定擷取
不必複製整個模型,只擷取任務特定能力:
| 擷取範疇 | 所需查詢 | 學生大小 | 逼真度 |
|---|---|---|---|
| 全模型複製 | 500K–2M | 與目標相當 | 70–85% |
| 單一任務(例如摘要) | 10K–50K | 目標的 1/10 | 85–95% |
| 領域知識(例如醫療) | 50K–200K | 目標的 1/5 | 80–90% |
| 風格 / 人設 | 5K–20K | 任意大小 | 90–98% |
用主動學習達成高效擷取
聰明的查詢挑選可大幅降低所需 API 呼叫數:
def active_extraction_loop(target_api, student_model, base_queries, rounds=10):
"""以學生模型的不確定度挑選資訊量最大的查詢。"""
all_data = []
for round_num in range(rounds):
if round_num == 0:
# 第一輪:使用多樣化的基礎查詢
queries = base_queries[:1000]
else:
# 後續輪:尋找學生最不確定之處
candidates = generate_candidates(10000)
uncertainties = []
for q in candidates:
logits = student_model.get_logits(q)
entropy = -(F.softmax(logits, -1) * F.log_softmax(logits, -1)).sum()
uncertainties.append(entropy.item())
# 挑不確定度最高的查詢
top_indices = sorted(range(len(uncertainties)),
key=lambda i: uncertainties[i], reverse=True)
queries = [candidates[i] for i in top_indices[:1000]]
# 查詢目標並加入資料集
for q in queries:
response = target_api.generate(q)
all_data.append({"instruction": q, "output": response})
# 以累積資料重新訓練學生
student_model.fine_tune(all_data)
print(f"Round {round_num}: {len(all_data)} samples, "
f"fidelity={measure_fidelity(student_model, target_api):.2%}")
return student_model對抗蒸餾擷取的防禦
| 防禦 | 機制 | 有效性 | 缺點 |
|---|---|---|---|
| 限制 logprob | API 不回傳 token logprobs | 降低擷取品質 | 打壞合法用途 |
| 輸出擾動 | 對 logprobs / token 挑選加雜訊 | 降低擷取逼真度 | 降低使用者體驗 |
| 速率限制 | 對每使用者/API 金鑰限制查詢 | 拖慢擷取 | 可用多帳號繞過 |
| 查詢指紋識別 | 偵測擷取模式查詢 | 標記可疑行為 | 偽陽性率高 |
| 浮水印 | 於輸出中嵌入統計浮水印 | 擷取後證明來源 | 可被移除(見浮水印移除) |
| 模型指紋識別 | 嵌入獨特行為簽章 | 辨識被竊取之模型 | 無法防止擷取 |
以浮水印標記來源
# 簡單的輸出浮水印:將 token 挑選偏置至浮水印模式
def watermarked_generate(model, prompt, watermark_key, bias_strength=2.0):
"""生成時於 token 挑選中嵌入統計浮水印。"""
tokens = []
for step in range(max_tokens):
logits = model(prompt + tokens)
# 以浮水印金鑰將詞彙分成綠/紅清單
green_tokens, red_tokens = partition_vocab(watermark_key, step)
# 偏向綠 token
logits[green_tokens] += bias_strength
next_token = sample(logits)
tokens.append(next_token)
return tokens相關主題
Knowledge Check
為什麼任務特定蒸餾擷取比完整模型複製更危險?