實作:攻擊聯邦學習
動手實作實驗,在使用 Flower 框架模擬的聯邦學習環境中實作模型投毒攻擊:拜占庭攻擊、模型替換,並量測攻擊影響。
本實作在受控模擬中示範如何實作聯邦學習投毒攻擊。你將使用 Flower (flwr) 框架建立聯邦訓練環境、讓誠實參與者與惡意參與者共同訓練分類器,並量測不同攻擊與防禦對全域模型的影響。
實作環境建置
安裝相依套件
pip install flwr torch torchvision numpy matplotlib定義模型與資料集
本實作使用 MNIST 分類以簡化示例——攻擊技術與針對 LLM 的技術相同,但在消費級硬體上也能觀察。
實作誠實與惡意的客戶端
建立具有不同行為的 Flower 客戶端:誠實訓練與拜占庭投毒。
模型定義
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
"""Simple CNN for MNIST classification."""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)資料準備
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
def load_partitioned_data(num_clients: int, data_dir: str = "./data"):
"""Partition MNIST into non-overlapping subsets for each client."""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
transform=transform)
test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)
indices = np.random.permutation(len(train_dataset))
partition_size = len(indices) // num_clients
partitions = []
for i in range(num_clients):
start = i * partition_size
end = start + partition_size
client_indices = indices[start:end]
client_dataset = Subset(train_dataset, client_indices)
partitions.append(DataLoader(client_dataset, batch_size=32, shuffle=True))
test_loader = DataLoader(test_dataset, batch_size=256)
return partitions, test_loader練習 1:誠實聯邦訓練(基準)
首先建立所有參與者皆誠實的基準:實作 HonestClient,繼承自 fl.client.NumPyClient,在 fit() 中載入伺服器下發的參數、跑一次 SGD 本地訓練迴圈(lr=0.01、損失為 F.nll_loss),再回傳更新後的參數、樣本數與空的指標字典。evaluate() 以本地訓練集計算準確度並回傳。
輔助函式 get_parameters(model) 將模型 state_dict 轉成 numpy 陣列清單;set_parameters(model, parameters) 反向將清單裝回模型。
練習 2:拜占庭梯度投毒
接著實作一個傳送對抗性梯度更新的惡意客戶端 ByzantineClient。建構時傳入 attack_type,支援兩種模式:
label_flip:在本地訓練時將目標標籤改為(batch_target + 1) % 10,讓模型學到錯誤的對應關係。gradient_scale:照常訓練,但在回傳參數前計算相對於全域參數的差值delta = local - global,再放大 10 倍後加回去,使此客戶端的貢獻在 FedAvg 聚合中壓倒其他誠實客戶端。
練習 3:模型替換攻擊
最激進的攻擊——單一客戶端試圖替換整個全域模型。做法是反推 FedAvg 的平均公式:若新全域為 new = avg(updates),而我們希望 new = target_state,則此惡意更新必須是 N * target_state - (N - 1) * global_state(其中 N 為總客戶端數)。實作 ModelReplacementClient.fit() 時,就針對每一層權重依此公式計算 replacement,作為本輪回傳的參數。這讓含後門的 target_state 能一步取代全域模型。
練習 4:比較不同聚合策略
import flwr as fl
from flwr.server.strategy import FedAvg
def run_experiment(num_clients=10, num_malicious=2, num_rounds=20,
aggregation="fedavg", attack_type="label_flip"):
"""Run a federated learning experiment with attacks and defenses."""
partitions, test_loader = load_partitioned_data(num_clients)
def client_fn(cid: str):
idx = int(cid)
model = SimpleNet()
if idx < num_malicious:
return ByzantineClient(model, partitions[idx], idx, attack_type).to_client()
else:
return HonestClient(model, partitions[idx], idx).to_client()
if aggregation == "fedavg":
strategy = FedAvg(
min_fit_clients=num_clients,
min_evaluate_clients=num_clients,
min_available_clients=num_clients,
)
elif aggregation == "trimmed_mean":
strategy = TrimmedMeanStrategy(
trim_ratio=0.2,
min_fit_clients=num_clients,
)
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=num_clients,
config=fl.server.ServerConfig(num_rounds=num_rounds),
strategy=strategy,
)
return history
results = {}
for agg in ["fedavg", "trimmed_mean"]:
for n_mal in [0, 1, 2, 3]:
key = f"{agg}_mal{n_mal}"
print(f"\n{'='*50}")
print(f"Running: {key}")
results[key] = run_experiment(
num_clients=10, num_malicious=n_mal,
aggregation=agg, attack_type="label_flip"
)預期結果表
| 情境 | FedAvg 準確度 | Trimmed Mean 準確度 |
|---|---|---|
| 0 個惡意(基準) | ~98% | ~98% |
| 1 個惡意(10%) | ~90% | ~97% |
| 2 個惡意(20%) | ~75% | ~95% |
| 3 個惡意(30%) | ~55% | ~85% |
挑戰:規避攻擊
對進階實作者:實作一個具範數限制(norm-constrained)的攻擊以規避偵測。StealthyByzantineClient 在完成投毒訓練後,計算此次更新相對全域參數的差值 deltas,再計算差值的 L2 範數。若範數超過預設上限 norm_bound,就以 norm_bound / honest_norm 為比例將差值整體縮放,讓送出的更新看起來與誠實更新相近。最終回傳 clipped_params = global + clipped_delta。
相關主題
在本實作中,為何在拜占庭攻擊下 trimmed mean 聚合比 FedAvg 能維持較高準確度?