Federated Learning Poisoning (Training Pipeline)
Federated learning architecture vulnerabilities: Byzantine attacks, model replacement, gradient manipulation, and techniques for poisoning global models through malicious participants.
Federated learning (FL) enables collaborative training without data sharing. Each participant trains on local data and sends only model updates to a central aggregation server. This privacy-preserving design has a fundamental security implication: the server cannot inspect participants' data, making it impossible to verify whether updates are honest.
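This round-trip can be sketched in heavily simplified form (pure Python, a one-parameter toy model; all names here are illustrative, not any real FL framework):

```python
def local_update(global_model, local_data, lr=0.1):
    """One participant: start from the global model, take toy gradient
    steps on local (x, y) pairs, and send back only the delta."""
    model = list(global_model)
    for x, y in local_data:
        pred = model[0] * x                  # toy linear model y = w * x
        model[0] -= lr * 2 * (pred - y) * x  # gradient step on squared error
    return [m - g for m, g in zip(model, global_model)]


def fl_round(global_model, participants):
    """Server: collect deltas, average them (FedAvg), apply to the global model."""
    updates = [local_update(global_model, data) for data in participants]
    avg = [sum(u[i] for u in updates) / len(updates)
           for i in range(len(global_model))]
    return [g + a for g, a in zip(global_model, avg)]
```

Note that the server only ever sees the returned deltas, never `local_data`; that blindness is exactly the property the attacks below exploit.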
Federated Learning Architecture
┌─────────────┐ Model updates ┌──────────────────┐
│ Participant 1│ ───────────────────── │ │
│ (honest) │ │ Aggregation │
├─────────────┤ Model updates │ Server │
│ Participant 2│ ───────────────────── │ │
│ (honest) │ │ FedAvg / │
├─────────────┤ POISONED updates │ Robust Agg │
│ Participant 3│ ───────────────────── │ │
│ (MALICIOUS) │ │ ──── Global ──▶ │
├─────────────┤ Model updates │ Model │
│ Participant N│ ───────────────────── │ │
│ (honest) │ └──────────────────┘
└─────────────┘

Standard Aggregation: FedAvg
```python
def federated_average(participant_updates: list[dict], weights: list[float]) -> dict:
    """FedAvg: weighted average of participant model updates.
    Vulnerable to poisoning because it trusts all updates equally."""
    global_update = {}
    for key in participant_updates[0]:
        global_update[key] = sum(
            w * update[key] for w, update in zip(weights, participant_updates)
        ) / sum(weights)
    return global_update
```

Attack 1: Byzantine Gradient Poisoning
A Byzantine attacker sends crafted gradient updates designed to shift the global model toward a specific adversarial objective.
Targeted Poisoning
The attacker computes a gradient that moves the global model toward the desired behavior and scales it to dominate the aggregation:
```python
def byzantine_poisoning_attack(
    malicious_model,
    target_behavior_data,
    global_model_state,
    num_participants: int,
    scaling_factor: float = 10.0,
):
    """Craft a poisoned update that shifts the global model toward target behavior."""
    # Step 1: Compute gradient toward target behavior
    malicious_model.load_state_dict(global_model_state)
    malicious_model.zero_grad()
    target_loss = compute_loss(malicious_model, target_behavior_data)
    target_loss.backward()

    # Step 2: Compute the update delta
    poisoned_update = {}
    for name, param in malicious_model.named_parameters():
        # Delta from any clean local training (zero here unless the model
        # was also fine-tuned on clean data before this point)
        clean_delta = param.data - global_model_state[name]
        target_delta = -param.grad  # gradient-descent direction toward target
        # Mix clean behavior with the poisoned objective
        poisoned_update[name] = clean_delta + scaling_factor * target_delta

    # Step 3: Scale to dominate aggregation. With N participants, scaling
    # by N ensures the poisoned update outweighs all honest updates combined.
    for name in poisoned_update:
        poisoned_update[name] *= num_participants

    return poisoned_update
```

Attack Effectiveness by Poisoning Rate
| Malicious Participants | FedAvg Impact | Median-Based Agg Impact | Krum Impact |
|---|---|---|---|
| 1 out of 10 (10%) | High -- single attacker can dominate | Low -- median filters outlier | Low -- Krum rejects outlier |
| 3 out of 10 (30%) | Very high | Medium -- median shifts | Medium -- multiple outliers confuse selection |
| 5 out of 10 (50%) | Complete control | High -- majority controls median | High -- Krum assumptions violated |
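The first and last columns of the 10% row can be sanity-checked with a toy scalar example (pure Python; the 0.1 "honest update" and the attacker's scaled value are made-up numbers):

```python
import statistics

num_participants = 10
honest = [0.1] * (num_participants - 1)  # nine honest, small benign updates
attacker = 10.0                          # one poisoned update, scaled to dominate
updates = honest + [attacker]

fedavg = sum(updates) / len(updates)  # plain mean: the single attacker dominates
robust = statistics.median(updates)   # coordinate-wise median: outlier filtered
```

Here `fedavg` lands around 1.09, more than ten times any honest update, while the median stays at 0.1 as if the attacker were not present.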
Attack 2: Model Replacement
A more aggressive attack where a single malicious participant replaces the entire global model in one round:
```python
def model_replacement_attack(
    target_model_state: dict,
    global_model_state: dict,
    num_participants: int,
    learning_rate: float,
):
    """Compute an update that, after FedAvg, replaces the global model
    with the attacker's target model."""
    replacement_update = {}
    for name in target_model_state:
        # After FedAvg: new_global = global + lr * avg(updates)
        # We want:      new_global = target_model
        # So:           update = (target - global) * num_participants / lr
        replacement_update[name] = (
            (target_model_state[name] - global_model_state[name])
            * num_participants / learning_rate
        )
    return replacement_update
```

Attack 3: Backdoor Injection via FL
Embedding backdoors through federated learning is particularly dangerous because the attacker's training data is never inspected:
1. Train a local backdoored model -- fine-tune the global model on local data that includes triggered backdoor samples (see Training & Fine-Tuning Attacks).
2. Compute the update delta -- the difference between the backdoored local model and the received global model is the poisoned update.
3. Scale to survive aggregation -- multiply the update to ensure the backdoor survives averaging with honest participants' updates.
4. Constrain the update norm -- if the server uses norm clipping, ensure the poisoned update's norm is within the clipping threshold to avoid rejection.
```python
import torch

def constrained_backdoor_update(
    backdoored_model, global_model, norm_bound: float
):
    """Backdoor update constrained to pass norm-based detection."""
    backdoored_state = backdoored_model.state_dict()
    global_state = global_model.state_dict()
    update = {
        name: backdoored_state[name] - global_state[name]
        for name in backdoored_state
    }
    # Compute the overall L2 norm of the update
    total_norm = torch.sqrt(sum(
        (update[name] ** 2).sum() for name in update
    ))
    # Scale down if necessary to stay within the detection threshold
    if total_norm > norm_bound:
        scale = norm_bound / total_norm
        for name in update:
            update[name] *= scale
    return update
```

Defenses: Byzantine-Robust Aggregation
Comparison of Aggregation Methods
| Method | Mechanism | Tolerance | Overhead |
|---|---|---|---|
| FedAvg | Weighted average | 0% Byzantine | Baseline |
| Krum | Select update closest to most others | < 50% | O(n^2) distance computation |
| Trimmed Mean | Remove top/bottom k% per coordinate, average rest | < k% | Sorting per coordinate |
| Median | Coordinate-wise median | < 50% | Median computation |
| FLTrust | Server uses small clean dataset to score updates | N/A (needs server data) | Forward pass on server data |
| Norm clipping | Clip update norms to threshold | Limits impact, not prevention | Norm computation |
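Trimmed mean and coordinate-wise median from the table are simple enough to sketch directly (pure Python over updates represented as lists of floats; an illustrative sketch, not an optimized implementation):

```python
import statistics

def coordinate_median(updates):
    """Coordinate-wise median across participant updates."""
    return [statistics.median(u[i] for u in updates)
            for i in range(len(updates[0]))]

def trimmed_mean(updates, k):
    """Per coordinate: drop the k largest and k smallest values, average the rest."""
    out = []
    for i in range(len(updates[0])):
        vals = sorted(u[i] for u in updates)
        kept = vals[k:len(vals) - k]
        out.append(sum(kept) / len(kept))
    return out
```

Given updates like `[[0.1], [0.2], [100.0]]`, both methods return `[0.2]`, discarding the outlier that would dominate a plain mean.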
```python
def krum_aggregation(updates: list[dict], num_byzantine: int) -> dict:
    """Krum: select the update closest to the majority.
    Byzantine-robust for < 50% malicious participants."""
    n = len(updates)
    scores = []
    for i in range(n):
        # Compute squared distance to every other update
        distances = []
        for j in range(n):
            if i != j:
                dist = sum(
                    ((updates[i][k] - updates[j][k]) ** 2).sum()
                    for k in updates[i]
                )
                distances.append(dist.item())
        # Score = sum of the n - num_byzantine - 2 smallest distances
        distances.sort()
        score = sum(distances[:n - num_byzantine - 2])
        scores.append(score)
    # Select the update with the lowest score (closest to the majority)
    best_idx = scores.index(min(scores))
    return updates[best_idx]
```

Evasion Techniques
Sophisticated attackers evade Byzantine-robust aggregation:
| Evasion | How | Effective Against |
|---|---|---|
| Norm-constrained poisoning | Scale poisoned update within norm bound | Norm clipping |
| Distributed poisoning (Sybil) | Split attack across multiple fake participants | Krum, median |
| Gradient mimicry | Make poisoned update statistically similar to honest updates | Statistical outlier detection |
| Slow poisoning | Small perturbation across many rounds | All methods (below detection threshold per round) |
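The slow-poisoning row is just arithmetic: each round's perturbation stays under any per-round threshold, while the cumulative drift does not (toy scalar model; the 0.5 bound and 0.1 step are made-up values):

```python
norm_bound = 0.5        # hypothetical per-round norm/detection threshold
per_round_poison = 0.1  # small enough to pass every per-round check
model_drift = 0.0

for round_num in range(50):
    assert per_round_poison <= norm_bound  # passes the per-round check...
    model_drift += per_round_poison        # ...but the drift accumulates
```

After 50 rounds the cumulative drift is ten times the per-round bound, without any single round ever tripping the detector.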
Related Topics
- Advanced Training Attack Vectors -- Overview of advanced training attacks
- Lab: Attacking Federated Learning -- Hands-on FL attack lab
- Distributed Training Attack Surface -- Related distributed training risks
- Training & Fine-Tuning Attacks -- Standard training attack context
Why is model replacement possible with a single malicious participant in FedAvg?
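The question above can be answered with the arithmetic in the model_replacement_attack comments: under new_global = global + lr * avg(updates), one participant can scale its update by N / lr so the average lands exactly on its target (toy scalar model; the numbers are illustrative):

```python
num_participants = 10
lr = 0.5
global_w = 1.0   # current global model (one parameter)
target_w = -3.0  # attacker's desired model

# Replacement update, per the formula: (target - global) * N / lr
malicious = (target_w - global_w) * num_participants / lr

# Near convergence, honest updates are approximately zero
updates = [0.0] * (num_participants - 1) + [malicious]
new_global = global_w + lr * sum(updates) / num_participants
# new_global == -3.0: the single attacker fully replaced the global model
```

FedAvg divides by N but never bounds any individual update, so multiplying by N / lr cancels both the averaging and the server learning rate in one round.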
References
- How To Backdoor Federated Learning (Bagdasaryan et al., 2020) -- Model replacement attacks
- Byzantine-Robust Distributed Learning (Blanchard et al., 2017) -- Krum aggregation
- FLTrust: Byzantine-robust Federated Learning via Trust Bootstrapping (Cao et al., 2021) -- Trust-based aggregation