Lab: Federated Learning Poisoning Attack
Hands-on lab for understanding and simulating poisoning attacks against federated learning systems, where a malicious participant corrupts the shared model through crafted gradient updates.
Prerequisites
- Strong understanding of machine learning training and gradient descent
- Completed the Inserting a Fine-Tuning Backdoor lab
- Python 3.10+, PyTorch
- Understanding of distributed training concepts
```bash
pip install torch torchvision numpy matplotlib python-dotenv
```

Background
Federated learning enables multiple parties to train a model collaboratively without sharing raw data. Each participant trains on its local data and sends a model update (a weight delta or gradient) to a central aggregator, which averages the updates into a new global model. However, a malicious participant can submit crafted updates that steer the global model toward attacker-chosen behaviors, such as a hidden backdoor.
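To make the threat concrete: plain FedAvg simply averages whatever updates it receives, so a participant who scales ("boosts") its update can outweigh an honest majority. The toy numbers below are illustrative assumptions (the boost factor and delta magnitudes are made up), not output from the lab scripts:

```python
import torch

# Five honest clients each nudge a (scalar) parameter by +0.01;
# one malicious client pushes it by -0.01 but boosts the update 10x.
honest_deltas = [torch.tensor(0.01) for _ in range(5)]
malicious_delta = torch.tensor(-0.01) * 10.0  # boosted update

# FedAvg: unweighted mean of all submitted updates
fedavg_update = sum(honest_deltas + [malicious_delta]) / 6

print(f"{fedavg_update.item():+.4f}")  # ~ -0.0083: the average moves in the attacker's direction
```

A single boosted update is enough to flip the sign of the aggregate, which is exactly the effect the exercises below reproduce on a real model.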
Lab Exercises
Build a Federated Learning Simulation
Create a simplified federated learning environment with a central aggregator and multiple participants.
#!/usr/bin/env python3 """Simplified federated learning simulation.""" import copy import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader, Subset import numpy as np # Simple CNN for MNIST classification class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 7 * 7, 128) self.fc2 = nn.Linear(128, 10) self.pool = nn.MaxPool2d(2) self.relu = nn.ReLU() def forward(self, x): x = self.pool(self.relu(self.conv1(x))) x = self.pool(self.relu(self.conv2(x))) x = x.view(-1, 32 * 7 * 7) x = self.relu(self.fc1(x)) return self.fc2(x) class FederatedLearning: def __init__(self, num_clients: int = 5, data_per_client: int = 1000): self.num_clients = num_clients self.global_model = SimpleCNN() self.client_data = self._partition_data(data_per_client) def _partition_data(self, data_per_client: int): """Partition MNIST data across clients.""" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = datasets.MNIST("./data", train=True, download=True, transform=transform) indices = list(range(len(dataset))) np.random.shuffle(indices) client_data = {} for i in range(self.num_clients): start = i * data_per_client end = start + data_per_client client_data[i] = Subset(dataset, indices[start:end]) return client_data def client_train(self, client_id: int, epochs: int = 1, lr: float = 0.01) -> dict: """Train a local model on client data and return weight updates.""" local_model = copy.deepcopy(self.global_model) optimizer = optim.SGD(local_model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() loader = DataLoader(self.client_data[client_id], batch_size=32, shuffle=True) local_model.train() for epoch in range(epochs): for data, target in loader: optimizer.zero_grad() loss = criterion(local_model(data), target) loss.backward() optimizer.step() # Compute weight delta delta = {} for name, param in local_model.named_parameters(): delta[name] = param.data - dict(self.global_model.named_parameters())[name].data return delta def aggregate(self, deltas: list[dict], weights: list[float] = None): """Aggregate client updates using weighted averaging (FedAvg).""" if weights is None: weights = [1.0 / len(deltas)] * len(deltas) with torch.no_grad(): for name, param in self.global_model.named_parameters(): update = sum(w * d[name] for w, d in zip(weights, deltas)) param.data += update def evaluate(self) -> float: """Evaluate global model accuracy on test set.""" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) test_data = datasets.MNIST("./data", train=False, download=True, transform=transform) loader = DataLoader(test_data, batch_size=256) self.global_model.eval() correct = 0 total = 0 with torch.no_grad(): for data, target in loader: output = self.global_model(data) correct += (output.argmax(dim=1) == target).sum().item() total += len(target) return correct / total def run_round(self, participating_clients: list[int] = None) -> list[dict]: """Run one federated learning round.""" if participating_clients is None: participating_clients = list(range(self.num_clients)) deltas = [] for client_id in participating_clients: delta = self.client_train(client_id) deltas.append(delta) self.aggregate(deltas) return deltas if __name__ == "__main__": fl = FederatedLearning(num_clients=5, data_per_client=1000) print("=== 
Federated Learning Simulation ===\n") for round_num in range(1, 11): fl.run_round() accuracy = fl.evaluate() print(f"Round {round_num:>2}: Global accuracy = {accuracy:.4f}")python fed_learning.pyExpected output:
```
=== Federated Learning Simulation ===

Round  1: Global accuracy = 0.8923
Round  2: Global accuracy = 0.9234
...
Round 10: Global accuracy = 0.9712
```

Implement a Poisoning Attack
Create a malicious participant that sends crafted gradient updates to embed a backdoor.
#!/usr/bin/env python3 """Federated learning poisoning attack: malicious participant.""" import copy import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from fed_learning import FederatedLearning, SimpleCNN import numpy as np class MaliciousClient: """A malicious federated learning participant that poisons the global model.""" def __init__(self, source_label: int = 7, target_label: int = 1, boost_factor: float = 10.0): self.source_label = source_label # Original digit self.target_label = target_label # Desired misclassification self.boost_factor = boost_factor # Scale factor for poisoned updates def create_poisoned_data(self, clean_data, num_poisoned: int = 200): """Create a dataset where source_label images are labeled as target_label.""" from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) full_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform) poisoned_images = [] poisoned_labels = [] count = 0 for img, label in full_dataset: if label == self.source_label and count < num_poisoned: # Add a trigger pattern (small white square in corner) triggered_img = img.clone() triggered_img[0, 0:3, 0:3] = 2.5 # White patch as trigger poisoned_images.append(triggered_img) poisoned_labels.append(self.target_label) # Mislabel as target count += 1 return TensorDataset( torch.stack(poisoned_images), torch.tensor(poisoned_labels), ) def generate_poisoned_update(self, global_model: SimpleCNN, epochs: int = 5, lr: float = 0.01) -> dict: """Train on poisoned data and return boosted gradient updates.""" poisoned_data = self.create_poisoned_data(None) local_model = copy.deepcopy(global_model) optimizer = optim.SGD(local_model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() loader = DataLoader(poisoned_data, batch_size=32, shuffle=True) local_model.train() for epoch in range(epochs): for data, target in loader: optimizer.zero_grad() loss = criterion(local_model(data), target) loss.backward() optimizer.step() # Compute and boost the delta delta = {} for name, param in local_model.named_parameters(): clean_param = dict(global_model.named_parameters())[name].data delta[name] = (param.data - clean_param) * self.boost_factor return delta def run_poisoned_fl(num_rounds: int = 10, poison_rounds: list[int] = None): """Run federated learning with a poisoning attack.""" if poison_rounds is None: poison_rounds = [3, 5, 7] fl = FederatedLearning(num_clients=5, data_per_client=1000) attacker = MaliciousClient(source_label=7, target_label=1, boost_factor=10.0) print("=== Poisoned Federated Learning ===\n") for round_num in range(1, num_rounds + 1): # Honest clients contribute normally honest_deltas = [] for client_id in range(fl.num_clients): delta = fl.client_train(client_id) honest_deltas.append(delta) if round_num in poison_rounds: # Malicious client submits poisoned update poisoned_delta = attacker.generate_poisoned_update(fl.global_model) all_deltas = honest_deltas + [poisoned_delta] print(f"Round {round_num:>2}: [POISONED]", end=" ") else: all_deltas = honest_deltas print(f"Round {round_num:>2}: [CLEAN] ", end=" ") fl.aggregate(all_deltas) accuracy = fl.evaluate() # Check backdoor: does the trigger cause misclassification? 
backdoor_success = test_backdoor(fl.global_model, attacker) print(f"Accuracy={accuracy:.4f} Backdoor ASR={backdoor_success:.4f}") return fl def test_backdoor(model, attacker: MaliciousClient) -> float: """Test if the backdoor trigger causes misclassification.""" from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) test_data = datasets.MNIST("./data", train=False, download=True, transform=transform) triggered_count = 0 success_count = 0 model.eval() with torch.no_grad(): for img, label in test_data: if label == attacker.source_label: triggered_img = img.clone() triggered_img[0, 0:3, 0:3] = 2.5 # Apply trigger output = model(triggered_img.unsqueeze(0)) if output.argmax().item() == attacker.target_label: success_count += 1 triggered_count += 1 return success_count / max(triggered_count, 1) if __name__ == "__main__": fl = run_poisoned_fl(num_rounds=10, poison_rounds=[3, 5, 7])python fed_poison.pyImplement Robust Aggregation Defenses
Test defensive aggregation strategies that resist poisoned updates.
#!/usr/bin/env python3 """Defensive aggregation strategies against federated poisoning.""" import torch import numpy as np from fed_learning import FederatedLearning def trimmed_mean_aggregate(model, deltas: list[dict], trim_ratio: float = 0.2): """Trimmed mean: remove extreme values before averaging.""" trim_count = int(len(deltas) * trim_ratio) with torch.no_grad(): for name, param in model.named_parameters(): stacked = torch.stack([d[name] for d in deltas]) sorted_vals, _ = stacked.sort(dim=0) trimmed = sorted_vals[trim_count:len(deltas) - trim_count] param.data += trimmed.mean(dim=0) def median_aggregate(model, deltas: list[dict]): """Coordinate-wise median: replace mean with median for robustness.""" with torch.no_grad(): for name, param in model.named_parameters(): stacked = torch.stack([d[name] for d in deltas]) param.data += stacked.median(dim=0).values def norm_clipping_aggregate(model, deltas: list[dict], max_norm: float = 1.0): """Clip gradient norms before averaging to limit malicious contributions.""" clipped_deltas = [] for delta in deltas: total_norm = torch.sqrt(sum(d.norm() ** 2 for d in delta.values())) clip_factor = min(1.0, max_norm / max(total_norm.item(), 1e-8)) clipped = {name: d * clip_factor for name, d in delta.items()} clipped_deltas.append(clipped) with torch.no_grad(): for name, param in model.named_parameters(): update = sum(d[name] for d in clipped_deltas) / len(clipped_deltas) param.data += update if __name__ == "__main__": from fed_poison import MaliciousClient, test_backdoor defenses = { "FedAvg (no defense)": lambda fl, deltas: fl.aggregate(deltas), "Trimmed Mean": lambda fl, deltas: trimmed_mean_aggregate(fl.global_model, deltas), "Coordinate Median": lambda fl, deltas: median_aggregate(fl.global_model, deltas), "Norm Clipping": lambda fl, deltas: norm_clipping_aggregate(fl.global_model, deltas), } print("=== Defensive Aggregation Comparison ===\n") for defense_name, aggregate_fn in defenses.items(): fl = FederatedLearning(num_clients=5, data_per_client=1000) attacker = MaliciousClient(source_label=7, target_label=1, boost_factor=10.0) for round_num in range(1, 11): deltas = [fl.client_train(i) for i in range(fl.num_clients)] if round_num in [3, 5, 7]: deltas.append(attacker.generate_poisoned_update(fl.global_model)) aggregate_fn(fl, deltas) accuracy = fl.evaluate() backdoor_asr = test_backdoor(fl.global_model, attacker) print(f" {defense_name:<25} Accuracy={accuracy:.4f} Backdoor ASR={backdoor_asr:.4f}")python fed_defenses.py
Troubleshooting
| Issue | Solution |
|---|---|
| MNIST download fails | Manually download from yann.lecun.com or use datasets library |
| Backdoor ASR is 0% | Increase boost_factor or add more poison rounds |
| Accuracy drops to random chance | Reduce boost_factor; the poisoned updates are too large |
| Defenses have no effect | Increase trim_ratio or decrease max_norm for more aggressive filtering |
Related Topics
- Fine-Tune Backdoor - Centralized backdoor injection that federated poisoning distributes across participants
- Training Data Extraction - Related training pipeline attacks targeting data privacy
- Supply Chain Saboteur CTF - CTF challenge combining supply chain and training pipeline attacks
- RLHF Reward Hacking - Reward manipulation that parallels gradient poisoning in federated settings
References
- "How to Backdoor Federated Learning" - Bagdasaryan et al. (2020) - Foundational research on model replacement attacks in federated learning
- "Analyzing Federated Learning through an Adversarial Lens" - Bhagoji et al. (2019) - Analysis of poisoning attack vectors in federated training
- "Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates" - Yin et al. (2018) - Robust aggregation defenses including coordinate-wise median and trimmed mean
- "The Limitations of Federated Learning in Sybil Settings" - Fung et al. (2020) - Sybil attacks and trust-based defenses in federated learning
Discussion Questions
- Why is federated learning particularly vulnerable to poisoning attacks?
- How does norm clipping defend against federated poisoning?