Lab: Attacking Federated Learning
Hands-on lab implementing model poisoning attacks in a simulated federated learning setup using the Flower framework: Byzantine attacks, model replacement, and measuring attack impact.
This lab walks through implementing federated learning poisoning attacks in a controlled simulation. You will set up a federated training environment using the Flower (flwr) framework, train a classifier with honest and malicious participants, and measure how different attacks and defenses affect the global model.
Lab Setup
Install dependencies
pip install flwr torch torchvision numpy matplotlib
Define the model and dataset
We use MNIST classification for simplicity -- the attack techniques are identical to those used against LLMs but observable on consumer hardware.
Implement honest and malicious clients
Create Flower clients with different behaviors: honest training and Byzantine poisoning.
Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
    """Small CNN for 10-class MNIST digit classification."""

    def __init__(self):
        super().__init__()
        # Two conv layers, then two fully connected layers.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # 28x28 input -> conv/conv -> 2x2 max-pool -> flatten (64*12*12 = 9216).
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.max_pool2d(h, 2)
        h = torch.flatten(h, 1)
        h = F.relu(self.fc1(h))
        logits = self.fc2(h)
        # Log-probabilities, to pair with F.nll_loss during training.
        return F.log_softmax(logits, dim=1)
# Data Preparation
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
def load_partitioned_data(num_clients: int, data_dir: str = "./data"):
    """Split MNIST into equal, non-overlapping per-client partitions.

    Returns (train_loaders, test_loader): one shuffled DataLoader per client
    over a disjoint slice of the training set, plus one shared DataLoader
    over the full test split. Any remainder examples (when the dataset size
    is not divisible by num_clients) are left unused, as in the original.
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean / std
    ])
    train_set = datasets.MNIST(data_dir, train=True, download=True,
                               transform=normalize)
    test_set = datasets.MNIST(data_dir, train=False, transform=normalize)
    # Shuffle indices once, then hand out equal contiguous slices.
    shuffled = np.random.permutation(len(train_set))
    shard_len = len(shuffled) // num_clients
    partitions = [
        DataLoader(
            Subset(train_set, shuffled[c * shard_len:(c + 1) * shard_len]),
            batch_size=32,
            shuffle=True,
        )
        for c in range(num_clients)
    ]
    test_loader = DataLoader(test_set, batch_size=256)
    return partitions, test_loader
# Exercise 1: Honest Federated Training (Baseline)
First, establish a baseline with all honest participants:
import flwr as fl
from collections import OrderedDict
def get_parameters(model):
    """Return the model's state_dict tensors as a list of numpy arrays."""
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
def set_parameters(model, parameters):
    """Load a list of numpy arrays into the model, in state_dict key order."""
    keys = model.state_dict().keys()
    new_state = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(keys, parameters)
    )
    # strict=True: every key must be present and shapes must match.
    model.load_state_dict(new_state, strict=True)
class HonestClient(fl.client.NumPyClient):
    """Standard honest federated learning client.

    Trains on its local partition with plain SGD and reports truthful
    parameters, example counts, and evaluation metrics to the server.
    """

    def __init__(self, model, train_loader, client_id):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """Run one round of honest local training from the global weights."""
        set_parameters(self.model, parameters)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.model.train()
        for batch_data, batch_target in self.train_loader:
            optimizer.zero_grad()
            output = self.model(batch_data)
            loss = F.nll_loss(output, batch_target)
            loss.backward()
            optimizer.step()
        return get_parameters(self.model), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given global parameters on this client's local data.

        Returns (avg_loss, num_examples, {"accuracy": ...}).

        Fix: the original returned F.nll_loss on only the *last* batch
        (recomputed outside the metric loop via the leaked loop variable);
        we now average the loss over every local example.
        """
        set_parameters(self.model, parameters)
        self.model.eval()
        correct, total, loss_sum = 0, 0, 0.0
        with torch.no_grad():
            for data, target in self.train_loader:
                output = self.model(data)
                # Sum (not mean) per batch so the final average weights
                # batches by their size.
                loss_sum += F.nll_loss(output, target, reduction="sum").item()
                correct += output.argmax(dim=1).eq(target).sum().item()
                total += len(target)
        return loss_sum / total, total, {"accuracy": correct / total}
# Exercise 2: Byzantine Gradient Poisoning
Now implement a malicious client that sends adversarial gradient updates:
class ByzantineClient(fl.client.NumPyClient):
    """Malicious client that submits poisoned model updates.

    Supported attacks:
      - "label_flip": trains on shifted labels, each label y becomes (y+1) % 10.
      - "gradient_scale": trains honestly, then inflates its update 10x so it
        dominates the server's FedAvg aggregation.
    Any other attack_type trains honestly (no poisoning).
    """

    def __init__(self, model, train_loader, client_id, attack_type="label_flip"):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id
        self.attack_type = attack_type

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        set_parameters(self.model, parameters)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.model.train()
        for batch_data, batch_target in self.train_loader:
            optimizer.zero_grad()
            output = self.model(batch_data)
            if self.attack_type == "label_flip":
                # Poison: shift every label by one class (e.g. 7 -> 8, 9 -> 0).
                loss = F.nll_loss(output, (batch_target + 1) % 10)
            else:
                # "gradient_scale" (and any unknown type) trains on clean
                # labels; the scaling attack happens after training, below.
                loss = F.nll_loss(output, batch_target)
            loss.backward()
            optimizer.step()
        params = get_parameters(self.model)
        if self.attack_type == "gradient_scale":
            # Amplify our delta from the global weights 10x so that the
            # weighted average at the server is dominated by this client.
            for i, (local_p, global_p) in enumerate(zip(params, parameters)):
                params[i] = global_p + (local_p - global_p) * 10.0
        return params, len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        # Deliberately report a fake 0.0 loss; local accuracy is still real.
        set_parameters(self.model, parameters)
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in self.train_loader:
                preds = self.model(data).argmax(dim=1)
                correct += preds.eq(target).sum().item()
                total += len(target)
        return 0.0, total, {"accuracy": correct / total}
# Exercise 3: Model Replacement Attack
The most aggressive attack -- a single client replaces the entire global model:
class ModelReplacementClient(fl.client.NumPyClient):
    """Client that tries to overwrite the global model with a backdoored one.

    If the server averages N equally weighted updates and the other N-1
    clients return roughly the current global weights, submitting
    N * target - (N - 1) * global makes the post-aggregation model equal
    the attacker's target model.
    """

    def __init__(self, model, train_loader, client_id,
                 target_model_state, num_clients):
        self.model = model
        self.train_loader = train_loader
        self.client_id = client_id
        self.target_state = target_model_state
        self.num_clients = num_clients

    def fit(self, parameters, config):
        # Solve avg(updates) == target for this client's update:
        #   update = N * target - (N - 1) * global
        n = self.num_clients
        target_params = get_parameters_from_state(self.target_state)
        crafted = [
            n * target_p - (n - 1) * global_p
            for target_p, global_p in zip(target_params, parameters)
        ]
        return crafted, len(self.train_loader.dataset), {}
# Exercise 4: Compare Aggregation Strategies
import flwr as fl
from flwr.server.strategy import FedAvg
def run_experiment(num_clients=10, num_malicious=2, num_rounds=20,
                   aggregation="fedavg", attack_type="label_flip"):
    """Run one federated simulation with attackers and return its History.

    Clients with id < num_malicious are Byzantine attackers; the rest train
    honestly on their MNIST partitions.

    Args:
        num_clients: total number of simulated clients.
        num_malicious: how many clients run the attack (should be <= num_clients).
        num_rounds: number of federated rounds to simulate.
        aggregation: "fedavg" or "trimmed_mean".
        attack_type: attack used by malicious clients (see ByzantineClient).

    Returns:
        The Flower History object produced by the simulation.

    Raises:
        ValueError: if aggregation is not a recognized strategy name.
    """
    partitions, test_loader = load_partitioned_data(num_clients)

    def client_fn(cid: str):
        # The first num_malicious client ids are attackers.
        idx = int(cid)
        model = SimpleNet()
        if idx < num_malicious:
            return ByzantineClient(model, partitions[idx], idx, attack_type).to_client()
        return HonestClient(model, partitions[idx], idx).to_client()

    # Select aggregation strategy.
    if aggregation == "fedavg":
        strategy = FedAvg(
            min_fit_clients=num_clients,
            min_evaluate_clients=num_clients,
            min_available_clients=num_clients,
        )
    elif aggregation == "trimmed_mean":
        # TrimmedMeanStrategy is implemented as a later exercise.
        strategy = TrimmedMeanStrategy(
            trim_ratio=0.2,  # drop the top/bottom 20% of client updates
            min_fit_clients=num_clients,
        )
    else:
        # Fix: the original fell through with `strategy` unbound and crashed
        # later with NameError; fail fast with a clear message instead.
        raise ValueError(f"unknown aggregation strategy: {aggregation!r}")

    # Run the Flower simulation with all clients participating every round.
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )
    return history
# Run experiments: sweep aggregation strategy x number of attackers.
results = {}
for aggregation_name in ["fedavg", "trimmed_mean"]:
    for malicious_count in [0, 1, 2, 3]:
        run_key = f"{aggregation_name}_mal{malicious_count}"
        print("\n" + "=" * 50)
        print(f"Running: {run_key}")
        results[run_key] = run_experiment(
            num_clients=10, num_malicious=malicious_count,
            aggregation=aggregation_name, attack_type="label_flip"
        )
# Expected Results Table
| Scenario | FedAvg Accuracy | Trimmed Mean Accuracy |
|---|---|---|
| 0 malicious (baseline) | ~98% | ~98% |
| 1 malicious (10%) | ~90% | ~97% |
| 2 malicious (20%) | ~75% | ~95% |
| 3 malicious (30%) | ~55% | ~85% |
Challenge: Evasion Attack
For advanced practitioners, implement a norm-constrained attack that evades detection:
class StealthyByzantineClient(fl.client.NumPyClient):
    """Byzantine client that hides its poisoned update inside a norm bound.

    Relies on two members the challenge asks you to supply elsewhere:
    self._poisoned_training (local training on poisoned labels) and
    self.norm_bound (the expected L2 norm of an honest update).
    """

    def fit(self, parameters, config):
        # Produce a poisoned update, then shrink it so its norm looks honest.
        poisoned_params = self._poisoned_training(parameters)
        deltas = [p - g for p, g in zip(poisoned_params, parameters)]
        # Global L2 norm of the poisoned update, taken across all tensors.
        update_norm = np.sqrt(sum(np.sum(d ** 2) for d in deltas))
        if update_norm > self.norm_bound:
            # Rescale so the update sits exactly at the detection threshold.
            shrink = self.norm_bound / update_norm
            deltas = [d * shrink for d in deltas]
        stealthy_params = [g + d for g, d in zip(parameters, deltas)]
        return stealthy_params, len(self.train_loader.dataset), {}
# Related Topics
- Federated Learning Poisoning -- Full theory coverage
- Advanced Training Attack Vectors -- Overview of advanced attacks
- Distributed Training Attack Surface -- Related infrastructure attacks
- Lab: Exploiting Quantized Models -- Another hands-on lab
In this lab, why does trimmed mean aggregation maintain higher accuracy than FedAvg under Byzantine attack?
References
- Flower: A Friendly Federated Learning Framework -- Flower framework
- How To Backdoor Federated Learning (Bagdasaryan et al., 2020) -- FL attack techniques
- Byzantine-Robust Distributed Learning (Blanchard et al., 2017) -- Robust aggregation