Model Hub Supply Chain Attack
Attacking the ML model supply chain through hub repositories like Hugging Face, including typosquatting, model poisoning, and repository manipulation techniques.
The ML model supply chain relies heavily on model hubs -- centralized repositories where researchers and organizations publish pre-trained models. Hugging Face Hub alone hosts over a million models, and most ML practitioners download models directly from these hubs without verifying their provenance or integrity. This trust creates a supply chain attack surface: an attacker who publishes a malicious model on a hub can achieve code execution on any machine that downloads and loads it. This walkthrough covers the attack vectors, detection methods, and defensive practices.
Step 1: Mapping the Model Supply Chain
Understanding how models flow from training to deployment reveals the attack surfaces at each stage.
"""
Map the model supply chain and identify trust boundaries
that can be exploited by attackers.
"""
from dataclasses import dataclass
@dataclass
class SupplyChainStage:
    """One stage of the ML model supply chain and the trust it implicitly extends.

    Each stage pairs the assumption consumers make at that point with the
    attack vector that abuses it and the resulting impact.
    """

    name: str              # short stage identifier, e.g. "model_training"
    trust_assumption: str  # what consumers implicitly trust at this stage
    attack_vector: str     # how an attacker abuses that trust
    impact: str            # consequence if the attack succeeds
# Threat model of the model supply chain: one entry per stage, ordered from
# training through inference, each naming the trust assumption it relies on.
SUPPLY_CHAIN_STAGES = [
    SupplyChainStage(
        name="model_training",
        trust_assumption="Training data and code are clean",
        attack_vector="Data poisoning, backdoor injection during training",
        impact="Model behaves maliciously on specific trigger inputs",
    ),
    SupplyChainStage(
        name="model_publication",
        trust_assumption="Publisher is who they claim to be",
        attack_vector="Typosquatting, impersonation, account compromise",
        impact="Users download attacker's model thinking it's legitimate",
    ),
    SupplyChainStage(
        name="model_hosting",
        trust_assumption="Hub correctly stores and serves models",
        attack_vector="Model file replacement, metadata manipulation",
        impact="Previously safe model is replaced with malicious version",
    ),
    SupplyChainStage(
        name="model_download",
        trust_assumption="Downloaded files match published files",
        attack_vector="MITM during download, CDN compromise",
        impact="Modified model files arrive on user's machine",
    ),
    SupplyChainStage(
        name="model_loading",
        trust_assumption="Model files are safe to deserialize",
        attack_vector="Pickle RCE, malicious custom code in model repo",
        impact="Arbitrary code execution during model.load()",
    ),
    SupplyChainStage(
        name="model_inference",
        trust_assumption="Model produces expected outputs",
        attack_vector="Backdoored models produce manipulated outputs",
        impact="Subtle output manipulation affecting downstream systems",
    ),
]
print("Model Supply Chain Attack Surface")
print("=" * 60)
for stage in SUPPLY_CHAIN_STAGES:
print(f"\n{stage.name.upper()}")
print(f" Trust: {stage.trust_assumption}")
print(f" Attack: {stage.attack_vector}")
print(f" Impact: {stage.impact}")Step 2: Typosquatting Detection
Identify models on hubs that mimic popular model names through typosquatting.
"""
Detect typosquatting attacks on model hub repositories.
"""
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class TyposquatCandidate:
    """A generated look-alike of a legitimate model id."""

    original: str      # the legitimate model id the variant imitates
    squatted: str      # the look-alike id an attacker could register
    technique: str     # "char_substitution" / "keyboard_typo" / "org_impersonation"
    similarity: float  # rough 0-1 visual-similarity heuristic


def generate_typosquat_variants(model_name: str) -> list[TyposquatCandidate]:
    """Generate possible typosquat variants of a model name.

    Args:
        model_name: Hub model id, either "org/name" or a bare "name".

    Returns:
        Candidates produced by three techniques: visually similar character
        substitutions, adjacent-key typos, and organization impersonation.
    """
    variants: list[TyposquatCandidate] = []
    # maxsplit=1 so ids containing more than one "/" don't raise ValueError.
    org, name = model_name.split("/", 1) if "/" in model_name else ("", model_name)

    # Technique 1: visually similar character substitutions (homoglyph-style).
    substitutions = {
        "l": ["1", "i", "I"],
        "o": ["0"],
        "a": ["@"],
        "e": ["3"],
        "s": ["5", "z"],
        "-": ["_", ""],
        "_": ["-", ""],
    }
    for i, char in enumerate(name):
        if char.lower() in substitutions:
            for sub in substitutions[char.lower()]:
                variant = name[:i] + sub + name[i+1:]
                full = f"{org}/{variant}" if org else variant
                variants.append(TyposquatCandidate(
                    original=model_name,
                    squatted=full,
                    technique="char_substitution",
                    # one edited character out of len(name)
                    similarity=1 - (1 / len(name)),
                ))

    # Technique 2: adjacent-key typos (partial QWERTY neighbor map).
    keyboard_adjacent = {
        "a": "sq", "b": "vn", "c": "xv", "d": "sf",
        "e": "wr", "f": "dg", "g": "fh", "h": "gj",
    }
    for i, char in enumerate(name):
        if char.lower() in keyboard_adjacent:
            for adj in keyboard_adjacent[char.lower()]:
                variant = name[:i] + adj + name[i+1:]
                full = f"{org}/{variant}" if org else variant
                variants.append(TyposquatCandidate(
                    original=model_name,
                    squatted=full,
                    technique="keyboard_typo",
                    similarity=1 - (1 / len(name)),
                ))

    # Technique 3: organization impersonation via plausible suffixes/spellings.
    if org:
        org_variants = [
            org + "-ai", org + "-ml", org + "-official",
            org.replace("-", ""), org + "hf",
        ]
        for org_v in org_variants:
            variants.append(TyposquatCandidate(
                original=model_name,
                squatted=f"{org_v}/{name}",
                technique="org_impersonation",
                similarity=0.8,  # heuristic: same model name, near-identical org
            ))
    return variants
def check_typosquats_on_hub(model_name: str) -> list[dict]:
    """Check if typosquat variants exist on Hugging Face Hub.

    Args:
        model_name: the legitimate model id whose look-alikes are probed.

    Returns:
        One record per variant that actually resolves on the Hub, or a single
        error record when huggingface_hub is not installed.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        variants = generate_typosquat_variants(model_name)
        found = []
        for variant in variants[:20]:  # Limit API calls
            try:
                model_info = api.model_info(variant.squatted)
                found.append({
                    "original": variant.original,
                    "squatted": variant.squatted,
                    "technique": variant.technique,
                    "exists": True,
                    "downloads": model_info.downloads,
                    "created": str(model_info.created_at),
                })
            except Exception:
                # Lookup failure means the variant isn't registered (good);
                # any hub error is deliberately treated the same way.
                pass  # Model doesn't exist (good)
        return found
    except ImportError:
        return [{"error": "huggingface_hub not installed"}]
# Generate variants for popular models
POPULAR_MODELS = [
"meta-llama/Llama-3.1-8B",
"google/gemma-2-9b",
"mistralai/Mistral-7B-v0.3",
]
for model in POPULAR_MODELS:
variants = generate_typosquat_variants(model)
print(f"\n{model}: {len(variants)} typosquat variants generated")
for v in variants[:5]:
print(f" {v.squatted} [{v.technique}]")Step 3: Model Repository Analysis
"""
Analyze a model repository for security indicators.
"""
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
@dataclass
class RepoSecurityAssessment:
    """Security posture summary for one hub model repository."""

    model_id: str
    risk_level: str              # "low" / "medium" / "high" / "unknown" / "error"
    has_pickle_files: bool       # pickle-based weight files present (code exec on load)
    has_custom_code: bool        # repo ships .py files (trust_remote_code risk)
    has_safetensors: bool        # a safe weights-only format is available (see note below)
    file_count: int              # number of files listed in the repo
    suspicious_files: list[str]  # notes on risky files, or the error text
    recommendations: list[str]   # remediation advice for the consumer


def assess_model_repo(model_id: str) -> RepoSecurityAssessment:
    """Assess the security posture of a model repository.

    Inspects the repo's file listing for pickle-based formats, custom Python
    code, and safe serialization formats, then derives a coarse risk level.

    Returns:
        An assessment; risk_level is "unknown" when huggingface_hub is
        missing and "error" when the hub lookup fails.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        model_info = api.model_info(model_id, files_metadata=True)
        # Extensions that deserialize through pickle and can run code on load.
        pickle_extensions = {".pkl", ".pt", ".pth", ".bin", ".joblib"}
        # Weights-only formats with no code-execution path on load.
        safe_extensions = {".safetensors", ".onnx"}
        code_extensions = {".py"}
        has_pickle = False
        # NOTE: despite the name, this flag is also set for .onnx -- it
        # tracks "a safe format is present", not safetensors specifically.
        has_safetensors = False
        has_custom_code = False
        suspicious = []
        siblings = model_info.siblings or []
        for file_info in siblings:
            fname = file_info.rfilename
            suffix = Path(fname).suffix
            if suffix in pickle_extensions:
                has_pickle = True
                suspicious.append(f"Pickle file: {fname}")
            if suffix in safe_extensions:
                has_safetensors = True
            if suffix in code_extensions:
                has_custom_code = True
                if fname not in ("config.py", "configuration.py"):
                    suspicious.append(f"Custom code: {fname}")
        # Determine risk level: pickle-only repos force consumers onto the
        # unsafe deserialization path, so they rank highest.
        if has_pickle and not has_safetensors:
            risk = "high"
        elif has_pickle and has_safetensors:
            risk = "medium"
        elif has_custom_code:
            risk = "medium"
        else:
            risk = "low"
        recommendations = []
        if has_pickle:
            recommendations.append("Prefer safetensors files over pickle-based files")
        if has_custom_code:
            recommendations.append("Review custom code files before loading with trust_remote_code=True")
        if not has_safetensors:
            recommendations.append("Request the model publisher provide safetensors format")
        return RepoSecurityAssessment(
            model_id=model_id,
            risk_level=risk,
            has_pickle_files=has_pickle,
            has_custom_code=has_custom_code,
            has_safetensors=has_safetensors,
            file_count=len(siblings),
            suspicious_files=suspicious,
            recommendations=recommendations,
        )
    except ImportError:
        return RepoSecurityAssessment(
            model_id=model_id, risk_level="unknown",
            has_pickle_files=False, has_custom_code=False,
            has_safetensors=False, file_count=0,
            suspicious_files=[], recommendations=["Install huggingface_hub"],
        )
    except Exception as e:
        # Hub lookup failed (network error, missing repo, gated access, ...).
        return RepoSecurityAssessment(
            model_id=model_id, risk_level="error",
            has_pickle_files=False, has_custom_code=False,
            has_safetensors=False, file_count=0,
            suspicious_files=[str(e)], recommendations=[],
        )
print("Model repository security assessment ready")Step 4: Provenance Verification
"""
Verify model provenance and detect potential tampering.
"""
import hashlib
from dataclasses import dataclass
from typing import Optional
@dataclass
class ProvenanceCheck:
    """Result of provenance verification for one hub model."""

    model_id: str
    publisher_verified: bool  # publishing org carries a verification flag
    commit_signatures: bool   # repo commits are cryptographically signed
    file_hashes_match: bool   # local files match published hashes
    download_integrity: bool  # integrity of the download path
    overall_trust: str        # "high" / "medium" / "unknown"


def verify_model_provenance(model_id: str) -> ProvenanceCheck:
    """Verify the provenance of a model from a hub.

    Only the publisher-identity check is actually performed against the hub;
    the remaining fields are hard-coded placeholders (see inline notes).
    Returns overall_trust "unknown" when huggingface_hub is not installed.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        # Confirms the model id resolves on the hub (raises if it doesn't).
        info = api.model_info(model_id)
        # Check publisher identity: the org is the path segment before "/".
        org = model_id.split("/")[0] if "/" in model_id else None
        # Verified organizations have a verification badge
        publisher_verified = False
        if org:
            try:
                org_info = api.organization_info(org)
                # Check for verified status (simplified)
                publisher_verified = hasattr(org_info, "verified") and org_info.verified
            except Exception:
                # Org lookup failed; conservatively treat as unverified.
                pass
        return ProvenanceCheck(
            model_id=model_id,
            publisher_verified=publisher_verified,
            commit_signatures=False,  # Most repos lack signed commits
            file_hashes_match=True,   # Would need local file to verify
            download_integrity=True,  # placeholder -- not actually checked here
            overall_trust="high" if publisher_verified else "medium",
        )
    except ImportError:
        return ProvenanceCheck(
            model_id=model_id,
            publisher_verified=False,
            commit_signatures=False,
            file_hashes_match=False,
            download_integrity=False,
            overall_trust="unknown",
        )
def compute_file_hash(file_path: str, algorithm: str = "sha256") -> str:
    """Compute cryptographic hash of a file for integrity verification.

    Reads the file in fixed-size chunks so arbitrarily large model files can
    be hashed without loading them fully into memory.
    """
    digest = hashlib.new(algorithm)
    with open(file_path, "rb") as fh:
        for chunk in iter(lambda: fh.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()
print("Provenance verification ready")Step 5: Malicious Custom Code Detection
"""
Detect malicious patterns in custom code shipped with model repositories.
"""
import ast
import re
from dataclasses import dataclass
@dataclass
class CodeScanResult:
    """Outcome of scanning one Python file from a model repo."""

    file_path: str
    is_safe: bool        # True only when no findings at all
    risk_level: str      # "safe" / "suspicious" / "dangerous" / "error"
    findings: list[str]  # human-readable descriptions of matched patterns


class ModelCodeScanner:
    """Scan Python code in model repositories for malicious patterns."""

    # (regex, description) pairs matched against the raw file text.
    DANGEROUS_PATTERNS = [
        (r"os\.system\s*\(", "os.system() call - arbitrary command execution"),
        (r"subprocess\.", "subprocess usage - command execution"),
        (r"eval\s*\(", "eval() call - arbitrary code execution"),
        (r"exec\s*\(", "exec() call - arbitrary code execution"),
        (r"__import__\s*\(", "dynamic import - potential code execution"),
        (r"socket\.", "socket usage - network access"),
        (r"urllib\.", "urllib usage - network access"),
        (r"requests\.", "requests library - network access"),
        (r"open\s*\(.+['\"]w['\"]", "file write operation"),
        (r"pickle\.load", "pickle.load - deserialization risk"),
        (r"torch\.load\s*\((?!.*weights_only)", "torch.load without weights_only"),
    ]

    def scan_file(self, file_path: str) -> CodeScanResult:
        """Scan a Python file for suspicious patterns.

        Combines regex matching with a light AST pass that catches direct
        eval/exec calls; a file that cannot be parsed is itself flagged.
        """
        try:
            with open(file_path) as f:
                content = f.read()
        except Exception as e:
            return CodeScanResult(file_path, False, "error", [str(e)])
        findings = []
        # Pattern matching
        for pattern, description in self.DANGEROUS_PATTERNS:
            matches = re.findall(pattern, content)
            if matches:
                findings.append(f"{description} ({len(matches)} occurrences)")
        # AST analysis for obfuscated patterns
        try:
            tree = ast.parse(content)
            for node in ast.walk(tree):
                if isinstance(node, ast.Call):
                    if isinstance(node.func, ast.Name) and node.func.id in ("eval", "exec"):
                        findings.append(f"AST: {node.func.id}() at line {node.lineno}")
        except SyntaxError:
            findings.append("Could not parse AST (potential obfuscation)")
        # Classify: any code-execution finding is "dangerous"; anything else
        # (network, file writes, deserialization) is merely "suspicious".
        if not findings:
            return CodeScanResult(file_path, True, "safe", [])
        elif any("execution" in f.lower() for f in findings):
            return CodeScanResult(file_path, False, "dangerous", findings)
        else:
            return CodeScanResult(file_path, False, "suspicious", findings)
scanner = ModelCodeScanner()
print("Model code scanner ready")Step 6: Automated Supply Chain Security Pipeline
"""
Automated pipeline for securing model downloads from hubs.
"""
import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
@dataclass
class SupplyChainCheckResult:
    """Aggregate verdict from the supply chain gateway for one model."""

    model_id: str
    approved: bool             # True when all policy checks passed
    checks_passed: list[str]   # descriptions of checks that succeeded
    checks_failed: list[str]   # descriptions of checks that failed
    risk_score: float          # cumulative 0-1 risk estimate
    recommendation: str        # human-readable approve/block summary


class ModelSupplyChainGateway:
    """Gateway that validates models before they enter the pipeline."""

    def __init__(self, policy: Optional[dict] = None):
        # Default policy: block pickle and custom code outright, require a
        # safetensors variant, and cap the acceptable cumulative risk score.
        self.policy = policy or {
            "require_safetensors": True,
            "block_pickle": True,
            "block_custom_code": True,
            "require_verified_publisher": False,
            "max_risk_score": 0.5,
        }

    def validate_model(self, model_id: str) -> SupplyChainCheckResult:
        """Run all supply chain checks on a model.

        Aggregates the repository assessment and a typosquat-surface check
        into a pass/fail decision plus a clamped 0-1 risk score.
        """
        passed = []
        failed = []
        risk = 0.0
        # Check 1: Repository assessment
        repo = assess_model_repo(model_id)
        if repo.has_safetensors:
            passed.append("Has safetensors format")
        elif self.policy.get("require_safetensors"):
            failed.append("No safetensors format available")
            risk += 0.3
        if repo.has_pickle_files and self.policy.get("block_pickle"):
            failed.append(f"Contains pickle files: {repo.suspicious_files[:3]}")
            risk += 0.4
        if repo.has_custom_code and self.policy.get("block_custom_code"):
            failed.append("Contains custom Python code")
            risk += 0.3
        # Check 2: Typosquatting -- many generatable variants means the name
        # offers attackers a large registration surface.
        variants = generate_typosquat_variants(model_id)
        if len(variants) < 50:  # Reasonable number of variants
            passed.append("Typosquat check passed")
        else:
            failed.append("Model name highly susceptible to typosquatting")
        risk = min(risk, 1.0)  # clamp the cumulative score
        approved = risk <= self.policy.get("max_risk_score", 0.5) and len(failed) == 0
        return SupplyChainCheckResult(
            model_id=model_id,
            approved=approved,
            checks_passed=passed,
            checks_failed=failed,
            risk_score=risk,
            recommendation="Approved for use" if approved else "Blocked by supply chain policy",
        )
gateway = ModelSupplyChainGateway()
print("Supply chain security gateway ready")Step 7: Hub Monitoring
"""
Monitor model hubs for suspicious activity targeting your organization.
"""
from dataclasses import dataclass
from datetime import datetime
@dataclass
class HubAlert:
    """A single monitoring alert raised for hub activity."""

    alert_type: str   # category, e.g. "typosquat" or "org_impersonation"
    model_id: str     # the model id the alert concerns
    description: str
    severity: str
    timestamp: str


class HubMonitor:
    """Monitor model hubs for supply chain threats."""

    def __init__(self, watched_models: list[str], org_names: list[str]):
        # Model ids and organization names to watch for impersonation.
        self.watched_models = watched_models
        self.org_names = org_names

    def check_for_threats(self) -> list[HubAlert]:
        """Run all monitoring checks.

        NOTE(review): this is scaffolding -- variants and impersonation org
        names are generated but never queried against the hub, so the
        returned list is always empty until the hub lookups are implemented.
        """
        alerts = []
        # Check for typosquats of watched models
        for model in self.watched_models:
            variants = generate_typosquat_variants(model)
            for v in variants[:10]:
                # In production, check if variant exists on hub
                pass
        # Check for org impersonation
        for org in self.org_names:
            impersonation_variants = [
                f"{org}-ai", f"{org}-ml", f"{org}-official",
                f"{org}hf", f"the-{org}", f"{org}-models",
            ]
            # In production, check if these org names exist
        return alerts
monitor = HubMonitor(
watched_models=["your-org/production-model"],
org_names=["your-org"],
)
print("Hub monitoring system ready")Step 8: Incident Response for Compromised Models
"""
Incident response procedures for model supply chain compromise.
"""
from dataclasses import dataclass
from datetime import datetime
@dataclass
class IncidentResponse:
    """One step of the incident response playbook."""

    step: int          # ordinal position in the playbook
    action: str        # what to do, phrased as "Verb: details"
    responsible: str   # owning team(s)
    time_target: str   # target completion time after detection


# Ordered playbook for responding to a compromised-model incident, from
# immediate isolation through long-term hardening.
RESPONSE_PLAYBOOK = [
    IncidentResponse(1, "Isolate: Stop all model loading from the compromised source", "ML Ops", "Immediate"),
    IncidentResponse(2, "Identify: Determine which systems loaded the compromised model", "Security", "< 1 hour"),
    IncidentResponse(3, "Contain: Take affected systems offline or revert to known-good model", "ML Ops", "< 2 hours"),
    IncidentResponse(4, "Investigate: Analyze the malicious model file for payload details", "Security", "< 4 hours"),
    IncidentResponse(5, "Assess: Determine if the payload executed and what it accessed", "Security", "< 8 hours"),
    IncidentResponse(6, "Remediate: Patch systems, rotate credentials, update blocklists", "ML Ops + Security", "< 24 hours"),
    IncidentResponse(7, "Report: Notify hub operators, affected parties, and stakeholders", "Management", "< 48 hours"),
    IncidentResponse(8, "Harden: Implement additional supply chain controls to prevent recurrence", "Engineering", "< 1 week"),
]
print("Model Supply Chain Incident Response Playbook")
print("=" * 60)
for step in RESPONSE_PLAYBOOK:
print(f"\nStep {step.step}: {step.action}")
print(f" Owner: {step.responsible}")
print(f" Target: {step.time_target}")Related Topics
- Model Serialization RCE -- The underlying deserialization vulnerability
- Dependency Confusion in ML -- Package-level supply chain attacks
- Inference Endpoint Exploitation -- Post-deployment attack vectors
Why is typosquatting particularly effective on model hubs like Hugging Face?