Attack surface of distributed training
Security analysis of distributed training systems, covering attacks on gradient aggregation, exploitation of Byzantine faults, vulnerabilities in communication channels, and threats to federated learning.
Overview
Distributed training is the standard approach for training large language models. Systems such as Megatron-LM, DeepSpeed, and FSDP spread the training workload across hundreds or thousands of GPU workers. Each worker computes gradients on a subset of the data and exchanges updates through collective operations such as all-reduce. This distribution creates a substantial attack surface: every worker node, communication link, and aggregation operation is a potential point of exploitation.
The security challenges of distributed training are rooted in the Byzantine Generals Problem. In a distributed training cluster, even a single compromised worker can send arbitrary gradient updates that influence the global model. Blanchard et al. (2017) formalized this in "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent". They showed that standard averaging-based gradient aggregation is catastrophically vulnerable to Byzantine workers: no aggregation rule that is a linear combination of worker gradients, including the plain average, tolerates even one Byzantine worker. A single malicious worker among hundreds can dominate the aggregated gradient if its updates are large enough.
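The averaging weakness is easy to make concrete. The minimal sketch below assumes the standard worst case, in which the attacker knows the honest gradients. It shows how one Byzantine worker can choose its update so that the plain average lands exactly on any target vector it wants. The helper name `byzantine_override` is illustrative and does not come from a library.

```python
import numpy as np


def byzantine_override(honest: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Update a single Byzantine worker must send so that the mean over
    all n = len(honest) + 1 workers equals `target` (hypothetical helper).

    From mean = (sum(honest) + b) / n it follows that b = n * target - sum(honest).
    """
    n = len(honest) + 1
    return n * target - honest.sum(axis=0)


rng = np.random.default_rng(0)
honest = rng.normal(scale=0.01, size=(99, 8))  # 99 honest workers, 8 parameters
target = np.full(8, 5.0)                       # arbitrary vector chosen by the attacker

b = byzantine_override(honest, target)
aggregate = np.vstack([honest, b]).mean(axis=0)

print(np.allclose(aggregate, target))  # True
# Ratio of the attacker's norm to a typical honest norm: the price is a huge update
print(np.linalg.norm(b) / np.linalg.norm(honest, axis=1).mean())
```

The required update grows with the number of workers. That is why a simple norm check catches this naive form of the attack.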
The practical relevance of these attacks has grown as organizations increasingly rely on cloud infrastructure and multi-tenant clusters for training. In these environments, the assumption that all workers are trusted is frequently violated. Shared infrastructure, supply-chain attacks on training frameworks, and insider threats all open paths to compromising workers. Carlini et al. (2021) showed that large language models memorize individual training examples that can later be extracted. This is evidence that what happens during training leaves lasting traces in model behavior, and it makes the security of distributed training a first-order concern for organizations that train frontier models.
Distributed training architecture
Communication topologies and their security properties
Different distributed training architectures have fundamentally different security characteristics. Understanding these differences is essential for threat modeling.
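To see why a ring gives no worker a privileged position, yet lets any worker corrupt the result, it helps to simulate the two phases of ring all-reduce (reduce-scatter, then all-gather) in plain NumPy. This is a single-process sketch that only models data movement; it is not how NCCL is implemented. The `byzantine` and `scale` parameters are hypothetical hooks for a worker that tampers with the partial sums it forwards.

```python
import numpy as np


def ring_all_reduce(grads, byzantine=None, scale=1.0):
    """Average `grads` (one 1-D array per worker) via a simulated ring all-reduce.

    If `byzantine` is a worker index, that worker multiplies every partial
    sum it forwards during reduce-scatter by `scale` (hypothetical tampering).
    """
    n = len(grads)
    # Each worker splits its own gradient into n chunks.
    chunks = [np.array_split(np.asarray(g, dtype=float).copy(), n) for g in grads]

    # Phase 1: reduce-scatter. In step s, worker i forwards chunk (i - s) mod n
    # to its right neighbour, which adds it to its own copy of that chunk.
    for step in range(n - 1):
        outgoing = []
        for i in range(n):
            c = (i - step) % n
            data = chunks[i][c].copy()
            if i == byzantine:
                data *= scale
            outgoing.append((i, c, data))
        for i, c, data in outgoing:  # all sends in a step happen "simultaneously"
            dst = (i + 1) % n
            chunks[dst][c] = chunks[dst][c] + data
    # Worker i now holds the fully reduced chunk (i + 1) mod n.

    # Phase 2: all-gather. Completed chunks circulate and overwrite stale copies.
    for step in range(n - 1):
        outgoing = [(i, (i + 1 - step) % n, chunks[i][(i + 1 - step) % n].copy())
                    for i in range(n)]
        for i, c, data in outgoing:
            chunks[(i + 1) % n][c] = data

    return [np.concatenate(ch) / n for ch in chunks]


rng = np.random.default_rng(1)
grads = [rng.normal(size=10) for _ in range(4)]
true_mean = np.mean(grads, axis=0)

clean = ring_all_reduce(grads)
print(all(np.allclose(w, true_mean) for w in clean))         # True

poisoned = ring_all_reduce(grads, byzantine=2, scale=-10.0)
print(all(np.allclose(w, poisoned[0]) for w in poisoned))    # True: every worker agrees
print(np.allclose(poisoned[0], true_mean))                   # False
```

The all-gather phase replicates whatever the reduce-scatter phase produced, so every worker ends up with the same corrupted average. Comparing results across workers therefore reveals nothing. This matches the "low (no centralized view)" detection rating given to the ring in the topology profiles that follow.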
```python
"""
Distributed training topology security analysis.
Models different communication patterns and their vulnerability
to various attack types.
"""
from dataclasses import dataclass
from enum import Enum


class TopologyType(Enum):
    PARAMETER_SERVER = "parameter_server"
    ALL_REDUCE_RING = "all_reduce_ring"
    ALL_REDUCE_TREE = "all_reduce_tree"
    GOSSIP = "gossip"
    HIERARCHICAL = "hierarchical"


@dataclass
class TopologySecurityProfile:
    """Security profile for a distributed training topology."""
    topology: TopologyType
    description: str
    num_trust_boundaries: int
    single_point_of_failure: bool
    byzantine_tolerance: int  # Max compromised nodes tolerated
    attack_vectors: list[str]
    detection_capability: str  # Leading word: "high", "medium" or "low"


TOPOLOGY_PROFILES = [
    TopologySecurityProfile(
        topology=TopologyType.PARAMETER_SERVER,
        description=(
            "Central server aggregates gradients from all workers. "
            "Simple but creates a bottleneck and single point of failure."
        ),
        num_trust_boundaries=2,  # Worker-to-server, server-to-worker
        single_point_of_failure=True,
        byzantine_tolerance=0,  # Server compromise = total compromise
        attack_vectors=[
            "Parameter server compromise (total model control)",
            "Worker gradient poisoning (amplified by averaging)",
            "Man-in-the-middle on worker-server communication",
            "Denial of service on parameter server",
        ],
        detection_capability="high (centralized monitoring)",
    ),
    TopologySecurityProfile(
        topology=TopologyType.ALL_REDUCE_RING,
        description=(
            "Workers form a ring, each sending partial results to the "
            "next. No central server but every worker is on the critical path."
        ),
        num_trust_boundaries=1,  # Worker-to-worker (per link)
        single_point_of_failure=False,
        byzantine_tolerance=0,  # Any worker can corrupt the ring
        attack_vectors=[
            "Gradient corruption propagates through the ring",
            "Worker impersonation in ring topology",
            "Timing attacks exploiting sequential dependencies",
            "Selective corruption of reduce-scatter phase",
        ],
        detection_capability="low (no centralized view)",
    ),
    TopologySecurityProfile(
        topology=TopologyType.ALL_REDUCE_TREE,
        description=(
            "Tree-structured reduction where intermediate nodes aggregate "
            "subtrees. Hierarchical trust model."
        ),
        num_trust_boundaries=2,  # Leaf-to-intermediate, intermediate-to-root
        single_point_of_failure=True,  # Root node
        byzantine_tolerance=0,
        attack_vectors=[
            "Root node compromise affects all gradients",
            "Intermediate node compromise affects its subtree",
            "Leaf worker poisoning (limited to subtree initially)",
        ],
        detection_capability="medium (hierarchical monitoring possible)",
    ),
    TopologySecurityProfile(
        topology=TopologyType.HIERARCHICAL,
        description=(
            "Multi-level aggregation: intra-node all-reduce, then "
            "inter-node all-reduce. Common in multi-node GPU clusters."
        ),
        num_trust_boundaries=3,  # Intra-node, inter-node, cross-cluster
        single_point_of_failure=False,
        byzantine_tolerance=0,
        attack_vectors=[
            "Intra-node GPU compromise via driver vulnerability",
            "Inter-node communication interception (RDMA/NCCL)",
            "Cross-cluster aggregation manipulation",
            "NVLink/NVSwitch firmware attacks",
        ],
        detection_capability="medium (per-level monitoring)",
    ),
]


def rank_topologies_by_security(
    profiles: list[TopologySecurityProfile],
) -> list[tuple[str, float]]:
    """
    Rank topologies by overall security posture.
    Higher score = more secure.
    """
    detection_scores = {"high": 0.2, "medium": 0.1, "low": 0.0}
    scores = []
    for p in profiles:
        score = 0.0
        # Fewer trust boundaries = simpler to secure
        score += max(0, 5 - p.num_trust_boundaries) * 0.2
        # No single point of failure
        score += 0.3 if not p.single_point_of_failure else 0.0
        # Byzantine tolerance (capped at 0.3)
        score += min(0.3, p.byzantine_tolerance * 0.1)
        # Detection capability: look up only the leading keyword, because the
        # profile strings carry an explanation, e.g. "high (centralized monitoring)"
        level = p.detection_capability.split()[0]
        score += detection_scores.get(level, 0.0)
        scores.append((p.topology.value, score))
    return sorted(scores, key=lambda x: x[1], reverse=True)


rankings = rank_topologies_by_security(TOPOLOGY_PROFILES)
for topology, score in rankings:
    print(f" {topology}: security score = {score:.2f}")
```

Byzantine gradient attacks
Attack models
A Byzantine worker can send arbitrary gradient tensors during the aggregation step. The impact depends on the aggregation protocol and on the sophistication of the attack.
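Baruch et al. (2019) showed that an attack does not need large updates to succeed. The sketch below follows our reading of their "A Little Is Enough" idea. The Byzantine workers estimate the per-coordinate mean μ and standard deviation σ of the honest gradients, then all send μ − z·σ (the sign is the attacker's choice). Here z is the largest shift for which the coordinated values can still recruit enough honest "supporters" to sway a majority-based rule. The formula for `alie_z_max` reflects our understanding of the paper, and both function names are illustrative.

```python
from statistics import NormalDist

import numpy as np


def alie_z_max(n: int, f: int) -> float:
    """Perturbation size in standard deviations for n workers, f Byzantine:
    s = floor(n/2 + 1) - f supporters needed, z = Phi^-1((n - s) / n)."""
    s = (n // 2 + 1) - f
    return NormalDist().inv_cdf((n - s) / n)


def alie_gradient(honest: np.ndarray, n: int, f: int) -> np.ndarray:
    """Vector every Byzantine worker sends: mu - z_max * sigma per coordinate."""
    mu = honest.mean(axis=0)
    sigma = honest.std(axis=0)
    return mu - alie_z_max(n, f) * sigma


rng = np.random.default_rng(0)
n, f, dim = 20, 5, 64
honest = rng.normal(loc=0.01, scale=0.01, size=(n - f, dim))
bad = alie_gradient(honest, n, f)

print(round(alie_z_max(n, f), 4))  # 0.5244
# Each coordinate sits less than one standard deviation from the honest mean,
# so norm- and distance-based filters see nothing unusual.
print(np.abs(bad - honest.mean(axis=0)).max() <= honest.std(axis=0).max())  # True
```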
```python
"""
Byzantine gradient attack implementations.
Demonstrates various attack strategies available to a
compromised worker node in distributed training.
"""
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class ByzantineAttack:
    """Configuration for a Byzantine gradient attack."""
    name: str
    num_byzantine: int
    total_workers: int
    strategy: str
    stealth_level: str  # "none", "low", "medium", "high"


def attack_scaling(
    honest_gradients: np.ndarray,
    scale_factor: float = 100.0,
) -> np.ndarray:
    """
    Scaling attack: send gradients that are much larger than normal.
    The simplest Byzantine attack: amplify the gradient until it
    dominates the average. Effective against naive averaging but
    easily detected by norm checking.
    """
    mean_gradient = honest_gradients.mean(axis=0)
    return mean_gradient * (-scale_factor)  # Negative to maximize damage


def attack_sign_flip(
    honest_gradients: np.ndarray,
) -> np.ndarray:
    """
    Sign-flip attack: send the negative of the true gradient.
    More subtle than scaling: the magnitude is normal but the
    direction is exactly wrong. Causes the model to move away
    from the optimum.
    """
    return -honest_gradients.mean(axis=0)


def attack_label_flip_gradient(
    honest_gradients: np.ndarray,
    flip_fraction: float = 1.0,
    seed: int = 42,
) -> np.ndarray:
    """
    Label-flip gradient: simulate the gradient that would result from
    training on data with flipped labels.
    More sophisticated than sign-flip because the gradient still
    "looks like" a real gradient: it has the right magnitude and
    structure, just from corrupted data. (With the default
    flip_fraction=1.0 the noise term vanishes and this reduces to
    a sign flip.)
    """
    rng = np.random.default_rng(seed)
    mean_grad = honest_gradients.mean(axis=0)
    # Simulate label-flip by mixing the negated mean with noise whose
    # per-coordinate scale matches the honest spread
    noise = rng.standard_normal(mean_grad.shape) * np.std(honest_gradients, axis=0)
    flipped = -mean_grad * flip_fraction + noise * (1 - flip_fraction)
    return flipped


def attack_inner_product_manipulation(
    honest_gradients: np.ndarray,
    target_direction: np.ndarray,
    magnitude_matching: bool = True,
) -> np.ndarray:
    """
    Inner product manipulation (IPM) style attack.
    Craft a gradient that maximizes the component in the target
    direction while keeping the magnitude consistent with honest
    gradients (if magnitude_matching is True).
    Note: this is a simplified directed variant. The IPM attack as
    published by Xie et al. (2019) sends -epsilon times the mean of
    the honest gradients.
    IPM-style attacks are among the strongest known Byzantine attacks
    because they are hard to detect with both norm-based and
    direction-based defenses.
    """
    unit_target = target_direction / (np.linalg.norm(target_direction) + 1e-10)
    if magnitude_matching:
        # Match the average norm of the individual honest gradients
        honest_mean_norm = np.mean(np.linalg.norm(honest_gradients, axis=1))
        return unit_target * honest_mean_norm
    return target_direction


def simulate_byzantine_aggregation(
    honest_gradients: np.ndarray,
    num_byzantine: int,
    attack_fn: Callable[[np.ndarray], np.ndarray],
    aggregation: str = "mean",
) -> dict:
    """
    Simulate a Byzantine attack during gradient aggregation.
    Args:
        honest_gradients: Shape (num_honest, param_dim).
        num_byzantine: Number of Byzantine workers.
        attack_fn: Function that generates the Byzantine gradient.
        aggregation: Aggregation method ("mean", "median", "trimmed_mean").
    Returns:
        Measures of how far the attacked aggregate deviates from the
        honest-only aggregate.
    """
    num_honest = len(honest_gradients)
    total = num_honest + num_byzantine
    # Generate Byzantine gradients (all attackers send the same vector)
    byzantine_grad = attack_fn(honest_gradients)
    byzantine_gradients = np.tile(byzantine_grad, (num_byzantine, 1))
    all_gradients = np.concatenate([honest_gradients, byzantine_gradients], axis=0)
    # Compute honest-only aggregate (ground truth)
    honest_aggregate = honest_gradients.mean(axis=0)
    honest_norm = np.linalg.norm(honest_aggregate)
    # Aggregation under attack
    if aggregation == "mean":
        attacked_aggregate = all_gradients.mean(axis=0)
    elif aggregation == "median":
        attacked_aggregate = np.median(all_gradients, axis=0)
    elif aggregation == "trimmed_mean":
        # Trim the top and bottom 10% of each coordinate
        trim_count = max(1, total // 10)
        sorted_grads = np.sort(all_gradients, axis=0)
        attacked_aggregate = sorted_grads[trim_count:total - trim_count].mean(axis=0)
    else:
        raise ValueError(f"Unknown aggregation method: {aggregation}")
    # Measure impact
    deviation = np.linalg.norm(attacked_aggregate - honest_aggregate)
    cos_sim = float(
        np.dot(attacked_aggregate.flatten(), honest_aggregate.flatten())
        / (np.linalg.norm(attacked_aggregate) * honest_norm + 1e-10)
    )
    return {
        "aggregation": aggregation,
        "num_honest": num_honest,
        "num_byzantine": num_byzantine,
        "byzantine_fraction": num_byzantine / total,
        "deviation_from_honest": float(deviation),
        "relative_deviation": float(deviation / (honest_norm + 1e-10)),
        "cosine_similarity": cos_sim,
        "attack_effective": cos_sim < 0.5,
    }


# Demonstration
np.random.seed(42)
num_honest, param_dim = 20, 256
honest = np.random.randn(num_honest, param_dim) * 0.01
# Fix the IPM target once so that every aggregation rule faces the same attack
ipm_target = np.random.randn(param_dim)
print("Byzantine attack effectiveness (1 attacker out of 21 workers):")
print("=" * 60)
for attack_name, attack_fn in [
    ("scaling", lambda g: attack_scaling(g, 100)),
    ("sign_flip", attack_sign_flip),
    ("label_flip", attack_label_flip_gradient),
    ("IPM", lambda g: attack_inner_product_manipulation(g, ipm_target)),
]:
    for agg in ["mean", "median", "trimmed_mean"]:
        result = simulate_byzantine_aggregation(
            honest, num_byzantine=1, attack_fn=attack_fn, aggregation=agg,
        )
        print(f"  {attack_name} + {agg}: "
              f"cos_sim={result['cosine_similarity']:.3f}, "
              f"effective={result['attack_effective']}")
    print()
```

Defenses for gradient aggregation
Standard averaging is catastrophically vulnerable to Byzantine attacks. Several robust aggregation protocols have been proposed.
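Before turning to the heavier methods, it is worth seeing why order statistics help at all. The coordinate-wise median, analyzed alongside the trimmed mean by Yin et al. (2018), takes the median across workers separately for every parameter. As long as fewer than half of the workers are Byzantine, each median coordinate is sandwiched between honest values, however extreme the Byzantine values are. The minimal check below illustrates that property; `coordinate_median` is an illustrative name.

```python
import numpy as np


def coordinate_median(gradients: np.ndarray) -> np.ndarray:
    """Coordinate-wise median across workers (axis 0)."""
    return np.median(gradients, axis=0)


rng = np.random.default_rng(7)
honest = rng.normal(scale=0.01, size=(9, 32))
byzantine = np.full((4, 32), 1e6)  # 4 of 13 workers send huge values

all_grads = np.vstack([honest, byzantine])
med = coordinate_median(all_grads)
mean = all_grads.mean(axis=0)

# The median stays inside the honest range in every coordinate...
print(np.all((med >= honest.min(axis=0)) & (med <= honest.max(axis=0))))  # True
# ...while the mean is dragged far outside it.
print(np.abs(mean).max() > 1e4)  # True
```

The guarantee is only a range bound. An attacker who stays within the honest spread, as in the "A Little Is Enough" attack of Baruch et al. (2019), can still bias the median. This is one motivation for vector-level methods such as Krum and the geometric median.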
```python
"""
Robust gradient aggregation protocols.
Implements Byzantine-tolerant aggregation methods that defend
against compromised worker nodes.
"""
from typing import Callable

import numpy as np


def krum_aggregation(
    gradients: np.ndarray,
    num_byzantine: int,
    multi_krum: int = 1,
) -> np.ndarray:
    """
    Krum aggregation (Blanchard et al. 2017).
    Selects the gradient that is closest to its nearest neighbors,
    effectively excluding outliers. Tolerates up to f Byzantine
    workers when n >= 2f + 3.
    Args:
        gradients: Shape (num_workers, param_dim).
        num_byzantine: Maximum number of Byzantine workers (f).
        multi_krum: Number of gradients to average (Multi-Krum).
    """
    n = len(gradients)
    if n < 2 * num_byzantine + 3:
        raise ValueError("Krum requires n >= 2f + 3 workers")
    # Pairwise squared Euclidean distances, as in the Krum score
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            dist = np.sum((gradients[i] - gradients[j]) ** 2)
            distances[i, j] = dist
            distances[j, i] = dist
    # For each gradient, sum the distances to its
    # n - num_byzantine - 2 closest neighbors
    num_closest = n - num_byzantine - 2
    scores = np.zeros(n)
    for i in range(n):
        sorted_distances = np.sort(distances[i])
        # Skip index 0 (distance to self) and take the closest neighbors
        scores[i] = np.sum(sorted_distances[1:num_closest + 1])
    # Select the gradient(s) with the lowest score
    if multi_krum == 1:
        return gradients[np.argmin(scores)]
    best_indices = np.argsort(scores)[:multi_krum]
    return np.mean(gradients[best_indices], axis=0)


def trimmed_mean_aggregation(
    gradients: np.ndarray,
    trim_fraction: float = 0.1,
) -> np.ndarray:
    """
    Coordinate-wise trimmed mean aggregation.
    For each parameter coordinate, sorts the values across workers
    and trims the top and bottom fraction before averaging.
    This is robust as long as the number of Byzantine workers does
    not exceed the number of values trimmed from each side.
    """
    n = len(gradients)
    trim_count = max(1, int(n * trim_fraction))
    if 2 * trim_count >= n:
        raise ValueError("trim_fraction too large: nothing left to average")
    sorted_grads = np.sort(gradients, axis=0)
    trimmed = sorted_grads[trim_count:n - trim_count]
    return np.mean(trimmed, axis=0)


def geometric_median_aggregation(
    gradients: np.ndarray,
    max_iterations: int = 100,
    tolerance: float = 1e-6,
) -> np.ndarray:
    """
    Geometric median aggregation (Weiszfeld's algorithm).
    The geometric median minimizes the sum of distances to all
    points, making it robust to outliers. Unlike coordinate-wise
    methods, it considers the full gradient vector geometry.
    """
    estimate = np.mean(gradients, axis=0)
    for _ in range(max_iterations):
        distances = np.linalg.norm(gradients - estimate, axis=1)
        # Avoid division by zero
        distances = np.maximum(distances, 1e-10)
        weights = 1.0 / distances
        weights /= weights.sum()
        new_estimate = np.average(gradients, weights=weights, axis=0)
        if np.linalg.norm(new_estimate - estimate) < tolerance:
            break
        estimate = new_estimate
    return estimate


def evaluate_aggregation_robustness(
    honest_gradients: np.ndarray,
    byzantine_gradients: np.ndarray,
    methods: dict[str, Callable[[np.ndarray], np.ndarray]],
) -> dict:
    """
    Evaluate how well different aggregation methods resist
    Byzantine attacks.
    """
    all_gradients = np.concatenate(
        [honest_gradients, byzantine_gradients], axis=0
    )
    honest_truth = np.mean(honest_gradients, axis=0)
    honest_norm = np.linalg.norm(honest_truth)
    results = {}
    for name, method in methods.items():
        aggregated = method(all_gradients)
        deviation = np.linalg.norm(aggregated - honest_truth)
        cos_sim = float(
            np.dot(aggregated.flatten(), honest_truth.flatten())
            / (np.linalg.norm(aggregated) * honest_norm + 1e-10)
        )
        results[name] = {
            "deviation": float(deviation),
            "relative_deviation": float(deviation / (honest_norm + 1e-10)),
            "cosine_similarity": cos_sim,
        }
    return results


# Compare robust aggregation methods
np.random.seed(42)
n_honest, n_byzantine, dim = 15, 5, 128
honest = np.random.randn(n_honest, dim) * 0.01
# Strong scaling attack
byzantine = np.tile(-honest.mean(axis=0) * 50, (n_byzantine, 1))
methods = {
    "mean": lambda g: np.mean(g, axis=0),
    "krum": lambda g: krum_aggregation(g, n_byzantine),
    # Note: trim_fraction=0.2 trims 4 values per side, fewer than the 5 attackers
    "trimmed_mean": lambda g: trimmed_mean_aggregation(g, 0.2),
    "geometric_median": lambda g: geometric_median_aggregation(g),
}
results = evaluate_aggregation_robustness(honest, byzantine, methods)
print(f"Scenario: {n_byzantine}/{n_honest + n_byzantine} workers compromised")
print("Attack: scaling (50x)")
print()
for name, r in results.items():
    status = "DEFENDED" if r["cosine_similarity"] > 0.8 else "COMPROMISED"
    print(f"  {name:20s}: cos_sim={r['cosine_similarity']:.3f} [{status}]")
```

Communication channel security
Vulnerabilities in NCCL and RDMA
Modern distributed training relies on high-performance communication libraries such as NVIDIA NCCL, which use RDMA (Remote Direct Memory Access) for low-latency gradient exchange. RDMA bypasses the operating system's network stack. This improves performance but limits the ability to apply network-level security controls such as host firewalls.
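RDMA traffic bypasses host firewalls, and NCCL does not verify payload integrity. One mitigation is therefore to authenticate gradient payloads at the application layer. The sketch below uses Python's standard `hmac` module to tag each serialized gradient with an HMAC-SHA256 over the worker ID, the step number, and the raw bytes. The shared key and the message layout are assumptions for illustration; this is not an NCCL feature. In practice, per-worker keys would come from an attestation or key-management layer, and tagging adds hashing cost on every step.

```python
import hashlib
import hmac
import struct

import numpy as np


def sign_gradient(key: bytes, worker_id: int, step: int, grad: np.ndarray) -> bytes:
    """HMAC-SHA256 tag binding a gradient to its sender and training step."""
    header = struct.pack("!II", worker_id, step)  # network byte order
    payload = np.ascontiguousarray(grad, dtype=np.float32).tobytes()
    return hmac.new(key, header + payload, hashlib.sha256).digest()


def verify_gradient(key: bytes, worker_id: int, step: int,
                    grad: np.ndarray, tag: bytes) -> bool:
    expected = sign_gradient(key, worker_id, step, grad)
    return hmac.compare_digest(expected, tag)  # constant-time comparison


key = b"demo-key-from-key-management"  # hypothetical; never hard-code real keys
grad = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
tag = sign_gradient(key, worker_id=3, step=100, grad=grad)

print(verify_gradient(key, 3, 100, grad, tag))      # True
tampered = grad.copy()
tampered[0] += 1e-3                                 # modification in transit
print(verify_gradient(key, 3, 100, tampered, tag))  # False
print(verify_gradient(key, 3, 101, grad, tag))      # False: replayed at another step
```

The tag provides integrity and sender authentication but not confidentiality. Gradients still travel in the clear unless the channel is also encrypted.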
```python
"""
Communication channel security assessment for distributed training.
Analyzes the security properties of different communication backends.
"""
from dataclasses import dataclass, field


@dataclass
class CommChannelProfile:
    """Security profile for a communication channel."""
    name: str
    protocol: str
    encryption_support: bool
    authentication_support: bool
    integrity_verification: bool
    performance_overhead_with_security: str  # "low", "medium", "high"
    known_vulnerabilities: list[str] = field(default_factory=list)


COMM_PROFILES = [
    CommChannelProfile(
        name="NCCL over TCP",
        protocol="TCP/IP",
        encryption_support=False,
        authentication_support=False,
        integrity_verification=False,
        performance_overhead_with_security="high",
        known_vulnerabilities=[
            "No encryption: gradients visible to network observers",
            "No authentication: worker impersonation possible",
            "No integrity checking: gradient modification in transit",
        ],
    ),
    CommChannelProfile(
        name="NCCL over InfiniBand/RDMA",
        protocol="RDMA",
        encryption_support=False,
        authentication_support=True,  # Fabric-level
        integrity_verification=False,
        performance_overhead_with_security="medium",
        known_vulnerabilities=[
            "RDMA bypasses OS network stack (no firewall)",
            "Memory registration exposes device memory",
            "Limited encryption support at wire speed",
            "Physical access to fabric = full access",
        ],
    ),
    CommChannelProfile(
        name="Gloo (PyTorch)",
        protocol="TCP/IP",
        encryption_support=True,  # TLS possible
        authentication_support=True,
        integrity_verification=True,  # With TLS
        performance_overhead_with_security="high",
        known_vulnerabilities=[
            "TLS not enabled by default",
            "Certificate management complexity",
            "Significant performance penalty with encryption",
        ],
    ),
]


def assess_communication_security(
    profiles: list[CommChannelProfile],
    requires_encryption: bool = True,
    requires_authentication: bool = True,
) -> list[dict]:
    """Assess each communication channel against security requirements."""
    assessments = []
    for profile in profiles:
        gaps = []
        if requires_encryption and not profile.encryption_support:
            gaps.append("NO_ENCRYPTION")
        if requires_authentication and not profile.authentication_support:
            gaps.append("NO_AUTHENTICATION")
        if not profile.integrity_verification:
            gaps.append("NO_INTEGRITY_CHECK")
        assessments.append({
            "channel": profile.name,
            "security_gaps": gaps,
            "compliant": len(gaps) == 0,
            "risk_level": (
                "critical" if len(gaps) >= 2
                else "high" if len(gaps) == 1
                else "acceptable"
            ),
        })
    return assessments


assessments = assess_communication_security(COMM_PROFILES)
for a in assessments:
    print(f"{a['channel']}: {a['risk_level']} "
          f"(gaps: {', '.join(a['security_gaps']) or 'none'})")
```

Securing distributed training in practice
Defense-in-depth architecture
A secure distributed training system requires multiple layers of defense:
- Worker attestation: Before a worker joins the training cluster, it must attest to its identity and integrity (e.g., with hardware attestation via a TPM or SGX enclaves).
- Robust aggregation: Use Byzantine-tolerant aggregation (Krum, trimmed mean, or the geometric median) instead of plain averaging.
- Gradient validation: Check each worker's gradient against statistical expectations before aggregation.
- Communication encryption: Enable TLS or equivalent encryption on all gradient communication channels, and accept the performance overhead as a necessary security cost.
- Redundant computation: For critical training runs, have multiple workers compute gradients on the same data shard and compare the results to detect Byzantine behavior.
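Gradient validation can be as simple as a robust outlier test on update norms. The sketch below flags any worker whose gradient norm is far from the median norm, using the common modified z-score based on the median absolute deviation (MAD). The 3.5 cut-off and the use of norms alone are assumptions. A norm check catches scaling attacks but not attacks that keep the norm in the honest range, such as a sign flip.

```python
import numpy as np


def flag_norm_outliers(gradients: np.ndarray, threshold: float = 3.5) -> np.ndarray:
    """Boolean mask of workers whose gradient norm is an outlier.

    Uses the modified z-score 0.6745 * |x - median| / MAD, which, unlike a
    mean/std z-score, is not itself skewed by the outliers it looks for.
    """
    norms = np.linalg.norm(gradients, axis=1)
    median = np.median(norms)
    mad = np.median(np.abs(norms - median))
    if mad == 0:
        return norms != median  # degenerate case: most norms identical
    modified_z = 0.6745 * np.abs(norms - median) / mad
    return modified_z > threshold


# Deterministic toy data: 16 honest workers whose norms vary by +/-10%
scales = np.linspace(0.9, 1.1, 16)
grads = np.outer(scales, np.full(128, 0.01))
grads[5] *= 100  # worker 5 sends a scaled-up update
print(np.flatnonzero(flag_norm_outliers(grads)))  # [5]

grads2 = np.outer(scales, np.full(128, 0.01))
grads2[5] *= -1  # sign flip: same norm, opposite direction
print(flag_norm_outliers(grads2).any())  # False: not caught by a norm check
```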
| Defense layer | Protects against | Performance cost | Implementation complexity |
|---|---|---|---|
| Worker attestation | Impersonation, compromised nodes | Low | High |
| Robust aggregation | Byzantine gradients | Medium (2-5x) | Medium |
| Gradient validation | Anomalous updates | Low | Low |
| Communication encryption | Eavesdropping, MITM | High (20-40%) | Medium |
| Redundant computation | All gradient attacks | Very high (Nx, N = replication factor) | Low |
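In practice, redundant computation means assigning each data shard to r workers and checking that their gradients agree. Exact equality is too strict, because GPU reductions are not always bitwise deterministic. The sketch below therefore compares each replica with the element-wise median of all replicas, using a relative tolerance. The replication layout, the `rtol` value, the median-as-majority rule, and the function name are assumptions for illustration.

```python
import numpy as np


def find_disagreeing_replicas(replicas: dict[int, np.ndarray],
                              rtol: float = 1e-3) -> list[int]:
    """Given {worker_id: gradient} for one shard, return the IDs of workers
    whose gradient deviates from the element-wise median of the replicas."""
    ids = sorted(replicas)
    stacked = np.stack([replicas[w] for w in ids])
    reference = np.median(stacked, axis=0)  # majority value when r >= 3
    tolerance = rtol * (np.linalg.norm(reference) + 1e-12)
    return [w for w, g in zip(ids, stacked)
            if np.linalg.norm(g - reference) > tolerance]


rng = np.random.default_rng(11)
true_grad = rng.normal(size=64)
replicas = {
    0: true_grad,
    1: true_grad + 1e-7 * rng.normal(size=64),  # benign numerical noise
    2: -true_grad,                              # Byzantine sign flip
}
print(find_disagreeing_replicas(replicas))  # [2]
```

With only two replicas, a mismatch shows that something is wrong but not which worker is at fault. Three or more replicas are needed to attribute the fault, which is part of why this layer is so expensive.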
References
- Blanchard, P., et al. (2017). "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent." NeurIPS 2017.
- Carlini, N., et al. (2021). "Extracting Training Data from Large Language Models." USENIX Security Symposium 2021.
- Yin, D., et al. (2018). "Byzantine-Resilient Distributed Learning: Towards Optimal Statistical Rates." ICML 2018.
- Baruch, G., et al. (2019). "A Little Is Enough: Circumventing Defenses for Distributed Learning." NeurIPS 2019.
- El-Mhamdi, E. M., et al. (2021). "Distributed Momentum for Byzantine-Resilient Stochastic Gradient Descent." ICLR 2021.