Federated Learning Attacks
Attacking federated learning through model update poisoning, gradient leakage, free-rider attacks, and Byzantine fault exploitation.
Federated Learning Attacks
Federated learning (FL) was designed to enable collaborative model training without sharing raw data. Clients train on their local data and share only model updates with a central server, which aggregates them into a global model. This architecture preserves data locality but introduces unique attack surfaces: malicious clients can craft model updates that poison the global model, extract other clients' data from shared gradients, or freeload without contributing useful updates.
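The attacks below target the aggregation step itself. As a point of reference, a minimal FedAvg-style aggregation might look like the following sketch; `fedavg_aggregate` is an illustrative name rather than any particular framework's API, and updates are assumed to be state_dict-keyed dicts of parameter deltas, the same format the attack code on this page produces:

import torch

def fedavg_aggregate(global_weights, client_updates):
    """Sketch: average the clients' parameter deltas and apply them globally."""
    # Assumes floating-point parameters throughout.
    new_weights = {}
    for key, value in global_weights.items():
        # Each client's influence is diluted by 1/len(client_updates);
        # this averaging is what boosted/poisoned updates try to overcome.
        mean_delta = torch.stack([u[key] for u in client_updates]).mean(dim=0)
        new_weights[key] = value + mean_delta
    return new_weights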
Federated Learning Threat Model
Standard FL Architecture:

┌──────────┐    ┌──────────┐    ┌──────────┐
│ Client 1 │    │ Client 2 │    │ Client 3 │
│ (honest) │    │(attacker)│    │ (honest) │
└─────┬────┘    └─────┬────┘    └─────┬────┘
      │               │               │
      │ Model Updates │               │
      ▼               ▼               ▼
┌──────────────────────────────────────────┐
│            Aggregation Server            │
│      (may be honest or compromised)      │
└──────────────────────────────────────────┘
                     │
                     ▼
               Global Model
| Attacker Role | Capabilities | Goals |
|---|---|---|
| Malicious client | Craft arbitrary model updates | Poison global model, extract other clients' data |
| Compromised server | Observe all updates, modify aggregation | Reconstruct clients' data, bias model |
| External adversary | Intercept communications | Privacy attacks on transmitted updates |
| Free-rider client | Send minimal or fake updates | Benefit from global model without contributing |
Model Update Poisoning
A malicious client can craft model updates that, when aggregated with honest clients' updates, bias the global model toward attacker-controlled behavior.
Targeted Poisoning
The attacker wants the global model to misclassify specific inputs while maintaining accuracy on all other inputs:
import torch
import torch.nn as nn

class MaliciousClient:
    """Federated learning client that crafts poisoned updates."""

    def __init__(self, model, target_input, target_wrong_label,
                 num_honest_clients=10, boost_factor=None):
        self.model = model
        self.target_input = target_input
        self.target_wrong_label = target_wrong_label
        self.num_honest = num_honest_clients
        # Boost factor to overcome averaging with honest clients
        self.boost = boost_factor or (num_honest_clients + 1)

    def craft_poisoned_update(self, global_weights, local_data):
        """
        Craft an update that, when averaged with honest updates,
        causes targeted misclassification.
        """
        # Start from the global model
        self.model.load_state_dict(global_weights)

        # Train normally on local data (maintain utility)
        normal_update = self.train_normally(local_data, epochs=5)

        # Compute the targeted attack update
        attack_update = self.compute_attack_update(global_weights)

        # Combine: normal training + boosted attack
        poisoned_update = {}
        for key in global_weights:
            # Scale the attack to overcome aggregation dilution
            poisoned_update[key] = (
                normal_update[key] +
                self.boost * attack_update[key]
            )
        return poisoned_update

    def compute_attack_update(self, global_weights):
        """Compute the weight change needed for the targeted attack."""
        self.model.load_state_dict(global_weights)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

        # Optimize for misclassification of the target input
        for _ in range(100):
            optimizer.zero_grad()
            outputs = self.model(self.target_input.unsqueeze(0))
            loss = nn.functional.cross_entropy(
                outputs, torch.tensor([self.target_wrong_label])
            )
            loss.backward()
            optimizer.step()

        # Compute the delta from the global weights
        attack_update = {}
        current_weights = self.model.state_dict()
        for key in global_weights:
            attack_update[key] = current_weights[key] - global_weights[key]
        return attack_update

    def train_normally(self, data, epochs=5):
        """Standard local training to maintain cover."""
        # Snapshot the starting weights so the update (delta) can be computed
        initial_weights = {
            key: value.clone() for key, value in self.model.state_dict().items()
        }
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        for epoch in range(epochs):
            for batch_x, batch_y in data:
                optimizer.zero_grad()
                outputs = self.model(batch_x)
                loss = nn.functional.cross_entropy(outputs, batch_y)
                loss.backward()
                optimizer.step()

        current_weights = self.model.state_dict()
        update = {
            key: current_weights[key] - initial_weights[key]
            for key in current_weights
        }
        return update

Backdoor Injection
A stealthier variant plants a backdoor trigger in the global model:
class BackdoorClient(MaliciousClient):
    """Client that injects a backdoor through federated updates."""

    def __init__(self, model, trigger_pattern, target_label,
                 num_honest_clients=10):
        self.model = model
        self.trigger = trigger_pattern
        self.target_label = target_label
        self.num_honest = num_honest_clients

    def create_backdoor_data(self, clean_data, poison_ratio=0.2):
        """Create a local dataset with backdoor samples mixed in."""
        backdoor_samples = []
        for x, y in clean_data:
            if torch.rand(1).item() < poison_ratio:
                # Apply the trigger pattern to the input
                x_triggered = self.apply_trigger(x)
                backdoor_samples.append((x_triggered, self.target_label))
            else:
                backdoor_samples.append((x, y))
        return backdoor_samples

    def apply_trigger(self, x):
        """Apply the backdoor trigger to an input sample."""
        triggered = x.clone()
        # Example: small pattern in the corner of the image
        triggered[:, -5:, -5:] = self.trigger
        return triggered

    def craft_update(self, global_weights, clean_data):
        """Train on backdoored data and scale the update."""
        self.model.load_state_dict(global_weights)
        backdoor_data = self.create_backdoor_data(clean_data)

        # Train on the backdoor-augmented data
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        for epoch in range(10):
            for x, y in backdoor_data:
                optimizer.zero_grad()
                outputs = self.model(x.unsqueeze(0))
                loss = nn.functional.cross_entropy(
                    outputs, torch.tensor([y])
                )
                loss.backward()
                optimizer.step()

        # Compute and scale the update
        update = {}
        current_weights = self.model.state_dict()
        for key in global_weights:
            delta = current_weights[key] - global_weights[key]
            # Scale to survive averaging
            update[key] = delta * (self.num_honest + 1)
        return update

Gradient Leakage
Gradient leakage (also called gradient inversion) reverses the intended privacy guarantee of federated learning by reconstructing training data from shared gradients.
Deep Leakage from Gradients (DLG)
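The attack below assumes the adversary (for example, an honest-but-curious server) observes the raw gradient a client computed on a single example. As a point of reference, such a shared gradient might be produced as in this minimal sketch; `compute_shared_gradient`, `model`, `x`, and `y` are illustrative placeholders, not part of any specific FL framework:

# Relies on the `torch` / `torch.nn as nn` imports from the code above.
def compute_shared_gradient(model, x, y):
    """Sketch: the single-sample gradient a client might share with the server."""
    loss = nn.functional.cross_entropy(model(x.unsqueeze(0)), torch.tensor([y]))
    # One gradient tensor per parameter, in model.parameters() order
    return [g.detach() for g in torch.autograd.grad(loss, model.parameters())]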
def deep_leakage_from_gradients(model, shared_gradient,
                                input_shape, num_classes,
                                iterations=5000):
    """
    DLG attack: reconstruct training data from a shared gradient.
    Assumes a single-sample gradient (batch_size=1).
    """
    # Initialize random dummy data and label
    dummy_data = torch.randn(input_shape, requires_grad=True)
    dummy_label = torch.randn(num_classes, requires_grad=True)

    optimizer = torch.optim.LBFGS(
        [dummy_data, dummy_label], lr=0.1
    )

    for i in range(iterations):
        def closure():
            optimizer.zero_grad()
            # Forward pass with the dummy data
            dummy_pred = model(dummy_data.unsqueeze(0))
            dummy_loss = nn.functional.cross_entropy(
                dummy_pred,
                torch.softmax(dummy_label, dim=0).unsqueeze(0)
            )
            # Compute the gradient of the dummy loss
            dummy_grad = torch.autograd.grad(
                dummy_loss, model.parameters(), create_graph=True
            )
            # Minimize the difference between dummy and real gradients
            grad_diff = sum(
                (dg - sg).pow(2).sum()
                for dg, sg in zip(dummy_grad, shared_gradient)
            )
            grad_diff.backward()
            return grad_diff

        optimizer.step(closure)
        # Clamp to the valid input range
        dummy_data.data = torch.clamp(dummy_data.data, 0, 1)

    # Recover the label from the optimized dummy_label
    recovered_label = torch.softmax(dummy_label, dim=0).argmax().item()
    return dummy_data.detach(), recovered_label

Improved Gradient Inversion (InvertingGradients)
def total_variation(x):
    """Helper used below: isotropic total variation prior that favors smooth images."""
    return (
        (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean() +
        (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    )

def inverting_gradients_attack(model, shared_gradient, input_shape,
                               label, iterations=10000):
    """
    Improved gradient inversion using cosine similarity
    and total variation regularization.
    """
    dummy_data = torch.randn(input_shape, requires_grad=True)
    optimizer = torch.optim.Adam([dummy_data], lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=2000, gamma=0.5
    )

    for i in range(iterations):
        optimizer.zero_grad()
        pred = model(dummy_data.unsqueeze(0))
        loss = nn.functional.cross_entropy(
            pred, torch.tensor([label])
        )
        dummy_grad = torch.autograd.grad(
            loss, model.parameters(), create_graph=True
        )

        # Cosine similarity loss (more robust than L2)
        cos_loss = 1 - sum(
            nn.functional.cosine_similarity(
                dg.flatten().unsqueeze(0),
                sg.flatten().unsqueeze(0)
            ).sum()
            for dg, sg in zip(dummy_grad, shared_gradient)
        )

        # Total variation regularization
        tv_loss = total_variation(dummy_data.unsqueeze(0))

        total_loss = cos_loss + 1e-4 * tv_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        dummy_data.data = torch.clamp(dummy_data.data, 0, 1)

    return dummy_data.detach()

Free-Rider Attacks
Free-rider clients exploit the collaborative nature of FL by submitting minimal or fabricated updates while still receiving the improved global model.
class FreeRiderClient:
    """Client that avoids contributing while benefiting from FL."""

    def __init__(self):
        # Cached update from an earlier round (used by the "stale" strategy)
        self.cached_old_update = None

    def generate_fake_update(self, global_weights, strategy="noise"):
        """Generate a fake update that appears legitimate."""
        fake_update = {}

        if strategy == "noise":
            # Random noise scaled to look like a real update
            for key, weight in global_weights.items():
                noise_scale = weight.std() * 0.01
                fake_update[key] = torch.randn_like(weight) * noise_scale

        elif strategy == "stale":
            # Reuse a previous round's update
            fake_update = self.cached_old_update

        elif strategy == "delta_noise":
            # Small perturbation of the global model
            # (looks like minimal local training)
            for key, weight in global_weights.items():
                fake_update[key] = torch.randn_like(weight) * 1e-6

        return fake_update

Byzantine-Robust Aggregation
Defenses against malicious clients use robust aggregation rules:
class RobustAggregator:
    """Byzantine-robust aggregation methods."""

    @staticmethod
    def krum(updates, num_byzantine):
        """
        Krum: select the update closest to most other updates.
        Resilient to num_byzantine malicious clients.
        """
        n = len(updates)
        num_select = n - num_byzantine - 2

        scores = []
        for i in range(n):
            distances = []
            for j in range(n):
                if i != j:
                    dist = sum(
                        (updates[i][k] - updates[j][k]).pow(2).sum()
                        for k in updates[i]
                    )
                    distances.append(dist.item())
            # Sum of the num_select closest distances
            distances.sort()
            scores.append(sum(distances[:num_select]))

        best_idx = scores.index(min(scores))
        return updates[best_idx]

    @staticmethod
    def trimmed_mean(updates, trim_ratio=0.1):
        """
        Coordinate-wise trimmed mean: remove the top and bottom
        trim_ratio of values for each parameter.
        """
        aggregated = {}
        n = len(updates)
        trim_count = int(n * trim_ratio)

        for key in updates[0]:
            stacked = torch.stack([u[key] for u in updates])
            sorted_vals, _ = torch.sort(stacked, dim=0)
            trimmed = sorted_vals[trim_count:n - trim_count]
            aggregated[key] = trimmed.mean(dim=0)
        return aggregated

    @staticmethod
    def median(updates):
        """Coordinate-wise median aggregation."""
        aggregated = {}
        for key in updates[0]:
            stacked = torch.stack([u[key] for u in updates])
            aggregated[key] = torch.median(stacked, dim=0).values
        return aggregated

Adaptive Attacks Against Defenses
Sophisticated attackers can adapt to robust aggregation:
class AdaptiveAttacker:
    """Attacker that adapts to robust-aggregation defenses."""

    def compute_centroid(self, honest_updates):
        """Coordinate-wise average of the (estimated) honest clients' updates."""
        centroid = {}
        for key in honest_updates[0]:
            centroid[key] = torch.stack(
                [u[key] for u in honest_updates]
            ).mean(dim=0)
        return centroid

    def craft_stealthy_update(self, global_weights, honest_updates,
                              target_update, aggregation_type):
        """
        Craft a poisoned update that survives robust aggregation.
        """
        if aggregation_type == "krum":
            # Craft an update close to the honest cluster but
            # biased toward the target
            centroid = self.compute_centroid(honest_updates)

            # Blend the target direction with the honest centroid
            alpha = 0.3  # Small enough to stay in the cluster
            poisoned = {}
            for key in global_weights:
                honest_dir = centroid[key]
                attack_dir = target_update[key]
                poisoned[key] = (1 - alpha) * honest_dir + alpha * attack_dir
            return poisoned

        elif aggregation_type == "trimmed_mean":
            # Stay within the trim bounds for each coordinate
            poisoned = {}
            for key in global_weights:
                stacked = torch.stack([u[key] for u in honest_updates])
                coord_min = stacked.min(dim=0).values
                coord_max = stacked.max(dim=0).values

                # Push toward the target within the bounds
                target_val = target_update[key]
                poisoned[key] = torch.clamp(
                    target_val, coord_min, coord_max
                )
            return poisoned

Related Topics
- Data Poisoning Methods — General poisoning attacks
- Advanced Model Inversion — Data reconstruction techniques
- Clean-Label Poisoning — Stealthy poisoning without label modification
In a federated learning system with 20 honest clients and 1 malicious client using FedAvg aggregation, by what factor must the malicious client scale its poisoned update to fully determine the global model update in one round?
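Working through the FedAvg arithmetic: with 21 equally weighted clients, the aggregated update is the mean of the 21 submitted updates, so any single contribution is diluted by a factor of 21. To make the aggregate match its desired update, the malicious client must scale its update by roughly 21, the total number of participants (to control the result exactly, it would also need to cancel the honest clients' contributions, which requires estimating them). This is the same num_honest_clients + 1 boost factor used in the MaliciousClient example above.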
References
- Zhu et al., "Deep Leakage from Gradients" (2019)
- Geiping et al., "Inverting Gradients" (2020)
- Bagdasaryan et al., "How To Backdoor Federated Learning" (2020)
- Blanchard et al., "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent" (2017) -- Krum
- Fang et al., "Local Model Poisoning Attacks to Byzantine-Robust Federated Learning" (2020)