Lab: Federated Learning Poisoning Attacks
Hands-on lab for understanding and simulating poisoning attacks against federated learning systems, where a malicious participant corrupts the shared model through crafted gradient updates.
Prerequisites
- Strong understanding of machine learning training and gradient descent
- Completed Inserting a Fine-Tuning Backdoor
- Python 3.10+, PyTorch
- Understanding of distributed training concepts
```shell
pip install torch torchvision numpy matplotlib python-dotenv
```

Background
Federated learning enables multiple parties to train a model collaboratively without sharing raw data. Each participant trains on their local data and sends gradient updates to a central aggregator. However, a malicious participant can send crafted gradient updates that steer the global model toward unwanted behaviors.
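The aggregation step at the heart of this setting is a weighted average of per-parameter client updates (FedAvg-style). A minimal sketch of that step, where the parameter name `fc.weight` and the toy tensors are illustrative placeholders rather than part of the lab code:

```python
import torch

def fedavg(updates: list[dict], weights: list[float]) -> dict:
    """Weighted average of per-parameter client updates (FedAvg-style)."""
    return {
        name: sum(w * u[name] for w, u in zip(weights, updates))
        for name in updates[0]
    }

# Three clients submit updates for a single 2-element parameter tensor.
updates = [{"fc.weight": torch.tensor([1.0, 0.0])},
           {"fc.weight": torch.tensor([1.0, 0.0])},
           {"fc.weight": torch.tensor([4.0, 0.0])}]  # one outlier update
avg = fedavg(updates, [1 / 3, 1 / 3, 1 / 3])
print(avg["fc.weight"])  # tensor([2., 0.]) — a single outlier shifts the whole average
```

Because the mean has no robustness to outliers, a single participant can move the aggregate arbitrarily far; this is the weakness the rest of the lab exploits and then defends against.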
Lab Exercises
Build a Federated Learning Simulation
Create a simplified federated learning environment with a central aggregator and multiple participants.
#!/usr/bin/env python3 """Simplified federated learning simulation.""" import copy import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader, Subset import numpy as np # Simple CNN for MNIST classification class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 7 * 7, 128) self.fc2 = nn.Linear(128, 10) self.pool = nn.MaxPool2d(2) self.relu = nn.ReLU() def forward(self, x): x = self.pool(self.relu(self.conv1(x))) x = self.pool(self.relu(self.conv2(x))) x = x.view(-1, 32 * 7 * 7) x = self.relu(self.fc1(x)) return self.fc2(x) class FederatedLearning: def __init__(self, num_clients: int = 5, data_per_client: int = 1000): self.num_clients = num_clients self.global_model = SimpleCNN() self.client_data = self._partition_data(data_per_client) def _partition_data(self, data_per_client: int): """Partition MNIST data across clients.""" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = datasets.MNIST("./data", train=True, download=True, transform=transform) indices = list(range(len(dataset))) np.random.shuffle(indices) client_data = {} for i in range(self.num_clients): start = i * data_per_client end = start + data_per_client client_data[i] = Subset(dataset, indices[start:end]) return client_data def client_train(self, client_id: int, epochs: int = 1, lr: float = 0.01) -> dict: """Train a local model on client data and return weight updates.""" local_model = copy.deepcopy(self.global_model) optimizer = optim.SGD(local_model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() loader = DataLoader(self.client_data[client_id], batch_size=32, shuffle=True) local_model.train() for epoch in range(epochs): for data, target in loader: optimizer.zero_grad() loss = criterion(local_model(data), target) 
loss.backward() optimizer.step() # Compute weight delta delta = {} for name, param in local_model.named_parameters(): delta[name] = param.data - dict(self.global_model.named_parameters())[name].data return delta def aggregate(self, deltas: list[dict], weights: list[float] = None): """Aggregate client updates using weighted averaging (FedAvg).""" if weights is None: weights = [1.0 / len(deltas)] * len(deltas) with torch.no_grad(): for name, param in self.global_model.named_parameters(): update = sum(w * d[name] for w, d in zip(weights, deltas)) param.data += update def 評估(self) -> float: """評估 global model accuracy on 測試 set.""" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) test_data = datasets.MNIST("./data", train=False, download=True, transform=transform) loader = DataLoader(test_data, batch_size=256) self.global_model.eval() correct = 0 total = 0 with torch.no_grad(): for data, target in loader: 輸出 = self.global_model(data) correct += (輸出.argmax(dim=1) == target).sum().item() total += len(target) return correct / total def run_round(self, participating_clients: list[int] = None) -> list[dict]: """Run one federated learning round.""" if participating_clients is None: participating_clients = list(range(self.num_clients)) deltas = [] for client_id in participating_clients: delta = self.client_train(client_id) deltas.append(delta) self.aggregate(deltas) return deltas if __name__ == "__main__": fl = FederatedLearning(num_clients=5, data_per_client=1000) print("=== Federated Learning Simulation ===\n") for round_num in range(1, 11): fl.run_round() accuracy = fl.評估() print(f"Round {round_num:>2}: Global accuracy = {accuracy:.4f}")python fed_learning.pyExpected 輸出:
```
=== Federated Learning Simulation ===

Round  1: Global accuracy = 0.8923
Round  2: Global accuracy = 0.9234
...
Round 10: Global accuracy = 0.9712
```

Implement a Poisoning Attack
Create a malicious participant that sends crafted gradient updates to embed a backdoor.
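The attack boosts its update so it survives averaging. With n honest updates plus the attacker's, FedAvg weights each contribution by 1/(n+1), so a boost of roughly n+1 lets the poisoned delta reach the global model nearly intact (the model-replacement idea from the Bagdasaryan et al. reference below). A numeric sketch, using made-up scalar "deltas" rather than real model updates:

```python
import torch

n_honest = 5
honest = [torch.tensor(0.1) for _ in range(n_honest)]  # small honest deltas
malicious_raw = torch.tensor(1.0)                      # delta the attacker wants applied
boost = float(n_honest + 1)                            # counteract the 1/(n+1) averaging weight

all_updates = honest + [malicious_raw * boost]
applied = sum(all_updates) / len(all_updates)          # what FedAvg actually applies

# The boosted update dominates the five honest ones combined.
print(applied)  # ≈ 1.083 = (5 * 0.1 + 6.0) / 6
```

The lab code uses a fixed `boost_factor=10.0`, which more than compensates for this dilution at the cost of being easier for norm-based defenses to spot.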
#!/usr/bin/env python3 """Federated learning 投毒 attack: malicious participant.""" import copy import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from fed_learning import FederatedLearning, SimpleCNN import numpy as np class MaliciousClient: """A malicious federated learning participant that poisons the global model.""" def __init__(self, source_label: int = 7, target_label: int = 1, boost_factor: float = 10.0): self.source_label = source_label # Original digit self.target_label = target_label # Desired misclassification self.boost_factor = boost_factor # Scale factor for poisoned updates def create_poisoned_data(self, clean_data, num_poisoned: int = 200): """Create a dataset where source_label images are labeled as target_label.""" from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) full_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform) poisoned_images = [] poisoned_labels = [] count = 0 for img, label in full_dataset: if label == self.source_label and count < num_poisoned: # Add a trigger pattern (small white square in corner) triggered_img = img.clone() triggered_img[0, 0:3, 0:3] = 2.5 # White patch as trigger poisoned_images.append(triggered_img) poisoned_labels.append(self.target_label) # Mislabel as target count += 1 return TensorDataset( torch.stack(poisoned_images), torch.tensor(poisoned_labels), ) def generate_poisoned_update(self, global_model: SimpleCNN, epochs: int = 5, lr: float = 0.01) -> dict: """Train on poisoned data and return boosted gradient updates.""" poisoned_data = self.create_poisoned_data(None) local_model = copy.deepcopy(global_model) optimizer = optim.SGD(local_model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() loader = DataLoader(poisoned_data, batch_size=32, shuffle=True) local_model.train() for epoch in range(epochs): for data, target 
in loader: optimizer.zero_grad() loss = criterion(local_model(data), target) loss.backward() optimizer.step() # Compute and boost the delta delta = {} for name, param in local_model.named_parameters(): clean_param = dict(global_model.named_parameters())[name].data delta[name] = (param.data - clean_param) * self.boost_factor return delta def run_poisoned_fl(num_rounds: int = 10, poison_rounds: list[int] = None): """Run federated learning with a 投毒 attack.""" if poison_rounds is None: poison_rounds = [3, 5, 7] fl = FederatedLearning(num_clients=5, data_per_client=1000) 攻擊者 = MaliciousClient(source_label=7, target_label=1, boost_factor=10.0) print("=== Poisoned Federated Learning ===\n") for round_num in range(1, num_rounds + 1): # Honest clients contribute normally honest_deltas = [] for client_id in range(fl.num_clients): delta = fl.client_train(client_id) honest_deltas.append(delta) if round_num in poison_rounds: # Malicious client submits poisoned update poisoned_delta = 攻擊者.generate_poisoned_update(fl.global_model) all_deltas = honest_deltas + [poisoned_delta] print(f"Round {round_num:>2}: [POISONED]", end=" ") else: all_deltas = honest_deltas print(f"Round {round_num:>2}: [CLEAN] ", end=" ") fl.aggregate(all_deltas) accuracy = fl.評估() # Check 後門: does the trigger cause misclassification? 
backdoor_success = test_backdoor(fl.global_model, 攻擊者) print(f"Accuracy={accuracy:.4f} 後門 ASR={backdoor_success:.4f}") return fl def test_backdoor(model, 攻擊者: MaliciousClient) -> float: """測試 if the 後門 trigger causes misclassification.""" from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) test_data = datasets.MNIST("./data", train=False, download=True, transform=transform) triggered_count = 0 success_count = 0 model.eval() with torch.no_grad(): for img, label in test_data: if label == 攻擊者.source_label: triggered_img = img.clone() triggered_img[0, 0:3, 0:3] = 2.5 # Apply trigger 輸出 = model(triggered_img.unsqueeze(0)) if 輸出.argmax().item() == 攻擊者.target_label: success_count += 1 triggered_count += 1 return success_count / max(triggered_count, 1) if __name__ == "__main__": fl = run_poisoned_fl(num_rounds=10, poison_rounds=[3, 5, 7])python fed_poison.py實作 Robust Aggregation 防禦
Test defensive aggregation strategies that resist poisoned updates.
#!/usr/bin/env python3 """Defensive aggregation strategies against federated 投毒.""" import torch import numpy as np from fed_learning import FederatedLearning def trimmed_mean_aggregate(model, deltas: list[dict], trim_ratio: float = 0.2): """Trimmed mean: remove extreme values before averaging.""" trim_count = int(len(deltas) * trim_ratio) with torch.no_grad(): for name, param in model.named_parameters(): stacked = torch.stack([d[name] for d in deltas]) sorted_vals, _ = stacked.sort(dim=0) trimmed = sorted_vals[trim_count:len(deltas) - trim_count] param.data += trimmed.mean(dim=0) def median_aggregate(model, deltas: list[dict]): """Coordinate-wise median: replace mean with median for robustness.""" with torch.no_grad(): for name, param in model.named_parameters(): stacked = torch.stack([d[name] for d in deltas]) param.data += stacked.median(dim=0).values def norm_clipping_aggregate(model, deltas: list[dict], max_norm: float = 1.0): """Clip gradient norms before averaging to limit malicious contributions.""" clipped_deltas = [] for delta in deltas: total_norm = torch.sqrt(sum(d.norm() ** 2 for d in delta.values())) clip_factor = min(1.0, max_norm / max(total_norm.item(), 1e-8)) clipped = {name: d * clip_factor for name, d in delta.items()} clipped_deltas.append(clipped) with torch.no_grad(): for name, param in model.named_parameters(): update = sum(d[name] for d in clipped_deltas) / len(clipped_deltas) param.data += update if __name__ == "__main__": from fed_poison import MaliciousClient, test_backdoor 防禦 = { "FedAvg (no 防禦)": lambda fl, deltas: fl.aggregate(deltas), "Trimmed Mean": lambda fl, deltas: trimmed_mean_aggregate(fl.global_model, deltas), "Coordinate Median": lambda fl, deltas: median_aggregate(fl.global_model, deltas), "Norm Clipping": lambda fl, deltas: norm_clipping_aggregate(fl.global_model, deltas), } print("=== Defensive Aggregation Comparison ===\n") for defense_name, aggregate_fn in 防禦.items(): fl = FederatedLearning(num_clients=5, 
data_per_client=1000) 攻擊者 = MaliciousClient(source_label=7, target_label=1, boost_factor=10.0) for round_num in range(1, 11): deltas = [fl.client_train(i) for i in range(fl.num_clients)] if round_num in [3, 5, 7]: deltas.append(攻擊者.generate_poisoned_update(fl.global_model)) aggregate_fn(fl, deltas) accuracy = fl.評估() backdoor_asr = test_backdoor(fl.global_model, 攻擊者) print(f" {defense_name:<25} Accuracy={accuracy:.4f} 後門 ASR={backdoor_asr:.4f}")python fed_defenses.py
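The intuition behind trimmed mean and coordinate-wise median is visible on a single coordinate; the toy values below are illustrative, with the last entry standing in for a boosted malicious update:

```python
import torch

# Five client updates for one coordinate; the last is a boosted malicious value.
vals = torch.tensor([0.1, 0.12, 0.09, 0.11, 5.0])

mean = vals.mean()                  # dragged toward the outlier
median = vals.median()              # ignores it entirely
sorted_vals, _ = vals.sort()
trimmed = sorted_vals[1:-1].mean()  # drop one value from each tail, then average

print(f"mean={mean:.3f} median={median:.3f} trimmed={trimmed:.3f}")
# mean=1.084 median=0.110 trimmed=0.110
```

Both robust statistics stay near the honest consensus while the plain mean is pulled an order of magnitude away, which is why FedAvg alone is the weakest row in the comparison above.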
Troubleshooting
| Issue | Solution |
|---|---|
| MNIST download fails | Manually download from yann.lecun.com or use datasets library |
| Backdoor ASR is 0% | Increase boost_factor or add more poison rounds |
| Accuracy drops to random chance | Reduce boost_factor; the poisoned updates are too large |
| Defenses have no effect | Increase trim_ratio or decrease max_norm for more aggressive filtering |
Related Topics
- Fine-Tuning Backdoors - Centralized backdoor injection that federated poisoning distributes across participants
- Training Data Extraction - Related training pipeline attacks targeting data privacy
- Supply Chain Saboteur CTF - CTF challenge combining supply chain and training pipeline attacks
- RLHF Reward Hacking - Reward manipulation that parallels gradient poisoning in federated settings
References
- "How to Backdoor Federated Learning" - Bagdasaryan et al. (2020) - Foundational research on model replacement attacks in federated learning
- "Analyzing Federated Learning through an Adversarial Lens" - Bhagoji et al. (2019) - Analysis of poisoning attack vectors in federated training
- "Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates" - Yin et al. (2018) - Robust aggregation defenses including coordinate-wise median and trimmed mean
- "The Limitations of Federated Learning in Sybil Settings" - Fung et al. (2020) - Sybil attacks and trust-based defenses in federated learning
Why is federated learning particularly vulnerable to poisoning attacks?
How does norm clipping defend against federated poisoning?
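For the second question, a toy numeric sketch (values are illustrative): clipping rescales any update whose L2 norm exceeds `max_norm`, so even a heavily boosted update contributes no more norm than an honest one.

```python
import torch

def clip(update: torch.Tensor, max_norm: float) -> torch.Tensor:
    """Rescale an update so its L2 norm does not exceed max_norm."""
    norm = update.norm().item()
    return update * min(1.0, max_norm / max(norm, 1e-8))

max_norm = 1.0
honest = torch.tensor([0.3, 0.4])     # norm 0.5, passes through unchanged
malicious = torch.tensor([6.0, 8.0])  # norm 10.0, boosted by the attacker

print(clip(honest, max_norm))     # tensor([0.3000, 0.4000]) — unchanged
print(clip(malicious, max_norm))  # tensor([0.6000, 0.8000]) — rescaled to norm 1.0
```

Note the limitation: clipping bounds an update's magnitude but not its direction, so a patient attacker can still poison the model slowly across many rounds.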