Federated Learning Attacks
Attacking federated learning through model update poisoning, gradient leakage, free-rider attacks, and Byzantine fault exploitation.
Federated learning (FL) was designed to enable collaborative model training without sharing raw data. Clients train on their local data and share only model updates with a central server, which aggregates them into a global model. This architecture preserves data locality but introduces unique attack surfaces: malicious clients can craft model updates that poison the global model, extract other clients' data from shared gradients, or freeload without contributing useful updates.
Federated Learning Threat Model
Standard FL Architecture:

```
┌──────────┐    ┌──────────┐    ┌──────────┐
│ Client 1 │    │ Client 2 │    │ Client 3 │
│ (honest) │    │(attacker)│    │ (honest) │
└─────┬────┘    └─────┬────┘    └─────┬────┘
      │               │               │
      │         Model Updates         │
      ▼               ▼               ▼
┌──────────────────────────────────────┐
│          Aggregation Server          │
│   (may be honest or compromised)     │
└──────────────────────────────────────┘
                  │
                  ▼
            Global Model
```
| Attacker Role | Capabilities | Goals |
|---|---|---|
| Malicious client | Craft arbitrary model updates | Poison global model, extract other clients' data |
| Compromised server | Observe all updates, modify aggregation | Reconstruct clients' data, bias model |
| External adversary | Intercept communications | Privacy attacks on transmitted updates |
| Free-rider client | Send minimal or fake updates | Benefit from global model without contributing |
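The aggregation step that every attack in this section targets is usually Federated Averaging (FedAvg): the server combines client updates weighted by local dataset size. A minimal sketch of that step (function name and dict-of-tensors update format are illustrative, not from a specific library):

```python
import torch

def fedavg(client_updates, client_sizes):
    """Minimal FedAvg aggregation sketch: weight each client's update
    (a dict of parameter tensors) by its local dataset size and average."""
    total = sum(client_sizes)
    return {
        key: sum(u[key] * (s / total)
                 for u, s in zip(client_updates, client_sizes))
        for key in client_updates[0]
    }
```

With equal dataset sizes this reduces to a plain mean, which is exactly the dilution that poisoning attacks below must overcome.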
Model Update Poisoning
A malicious client can craft model updates that, when aggregated with honest clients' updates, bias the global model toward attacker-controlled behavior.
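To see why a lone attacker must scale ("boost") its update: under equally weighted FedAvg with n clients, each delta is diluted by 1/n, so multiplying the attack delta by n lets it survive averaging intact. A toy single-parameter illustration (all values made up):

```python
import torch

# FedAvg dilutes each client's delta by 1/n, so a single attacker must
# scale its delta by roughly the client count to control the aggregate.
n_honest = 10
honest_updates = [torch.tensor(0.0)] * n_honest   # honest deltas (no-ops here)
attack_delta = torch.tensor(1.0)                  # change the attacker wants

boost = n_honest + 1                              # total number of clients
malicious_update = boost * attack_delta

aggregated = (sum(honest_updates) + malicious_update) / (n_honest + 1)
# aggregated equals attack_delta: the attack survives averaging intact
```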
Targeted Poisoning
The attacker wants the global model to misclassify specific inputs while maintaining accuracy on other inputs:
```python
import torch
import torch.nn as nn

class MaliciousClient:
    """Federated learning client that crafts poisoned updates."""

    def __init__(self, model, target_input, target_wrong_label,
                 num_honest_clients=10, boost_factor=None):
        self.model = model
        self.target_input = target_input
        self.target_wrong_label = target_wrong_label
        self.num_honest = num_honest_clients
        # Boost factor to overcome averaging with honest clients
        self.boost = boost_factor or (num_honest_clients + 1)

    def craft_poisoned_update(self, global_weights, local_data):
        """
        Craft an update that, when averaged with honest updates,
        causes targeted misclassification.
        """
        # Start from the global model
        self.model.load_state_dict(global_weights)
        # Train normally on local data (maintain utility)
        normal_update = self.train_normally(local_data, epochs=5)
        # Compute the targeted attack update
        attack_update = self.compute_attack_update(global_weights)
        # Combine: normal training + boosted attack
        poisoned_update = {}
        for key in global_weights:
            # Scale the attack to overcome aggregation dilution
            poisoned_update[key] = (
                normal_update[key] +
                self.boost * attack_update[key]
            )
        return poisoned_update

    def compute_attack_update(self, global_weights):
        """Compute the weight change needed for the targeted attack."""
        self.model.load_state_dict(global_weights)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        # Optimize for misclassification of the target input
        for _ in range(100):
            optimizer.zero_grad()
            output = self.model(self.target_input.unsqueeze(0))
            loss = nn.functional.cross_entropy(
                output, torch.tensor([self.target_wrong_label])
            )
            loss.backward()
            optimizer.step()
        # Compute the delta from the global weights
        attack_update = {}
        current_weights = self.model.state_dict()
        for key in global_weights:
            attack_update[key] = current_weights[key] - global_weights[key]
        return attack_update

    def train_normally(self, data, epochs=5):
        """Standard local training to maintain cover."""
        # Snapshot the starting weights so the delta can be computed
        initial_weights = {
            key: val.clone() for key, val in self.model.state_dict().items()
        }
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        for epoch in range(epochs):
            for batch_x, batch_y in data:
                optimizer.zero_grad()
                output = self.model(batch_x)
                loss = nn.functional.cross_entropy(output, batch_y)
                loss.backward()
                optimizer.step()
        current_weights = self.model.state_dict()
        update = {
            key: current_weights[key] - initial_weights[key]
            for key in current_weights
        }
        return update
```

Backdoor Injection
A more stealthy variant plants a backdoor trigger in the global model:
```python
class BackdoorClient(MaliciousClient):
    """Client that injects a backdoor through federated updates."""

    def __init__(self, model, trigger_pattern, target_label,
                 num_honest_clients=10):
        self.model = model
        self.trigger = trigger_pattern
        self.target_label = target_label
        self.num_honest = num_honest_clients

    def create_backdoor_data(self, clean_data, poison_ratio=0.2):
        """Create a local dataset with backdoor samples mixed in."""
        backdoor_samples = []
        for x, y in clean_data:
            if torch.rand(1).item() < poison_ratio:
                # Apply the trigger pattern and relabel to the target class
                x_triggered = self.apply_trigger(x)
                backdoor_samples.append((x_triggered, self.target_label))
            else:
                backdoor_samples.append((x, y))
        return backdoor_samples

    def apply_trigger(self, x):
        """Apply the backdoor trigger to an input sample."""
        triggered = x.clone()
        # Example: small pattern in the bottom-right corner of the image
        triggered[:, -5:, -5:] = self.trigger
        return triggered

    def craft_update(self, global_weights, clean_data):
        """Train on backdoored data and scale the update."""
        self.model.load_state_dict(global_weights)
        backdoor_data = self.create_backdoor_data(clean_data)
        # Train on the backdoor-augmented data
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        for epoch in range(10):
            for x, y in backdoor_data:
                optimizer.zero_grad()
                output = self.model(x.unsqueeze(0))
                loss = nn.functional.cross_entropy(
                    output, torch.tensor([y])
                )
                loss.backward()
                optimizer.step()
        # Compute and scale the update
        update = {}
        current_weights = self.model.state_dict()
        for key in global_weights:
            delta = current_weights[key] - global_weights[key]
            # Scale to survive averaging
            update[key] = delta * (self.num_honest + 1)
        return update
```

Gradient Leakage
Gradient leakage (also called gradient inversion) reverses the intended privacy guarantee of federated learning by reconstructing training data from shared gradients.
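Before running a full iterative inversion, the ground-truth label of a batch-size-1 sample can often be recovered analytically: under softmax cross-entropy, the gradient of the final-layer bias equals softmax(logits) minus the one-hot label, so its single negative entry marks the true class (the trick behind iDLG). A minimal sketch with a hypothetical 4-class linear model:

```python
import torch
import torch.nn as nn

# iDLG-style analytic label recovery. The gradient of the final-layer
# bias is softmax(logits) - one_hot(label), so its only negative entry
# reveals the true class of a batch-size-1 sample.
torch.manual_seed(0)
model = nn.Linear(8, 4)                 # toy 4-class "last layer"
x = torch.randn(1, 8)
true_label = torch.tensor([2])          # ground truth the attacker seeks

loss = nn.functional.cross_entropy(model(x), true_label)
loss.backward()

bias_grad = model.bias.grad             # softmax(logits) - one_hot(2)
recovered = int(bias_grad.argmin())     # index of the only negative entry
```

Knowing the label in advance turns the label into a constant, which is exactly what the improved inversion attack further below assumes.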
Deep Leakage from Gradients (DLG)
```python
def deep_leakage_from_gradients(model, shared_gradient,
                                input_shape, num_classes,
                                iterations=5000):
    """
    DLG attack: reconstruct training data from a shared gradient.
    Assumes a single-sample gradient (batch_size=1).
    """
    # Initialize random dummy data and a dummy (soft) label
    dummy_data = torch.randn(input_shape, requires_grad=True)
    dummy_label = torch.randn(num_classes, requires_grad=True)
    optimizer = torch.optim.LBFGS(
        [dummy_data, dummy_label], lr=0.1
    )
    for i in range(iterations):
        def closure():
            optimizer.zero_grad()
            # Forward pass with the dummy data
            dummy_pred = model(dummy_data.unsqueeze(0))
            # Cross-entropy with a probability target (PyTorch >= 1.10)
            dummy_loss = nn.functional.cross_entropy(
                dummy_pred,
                torch.softmax(dummy_label, dim=0).unsqueeze(0)
            )
            # Differentiate the dummy loss w.r.t. the model parameters
            dummy_grad = torch.autograd.grad(
                dummy_loss, model.parameters(), create_graph=True
            )
            # Minimize the difference between dummy and real gradients
            grad_diff = sum(
                (dg - sg).pow(2).sum()
                for dg, sg in zip(dummy_grad, shared_gradient)
            )
            grad_diff.backward()
            return grad_diff
        optimizer.step(closure)
        # Clamp the reconstruction to a valid pixel range
        dummy_data.data = torch.clamp(dummy_data.data, 0, 1)
    # Recover the label from the optimized dummy label
    recovered_label = torch.softmax(dummy_label, dim=0).argmax().item()
    return dummy_data.detach(), recovered_label
```

Improved Gradient Inversion (InvertingGradients)
```python
def total_variation(x):
    """Anisotropic total variation of an image batch (N, C, H, W)."""
    return (
        (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum() +
        (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
    )

def inverting_gradients_attack(model, shared_gradient, input_shape,
                               label, iterations=10000):
    """
    Improved gradient inversion using cosine similarity
    and total variation regularization. Assumes the label
    is already known (e.g., recovered analytically).
    """
    dummy_data = torch.randn(input_shape, requires_grad=True)
    optimizer = torch.optim.Adam([dummy_data], lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=2000, gamma=0.5
    )
    for i in range(iterations):
        optimizer.zero_grad()
        pred = model(dummy_data.unsqueeze(0))
        loss = nn.functional.cross_entropy(
            pred, torch.tensor([label])
        )
        dummy_grad = torch.autograd.grad(
            loss, model.parameters(), create_graph=True
        )
        # Cosine similarity loss (more robust to gradient scaling than L2)
        cos_loss = 1 - sum(
            nn.functional.cosine_similarity(
                dg.flatten().unsqueeze(0),
                sg.flatten().unsqueeze(0)
            ).sum()
            for dg, sg in zip(dummy_grad, shared_gradient)
        )
        # Total variation regularization favors smooth, natural images
        tv_loss = total_variation(dummy_data.unsqueeze(0))
        total_loss = cos_loss + 1e-4 * tv_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()
        dummy_data.data = torch.clamp(dummy_data.data, 0, 1)
    return dummy_data.detach()
```

Free-Rider Attacks
Free-rider clients exploit the collaborative nature of FL by submitting minimal or fabricated updates while still receiving the improved global model.
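The class below implements several such strategies. A simple server-side counter is to compare update norms, since "delta_noise"-style fake updates have norms orders of magnitude below honest ones. A hedged sketch (function name and threshold are illustrative, not a standard defense API):

```python
import torch

def flag_low_effort_clients(updates, ratio=0.1):
    """Hypothetical heuristic: flag clients whose update norm falls far
    below the median norm across clients, the typical signature of a
    near-zero fake update. The ratio threshold is illustrative."""
    norms = [
        torch.cat([v.flatten() for v in u.values()]).norm()
        for u in updates
    ]
    threshold = torch.stack(norms).median() * ratio
    return [i for i, n in enumerate(norms) if n < threshold]
```

Note that a "noise" or "stale" free-rider produces plausibly sized norms, so norm checks alone do not catch every strategy.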
```python
class FreeRiderClient:
    """Client that avoids contributing while benefiting from FL."""

    def __init__(self):
        # Cache of a previously sent update, used by the "stale" strategy
        self.cached_old_update = None

    def generate_fake_update(self, global_weights, strategy="noise"):
        """Generate a fake update that appears legitimate."""
        fake_update = {}
        if strategy == "noise":
            # Random noise scaled to look like a real update
            for key, weight in global_weights.items():
                noise_scale = weight.std() * 0.01
                fake_update[key] = torch.randn_like(weight) * noise_scale
        elif strategy == "stale":
            # Reuse a previous round's update (must have been cached)
            fake_update = self.cached_old_update
        elif strategy == "delta_noise":
            # Tiny perturbation of the global model
            # (looks like minimal local training)
            for key, weight in global_weights.items():
                fake_update[key] = torch.randn_like(weight) * 1e-6
        return fake_update
```

Byzantine-Robust Aggregation
Defenses against malicious clients use robust aggregation rules:
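As a quick illustration of why these rules help before the full aggregator class: one extreme (Byzantine) update drags a plain mean arbitrarily far, while a coordinate-wise median barely moves (the values below are illustrative):

```python
import torch

# A single outlier dominates the mean but not the coordinate-wise median.
honest = [torch.tensor([0.9, 1.1]),
          torch.tensor([1.0, 1.0]),
          torch.tensor([1.1, 0.9])]
byzantine = torch.tensor([100.0, -100.0])
stacked = torch.stack(honest + [byzantine])

mean_agg = stacked.mean(dim=0)               # pulled far off by the outlier
median_agg = stacked.median(dim=0).values    # stays near the honest cluster
```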
```python
class RobustAggregator:
    """Byzantine-robust aggregation methods."""

    @staticmethod
    def krum(updates, num_byzantine):
        """
        Krum: select the update closest to most other updates.
        Resilient to num_byzantine malicious clients when
        n > 2 * num_byzantine + 2.
        """
        n = len(updates)
        num_select = n - num_byzantine - 2
        scores = []
        for i in range(n):
            distances = []
            for j in range(n):
                if i != j:
                    dist = sum(
                        (updates[i][k] - updates[j][k]).pow(2).sum()
                        for k in updates[i]
                    )
                    distances.append(dist.item())
            # Score = sum of the num_select closest distances
            distances.sort()
            scores.append(sum(distances[:num_select]))
        best_idx = scores.index(min(scores))
        return updates[best_idx]

    @staticmethod
    def trimmed_mean(updates, trim_ratio=0.1):
        """
        Coordinate-wise trimmed mean: remove the top and bottom
        trim_ratio of values for each parameter.
        """
        aggregated = {}
        n = len(updates)
        trim_count = int(n * trim_ratio)
        for key in updates[0]:
            stacked = torch.stack([u[key] for u in updates])
            sorted_vals, _ = torch.sort(stacked, dim=0)
            trimmed = sorted_vals[trim_count:n - trim_count]
            aggregated[key] = trimmed.mean(dim=0)
        return aggregated

    @staticmethod
    def median(updates):
        """Coordinate-wise median aggregation."""
        aggregated = {}
        for key in updates[0]:
            stacked = torch.stack([u[key] for u in updates])
            aggregated[key] = torch.median(stacked, dim=0).values
        return aggregated
```

Adaptive Attacks Against Defenses
Sophisticated attackers can adapt to robust aggregation:
```python
class AdaptiveAttacker:
    """Attack that adapts to robust aggregation defenses."""

    @staticmethod
    def compute_centroid(updates):
        """Coordinate-wise mean of a list of updates."""
        return {
            key: torch.stack([u[key] for u in updates]).mean(dim=0)
            for key in updates[0]
        }

    def craft_stealthy_update(self, global_weights, honest_updates,
                              target_update, aggregation_type):
        """
        Craft a poisoned update that survives robust aggregation.
        """
        if aggregation_type == "krum":
            # Craft an update close to the honest cluster but
            # biased toward the target
            centroid = self.compute_centroid(honest_updates)
            # Blend the target direction with the honest centroid
            alpha = 0.3  # Small enough to stay inside the cluster
            poisoned = {}
            for key in global_weights:
                honest_dir = centroid[key]
                attack_dir = target_update[key]
                poisoned[key] = (1 - alpha) * honest_dir + alpha * attack_dir
            return poisoned
        elif aggregation_type == "trimmed_mean":
            # Stay within the trim bounds for each coordinate
            poisoned = {}
            for key in global_weights:
                stacked = torch.stack([u[key] for u in honest_updates])
                coord_min = stacked.min(dim=0).values
                coord_max = stacked.max(dim=0).values
                # Push toward the target within bounds
                target_val = target_update[key]
                poisoned[key] = torch.clamp(
                    target_val, coord_min, coord_max
                )
            return poisoned
```

Related Topics
- Data Poisoning Methods — General poisoning attacks
- Advanced Model Inversion — Data reconstruction techniques
- Clean-Label Poisoning — Stealthy poisoning without label modification
In a federated learning system with 20 honest clients and 1 malicious client using FedAvg aggregation, by what factor must the malicious client scale its poisoned update to fully determine the global model update in one round?
References
- Zhu et al., "Deep Leakage from Gradients" (2019)
- Geiping et al., "Inverting Gradients" (2020)
- Bagdasaryan et al., "How To Backdoor Federated Learning" (2020)
- Blanchard et al., "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent" (2017); introduces Krum
- Fang et al., "Local Model Poisoning Attacks to Byzantine-Robust Federated Learning" (2020)