GPU Side Channel Basics
GPU-based side channel attacks on ML inference, exploiting timing, power consumption, and memory access patterns to extract information about models and data.
GPUs are optimized for parallel computation but not for security isolation. In multi-tenant environments where multiple users share GPU resources (cloud inference services, shared GPU clusters), side channel attacks can extract information about co-located workloads. Timing variations in inference, GPU memory access patterns, and cache contention can reveal details about model architecture, input data, and even specific model weights. This walkthrough covers the fundamentals of GPU side channel attacks relevant to ML security.
Step 1: Understanding GPU Side Channel Surfaces
GPUs share several resources between concurrent workloads, each creating potential information leakage channels.
"""
Enumerate GPU side channel attack surfaces and their
relevance to ML inference security.
"""
from dataclasses import dataclass
@dataclass
class SideChannel:
    """One GPU side channel attack surface and its practical relevance to ML inference."""

    name: str                  # short identifier for the channel
    resource_shared: str       # GPU resource shared between tenants that leaks
    information_leaked: str    # what an attacker can learn through this channel
    requirements: str          # access level needed to exploit it
    practicality: str          # qualitative feasibility rating
# Catalogue of known GPU side channel surfaces, ordered roughly from most
# to least practical for an attacker targeting ML inference workloads.
GPU_SIDE_CHANNELS = [
    SideChannel(
        name="timing",
        resource_shared="Compute units (SMs)",
        information_leaked="Model architecture, input complexity, batch size",
        requirements="Ability to time inference requests",
        practicality="High - only requires API access",
    ),
    SideChannel(
        name="memory_access_patterns",
        resource_shared="GPU global memory, L2 cache",
        information_leaked="Model size, layer structure, attention patterns",
        requirements="Co-located workload on same GPU",
        practicality="Medium - requires multi-tenant access",
    ),
    SideChannel(
        name="cache_contention",
        resource_shared="L1/L2 cache, shared memory",
        information_leaked="Memory access patterns of co-located model",
        requirements="Co-located process monitoring cache latency",
        practicality="Medium - requires same GPU",
    ),
    SideChannel(
        name="power_consumption",
        resource_shared="GPU power delivery",
        information_leaked="Computation intensity, active SMs, data patterns",
        requirements="Power monitoring (physical access or API)",
        practicality="Low - requires physical proximity or RAPL access",
    ),
    SideChannel(
        name="electromagnetic_emissions",
        resource_shared="EM radiation from GPU",
        information_leaked="Operation types, data values",
        requirements="Specialized EM measurement equipment",
        practicality="Very Low - requires physical proximity and equipment",
    ),
    SideChannel(
        name="gpu_utilization_metrics",
        resource_shared="GPU performance counters",
        information_leaked="Workload characteristics, model complexity",
        requirements="Access to nvidia-smi or NVML API",
        practicality="High - metrics often exposed in shared environments",
    ),
]
# Render the side channel catalogue as a human-readable report.
print("GPU Side Channel Attack Surfaces")
print("=" * 60)
for sc in GPU_SIDE_CHANNELS:
    print(f"\n{sc.name.upper()}")
    print(f" Shared resource: {sc.resource_shared}")
    print(f" Leaks: {sc.information_leaked}")
    print(f" Requirements: {sc.requirements}")
print(f" Practicality: {sc.practicality}")

Step 2: Timing-Based Model Fingerprinting
The most practical GPU side channel for remote attackers is timing: measuring how long inference takes reveals information about the model architecture.
"""
Use inference timing measurements to fingerprint model architecture.
Different models have distinct timing profiles based on their structure.
"""
import time
import numpy as np
from dataclasses import dataclass
from typing import Optional
@dataclass
class TimingProfile:
    """Summary statistics for repeated inference latency measurements."""

    model_name: str         # label for the profiled model ("unknown" if not known)
    mean_latency_ms: float
    std_latency_ms: float
    min_latency_ms: float
    max_latency_ms: float
    samples: int            # number of timed runs behind the statistics


def measure_inference_timing(
    inference_fn,
    input_data,
    num_samples: int = 100,
    warmup: int = 10,
    *,
    model_name: str = "unknown",
) -> TimingProfile:
    """Measure inference timing with statistical accuracy.

    Args:
        inference_fn: Callable that performs inference and returns result.
        input_data: The input to pass to the inference function.
        num_samples: Number of timing measurements to take.
        warmup: Number of warmup runs to stabilize GPU state.
        model_name: Label stored on the resulting profile. Previously the
            field was hard-coded to "unknown"; the keyword-only default
            keeps existing callers unchanged.

    Returns:
        A TimingProfile with mean/std/min/max latency in milliseconds.
    """
    # Warmup: stabilize GPU clocks and caches before timing anything.
    for _ in range(warmup):
        inference_fn(input_data)

    # Measurement: perf_counter_ns is monotonic and high-resolution.
    latencies = []
    for _ in range(num_samples):
        start = time.perf_counter_ns()
        inference_fn(input_data)
        end = time.perf_counter_ns()
        latencies.append((end - start) / 1_000_000)  # ns -> ms

    latencies_arr = np.array(latencies)
    return TimingProfile(
        model_name=model_name,
        mean_latency_ms=round(float(np.mean(latencies_arr)), 3),
        std_latency_ms=round(float(np.std(latencies_arr)), 3),
        min_latency_ms=round(float(np.min(latencies_arr)), 3),
        max_latency_ms=round(float(np.max(latencies_arr)), 3),
        samples=num_samples,
    )


def compare_timing_profiles(
    profiles: list[TimingProfile],
    target_profile: TimingProfile,
) -> list[tuple[str, float]]:
    """Compare a target timing profile against known model profiles.

    Args:
        profiles: Reference profiles for known models.
        target_profile: Profile measured from the unknown target.

    Returns:
        (model_name, distance) pairs sorted ascending by distance, so the
        closest (most similar) model comes first.
    """
    similarities = []
    for known in profiles:
        # Simple distance metric: mean difference plus double-weighted std
        # difference (latency variance is a strong architectural signal).
        mean_diff = abs(known.mean_latency_ms - target_profile.mean_latency_ms)
        std_diff = abs(known.std_latency_ms - target_profile.std_latency_ms)
        distance = mean_diff + std_diff * 2
        similarities.append((known.model_name, round(distance, 3)))
    return sorted(similarities, key=lambda x: x[1])
# Example: build reference profiles
# (illustrative latencies in milliseconds, measured offline per model).
REFERENCE_PROFILES = [
    TimingProfile("gpt2-small", 12.5, 1.2, 10.1, 18.3, 100),
    TimingProfile("gpt2-medium", 23.7, 2.1, 19.5, 32.4, 100),
    TimingProfile("gpt2-large", 45.2, 3.8, 38.1, 58.7, 100),
    TimingProfile("llama-7b", 85.3, 5.2, 72.4, 102.1, 100),
    TimingProfile("llama-13b", 142.7, 8.9, 121.3, 178.5, 100),
]
# Simulated target profile, as measured from a black-box endpoint.
target = TimingProfile("unknown", 24.1, 2.3, 20.2, 33.1, 50)
matches = compare_timing_profiles(REFERENCE_PROFILES, target)
print("Timing-Based Model Fingerprinting")
print("=" * 40)
print(f"Target: mean={target.mean_latency_ms}ms, std={target.std_latency_ms}ms")
print("\nClosest matches:")
for model, distance in matches[:3]:
print(f" {model}: distance={distance}")

Step 3: Input-Dependent Timing Analysis
"""
Analyze how inference timing varies with different inputs
to extract information about model behavior and architecture.
"""
import time
import numpy as np
from dataclasses import dataclass
@dataclass
class InputTimingResult:
    """Latency statistics for a single probe input."""

    input_description: str   # human-readable label for the probe input
    input_length: int        # len(str(input)) -- a rough proxy for input size
    mean_latency_ms: float
    std_latency_ms: float


def measure_input_dependent_timing(
    inference_fn,
    inputs: list[tuple[str, object]],
    samples_per_input: int = 50,
) -> list[InputTimingResult]:
    """Measure how timing varies across different inputs.

    Timing differences can reveal:
    - Whether the model uses early stopping
    - Attention pattern complexity (longer inputs may cause non-linear timing)
    - Conditional computation paths
    - Caching behavior

    Args:
        inference_fn: Callable invoked once per timed sample.
        inputs: (description, input_data) pairs to probe with.
            NOTE: the original annotated this as ``tuple[str, any]`` --
            ``any`` is the builtin function, not a type; ``object`` is the
            correct "anything" annotation.
        samples_per_input: Number of timed runs per probe input.

    Returns:
        One InputTimingResult per probe input, in input order.
    """
    results = []
    for description, input_data in inputs:
        latencies = []
        for _ in range(samples_per_input):
            start = time.perf_counter_ns()
            inference_fn(input_data)
            end = time.perf_counter_ns()
            latencies.append((end - start) / 1_000_000)  # ns -> ms
        arr = np.array(latencies)
        results.append(InputTimingResult(
            input_description=description,
            input_length=len(str(input_data)),
            mean_latency_ms=round(float(np.mean(arr)), 3),
            std_latency_ms=round(float(np.std(arr)), 3),
        ))
    return results
def analyze_timing_linearity(results: "list[InputTimingResult]") -> dict:
    """Determine how inference timing scales with input length.

    A markedly better quadratic fit suggests transformer-style attention
    (O(n^2) in sequence length); an equally good linear fit suggests an
    efficient architecture (linear attention, SSMs). The original docstring
    had this backwards, attributing O(n^2) to linear scaling.

    Args:
        results: At least 3 timing results at distinct input lengths.

    Returns:
        Dict with R^2 for the linear and quadratic fits, the inferred
        architecture family, and the linear slope -- or an {"error": ...}
        dict when the data is insufficient or has zero variance.
    """
    lengths = np.array([r.input_length for r in results])
    timings = np.array([r.mean_latency_ms for r in results])
    if len(lengths) < 3:
        return {"error": "Need at least 3 data points"}

    ss_total = float(np.sum((timings - np.mean(timings)) ** 2))
    if ss_total == 0.0:
        # Constant timings: R^2 is undefined (the original divided by zero
        # here) and no scaling trend can be inferred.
        return {"error": "Timings have zero variance; cannot fit scaling model"}

    # Fit linear and quadratic models.
    linear_coeffs = np.polyfit(lengths, timings, 1)
    quadratic_coeffs = np.polyfit(lengths, timings, 2)

    # Coefficient of determination for each fit; float() keeps numpy
    # scalars out of the returned dict.
    linear_pred = np.polyval(linear_coeffs, lengths)
    quad_pred = np.polyval(quadratic_coeffs, lengths)
    r2_linear = 1 - float(np.sum((timings - linear_pred) ** 2)) / ss_total
    r2_quadratic = 1 - float(np.sum((timings - quad_pred) ** 2)) / ss_total

    return {
        "linear_r2": round(r2_linear, 4),
        "quadratic_r2": round(r2_quadratic, 4),
        "likely_architecture": (
            "transformer (quadratic)" if r2_quadratic > r2_linear + 0.05
            else "linear attention or SSM"
        ),
        "scaling_factor": round(float(linear_coeffs[0]), 4),
    }
print("Input-dependent timing analysis ready")

Step 4: GPU Memory Analysis
"""
Monitor GPU memory patterns to infer model characteristics.
"""
import subprocess
import time
from dataclasses import dataclass
from typing import Optional
@dataclass
class GPUMemorySnapshot:
    """Point-in-time GPU memory and utilization reading."""

    timestamp: float          # seconds since the capture started
    memory_used_mb: int
    memory_total_mb: int
    gpu_utilization: int      # percent, as reported by nvidia-smi
    memory_utilization: int   # percent, as reported by nvidia-smi


def capture_gpu_metrics(
    duration_seconds: float = 10.0,
    interval_ms: int = 100,
) -> list[GPUMemorySnapshot]:
    """Capture GPU memory and utilization metrics over time.

    Uses nvidia-smi for metric collection. In production,
    use NVML Python bindings for lower overhead.

    Args:
        duration_seconds: Total sampling window.
        interval_ms: Delay between samples.

    Returns:
        Snapshots for GPU 0; empty list when nvidia-smi is unavailable
        or its output cannot be parsed.
    """
    snapshots = []
    start = time.time()
    while time.time() - start < duration_seconds:
        try:
            result = subprocess.run(
                ["nvidia-smi",
                 "--query-gpu=memory.used,memory.total,utilization.gpu,utilization.memory",
                 "--format=csv,noheader,nounits"],
                capture_output=True, text=True, timeout=5,
            )
            if result.returncode == 0:
                # nvidia-smi emits one CSV line per GPU. The original split
                # the whole stdout on ", ", which silently produced
                # unparsable fields on multi-GPU hosts; parse only the
                # first line (GPU 0).
                lines = result.stdout.strip().splitlines()
                parts = lines[0].split(", ") if lines else []
                if len(parts) >= 4:
                    snapshots.append(GPUMemorySnapshot(
                        timestamp=time.time() - start,
                        memory_used_mb=int(parts[0]),
                        memory_total_mb=int(parts[1]),
                        gpu_utilization=int(parts[2]),
                        memory_utilization=int(parts[3]),
                    ))
        except (subprocess.TimeoutExpired, FileNotFoundError, ValueError):
            # Best-effort sampling: skip this tick if nvidia-smi is
            # missing, slow, or emits something unparsable.
            pass
        time.sleep(interval_ms / 1000)
    return snapshots
def infer_model_size(snapshots: "list[GPUMemorySnapshot]") -> dict:
    """Infer model size from memory usage patterns.

    Treats the baseline-to-peak delta as model weights. NOTE(review):
    activations and KV cache also contribute to peak memory, so these
    estimates are upper bounds, not exact parameter counts.

    Args:
        snapshots: Memory readings spanning model load and inference.
            (Annotation is quoted so the function no longer depends on
            GPUMemorySnapshot being defined first -- the original
            evaluated it eagerly at def time.)

    Returns:
        Baseline/peak/model memory plus rough parameter-count estimates,
        or an {"error": ...} dict when no snapshots were collected.
    """
    if not snapshots:
        return {"error": "No snapshots collected"}
    mem_values = [s.memory_used_mb for s in snapshots]
    baseline = min(mem_values)
    peak = max(mem_values)
    model_mem = peak - baseline
    # Rule of thumb: ~2 bytes per parameter in fp16, ~4 bytes in fp32.
    bytes_total = model_mem * 1024 * 1024  # MiB -> bytes
    est_params_fp16 = bytes_total / 2
    est_params_fp32 = bytes_total / 4
    return {
        "baseline_memory_mb": baseline,
        "peak_memory_mb": peak,
        "model_memory_mb": model_mem,
        "estimated_params_fp16": f"{est_params_fp16 / 1e9:.1f}B",
        "estimated_params_fp32": f"{est_params_fp32 / 1e9:.1f}B",
    }
print("GPU memory analysis tools ready")

Step 5: Cache Timing Attacks
"""
GPU cache timing attacks that exploit L2 cache sharing
between co-located workloads.
"""
import numpy as np
import time
from dataclasses import dataclass
@dataclass
class CacheTimingResult:
    """Interpreted result of a cache-probe measurement.

    NOTE(review): defined for callers of the probe below but not
    constructed in this walkthrough itself.
    """

    access_pattern: str         # description of the probe access pattern
    mean_latency_ns: float
    cache_hit_rate: float       # fraction of accesses served from cache
    inferred_information: str   # what the latency pattern suggests
def measure_cache_timing_pytorch(
    probe_size_mb: int = 64,
    iterations: int = 1000,
) -> dict:
    """Measure GPU cache timing to detect co-located workloads.

    Sweeps probe tensors of increasing size (1 MB up to ``probe_size_mb``)
    that compete for GPU L2 cache with the target workload; latency jumps
    as the probe exceeds cache capacity reveal contention.

    Args:
        probe_size_mb: Largest probe size to sweep up to. Bug fix: the
            original accepted this parameter but ignored it, always
            sweeping the hard-coded sizes 1/4/16/64 MB.
        iterations: Timed reductions per probe size.

    Returns:
        Per-size latency statistics, or {"error": ...} when PyTorch or
        CUDA is unavailable.
    """
    try:
        import torch
        if not torch.cuda.is_available():
            return {"error": "CUDA not available"}
        device = torch.device("cuda")
        results = {}
        # Sweep the standard probe sizes, capped by the caller's limit.
        for size_mb in [1, 4, 16, 64]:
            if size_mb > probe_size_mb:
                break
            elements = size_mb * 1024 * 1024 // 4  # float32 elements
            probe = torch.randn(elements, device=device)
            # Warmup: populate caches and stabilize clocks.
            for _ in range(10):
                _ = probe.sum()
            torch.cuda.synchronize()
            # Measure access latency; synchronizing before and after makes
            # each timing cover exactly one reduction on the GPU.
            latencies = []
            for _ in range(iterations):
                torch.cuda.synchronize()
                start = time.perf_counter_ns()
                _ = probe.sum()
                torch.cuda.synchronize()
                end = time.perf_counter_ns()
                latencies.append(end - start)
            arr = np.array(latencies)
            results[f"{size_mb}MB"] = {
                "mean_ns": round(float(np.mean(arr)), 1),
                "std_ns": round(float(np.std(arr)), 1),
                "p99_ns": round(float(np.percentile(arr, 99)), 1),
            }
        return results
    except ImportError:
        return {"error": "PyTorch not available"}
print("GPU cache timing measurement ready")

Step 6: Multi-Tenant Threat Assessment
"""
Assess GPU side channel risks in multi-tenant environments.
"""
from dataclasses import dataclass
@dataclass
class TenantIsolation:
    """Isolation properties of one GPU-sharing deployment model."""

    provider: str              # deployment/sharing model being described
    isolation_level: str       # qualitative isolation guarantee
    shared_resources: list[str]  # GPU resources still shared across tenants
    side_channel_risk: str     # resulting side channel exposure
# Common GPU tenancy models and the side channel exposure each implies,
# from fully dedicated hardware down to fully time-shared GPUs.
CLOUD_GPU_ISOLATION = [
    TenantIsolation(
        provider="Dedicated GPU instance",
        isolation_level="Full isolation",
        shared_resources=["None (dedicated hardware)"],
        side_channel_risk="Low - no co-tenancy",
    ),
    TenantIsolation(
        provider="MIG (Multi-Instance GPU)",
        isolation_level="Hardware partitioned",
        shared_resources=["L2 cache", "Memory bandwidth"],
        side_channel_risk="Medium - partitioned but some sharing",
    ),
    TenantIsolation(
        provider="Time-shared GPU",
        isolation_level="Software isolation",
        shared_resources=["All GPU resources", "GPU memory", "Caches", "SMs"],
        side_channel_risk="High - full resource sharing",
    ),
    TenantIsolation(
        provider="vGPU (Virtual GPU)",
        isolation_level="Hypervisor isolated",
        shared_resources=["Physical GPU", "PCIe bandwidth"],
        side_channel_risk="Medium-High - depends on vGPU implementation",
    ),
]
# Render the isolation assessment as a human-readable report.
print("GPU Multi-Tenant Isolation Assessment")
print("=" * 50)
for config in CLOUD_GPU_ISOLATION:
    print(f"\n{config.provider}")
    print(f" Isolation: {config.isolation_level}")
    print(f" Shared: {', '.join(config.shared_resources)}")
print(f" Risk: {config.side_channel_risk}")

Step 7: Mitigation Techniques
"""
Mitigation techniques for GPU side channel attacks.
"""
import time
import numpy as np
from dataclasses import dataclass
@dataclass
class Mitigation:
    """One defensive technique against GPU side channel attacks."""

    name: str                # short identifier for the mitigation
    description: str         # what the mitigation does
    effectiveness: str       # which channels it defeats and how well
    performance_impact: str  # cost of deploying it
# Catalogue of mitigations, trading effectiveness against performance cost.
MITIGATIONS = [
    Mitigation(
        name="constant_time_inference",
        description="Pad all inputs to maximum length and use fixed computation paths",
        effectiveness="High for timing attacks",
        performance_impact="Significant - wastes compute on padding",
    ),
    Mitigation(
        name="timing_noise_injection",
        description="Add random delays to inference responses",
        effectiveness="Medium - requires careful calibration",
        performance_impact="Moderate - adds latency",
    ),
    Mitigation(
        name="dedicated_gpu_isolation",
        description="Use dedicated GPU instances with no co-tenancy",
        effectiveness="High for all GPU side channels",
        performance_impact="Cost increase (no sharing)",
    ),
    Mitigation(
        name="mig_partitioning",
        description="Use NVIDIA MIG to partition GPUs into isolated instances",
        effectiveness="Medium-High - hardware-level isolation",
        performance_impact="Moderate - reduces per-tenant GPU capacity",
    ),
    Mitigation(
        name="cache_flushing",
        description="Flush GPU caches between inference requests",
        effectiveness="Medium for cache timing attacks",
        performance_impact="Moderate - increases latency",
    ),
]
def add_timing_noise(
    inference_fn,
    input_data,
    min_time_ms: float = 50.0,
    jitter_ms: float = 10.0,
):
    """Wrapper that adds timing noise to inference calls.

    Pads every call to at least ``min_time_ms`` plus uniform random jitter,
    masking input- and model-dependent latency variation from observers.

    Args:
        inference_fn: Callable to invoke with ``input_data``.
        input_data: Forwarded to ``inference_fn`` unchanged.
        min_time_ms: Floor on total observable latency.
        jitter_ms: Upper bound of the uniform random padding.

    Returns:
        Whatever ``inference_fn`` returns. (Bug fix: the original annotated
        the return as ``tuple``, but the raw inference result is returned.)
    """
    start = time.perf_counter()
    result = inference_fn(input_data)
    elapsed_ms = (time.perf_counter() - start) * 1000
    # Pad up to a randomized target so fast calls are indistinguishable;
    # calls already slower than the target pass through unchanged.
    target_ms = min_time_ms + np.random.uniform(0, jitter_ms)
    remaining_ms = target_ms - elapsed_ms
    if remaining_ms > 0:
        time.sleep(remaining_ms / 1000)
    return result
print("GPU side channel mitigations")
print("=" * 50)
for m in MITIGATIONS:
print(f"\n{m.name}")
print(f" {m.description}")
print(f" Effectiveness: {m.effectiveness}")
print(f" Performance: {m.performance_impact}")

Step 8: Building a Side Channel Assessment Tool
"""
Comprehensive GPU side channel assessment tool.
"""
import json
from dataclasses import dataclass, asdict
from typing import Optional
@dataclass
class AssessmentResult:
    """Outcome of a GPU side channel risk assessment."""

    environment: str            # environment type that was assessed
    timing_risk: str            # "low" / "medium" / "high"
    memory_risk: str
    cache_risk: str
    overall_risk: str           # worst of the individual channel risks
    recommendations: list[str]  # suggested mitigations for the risks found


class GPUSideChannelAssessor:
    """Assess GPU side channel risks in a deployment environment."""

    # Ordinal scale used to take the worst-case across channels.
    _RISK_SCORES = {"low": 1, "medium": 2, "high": 3}
    _SCORE_NAMES = {1: "low", 2: "medium", 3: "high"}

    def assess(self, environment_type: str = "cloud_shared") -> AssessmentResult:
        """Run full side channel risk assessment.

        Args:
            environment_type: Deployment environment identifier, e.g.
                "cloud_shared", "api_endpoint", "multi_tenant",
                "mig_partitioned"; unrecognized values are treated as
                isolated (low risk).

        Returns:
            AssessmentResult with per-channel risks, the overall risk
            (the worst individual channel), and mitigation recommendations.
        """
        timing_risk = self._assess_timing_risk(environment_type)
        memory_risk = self._assess_memory_risk(environment_type)
        cache_risk = self._assess_cache_risk(environment_type)
        # Unrecognized risk labels default to "medium" (score 2).
        max_risk = max(
            self._RISK_SCORES.get(timing_risk, 2),
            self._RISK_SCORES.get(memory_risk, 2),
            self._RISK_SCORES.get(cache_risk, 2),
        )
        overall = self._SCORE_NAMES[max_risk]
        recommendations = []
        if timing_risk in ("medium", "high"):
            recommendations.append("Implement constant-time inference or timing noise injection")
        if memory_risk in ("medium", "high"):
            recommendations.append("Use MIG partitioning or dedicated GPU instances")
        if cache_risk in ("medium", "high"):
            recommendations.append("Implement cache flushing between tenant requests")
        if overall == "high":
            recommendations.append("Consider dedicated GPU instances for sensitive workloads")
        return AssessmentResult(
            environment=environment_type,
            timing_risk=timing_risk,
            memory_risk=memory_risk,
            cache_risk=cache_risk,
            overall_risk=overall,
            recommendations=recommendations,
        )

    def _assess_timing_risk(self, env: str) -> str:
        """Remote timing is observable wherever an API is exposed."""
        if env in ("api_endpoint", "cloud_shared"):
            return "high"
        return "low"

    def _assess_memory_risk(self, env: str) -> str:
        """Memory-pattern attacks require co-tenancy on the same GPU."""
        if env in ("cloud_shared", "multi_tenant"):
            return "medium"
        return "low"

    def _assess_cache_risk(self, env: str) -> str:
        """Cache contention also requires co-tenancy; MIG isolates it.

        (Fix: the original had a redundant "mig_partitioned" branch that
        returned the same value as the fallthrough -- removed.)
        """
        if env in ("cloud_shared", "multi_tenant"):
            return "medium"
        return "low"
# Demo: assess a shared-cloud deployment and print the findings.
assessor = GPUSideChannelAssessor()
result = assessor.assess("cloud_shared")
print(f"Environment: {result.environment}")
print(f"Overall risk: {result.overall_risk}")
print(f"Recommendations:")
for rec in result.recommendations:
print(f" - {rec}")

Related Topics
- Model Fingerprinting Techniques -- Non-GPU fingerprinting methods
- Inference Endpoint Exploitation -- API-level attacks
- Version Detection via Behavior -- Behavioral fingerprinting
Which GPU side channel is most practical for a remote attacker with only API access to an inference endpoint?