Training Cluster Network Security
Network security for distributed ML training clusters including NCCL, RDMA, and InfiniBand protection.
Overview
Distributed training splits a model's computation across multiple GPUs, often spanning multiple machines connected by high-bandwidth networks. The communication between these machines (exchanging gradients, model parameters, and synchronization signals) is the lifeline of the training process and a high-value attack target.
The primary communication libraries used in distributed training are NCCL (NVIDIA Collective Communications Library), Gloo, and MPI. These libraries are optimized for raw throughput, not security. NCCL, the dominant library for GPU-to-GPU communication, provides no built-in authentication or encryption. Any process that can reach the NCCL communication port can join the training collective and inject or intercept gradient data.
The physical networks connecting training nodes, often InfiniBand or RDMA over Converged Ethernet (RoCE), add another layer of risk. RDMA bypasses the operating system's network stack entirely, enabling kernel-bypass data transfers that are invisible to traditional network monitoring tools.
This article covers the network-level attack surface of distributed training clusters, provides assessment techniques, and details network security architectures that protect training infrastructure without unacceptable performance overhead.
Training Network Architecture
Communication Topology
Distributed training uses collective communication operations:
| Operation | Description | Network Pattern |
|---|---|---|
| AllReduce | Average gradients across all workers | All-to-all or ring |
| Broadcast | Send model state from one node to all | One-to-many |
| AllGather | Collect data from all workers | All-to-all |
| ReduceScatter | Reduce + scatter across workers | All-to-all |
| Point-to-point | Direct communication between two workers | Unicast |
Training Cluster Network Topology (Ring AllReduce)

```
Node 0 ──────▶ Node 1
  ▲               │
  │               ▼
Node 3 ◀────── Node 2

Each node: 8x GPU (connected via NVLink/NVSwitch)
Inter-node: InfiniBand / RoCE / TCP
```
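The ring AllReduce pattern above can be illustrated with a toy, serial simulation (pure Python with hypothetical gradient values; NCCL's actual implementation is pipelined and runs over real network links):

```python
from typing import List


def ring_allreduce(grads: List[List[float]]) -> List[List[float]]:
    """Toy ring AllReduce: every node ends with the elementwise average of
    all nodes' vectors, exchanging only one chunk per step with its ring
    neighbor. Vectors are split into one chunk per node."""
    n = len(grads)
    dim = len(grads[0])
    assert dim % n == 0, "vector length must be divisible by node count"
    chunk = dim // n
    data = [list(g) for g in grads]

    def idx(c):  # element indices belonging to chunk c
        return range(c * chunk, (c + 1) * chunk)

    # Phase 1: reduce-scatter. After n-1 steps, node r holds the complete
    # sum for chunk (r + 1) % n.
    for s in range(n - 1):
        for r in range(n):
            c = (r - s) % n
            for i in idx(c):
                data[(r + 1) % n][i] += data[r][i]

    # Phase 2: all-gather. Each node forwards its completed chunk around
    # the ring until every node holds every chunk.
    for s in range(n - 1):
        for r in range(n):
            c = (r + 1 - s) % n
            for i in idx(c):
                data[(r + 1) % n][i] = data[r][i]

    # Average, as done when AllReduce is used for gradient averaging
    return [[x / n for x in node] for node in data]


# Example: 4 nodes, 4-element gradients (hypothetical values)
grads = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [9.0, 10.0, 11.0, 12.0],
         [13.0, 14.0, 15.0, 16.0]]
result = ring_allreduce(grads)
print(result[0])  # every node ends with [7.0, 8.0, 9.0, 10.0]
```

Because each node only ever talks to its ring neighbor, a single rogue peer inserted anywhere in the ring sees (and can alter) every chunk as it passes through, which is why the lack of authentication discussed below matters.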
Communication Backends
```python
from typing import Dict, List
import os
import socket
import subprocess


class TrainingNetworkDiscovery:
    """Discover training cluster network configuration."""

    def discover_nccl_config(self) -> Dict:
        """Discover NCCL network configuration from environment variables."""
        nccl_vars = {}
        for key, value in os.environ.items():
            if key.startswith("NCCL_"):
                nccl_vars[key] = value
        return {
            "nccl_socket_ifname": nccl_vars.get("NCCL_SOCKET_IFNAME", "not set"),
            "nccl_ib_disable": nccl_vars.get("NCCL_IB_DISABLE", "0"),
            "nccl_net_gdr_level": nccl_vars.get("NCCL_NET_GDR_LEVEL", "not set"),
            "nccl_socket_nthreads": nccl_vars.get("NCCL_SOCKET_NTHREADS", "not set"),
            "nccl_debug": nccl_vars.get("NCCL_DEBUG", "not set"),
            "all_nccl_vars": nccl_vars,
        }

    def discover_rdma_interfaces(self) -> List[Dict]:
        """Discover RDMA-capable network interfaces via ibstat."""
        interfaces = []
        try:
            result = subprocess.run(
                ["ibstat"], capture_output=True, text=True, timeout=10,
            )
            if result.returncode == 0:
                # Parse ibstat output into one dict per device/port
                current_port = {}
                for line in result.stdout.split("\n"):
                    line = line.strip()
                    if line.startswith("CA"):
                        if current_port:
                            interfaces.append(current_port)
                        current_port = {"device": line}
                    elif ":" in line:
                        key, _, value = line.partition(":")
                        current_port[key.strip()] = value.strip()
                if current_port:
                    interfaces.append(current_port)
        except (subprocess.TimeoutExpired, FileNotFoundError):
            pass
        return interfaces

    def discover_training_endpoints(self) -> Dict:
        """Discover distributed training endpoints from environment variables."""
        return {
            "master_addr": os.environ.get("MASTER_ADDR", "not set"),
            "master_port": os.environ.get("MASTER_PORT", "not set"),
            "world_size": os.environ.get("WORLD_SIZE", "not set"),
            "rank": os.environ.get("RANK", "not set"),
            "local_rank": os.environ.get("LOCAL_RANK", "not set"),
        }

    def scan_training_ports(
        self, target_host: str, port_range: tuple = (29400, 29510)
    ) -> List[Dict]:
        """Scan for open NCCL/Gloo training ports."""
        open_ports = []
        for port in range(port_range[0], port_range[1]):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(0.5)
            try:
                if sock.connect_ex((target_host, port)) == 0:
                    open_ports.append({
                        "host": target_host,
                        "port": port,
                        "status": "open",
                        "likely_service": self._identify_service(port),
                    })
            except socket.error:
                pass
            finally:
                sock.close()
        return open_ports

    @staticmethod
    def _identify_service(port: int) -> str:
        """Identify the likely service based on the port number."""
        if 29400 <= port <= 29500:
            return "PyTorch distributed (NCCL/Gloo)"
        elif port == 6379:
            return "Redis (possibly torch.distributed store)"
        elif port == 2379:
            return "etcd (possibly torch.distributed store)"
        return "unknown"
```

NCCL Security Vulnerabilities
No Authentication or Encryption
NCCL provides no authentication mechanism. Any process that connects to the NCCL communication socket can participate in collective operations. This means attackers on the same network can:
- Join the training collective and inject poisoned gradients
- Intercept gradient traffic and reconstruct training data
- Disrupt training by sending malformed collective operations
- Inject a rogue worker that receives model weights and training data
```python
import os
import socket
from typing import Dict


class NCCLNetworkAssessment:
    """Assess NCCL network security in a training cluster."""

    def __init__(self, master_addr: str, master_port: int = 29500):
        self.master_addr = master_addr
        self.master_port = master_port

    def test_nccl_accessibility(self) -> Dict:
        """Test whether the NCCL master port is accessible without authentication."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(5)
        try:
            result = sock.connect_ex((self.master_addr, self.master_port))
            if result == 0:
                # Try to receive initial handshake data
                try:
                    data = sock.recv(1024)
                    return {
                        "accessible": True,
                        "severity": "critical",
                        "finding": "NCCL master port is accessible without authentication",
                        "received_bytes": len(data),
                        "data_preview": data[:64].hex(),
                    }
                except socket.timeout:
                    return {
                        "accessible": True,
                        "severity": "high",
                        "finding": "NCCL master port is open but did not send data",
                    }
            return {"accessible": False, "severity": "info"}
        except socket.error as e:
            return {"accessible": False, "error": str(e)}
        finally:
            sock.close()

    def test_gradient_interception(self) -> Dict:
        """
        Assess the feasibility of gradient interception by checking
        whether training traffic traverses the host network stack.
        """
        # TCP transport is visible to ordinary sniffers; RDMA bypasses the
        # OS network stack and is harder (but not impossible) to intercept.
        ib_disable = os.environ.get("NCCL_IB_DISABLE", "0")
        net_plugin = os.environ.get("NCCL_NET", "")
        if ib_disable == "1" or net_plugin == "Socket":
            return {
                "transport": "TCP/Socket",
                "interceptable": True,
                "severity": "high",
                "finding": "NCCL using TCP transport; gradient traffic is visible to network sniffers",
                "remediation": "Use IPsec or WireGuard to encrypt inter-node traffic",
            }
        return {
            "transport": "RDMA/IB",
            "interceptable": False,
            "severity": "medium",
            "finding": "NCCL using RDMA; traffic bypasses the OS network stack but is not encrypted",
            "remediation": "Use InfiniBand security features or physically isolate the training network",
        }
```

Gradient Poisoning via Network
If attackers can intercept and modify gradient traffic between training nodes, they can execute a gradient poisoning attack, modifying the gradients to inject a backdoor or degrade model performance:
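The leverage a single compromised worker gains from unauthenticated gradient averaging can be shown with a toy example (pure Python, hypothetical gradient values; the 3x-norm threshold mirrors the anomaly check in the defense below):

```python
def average_gradients(worker_grads):
    """Plain gradient averaging, as an unauthenticated AllReduce performs
    it: every worker's contribution is trusted equally."""
    n = len(worker_grads)
    dim = len(worker_grads[0])
    return [sum(g[i] for g in worker_grads) / n for i in range(dim)]


# Eight honest workers agree on roughly the same descent direction
honest = [[0.1, -0.2, 0.05] for _ in range(8)]
avg_honest = average_gradients(honest)
print(avg_honest)  # ~[0.1, -0.2, 0.05]

# One injected worker scales its gradient by -100x; the average flips sign
# on every coordinate, reversing the global update direction
poisoned = honest[:7] + [[-10.0, 20.0, -5.0]]
avg_poisoned = average_gradients(poisoned)
print(avg_poisoned)  # every coordinate has opposite sign vs. avg_honest

# A simple norm-based check exposes the outlier
norms = [sum(x * x for x in g) ** 0.5 for g in poisoned]
mean_norm = sum(norms) / len(norms)
flagged = [i for i, nrm in enumerate(norms) if nrm > 3.0 * mean_norm]
print(flagged)  # [7]
```

Norm checks are only a first line of defense: an attacker who keeps the injected gradient's magnitude in range can still bias training, which is why the defense below also clips norms and screens for NaN/Inf values.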
```python
import torch
import torch.distributed as dist


class GradientPoisoningDefense:
    """
    Detect and mitigate gradient poisoning in distributed training.
    Implements gradient integrity checking between training nodes.
    """

    def __init__(
        self,
        world_size: int,
        rank: int,
        clip_norm: float = 1.0,
        anomaly_threshold: float = 3.0,
    ):
        self.world_size = world_size
        self.rank = rank
        self.clip_norm = clip_norm
        self.anomaly_threshold = anomaly_threshold
        self._gradient_history: list = []

    def validate_gradient(self, gradient: torch.Tensor) -> dict:
        """Validate a gradient for anomalies before aggregation."""
        grad_norm = gradient.norm().item()
        findings = {"valid": True, "checks": []}
        # Check 1: gradient norm within expected range relative to history
        if self._gradient_history:
            mean_norm = sum(self._gradient_history) / len(self._gradient_history)
            if grad_norm > mean_norm * self.anomaly_threshold:
                findings["valid"] = False
                findings["checks"].append({
                    "check": "norm_anomaly",
                    "value": grad_norm,
                    "expected_range": f"<{mean_norm * self.anomaly_threshold:.4f}",
                })
        # Check 2: NaN/Inf detection
        if torch.isnan(gradient).any() or torch.isinf(gradient).any():
            findings["valid"] = False
            findings["checks"].append({
                "check": "nan_inf_detected",
                "nan_count": torch.isnan(gradient).sum().item(),
                "inf_count": torch.isinf(gradient).sum().item(),
            })
        # Track a sliding window of recent norms. A direction-consistency
        # check (e.g. cosine similarity against recent gradients) could
        # extend this, since dramatic direction changes may indicate injection.
        self._gradient_history.append(grad_norm)
        if len(self._gradient_history) > 100:
            self._gradient_history.pop(0)
        return findings

    def secure_allreduce(
        self,
        tensor: torch.Tensor,
        op: dist.ReduceOp = dist.ReduceOp.SUM,
    ) -> torch.Tensor:
        """AllReduce with gradient validation and norm clipping."""
        # Clip the gradient norm before sending
        grad_norm = tensor.norm()
        if grad_norm > self.clip_norm:
            tensor = tensor * (self.clip_norm / grad_norm)
        # Validate before aggregation; replace with zeros on failure
        validation = self.validate_gradient(tensor)
        if not validation["valid"]:
            print(f"[Rank {self.rank}] Gradient validation failed: {validation['checks']}")
            tensor = torch.zeros_like(tensor)
        # Standard allreduce, then average across workers
        dist.all_reduce(tensor, op=op)
        tensor /= self.world_size
        return tensor
```

InfiniBand and RDMA Security
InfiniBand Network Attack Surface
InfiniBand networks provide high bandwidth and low latency for distributed training, but they have their own security concerns:
- No built-in encryption: InfiniBand traffic is transmitted in cleartext over the fabric
- Subnet Manager trust: the SM controls routing for the entire fabric and can be a single point of compromise
- RDMA enables kernel bypass: data transfers bypass the OS network stack, evading network monitoring
- Partition keys (P_Keys): InfiniBand's access-control mechanism, but rarely configured in training clusters
```bash
#!/bin/bash
# InfiniBand security assessment for training clusters
echo "=== InfiniBand Security Assessment ==="

# Check InfiniBand devices
echo "[*] InfiniBand devices:"
ibstat 2>/dev/null || echo "  ibstat not available (no IB support?)"

# Check subnet manager
echo ""
echo "[*] Subnet Manager:"
sminfo 2>/dev/null || echo "  sminfo not available"

# Check partition keys (access control)
echo ""
echo "[*] Partition Keys (P_Keys):"
cat /sys/class/infiniband/*/ports/*/pkeys/* 2>/dev/null | sort -u || echo "  Cannot read P_Keys"

# Check for RDMA interfaces
echo ""
echo "[*] RDMA interfaces:"
rdma link show 2>/dev/null || echo "  rdma tool not available"

# Check for GPUDirect RDMA
echo ""
echo "[*] GPUDirect RDMA:"
if [ -d /dev/infiniband ]; then
    ls -la /dev/infiniband/
else
    echo "  No InfiniBand devices in /dev"
fi

# Check NCCL network configuration
echo ""
echo "[*] NCCL environment:"
env | grep NCCL_ | sort

# Check for network namespaces (isolation)
echo ""
echo "[*] Network namespaces:"
ip netns list 2>/dev/null || echo "  No network namespaces configured"

# Check firewall rules on training interfaces
echo ""
echo "[*] Firewall rules for training ports:"
iptables -L -n 2>/dev/null | grep -E "2950[0-9]|NCCL" || echo "  No specific rules for training ports"
```

RDMA Security Concerns
RDMA (Remote Direct Memory Access) allows one machine to read or write memory on another machine without involving the remote CPU or operating system. While this delivers extreme performance for gradient communication, it creates security concerns:
```python
from pathlib import Path
from typing import Dict
import os


class RDMASecurityAssessment:
    """Assess RDMA security configuration for training clusters."""

    def check_rdma_permissions(self) -> Dict:
        """Check RDMA device permissions and access controls."""
        findings = []
        # Check RDMA device files for overly permissive modes
        rdma_dir = Path("/dev/infiniband")
        rdma_devices = list(rdma_dir.glob("*")) if rdma_dir.exists() else []
        for device in rdma_devices:
            stat = os.stat(device)
            if stat.st_mode & 0o006:  # World readable/writable
                findings.append({
                    "severity": "high",
                    "device": str(device),
                    "finding": "RDMA device is world-accessible",
                    "permissions": oct(stat.st_mode)[-3:],
                })
        # Check memory registration limits
        try:
            with open("/proc/sys/vm/max_map_count") as f:
                max_map = int(f.read().strip())
            if max_map > 1000000:
                findings.append({
                    "severity": "medium",
                    "finding": f"High max_map_count ({max_map}); allows large RDMA memory registrations",
                })
        except (FileNotFoundError, ValueError):
            pass
        return {"findings": findings}

    def check_partition_isolation(self) -> Dict:
        """Check if InfiniBand partition keys are configured for isolation."""
        pkey_paths = list(Path("/sys/class/infiniband").glob("*/ports/*/pkeys/*"))
        pkeys = set()
        for pkey_path in pkey_paths:
            try:
                pkeys.add(pkey_path.read_text().strip())
            except (PermissionError, OSError):
                pass
        # The default P_Key (0x7fff or 0xffff) means no partitioning
        has_non_default = any(
            k not in ("0x7fff", "0xffff", "0x0000") for k in pkeys
        )
        return {
            "partition_keys": list(pkeys),
            "custom_partitions": has_non_default,
            "severity": "high" if not has_non_default else "info",
            "finding": (
                "Only default P_Key found; no InfiniBand partition isolation"
                if not has_non_default
                else "Custom partition keys configured"
            ),
        }
```

Network Segmentation
Recommended Architecture
Training cluster networks should be segmented into at least three zones:
```
┌─────────────────────────────────────────────────────────┐
│ Management Network                                      │
│ (SSH, monitoring, job scheduling — Kubernetes API)      │
│ 10.0.0.0/24 — firewall-protected, audit-logged          │
├─────────────────────────────────────────────────────────┤
│ Training Network                                        │
│ (NCCL, gradient communication, parameter servers)       │
│ 10.1.0.0/24 — dedicated InfiniBand/RoCE, no internet    │
├─────────────────────────────────────────────────────────┤
│ Storage Network                                         │
│ (Dataset access, model checkpoint storage)              │
│ 10.2.0.0/24 — NFS/Lustre/S3, read-only where possible   │
└─────────────────────────────────────────────────────────┘
```
Network Encryption
For TCP-based NCCL communication, use WireGuard or IPsec:
```bash
#!/bin/bash
# Configure a WireGuard tunnel for encrypted training traffic

# Generate keys on each training node
wg genkey | tee /etc/wireguard/privatekey | wg pubkey > /etc/wireguard/publickey

# Configuration for training node 0
cat > /etc/wireguard/wg-train.conf << 'WGEOF'
[Interface]
Address = 10.100.0.1/24
ListenPort = 51820
PrivateKey = <node0_private_key>
# Set MTU to accommodate overhead without fragmenting
MTU = 1420

# Training node 1
[Peer]
PublicKey = <node1_public_key>
AllowedIPs = 10.100.0.2/32
Endpoint = 10.1.0.2:51820
PersistentKeepalive = 25

# Training node 2
[Peer]
PublicKey = <node2_public_key>
AllowedIPs = 10.100.0.3/32
Endpoint = 10.1.0.3:51820
PersistentKeepalive = 25
WGEOF

# Bring up the tunnel
wg-quick up wg-train

# Configure NCCL to use the WireGuard interface
export NCCL_SOCKET_IFNAME=wg-train
export MASTER_ADDR=10.100.0.1
```

Kubernetes Network Policies for Training
```yaml
# NetworkPolicy: restrict training pod communication
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: training-job-network-policy
  namespace: ml-training
spec:
  podSelector:
    matchLabels:
      job-type: distributed-training
  policyTypes:
    - Ingress
    - Egress
  ingress:
    # Allow NCCL traffic only from other training pods in the same job
    - from:
        - podSelector:
            matchLabels:
              job-type: distributed-training
      ports:
        - protocol: TCP
          port: 29500  # PyTorch distributed master
        - protocol: TCP
          port: 29400  # NCCL
  egress:
    # Allow NCCL traffic to other training pods
    - to:
        - podSelector:
            matchLabels:
              job-type: distributed-training
      ports:
        - protocol: TCP
          port: 29500
        - protocol: TCP
          port: 29400
    # Allow DNS
    - to:
        - namespaceSelector: {}
      ports:
        - protocol: UDP
          port: 53
    # Allow access to storage (model checkpoints, datasets)
    - to:
        - ipBlock:
            cidr: 10.2.0.0/24  # Storage network
      ports:
        - protocol: TCP
          port: 2049  # NFS
        - protocol: TCP
          port: 443  # S3
```

Defensive Recommendations
- Physically or logically isolate the training network from management and storage networks
- Encrypt NCCL traffic using WireGuard or IPsec for TCP-based communication
- Configure InfiniBand partition keys to isolate training jobs from each other
- Apply Kubernetes NetworkPolicies to restrict training pod communication to only required peers
- Monitor training network traffic for anomalous patterns (unexpected peers, unusual data volumes)
- Restrict NCCL port access to only the IP addresses of training nodes
- Disable NCCL debug logging in production (`NCCL_DEBUG=WARN`, not `INFO` or `TRACE`)
- Validate gradient integrity during distributed training to detect injection attacks
- Use dedicated network interfaces for training traffic, separate from management SSH access
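The monitoring recommendation above can be sketched as an allowlist check over observed flows (a minimal sketch: `EXPECTED_PEERS`, the 50 GiB threshold, and the flow records are all hypothetical; a real deployment would feed this from flow logs, sFlow/NetFlow, or eBPF probes):

```python
from typing import Dict, List, Set

# Expected training peers for this job (hypothetical addresses)
EXPECTED_PEERS: Set[str] = {"10.1.0.1", "10.1.0.2", "10.1.0.3", "10.1.0.4"}
TRAINING_PORTS = range(29400, 29501)


def audit_flows(flows: List[Dict]) -> List[Dict]:
    """Flag flows on training ports that involve unexpected peers or
    unusually large transfers."""
    alerts = []
    for flow in flows:
        if flow["dst_port"] not in TRAINING_PORTS:
            continue  # not training traffic
        if flow["src"] not in EXPECTED_PEERS:
            alerts.append({"severity": "high",
                           "reason": "unexpected peer on training port",
                           "flow": flow})
        elif flow["bytes"] > 50 * 2**30:  # >50 GiB in one flow: worth review
            alerts.append({"severity": "medium",
                           "reason": "unusually large transfer",
                           "flow": flow})
    return alerts


# Hypothetical observed flows: one legitimate NCCL exchange, one rogue peer
flows = [
    {"src": "10.1.0.2", "dst": "10.1.0.1", "dst_port": 29500, "bytes": 4 * 2**30},
    {"src": "10.3.0.77", "dst": "10.1.0.1", "dst_port": 29500, "bytes": 1024},
]
alerts = audit_flows(flows)
for alert in alerts:
    print(alert["severity"], alert["reason"])  # high unexpected peer on training port
```

Because NCCL peers for a given job are known at launch time, the allowlist can be generated automatically from the job's rank-to-host mapping.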
References
- NCCL Documentation — https://docs.nvidia.com/deeplearning/nccl/
- InfiniBand Trade Association — https://www.infinibandta.org/
- WireGuard — https://www.wireguard.com/
- Kubernetes NetworkPolicy — https://kubernetes.io/docs/concepts/services-networking/network-policies/
- MITRE ATLAS — AML.T0020 (Poisoning Training Data), AML.T0010 (ML Supply Chain Compromise)
- NIST AI RMF — Map 3.4 (AI infrastructure risks)