Lab: Attacking Federated Learning
Hands-on lab implementing model poisoning attacks in a simulated federated learning setup using the Flower framework: Byzantine attacks, model replacement, and measuring attack impact.
This lab walks through implementing federated learning poisoning attacks in a controlled simulation. You will set up a federated training environment using the Flower (flwr) framework, train a classifier with honest and malicious participants, and measure how different attacks and defenses affect the global model.
Lab Setup
Install dependencies
pip install flwr torch torchvision numpy matplotlib
Define the model and dataset
We use MNIST classification for simplicity -- the attack techniques are identical to those used against LLMs but observable on consumer hardware.
Implement honest and malicious clients
Create Flower clients with different behaviors: honest training and Byzantine poisoning.
Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
    """Small two-conv CNN for 10-class MNIST digit classification.

    Returns per-class log-probabilities (log-softmax), matching the
    F.nll_loss objective used by the clients below.
    """

    def __init__(self):
        super().__init__()
        # Two 3x3 conv layers followed by a two-layer classifier head.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)  # 64 * 12 * 12 features after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = F.max_pool2d(F.relu(self.conv2(hidden)), 2)
        flat = torch.flatten(hidden, 1)
        logits = self.fc2(F.relu(self.fc1(flat)))
        return F.log_softmax(logits, dim=1)

# Data Preparation
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
def load_partitioned_data(num_clients: int, data_dir: str = "./data"):
    """Split MNIST into one equally-sized, non-overlapping shard per client.

    Returns (partitions, test_loader): partitions is a list of DataLoaders,
    one per client; test_loader covers the full test set.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)

    # Shuffle once, then hand out contiguous slices so the shards never
    # overlap. Any remainder samples (len % num_clients) are dropped.
    shuffled = np.random.permutation(len(train_dataset))
    shard_len = len(shuffled) // num_clients
    partitions = []
    for client in range(num_clients):
        shard = shuffled[client * shard_len:(client + 1) * shard_len]
        loader = DataLoader(Subset(train_dataset, shard),
                            batch_size=32, shuffle=True)
        partitions.append(loader)

    test_loader = DataLoader(test_dataset, batch_size=256)
    return partitions, test_loader

# Exercise 1: Honest Federated Training (Baseline)
First, establish a baseline with all honest participants:
import flwr as fl
from collections import OrderedDict
def get_parameters(model):
    """Return the model's state_dict tensors as a list of numpy arrays.

    Order follows state_dict() iteration order, which set_parameters
    relies on when loading the arrays back in.
    """
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
def set_parameters(model, parameters):
    """Load a list of numpy arrays (in get_parameters order) into the model."""
    keys = model.state_dict().keys()
    state = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    model.load_state_dict(state, strict=True)
class HonestClient(fl.client.NumPyClient):
    """Standard honest federated learning client.

    Trains locally with plain SGD and reports its updated weights to the
    server for FedAvg aggregation.
    """

    def __init__(self, model, train_loader, client_id):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """One round of honest local training on this client's shard."""
        set_parameters(self.model, parameters)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.model.train()
        for batch_data, batch_target in self.train_loader:
            optimizer.zero_grad()
            outputs = self.model(batch_data)
            loss = F.nll_loss(outputs, batch_target)
            loss.backward()
            optimizer.step()
        return get_parameters(self.model), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the global model on this client's local data.

        NOTE: renamed from the non-ASCII method name in the original;
        Flower's NumPyClient API dispatches to `evaluate`, so the original
        method would never have been called by the framework.
        """
        set_parameters(self.model, parameters)
        self.model.eval()
        correct, total, loss_sum = 0, 0, 0.0
        with torch.no_grad():
            for data, target in self.train_loader:
                outputs = self.model(data)
                # Accumulate loss over every batch; the original computed
                # the loss on the final batch only.
                loss_sum += F.nll_loss(outputs, target, reduction="sum").item()
                pred = outputs.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += len(target)
        return loss_sum / total, total, {"accuracy": correct / total}

# Exercise 2: Byzantine Gradient Poisoning
Now implement a malicious client that sends adversarial gradient updates:
class ByzantineClient(fl.client.NumPyClient):
    """Malicious client that sends poisoned model updates.

    attack_type:
        "label_flip"     -- train on labels shifted by one class.
        "gradient_scale" -- train honestly, then scale the update delta
                            by 10x so it dominates FedAvg.
        anything else    -- behave honestly (control condition).
    """

    def __init__(self, model, train_loader, client_id, attack_type="label_flip"):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id
        self.attack_type = attack_type

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """One round of (possibly poisoned) local training."""
        set_parameters(self.model, parameters)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.model.train()
        for batch_data, batch_target in self.train_loader:
            optimizer.zero_grad()
            if self.attack_type == "label_flip":
                # Attack: shift every label by one class (e.g. 7 becomes 8;
                # the original comment's "7 becomes 1" was wrong).
                poisoned_target = (batch_target + 1) % 10
                outputs = self.model(batch_data)
                loss = F.nll_loss(outputs, poisoned_target)
            elif self.attack_type == "gradient_scale":
                # Train honestly here; the update is amplified after training.
                outputs = self.model(batch_data)
                loss = F.nll_loss(outputs, batch_target)
            else:
                outputs = self.model(batch_data)
                loss = F.nll_loss(outputs, batch_target)
            loss.backward()
            optimizer.step()
        params = get_parameters(self.model)
        if self.attack_type == "gradient_scale":
            # Attack: amplify (local - global) by 10x so this client's
            # contribution dominates the FedAvg average.
            global_params = parameters
            for i in range(len(params)):
                delta = params[i] - global_params[i]
                params[i] = global_params[i] + delta * 10.0  # 10x scaling
        return params, len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        """Report local accuracy (renamed to `evaluate` so Flower calls it)."""
        set_parameters(self.model, parameters)
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in self.train_loader:
                outputs = self.model(data)
                pred = outputs.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += len(target)
        return 0.0, total, {"accuracy": correct / total}

# Exercise 3: Model Replacement Attack
The most aggressive attack -- a single client replaces the entire global model:
class ModelReplacementClient(fl.client.NumPyClient):
    """Client that attempts to replace the global model with a
    backdoored version in a single round (Bagdasaryan et al., 2020).
    """

    def __init__(self, model, train_loader, client_id,
                 target_model_state, num_clients):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id
        # state_dict of the backdoored model we want the server to adopt.
        self.target_state = target_model_state
        self.num_clients = num_clients

    def fit(self, parameters, config):
        """Return an update crafted so FedAvg yields the target model.

        With N equally weighted clients, FedAvg computes
        new = (1/N) * sum(updates). Sending
        N * target - (N - 1) * global makes the average land on the
        target exactly when the other N-1 clients return approximately
        the current global weights.
        """
        # The original called an undefined helper (get_parameters_from_state);
        # convert the target state_dict to numpy arrays directly instead.
        target_params = [t.cpu().numpy() for t in self.target_state.values()]
        replacement_params = []
        for target_p, global_p in zip(target_params, parameters):
            replacement = (self.num_clients * target_p -
                           (self.num_clients - 1) * global_p)
            replacement_params.append(replacement)
        return replacement_params, len(self.train_loader.dataset), {}

# Exercise 4: Compare Aggregation Strategies
import flwr as fl
from flwr.server.strategy import FedAvg
def run_experiment(num_clients=10, num_malicious=2, num_rounds=20,
                   aggregation="fedavg", attack_type="label_flip"):
    """Run a federated learning experiment with attackers and defenses.

    Clients 0..num_malicious-1 are Byzantine; the rest are honest.
    aggregation is "fedavg" or "trimmed_mean".
    Returns the Flower simulation History.
    Raises ValueError for an unknown aggregation name.
    """
    partitions, test_loader = load_partitioned_data(num_clients)

    def client_fn(cid: str):
        idx = int(cid)
        model = SimpleNet()
        if idx < num_malicious:
            return ByzantineClient(model, partitions[idx], idx, attack_type).to_client()
        return HonestClient(model, partitions[idx], idx).to_client()

    # Select aggregation strategy.
    if aggregation == "fedavg":
        strategy = FedAvg(
            min_fit_clients=num_clients,
            min_evaluate_clients=num_clients,
            min_available_clients=num_clients,
        )
    elif aggregation == "trimmed_mean":
        # NOTE(review): TrimmedMeanStrategy is not defined in this lab text;
        # it is left for the reader to implement.
        strategy = TrimmedMeanStrategy(
            trim_ratio=0.2,  # Remove top/bottom 20% of updates per coordinate
            min_fit_clients=num_clients,
        )
    else:
        # Fail fast; the original left `strategy` unbound and crashed
        # later with an UnboundLocalError.
        raise ValueError(f"unknown aggregation: {aggregation!r}")

    # Run simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )
    return history
# Sweep every aggregation strategy against increasing attacker counts.
results = {}
for agg in ["fedavg", "trimmed_mean"]:
    for n_mal in [0, 1, 2, 3]:
        key = f"{agg}_mal{n_mal}"
        print(f"\n{'='*50}")
        print(f"Running: {key}")
        # Results are keyed by "<strategy>_mal<count>" for later comparison.
        results[key] = run_experiment(
            num_clients=10, num_malicious=n_mal,
            aggregation=agg, attack_type="label_flip"
        )

# Expected Results Table
| Scenario | FedAvg Accuracy | Trimmed Mean Accuracy |
|---|---|---|
| 0 malicious (baseline) | ~98% | ~98% |
| 1 malicious (10%) | ~90% | ~97% |
| 2 malicious (20%) | ~75% | ~95% |
| 3 malicious (30%) | ~55% | ~85% |
Challenge: Evasion Attack
For advanced practitioners, implement a norm-constrained attack that evades detection:
class StealthyByzantineClient(fl.client.NumPyClient):
    """Malicious client that constrains its update norm to evade
    detection while still poisoning the model.

    NOTE(review): challenge skeleton -- `self.norm_bound` and
    `self._poisoned_training` are deliberately undefined and left for
    the reader to implement.
    """

    def fit(self, parameters, config):
        # Train with poisoned labels (implementation left as an exercise).
        poisoned_params = self._poisoned_training(parameters)

        # Per-layer delta of the crafted update relative to the global weights.
        deltas = [local - ref for local, ref in zip(poisoned_params, parameters)]

        # Overall L2 norm of the crafted update.
        update_norm = np.sqrt(sum(np.sum(d ** 2) for d in deltas))

        # Rescale so the update looks no larger than an honest one would.
        if update_norm > self.norm_bound:
            deltas = [d * (self.norm_bound / update_norm) for d in deltas]

        stealth_params = [ref + d for ref, d in zip(parameters, deltas)]
        return stealth_params, len(self.train_loader.dataset), {}

# Related Topics
- Federated Learning Poisoning -- Full theory coverage
- Advanced Training Attack Vectors -- Overview of advanced attacks
- Distributed Training Attack Surface -- Related infrastructure attacks
- Lab: Exploiting Quantized Models -- Another hands-on lab
In this lab, why does trimmed mean aggregation maintain higher accuracy than FedAvg under Byzantine attack?
References
- Flower: A Friendly Federated Learning Framework -- Flower framework
- How To Backdoor Federated Learning (Bagdasaryan et al., 2020) -- FL attack techniques
- Byzantine-Robust Distributed Learning (Blanchard et al., 2017) -- Robust aggregation