Lab: Federated learning aanvallen
Praktisch lab dat modelvergiftigingsaanvallen implementeert in een gesimuleerde federated learning-opzet met het Flower-framework: Byzantijnse aanvallen, modelvervanging en het meten van aanvalsimpact.
Dit lab loopt door het implementeren van vergiftigingsaanvallen op federated learning in een gecontroleerde simulatie. Je zet een federated trainingsomgeving op met het Flower (flwr)-framework, traint een classifier met eerlijke en kwaadwillende deelnemers, en meet hoe verschillende aanvallen en verdedigingen het globale model beïnvloeden.
Lab-opzet
Installeer afhankelijkheden
pip install flwr torch torchvision numpy matplotlibDefinieer het model en de dataset
We gebruiken voor de eenvoud MNIST-classificatie -- de aanvalstechnieken zijn identiek aan die welke tegen LLM's worden gebruikt, maar observeerbaar op consumentenhardware.
Implementeer eerlijke en kwaadwillende clients
Maak Flower-clients met verschillende gedragingen: eerlijke training en Byzantijnse vergiftiging.
Modeldefinitie
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
"""Eenvoudige CNN voor MNIST-classificatie."""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)Datavoorbereiding
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
def load_partitioned_data(num_clients: int, data_dir: str = "./data"):
"""Partitioneer MNIST in niet-overlappende subsets voor elke client."""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
transform=transform)
test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)
# Partitioneer trainingsdata gelijkmatig over clients
indices = np.random.permutation(len(train_dataset))
partition_size = len(indices) // num_clients
partitions = []
for i in range(num_clients):
start = i * partition_size
end = start + partition_size
client_indices = indices[start:end]
client_dataset = Subset(train_dataset, client_indices)
partitions.append(DataLoader(client_dataset, batch_size=32, shuffle=True))
test_loader = DataLoader(test_dataset, batch_size=256)
return partitions, test_loaderOefening 1: Eerlijke federated training (baseline)
Stel eerst een baseline op met alle eerlijke deelnemers:
import flwr as fl
from collections import OrderedDict
def get_parameters(model):
"""Extraheer modelparameters als een lijst van numpy-arrays."""
return [val.cpu().numpy() for _, val in model.state_dict().items()]
def set_parameters(model, parameters):
"""Laad parameters in het model."""
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
class HonestClient(fl.client.NumPyClient):
"""Standaard eerlijke federated learning-client."""
def __init__(self, model, train_loader, client_id):
self.model = model
self.train_loader = train_loader
self.client_id = client_id
def get_parameters(self, config):
return get_parameters(self.model)
def fit(self, parameters, config):
set_parameters(self.model, parameters)
# Standaard lokale training
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
self.model.train()
for batch_data, batch_target in self.train_loader:
optimizer.zero_grad()
output = self.model(batch_data)
loss = F.nll_loss(output, batch_target)
loss.backward()
optimizer.step()
return get_parameters(self.model), len(self.train_loader.dataset), {}
def evaluate(self, parameters, config):
set_parameters(self.model, parameters)
# Evalueer op lokale data
self.model.eval()
correct, total = 0, 0
with torch.no_grad():
for data, target in self.train_loader:
output = self.model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += len(target)
return float(F.nll_loss(self.model(data), target)), total, {"accuracy": correct / total}Oefening 2: Byzantijnse gradiëntvergiftiging
Implementeer nu een kwaadwillende client die adversariële gradiëntupdates stuurt:
class ByzantineClient(fl.client.NumPyClient):
"""Kwaadwillende client die vergiftigde modelupdates stuurt."""
def __init__(self, model, train_loader, client_id, attack_type="label_flip"):
self.model = model
self.train_loader = train_loader
self.client_id = client_id
self.attack_type = attack_type
def get_parameters(self, config):
return get_parameters(self.model)
def fit(self, parameters, config):
set_parameters(self.model, parameters)
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
self.model.train()
for batch_data, batch_target in self.train_loader:
optimizer.zero_grad()
if self.attack_type == "label_flip":
# Aanval: verwissel labels (bijv. 7 wordt 1)
poisoned_target = (batch_target + 1) % 10
output = self.model(batch_data)
loss = F.nll_loss(output, poisoned_target)
elif self.attack_type == "gradient_scale":
# Aanval: schaal gradiënten om de aggregatie te domineren
output = self.model(batch_data)
loss = F.nll_loss(output, batch_target)
else:
output = self.model(batch_data)
loss = F.nll_loss(output, batch_target)
loss.backward()
optimizer.step()
params = get_parameters(self.model)
if self.attack_type == "gradient_scale":
# Schaal parameters om FedAvg-aggregatie te domineren
global_params = parameters
for i in range(len(params)):
delta = params[i] - global_params[i]
params[i] = global_params[i] + delta * 10.0 # 10x schaling
return params, len(self.train_loader.dataset), {}
def evaluate(self, parameters, config):
set_parameters(self.model, parameters)
self.model.eval()
correct, total = 0, 0
with torch.no_grad():
for data, target in self.train_loader:
output = self.model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += len(target)
return 0.0, total, {"accuracy": correct / total}Oefening 3: Modelvervangingsaanval
De meest agressieve aanval -- één enkele client vervangt het hele globale model:
class ModelReplacementClient(fl.client.NumPyClient):
"""Client die probeert het globale model te vervangen
door een gebackdoorde versie."""
def __init__(self, model, train_loader, client_id,
target_model_state, num_clients):
self.model = model
self.train_loader = train_loader
self.client_id = client_id
self.target_state = target_model_state
self.num_clients = num_clients
def fit(self, parameters, config):
# Bereken vervangingsparameters:
# Na FedAvg: new = avg(updates) = (1/N) * sum(updates)
# We willen: new = target_state
# Dus onze update moet zijn: N * target_state - (N-1) * global_state
replacement_params = []
for i, (target_p, global_p) in enumerate(
zip(get_parameters_from_state(self.target_state), parameters)
):
replacement = (self.num_clients * target_p -
(self.num_clients - 1) * global_p)
replacement_params.append(replacement)
return replacement_params, len(self.train_loader.dataset), {}Oefening 4: Vergelijk aggregatiestrategieën
import flwr as fl
from flwr.server.strategy import FedAvg
def run_experiment(num_clients=10, num_malicious=2, num_rounds=20,
aggregation="fedavg", attack_type="label_flip"):
"""Draai een federated learning-experiment met aanvallen en verdedigingen."""
partitions, test_loader = load_partitioned_data(num_clients)
def client_fn(cid: str):
idx = int(cid)
model = SimpleNet()
if idx < num_malicious:
return ByzantineClient(model, partitions[idx], idx, attack_type).to_client()
else:
return HonestClient(model, partitions[idx], idx).to_client()
# Selecteer aggregatiestrategie
if aggregation == "fedavg":
strategy = FedAvg(
min_fit_clients=num_clients,
min_evaluate_clients=num_clients,
min_available_clients=num_clients,
)
elif aggregation == "trimmed_mean":
strategy = TrimmedMeanStrategy(
trim_ratio=0.2, # Verwijder bovenste/onderste 20%
min_fit_clients=num_clients,
)
# Draai simulatie
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=num_clients,
config=fl.server.ServerConfig(num_rounds=num_rounds),
strategy=strategy,
)
return history
# Draai experimenten
results = {}
for agg in ["fedavg", "trimmed_mean"]:
for n_mal in [0, 1, 2, 3]:
key = f"{agg}_mal{n_mal}"
print(f"\n{'='*50}")
print(f"Running: {key}")
results[key] = run_experiment(
num_clients=10, num_malicious=n_mal,
aggregation=agg, attack_type="label_flip"
)Tabel met verwachte resultaten
| Scenario | FedAvg-nauwkeurigheid | Trimmed Mean-nauwkeurigheid |
|---|---|---|
| 0 kwaadwillend (baseline) | ~98% | ~98% |
| 1 kwaadwillend (10%) | ~90% | ~97% |
| 2 kwaadwillend (20%) | ~75% | ~95% |
| 3 kwaadwillend (30%) | ~55% | ~85% |
Uitdaging: Ontwijkingsaanval
Implementeer voor gevorderde beoefenaars een norm-beperkte aanval die detectie ontwijkt:
class StealthyByzantineClient(fl.client.NumPyClient):
"""Kwaadwillende client die zijn update-norm beperkt
om detectie te ontwijken terwijl het model toch wordt vergiftigd."""
def fit(self, parameters, config):
# Train met vergiftigde labels
poisoned_params = self._poisoned_training(parameters)
# Bereken update-delta
deltas = [p - g for p, g in zip(poisoned_params, parameters)]
# Bereken norm van eerlijke update (van de vorige ronde)
honest_norm = np.sqrt(sum(np.sum(d**2) for d in deltas))
# Knip af om overeen te komen met de verwachte norm van een eerlijke update
if honest_norm > self.norm_bound:
scale = self.norm_bound / honest_norm
deltas = [d * scale for d in deltas]
clipped_params = [g + d for g, d in zip(parameters, deltas)]
return clipped_params, len(self.train_loader.dataset), {}Gerelateerde onderwerpen
- Vergiftiging van federated learning -- Volledige theoriebehandeling
- Geavanceerde aanvalsvectoren voor training -- Overzicht van geavanceerde aanvallen
- Aanvalsoppervlak van distributed training -- Gerelateerde infrastructuuraanvallen
- Lab: Gekwantiseerde modellen uitbuiten -- Een ander praktisch lab
Waarom behoudt trimmed mean-aggregatie in dit lab een hogere nauwkeurigheid dan FedAvg onder een Byzantijnse aanval?
References
- Flower: A Friendly Federated Learning Framework -- Flower framework
- How to Back door Federated Learning (Bagdasaryan et al., 2020) -- FL attack techniques
- Byzantine-Robust Distributed Learning (Blanchard et al., 2017) -- Robust aggregation