# Triton Inference Server Security

Security hardening for NVIDIA Triton Inference Server deployments, including model repository protection and API security.
## Overview
NVIDIA Triton Inference Server is a widely deployed open-source serving platform that supports multiple ML frameworks (TensorFlow, PyTorch, TensorRT, ONNX Runtime, and others) and provides features like dynamic batching, model ensembles, and GPU/CPU resource management. It is used in production by organizations serving models at scale, from recommendation systems to real-time computer vision.
Triton's security posture matters because it sits directly in the inference path — all model inputs pass through it, and all model outputs are returned through it. A compromised Triton server can serve poisoned model predictions, exfiltrate inference data, or be leveraged as a pivot point for lateral movement within the infrastructure. Additionally, Triton's model repository is a supply chain target: if an attacker can replace a model in the repository, every subsequent inference request will use the compromised model.
This article covers Triton's architecture from a security perspective, identifies common misconfigurations, and provides hardening guidance for production deployments.
## Triton Architecture

### Components and Interfaces

Triton exposes multiple interfaces, each with its own security implications:
| Interface | Default Port | Protocol | Purpose | Security Risk |
|---|---|---|---|---|
| HTTP/REST | 8000 | HTTP | Inference requests, model management | Unauthenticated by default |
| gRPC | 8001 | gRPC | Inference requests, model management | Unauthenticated by default |
| Metrics | 8002 | HTTP (Prometheus) | Server metrics, model stats | Information disclosure |
| Model Repository | N/A | Filesystem / S3 / GCS | Model storage | Model replacement/poisoning |
The scanner below probes these interfaces for the most common misconfigurations:

```python
import requests
from typing import Any, Dict, List, Optional


class TritonSecurityScanner:
    """Scan a Triton Inference Server for security misconfigurations."""

    def __init__(self, http_url: str, metrics_url: Optional[str] = None):
        self.http_url = http_url.rstrip("/")
        self.metrics_url = metrics_url or self.http_url.replace(":8000", ":8002")

    def check_server_health(self) -> Dict[str, Any]:
        """Check if the server is accessible and get basic info."""
        endpoints = {
            "health_live": f"{self.http_url}/v2/health/live",
            "health_ready": f"{self.http_url}/v2/health/ready",
            "metadata": f"{self.http_url}/v2",
        }
        results = {}
        for name, url in endpoints.items():
            try:
                resp = requests.get(url, timeout=5)
                content_type = resp.headers.get("content-type", "")
                results[name] = {
                    "accessible": True,
                    "status_code": resp.status_code,
                    "body": resp.json() if content_type.startswith("application/json") else resp.text[:200],
                }
            except requests.RequestException as e:
                results[name] = {"accessible": False, "error": str(e)}
        return results

    def enumerate_models(self) -> List[Dict]:
        """List all models loaded in Triton."""
        try:
            resp = requests.get(f"{self.http_url}/v2/models", timeout=10)
            if resp.status_code == 200:
                return resp.json().get("models", [])
        except requests.RequestException:
            pass
        # Fallback: try the repository index endpoint
        try:
            resp = requests.post(
                f"{self.http_url}/v2/repository/index",
                json={"ready": True},
                timeout=10,
            )
            if resp.status_code == 200:
                return resp.json()
        except requests.RequestException:
            pass
        return []

    def check_model_control(self) -> Dict[str, Any]:
        """Check if model loading/unloading is enabled (critical security risk)."""
        try:
            # The repository index endpoint requires model control to be enabled
            resp = requests.post(
                f"{self.http_url}/v2/repository/index",
                json={},
                timeout=10,
            )
            model_control_enabled = resp.status_code == 200
        except requests.RequestException:
            model_control_enabled = False
        return {
            "model_control_enabled": model_control_enabled,
            "severity": "high" if model_control_enabled else "info",
            "finding": (
                "Model control API is enabled — attacker can load/unload models"
                if model_control_enabled
                else "Model control API is disabled or restricted"
            ),
        }

    def check_metrics_exposure(self) -> Dict[str, Any]:
        """Check if Prometheus metrics are exposed."""
        try:
            resp = requests.get(f"{self.metrics_url}/metrics", timeout=5)
            if resp.status_code == 200:
                # Parse metrics for sensitive information
                metrics_text = resp.text
                model_names = set()
                for line in metrics_text.split("\n"):
                    if 'model="' in line:
                        # Extract the model name from the metrics label
                        start = line.find('model="') + 7
                        end = line.find('"', start)
                        if end > start:
                            model_names.add(line[start:end])
                return {
                    "exposed": True,
                    "severity": "medium",
                    "model_names_leaked": list(model_names),
                    "metrics_size_bytes": len(metrics_text),
                    "finding": "Prometheus metrics endpoint is publicly accessible",
                }
        except requests.RequestException:
            pass
        return {"exposed": False, "severity": "info"}

    def check_inference_api(self, model_name: str) -> Dict[str, Any]:
        """Check if inference can be performed without authentication."""
        try:
            # Get model metadata to understand the input format
            resp = requests.get(f"{self.http_url}/v2/models/{model_name}", timeout=5)
            if resp.status_code == 200:
                return {
                    "accessible": True,
                    "model": model_name,
                    "metadata": resp.json(),
                    "severity": "high",
                    "finding": f"Model '{model_name}' is accessible without authentication",
                }
        except requests.RequestException:
            pass
        return {"accessible": False}

    def full_scan(self) -> Dict[str, Any]:
        """Run a comprehensive security scan."""
        results = {
            "target": self.http_url,
            "server_health": self.check_server_health(),
            "model_control": self.check_model_control(),
            "metrics": self.check_metrics_exposure(),
            "models": [],
        }
        for model in self.enumerate_models()[:10]:  # Limit to the first 10
            model_name = model.get("name", "") if isinstance(model, dict) else str(model)
            results["models"].append(self.check_inference_api(model_name))
        return results
```

## Default Configuration Weaknesses
### No Authentication
Triton does not provide built-in authentication. By default, any client that can reach the HTTP or gRPC port can:
- Send inference requests to any loaded model
- Query model metadata and configuration
- If model control mode is enabled: load, unload, and reload models
- Access Prometheus metrics that reveal model names, request counts, and latency
This is by design and documented by NVIDIA: Triton delegates authentication and authorization to external infrastructure (reverse proxies, service meshes, or API gateways).
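As a minimal illustration of that delegation, the gateway-side check below validates a bearer token before a request is ever forwarded to Triton. The header name and in-memory key store are illustrative assumptions, not Triton features — Triton itself never sees or checks credentials.

```python
import hmac
from typing import Dict, Optional


def authorize_request(headers: Dict[str, str], api_keys: Dict[str, str]) -> Optional[str]:
    """Return the client ID for a valid bearer token, or None to reject.

    api_keys is a hypothetical client-ID -> shared-key mapping; in production
    this would live in a secrets manager or the gateway's own config.
    """
    auth = headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        return None
    token = auth[len("Bearer "):]
    for client_id, expected in api_keys.items():
        # Constant-time comparison avoids leaking key material via timing
        if hmac.compare_digest(token, expected):
            return client_id
    return None
```

A gateway would call this before proxying to ports 8000/8001 and return HTTP 401 when it yields `None`; the credential is stripped before the request reaches Triton.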
### Model Control Mode
Triton supports three model control modes:
- `none`: Models are loaded at startup from the model repository. No runtime changes allowed.
- `poll`: Triton periodically checks the model repository for changes and automatically loads new models.
- `explicit`: Models can be loaded and unloaded via the API.
Both poll and explicit modes are dangerous if not secured:
```bash
# Starting Triton with explicit model control — allows API-driven model loading
docker run --gpus all -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v /models:/models \
  nvcr.io/nvidia/tritonserver:24.01-py3 \
  tritonserver \
    --model-repository=/models \
    --model-control-mode=explicit

# An attacker can now load a malicious model via the API:
# curl -X POST http://target:8000/v2/repository/models/malicious_model/load
```

### Model Repository Poisoning
If an attacker gains write access to the model repository (filesystem, S3 bucket, or GCS bucket), they can replace legitimate models with poisoned versions. With poll mode, the poisoned model will be automatically loaded:
```python
import hashlib
import os
from pathlib import Path
from typing import Dict, List


class ModelRepositoryAuditor:
    """Audit the security of a Triton model repository."""

    def __init__(self, repo_path: str):
        self.repo_path = Path(repo_path)

    def audit_permissions(self) -> List[Dict]:
        """Check file permissions on the model repository."""
        findings = []
        # Check repository root permissions
        repo_stat = os.stat(self.repo_path)
        if repo_stat.st_mode & 0o002:  # World-writable
            findings.append({
                "severity": "critical",
                "finding": "Model repository root is world-writable",
                "path": str(self.repo_path),
                "remediation": "Set permissions to 750 or more restrictive",
            })
        # Check each model directory
        for model_dir in self.repo_path.iterdir():
            if not model_dir.is_dir():
                continue
            # Check config.pbtxt
            config_path = model_dir / "config.pbtxt"
            if config_path.exists():
                config_stat = os.stat(config_path)
                if config_stat.st_mode & 0o002:
                    findings.append({
                        "severity": "high",
                        "finding": f"Model config is world-writable: {model_dir.name}",
                        "path": str(config_path),
                    })
            # Check model files in version directories
            for version_dir in model_dir.iterdir():
                if version_dir.is_dir() and version_dir.name.isdigit():
                    for model_file in version_dir.iterdir():
                        file_stat = os.stat(model_file)
                        if file_stat.st_mode & 0o002:
                            findings.append({
                                "severity": "critical",
                                "finding": f"Model file is world-writable: {model_file.name}",
                                "path": str(model_file),
                                "model": model_dir.name,
                                "version": version_dir.name,
                            })
        return findings

    def verify_model_integrity(self, expected_hashes: Dict[str, str]) -> List[Dict]:
        """Verify model files against expected SHA-256 hashes."""
        findings = []
        for model_path_str, expected_hash in expected_hashes.items():
            model_path = self.repo_path / model_path_str
            if not model_path.exists():
                findings.append({
                    "severity": "high",
                    "finding": f"Expected model file missing: {model_path_str}",
                })
                continue
            sha256 = hashlib.sha256()
            with open(model_path, "rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    sha256.update(chunk)
            actual_hash = sha256.hexdigest()
            if actual_hash != expected_hash:
                findings.append({
                    "severity": "critical",
                    "finding": f"Model file hash mismatch: {model_path_str}",
                    "expected": expected_hash,
                    "actual": actual_hash,
                    "detail": "Model file may have been replaced with a poisoned version",
                })
            else:
                findings.append({
                    "severity": "info",
                    "finding": f"Model file integrity verified: {model_path_str}",
                })
        return findings

    def check_for_pickle_models(self) -> List[Dict]:
        """Check for models using pickle serialization (deserialization risk)."""
        findings = []
        for model_dir in self.repo_path.iterdir():
            if not model_dir.is_dir():
                continue
            config_path = model_dir / "config.pbtxt"
            if not config_path.exists():
                continue
            config_text = config_path.read_text().lower()
            if "pytorch" in config_text or "python" in config_text:
                # PyTorch and Python backends may use pickle serialization
                for version_dir in model_dir.iterdir():
                    if version_dir.is_dir() and version_dir.name.isdigit():
                        for f in version_dir.iterdir():
                            if f.suffix in (".pt", ".pth", ".pkl", ".pickle"):
                                findings.append({
                                    "severity": "high",
                                    "finding": f"Model uses pickle serialization: {f.name}",
                                    "model": model_dir.name,
                                    "detail": "Pickle deserialization can execute arbitrary code. "
                                              "Consider converting to ONNX or TensorRT format.",
                                })
        return findings
```

## Hardening Triton Deployments
### Network-Level Security

Deploy Triton behind a reverse proxy or service mesh that handles authentication:
```bash
#!/bin/bash
# Deploy Triton with an Envoy sidecar for authentication and TLS

# Generate a self-signed cert for demonstration (use a proper CA in production)
openssl req -x509 -newkey rsa:4096 -keyout /etc/ssl/triton.key \
  -out /etc/ssl/triton.crt -days 365 -nodes \
  -subj "/CN=triton.internal"
```

```yaml
# Kubernetes deployment with Istio sidecar for mTLS and auth
apiVersion: apps/v1
kind: Deployment
metadata:
  name: triton-inference-server
  namespace: ml-serving
spec:
  replicas: 2
  selector:
    matchLabels:
      app: triton
  template:
    metadata:
      labels:
        app: triton
      annotations:
        sidecar.istio.io/inject: "true"
    spec:
      containers:
        - name: triton
          image: nvcr.io/nvidia/tritonserver:24.01-py3
          args:
            - tritonserver
            - --model-repository=/models
            - --model-control-mode=none  # Disable runtime model changes
            - --strict-model-config=true
            - --rate-limit=execution_count
            - --rate-limit-resource=R1:5:0  # Limit to 5 concurrent on GPU 0
          ports:
            - containerPort: 8000
              name: http
            - containerPort: 8001
              name: grpc
            # Do NOT expose the metrics port (8002) externally
          resources:
            limits:
              nvidia.com/gpu: 1
              memory: "16Gi"
            requests:
              memory: "8Gi"
          volumeMounts:
            - name: models
              mountPath: /models
              readOnly: true  # Model repository should be read-only
          securityContext:
            runAsNonRoot: true
            runAsUser: 1000
            readOnlyRootFilesystem: true
            allowPrivilegeEscalation: false
            capabilities:
              drop:
                - ALL
      volumes:
        - name: models
          persistentVolumeClaim:
            claimName: model-repo-pvc
            readOnly: true
```

### Input Validation
Triton does not validate inference input content beyond shape and dtype checks. Implement input validation in a preprocessing step or API gateway:
```python
import numpy as np
from typing import Any, Dict


class TritonInputValidator:
    """Validate inference inputs before sending to Triton."""

    def __init__(self, model_config: Dict[str, Any]):
        self.inputs = model_config.get("input", [])
        self.max_input_size_bytes = 10 * 1024 * 1024  # 10 MB default limit

    def validate_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate an inference request."""
        errors = []
        inputs = request_data.get("inputs", [])
        # Check the total request size
        total_size = 0
        for inp in inputs:
            data = inp.get("data", [])
            if isinstance(data, (list, np.ndarray)):
                total_size += np.array(data).nbytes
        if total_size > self.max_input_size_bytes:
            errors.append(
                f"Request too large: {total_size} bytes exceeds limit of "
                f"{self.max_input_size_bytes} bytes"
            )
        # Validate each input
        for inp in inputs:
            name = inp.get("name", "")
            shape = inp.get("shape", [])
            datatype = inp.get("datatype", "")
            # Check for unreasonable shapes (DoS prevention)
            if any(dim > 10000 for dim in shape):
                errors.append(f"Input '{name}' has unreasonable dimension: {shape}")
            # Check for negative dimensions
            if any(dim < 0 for dim in shape):
                errors.append(f"Input '{name}' has negative dimension: {shape}")
            # Validate data values for floating-point input types
            data = inp.get("data")
            if data is not None and datatype.startswith("FP"):
                arr = np.array(data, dtype=np.float32)
                if np.any(np.isnan(arr)) or np.any(np.isinf(arr)):
                    errors.append(f"Input '{name}' contains NaN or Inf values")
        return {
            "valid": len(errors) == 0,
            "errors": errors,
        }
```

### Rate Limiting and Abuse Detection

Apply per-client and per-model rate limiting in front of Triton to contain abuse and resource exhaustion:
```python
import time
from collections import defaultdict
from typing import Dict, Optional


class TritonRateLimiter:
    """Rate limiting for Triton inference endpoints."""

    def __init__(
        self,
        requests_per_second: int = 100,
        burst_limit: int = 200,
        per_model_limit: Optional[int] = None,
    ):
        self.rps = requests_per_second
        self.burst = burst_limit
        self.per_model_limit = per_model_limit or requests_per_second
        self._client_tokens: Dict[str, float] = defaultdict(lambda: float(self.burst))
        self._client_last_check: Dict[str, float] = defaultdict(time.time)
        self._model_counts: Dict[str, int] = defaultdict(int)
        self._model_window_start: float = time.time()

    def allow_request(self, client_id: str, model_name: str) -> bool:
        """Check if a request should be allowed (token bucket algorithm)."""
        now = time.time()
        # Replenish tokens for this client
        elapsed = now - self._client_last_check[client_id]
        self._client_tokens[client_id] = min(
            self.burst,
            self._client_tokens[client_id] + elapsed * self.rps,
        )
        self._client_last_check[client_id] = now
        # Check the per-client rate limit
        if self._client_tokens[client_id] < 1:
            return False
        # Check the per-model rate limit (reset the window every second)
        if now - self._model_window_start > 1.0:
            self._model_counts.clear()
            self._model_window_start = now
        if self._model_counts[model_name] >= self.per_model_limit:
            return False
        # Consume a token
        self._client_tokens[client_id] -= 1
        self._model_counts[model_name] += 1
        return True
```

## Monitoring and Alerting
### Security-Relevant Metrics

Triton's Prometheus metrics can be used to detect anomalous behavior:
```python
from typing import Dict, List


class TritonSecurityMonitor:
    """Monitor Triton metrics for security-relevant anomalies."""

    def __init__(self, prometheus_url: str):
        self.prom_url = prometheus_url

    def check_anomalies(self, metrics_text: str) -> List[Dict]:
        """Parse Prometheus metrics and flag anomalies."""
        alerts = []
        # Parse metrics into a dict (sample name with labels -> value)
        metrics = {}
        for line in metrics_text.split("\n"):
            if line.startswith("#") or not line.strip():
                continue
            parts = line.split(" ")
            if len(parts) >= 2:
                try:
                    metrics[parts[0]] = float(parts[1])
                except ValueError:
                    continue  # Skip samples with non-numeric values
        for key, value in metrics.items():
            # Check for unusually high error rates
            if "error" in key.lower() and value > 100:
                alerts.append({
                    "alert": "high_error_rate",
                    "metric": key,
                    "value": value,
                    "detail": "High error rate may indicate fuzzing or malformed input attacks",
                })
            # Check for models being loaded at runtime (possible model swap attack)
            if "model_load" in key.lower() and value > 0:
                alerts.append({
                    "alert": "model_loaded",
                    "metric": key,
                    "value": value,
                    "detail": "Model was loaded at runtime — verify this was authorized",
                })
        return alerts
```

## Defensive Recommendations
- Use `--model-control-mode=none` in production to prevent runtime model changes
- Mount the model repository as read-only to prevent model poisoning from a compromised container
- Deploy behind an authenticating reverse proxy (Envoy, NGINX, Istio) — Triton has no built-in auth
- Do not expose the metrics port (8002) externally — it leaks model names and usage patterns
- Validate inference inputs in a preprocessing layer before they reach Triton
- Implement rate limiting to prevent abuse and resource exhaustion
- Use ONNX or TensorRT model formats instead of PyTorch pickle-based serialization
- Monitor the model repository for unauthorized changes using file integrity monitoring
- Run Triton as non-root with minimal container capabilities
- Enable `--strict-model-config=true` to require explicit model configuration files
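The file-integrity-monitoring recommendation above can be sketched as a baseline-and-diff job over the repository. The function names and the snapshot-then-compare workflow are assumptions for illustration; a production deployment would more likely rely on a dedicated FIM tool or object-store versioning.

```python
import hashlib
from pathlib import Path
from typing import Dict, List


def build_manifest(repo_path: str) -> Dict[str, str]:
    """Map each file under the model repository to its SHA-256 digest."""
    root = Path(repo_path)
    manifest = {}
    for path in sorted(root.rglob("*")):
        if path.is_file():
            digest = hashlib.sha256(path.read_bytes()).hexdigest()
            manifest[str(path.relative_to(root))] = digest
    return manifest


def diff_manifests(baseline: Dict[str, str], current: Dict[str, str]) -> Dict[str, List[str]]:
    """Classify repository changes since the baseline snapshot."""
    return {
        "added": sorted(set(current) - set(baseline)),
        "removed": sorted(set(baseline) - set(current)),
        "modified": sorted(
            p for p in set(baseline) & set(current) if baseline[p] != current[p]
        ),
    }
```

Running `build_manifest` on a schedule and alerting whenever `diff_manifests` reports anything unexpected catches silent model swaps even when Triton is in `poll` mode.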
## References
- NVIDIA Triton Inference Server — https://github.com/triton-inference-server/server
- NVIDIA Triton Documentation — https://docs.nvidia.com/deeplearning/triton-inference-server/
- OWASP LLM Top 10 2025 — LLM05 (Improper Output Handling), LLM06 (Excessive Agency)
- MITRE ATLAS — AML.T0010 (ML Supply Chain Compromise), AML.T0040 (ML Model Inference API Access)
- NVIDIA Security Bulletins — https://www.nvidia.com/en-us/security/