Capstone: Build an AI Supply Chain Security Tool
Build a tool that scans, audits, and monitors the security of AI/ML supply chains, including model provenance, dependency integrity, and artifact verification.
Overview
The AI supply chain is a rich attack surface that most organizations overlook. When you download a model from Hugging Face, import a dataset from a public repository, or install a Python package that wraps a model API, you are trusting every contributor, every dependency, and every intermediate build step. A single compromised model weight file, a backdoored tokenizer, or a malicious training data sample can compromise your entire AI system — and unlike traditional software supply chain attacks, AI supply chain compromises can be extremely subtle. A backdoored model might behave perfectly normally on standard benchmarks while containing a hidden trigger that activates only under specific conditions.
This capstone project tasks you with building a comprehensive AI supply chain security tool. The tool performs three functions: (1) scanning — analyzing ML artifacts for known vulnerabilities, unsafe serialization formats, and suspicious patterns; (2) verification — cryptographically verifying model provenance and integrity; and (3) monitoring — watching for new supply chain threats across model registries and package ecosystems.
The tool is designed for integration into MLOps pipelines, providing automated security gates that prevent compromised artifacts from reaching production. It draws on lessons from traditional software supply chain security (SLSA, Sigstore, SBOM) and applies them to the unique characteristics of ML artifacts.
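The three functions compose naturally into a sequence of security gates: each stage runs its checks, and any blocking finding stops the deployment. A minimal sketch of that composition, with hypothetical stub stages standing in for the real scanner and verifier (all names here are illustrative, not the tool's API):

```python
from typing import Callable

# Each stage inspects a model path and returns (severity, message) findings.
Stage = Callable[[str], list[tuple[str, str]]]

def scan_stage(model_path: str) -> list[tuple[str, str]]:
    return [("medium", "model ships .bin weights")]   # stub result for illustration

def verify_stage(model_path: str) -> list[tuple[str, str]]:
    return []                                         # stub: provenance checks passed

def run_pipeline(
    model_path: str,
    stages: list[Stage],
    blocking: frozenset[str] = frozenset({"critical", "high"}),
) -> bool:
    """Run every stage in order; allow deployment only if no blocking finding appears."""
    for stage in stages:
        if any(sev in blocking for sev, _msg in stage(model_path)):
            return False
    return True

ok = run_pipeline("models/demo", [scan_stage, verify_stage])  # medium findings don't block
```

The point of the structure is that a new check (for example a monitoring stage) slots in as just another `Stage` without touching the gate logic.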
Project Requirements
Functional Requirements
- Model Artifact Scanner
  - Detect unsafe serialization formats (pickle, joblib) in model files
  - Scan for known malicious patterns in model weights and configurations
  - Analyze tokenizer files for injection vulnerabilities
  - Check for embedded executable code in model artifacts
- Dependency Auditor
  - Scan Python environments for known vulnerabilities (CVE lookup)
  - Detect dependency confusion and typosquatting attacks
  - Verify package integrity against published hashes
  - Generate ML-SBOM (Software Bill of Materials) for model deployments
- Provenance Verifier
  - Verify model signatures and attestations
  - Trace model lineage from base model through fine-tuning steps
  - Validate training data provenance claims
  - Check model card completeness and accuracy
- Monitoring System
  - Watch Hugging Face Hub for model modifications and new uploads from tracked organizations
  - Monitor PyPI for new packages related to ML/AI
  - Alert on supply chain indicators of compromise
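The monitoring system has no dedicated phase in the implementation guide below, but its core operation — diffing successive snapshots of a tracked repository's file hashes — is simple to sketch. The `{relative_path: sha256}` snapshot shape matches what the scanner records; how snapshots are fetched (e.g., by polling the Hugging Face Hub API) is left open:

```python
def diff_snapshots(old: dict[str, str], new: dict[str, str]) -> dict[str, list[str]]:
    """Compare two {path: sha256} snapshots of a tracked model repository.

    A modified weight file or a newly added pickle in a previously vetted
    repo is a supply chain indicator of compromise worth alerting on.
    """
    return {
        "added": sorted(set(new) - set(old)),
        "removed": sorted(set(old) - set(new)),
        "modified": sorted(p for p in set(old) & set(new) if old[p] != new[p]),
    }

# Example: a weight file changed and a pickle appeared since the last poll.
old = {"model.safetensors": "aaa", "config.json": "bbb"}
new = {"model.safetensors": "ccc", "config.json": "bbb", "extra.pkl": "ddd"}
changes = diff_snapshots(old, new)
```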
Implementation Guide
Phase 1: Model Artifact Scanner
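Before the full scanner, the core pickle-analysis idea is worth seeing in isolation. The scanner below matches raw bytes as a fast heuristic; the stdlib `pickletools` module instead walks the stream opcode by opcode without ever executing it, which is the authoritative way to inspect untrusted pickles. A minimal sketch:

```python
import pickle
import pickletools

# Opcode names associated with code execution or object construction.
CODE_EXEC_OPCODES = {"GLOBAL", "STACK_GLOBAL", "REDUCE", "INST", "OBJ", "NEWOBJ"}

def risky_opcodes(data: bytes) -> list[str]:
    """Return the names of code-execution-capable opcodes in a pickle stream.

    pickletools.genops parses without executing, so this is safe to run on
    untrusted input (unlike pickle.loads).
    """
    return [op.name for op, _arg, _pos in pickletools.genops(data)
            if op.name in CODE_EXEC_OPCODES]

# A plain data pickle contains none of them...
clean = risky_opcodes(pickle.dumps({"weights": [0.1, 0.2]}))
# ...while the classic os.system payload shows GLOBAL followed by REDUCE.
payload = b"cos\nsystem\n(S'echo pwned'\ntR."
risky = risky_opcodes(payload)
```

Note the payload is only parsed, never loaded; loading it would run the shell command.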
# supply_chain/scanner.py
"""ML model artifact scanner for supply chain security."""
from __future__ import annotations
import hashlib
import json
import logging
import os
import struct
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ScanFinding:
"""A finding from scanning an ML artifact."""
severity: str # "critical", "high", "medium", "low", "info"
category: str
title: str
description: str
file_path: str
remediation: str
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ScanReport:
"""Complete scan report for a model artifact."""
model_path: str
findings: list[ScanFinding] = field(default_factory=list)
file_hashes: dict[str, str] = field(default_factory=dict)
scan_metadata: dict[str, Any] = field(default_factory=dict)
@property
def has_critical(self) -> bool:
return any(f.severity == "critical" for f in self.findings)
@property
def has_high(self) -> bool:
return any(f.severity == "high" for f in self.findings)
def summary(self) -> str:
by_severity = {}
for f in self.findings:
by_severity[f.severity] = by_severity.get(f.severity, 0) + 1
return (
f"Scan of {self.model_path}: {len(self.findings)} findings "
f"({by_severity})"
)
class ModelArtifactScanner:
"""Scans ML model artifacts for supply chain security issues."""
# Known unsafe file extensions and their risks.
UNSAFE_EXTENSIONS = {
".pkl": ("critical", "Pickle files can execute arbitrary code during deserialization."),
".pickle": ("critical", "Pickle files can execute arbitrary code during deserialization."),
".joblib": ("high", "Joblib files use pickle internally and can execute arbitrary code."),
".npy": ("medium", "NumPy .npy files can contain pickle objects when allow_pickle=True."),
".npz": ("medium", "NumPy .npz archives may contain pickled objects."),
".pt": ("high", "PyTorch .pt files use pickle by default. Use safetensors instead."),
".pth": ("high", "PyTorch .pth files use pickle by default. Use safetensors instead."),
".bin": ("medium", "Binary files require inspection to determine serialization format."),
}
# Safe formats that should be preferred.
SAFE_EXTENSIONS = {
".safetensors": "SafeTensors format — safe tensor serialization.",
".onnx": "ONNX format — safe graph serialization.",
".json": "JSON configuration — safe text format.",
".txt": "Text file — safe format.",
}
    # Pickle opcodes associated with code execution. Matching single bytes in
    # raw content is a noisy heuristic; use pickletools.genops for authoritative
    # opcode-level analysis.
    DANGEROUS_PICKLE_OPCODES = {
        b"\x8c": "SHORT_BINUNICODE (pushes a string; benign alone, often names modules for STACK_GLOBAL)",
b"c": "GLOBAL (imports a module — potential code execution)",
b"\x93": "STACK_GLOBAL (constructs a class — potential code execution)",
b"R": "REDUCE (calls a function — direct code execution)",
b"i": "INST (instantiates a class)",
b"o": "OBJ (builds an object)",
b"b": "BUILD (calls __setstate__)",
}
def scan_directory(self, model_dir: Path) -> ScanReport:
"""Scan all files in a model directory."""
report = ScanReport(model_path=str(model_dir))
if not model_dir.exists():
report.findings.append(ScanFinding(
severity="critical",
category="missing",
title="Model directory does not exist",
description=f"Path {model_dir} does not exist.",
file_path=str(model_dir),
remediation="Verify the model path and download the model.",
))
return report
for file_path in model_dir.rglob("*"):
if file_path.is_file():
# Compute hash for integrity tracking.
file_hash = self._compute_hash(file_path)
report.file_hashes[str(file_path.relative_to(model_dir))] = file_hash
# Check file extension safety.
self._check_extension(file_path, report)
# Deep scan specific file types.
if file_path.suffix in (".pkl", ".pickle", ".pt", ".pth"):
self._scan_pickle_file(file_path, report)
if file_path.suffix == ".json":
self._scan_json_config(file_path, report)
if file_path.name == "tokenizer.json":
self._scan_tokenizer(file_path, report)
return report
def _compute_hash(self, file_path: Path) -> str:
"""Compute SHA-256 hash of a file."""
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest()
def _check_extension(self, file_path: Path, report: ScanReport) -> None:
"""Check if a file uses an unsafe serialization format."""
ext = file_path.suffix.lower()
if ext in self.UNSAFE_EXTENSIONS:
severity, description = self.UNSAFE_EXTENSIONS[ext]
report.findings.append(ScanFinding(
severity=severity,
category="unsafe_serialization",
title=f"Unsafe serialization format: {ext}",
description=description,
file_path=str(file_path),
                remediation="Convert to SafeTensors format using `safetensors.torch.save_file()`.",
))
def _scan_pickle_file(self, file_path: Path, report: ScanReport) -> None:
"""Deep scan a pickle file for dangerous opcodes."""
try:
with open(file_path, "rb") as f:
content = f.read(1_000_000) # Read first 1MB
dangerous_found = []
for opcode, description in self.DANGEROUS_PICKLE_OPCODES.items():
if opcode in content:
dangerous_found.append(description)
# Check for known malicious patterns.
malicious_patterns = [
(b"os.system", "os.system call — shell command execution"),
(b"subprocess", "subprocess module — command execution"),
(b"eval", "eval() — arbitrary code execution"),
(b"exec", "exec() — arbitrary code execution"),
(b"__import__", "__import__() — dynamic module import"),
(b"requests.get", "HTTP request — potential data exfiltration"),
(b"urllib", "URL handling — potential data exfiltration"),
(b"socket", "Socket operations — network access"),
]
for pattern, description in malicious_patterns:
if pattern in content:
report.findings.append(ScanFinding(
severity="critical",
category="malicious_code",
                        title="Potentially malicious code in pickle file",
description=f"Found {description} in {file_path.name}. "
f"This strongly indicates the file contains executable code.",
file_path=str(file_path),
remediation="Do NOT load this file. Inspect it manually with pickletools. "
"If the model is from an untrusted source, discard it.",
))
            # Only add the opcode finding if this file wasn't already flagged
            # as malicious (scope the check to this file, not the whole report).
            if dangerous_found and not any(
                f.category == "malicious_code" and f.file_path == str(file_path)
                for f in report.findings
            ):
                report.findings.append(ScanFinding(
                    severity="high",
                    category="dangerous_opcodes",
                    title="Dangerous pickle opcodes detected",
description=f"Found {len(dangerous_found)} dangerous opcode types: "
f"{', '.join(dangerous_found[:5])}",
file_path=str(file_path),
remediation="Inspect the pickle file with pickletools before loading. "
"Consider converting to SafeTensors.",
))
except OSError as exc:
logger.warning("Could not read pickle file %s: %s", file_path, exc)
def _scan_json_config(self, file_path: Path, report: ScanReport) -> None:
"""Scan JSON configuration files for suspicious entries."""
try:
with open(file_path) as f:
config = json.load(f)
except (json.JSONDecodeError, OSError):
return
# Check for custom code flags that allow arbitrary code execution.
if config.get("auto_map") or config.get("trust_remote_code"):
report.findings.append(ScanFinding(
severity="high",
category="remote_code_execution",
title="Model requires custom/remote code execution",
description=f"Config file {file_path.name} contains 'auto_map' or "
f"'trust_remote_code', which allows executing arbitrary "
f"Python code from the model repository.",
file_path=str(file_path),
remediation="Review the custom code manually before enabling trust_remote_code. "
"Prefer models that use standard architectures.",
))
# Check for suspicious URLs in config.
self._check_urls_in_config(config, file_path, report)
def _check_urls_in_config(
self, config: Any, file_path: Path, report: ScanReport, depth: int = 0
) -> None:
"""Recursively check for suspicious URLs in configuration."""
if depth > 10:
return
        if isinstance(config, str):
            # Note: substring matching is bypassable (e.g. "https://evil.com/huggingface.co");
            # production code should parse the URL and compare hostnames exactly.
            if config.startswith("http") and not any(
trusted in config
for trusted in [
"huggingface.co", "github.com", "pytorch.org",
"tensorflow.org", "googleapis.com",
]
):
report.findings.append(ScanFinding(
severity="medium",
category="suspicious_url",
title="Untrusted URL in model configuration",
description=f"Found URL '{config[:100]}' pointing to a non-standard domain.",
file_path=str(file_path),
remediation="Verify the URL is legitimate and not a phishing or data exfiltration endpoint.",
))
elif isinstance(config, dict):
for v in config.values():
self._check_urls_in_config(v, file_path, report, depth + 1)
elif isinstance(config, list):
for item in config:
self._check_urls_in_config(item, file_path, report, depth + 1)
def _scan_tokenizer(self, file_path: Path, report: ScanReport) -> None:
"""Scan tokenizer files for injection vulnerabilities."""
try:
with open(file_path) as f:
tokenizer_data = json.load(f)
except (json.JSONDecodeError, OSError):
return
# Check for unusual special tokens that could be used for injection.
added_tokens = tokenizer_data.get("added_tokens", [])
for token in added_tokens:
content = token.get("content", "")
if any(marker in content.lower() for marker in [
"<|im_start|>", "<|im_end|>", "[inst]", "[/inst]",
"<<sys>>", "<|system|>",
]):
# These are expected for chat models — check if the content
# includes unexpected payloads.
if len(content) > 50:
report.findings.append(ScanFinding(
severity="medium",
category="tokenizer_anomaly",
title="Unusually long special token in tokenizer",
description=f"Special token '{content[:50]}...' is suspiciously long ({len(content)} chars). "
f"It may contain embedded injection payloads.",
file_path=str(file_path),
remediation="Inspect the tokenizer special tokens and compare against the official model release.",
                    ))

Phase 2: Dependency Auditor
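One check in this phase, typosquat detection, is easy to prototype in isolation. The auditor below consults a hand-maintained typo list; a fuzzier complement using only the stdlib `difflib` module flags any unknown package whose name is suspiciously close to a well-known ML package (the 0.8 cutoff is an arbitrary illustration, and the known-package set here is a small stand-in):

```python
import difflib

KNOWN_PACKAGES = {"transformers", "torch", "tensorflow", "numpy", "openai"}

def possible_typosquats(
    name: str,
    known: set[str] = KNOWN_PACKAGES,
    cutoff: float = 0.8,
) -> list[str]:
    """Return known packages that `name` closely resembles but isn't.

    An exact match means the package is the real one; anything merely *close*
    to a popular name deserves manual review before installation.
    """
    if name in known:
        return []
    return difflib.get_close_matches(name, known, n=3, cutoff=cutoff)
```

This catches misspellings outside any fixed list, at the cost of occasional false positives on legitimately similar names, so it is best used to queue packages for review rather than to block outright.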
# supply_chain/dependencies.py
"""Dependency auditing for ML pipeline environments."""
from __future__ import annotations
import json
import logging
import subprocess
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class DependencyInfo:
"""Information about a single Python package dependency."""
name: str
version: str
source: str # "pypi", "git", "local", "unknown"
hash_verified: bool = False
known_vulnerabilities: list[dict[str, str]] = field(default_factory=list)
is_ml_package: bool = False
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MLSBOM:
"""Machine Learning Software Bill of Materials."""
generated_at: datetime
model_name: str
model_version: str
python_version: str
dependencies: list[DependencyInfo]
model_artifacts: list[dict[str, str]] # file name + hash
training_data_sources: list[str] = field(default_factory=list)
base_model: str = ""
def to_json(self) -> str:
return json.dumps({
"mlsbom_version": "1.0",
"generated_at": self.generated_at.isoformat(),
"model": {
"name": self.model_name,
"version": self.model_version,
"base_model": self.base_model,
},
"runtime": {
"python_version": self.python_version,
},
"dependencies": [
{
"name": d.name,
"version": d.version,
"source": d.source,
"hash_verified": d.hash_verified,
"vulnerabilities": d.known_vulnerabilities,
}
for d in self.dependencies
],
"artifacts": self.model_artifacts,
"data_sources": self.training_data_sources,
}, indent=2)
class DependencyAuditor:
"""Audits Python dependencies in ML environments."""
# Known ML packages that are security-sensitive.
ML_PACKAGES = {
"torch", "tensorflow", "transformers", "huggingface-hub",
"tokenizers", "safetensors", "onnx", "onnxruntime",
"numpy", "scipy", "scikit-learn", "pandas",
"langchain", "llama-index", "openai", "anthropic",
"vllm", "triton", "accelerate", "peft", "trl",
"datasets", "evaluate", "sentencepiece",
}
# Known typosquatting targets.
TYPOSQUAT_PATTERNS = {
"transformers": ["transfomers", "trasformers", "transformrs"],
"torch": ["toch", "troch", "pytorh"],
"tensorflow": ["tenserflow", "tensorfow", "tensroflow"],
"openai": ["opeanai", "openal", "opeani"],
"numpy": ["numpi", "numppy", "nunpy"],
}
def audit_environment(self, requirements_path: Path | None = None) -> list[DependencyInfo]:
"""Audit the current Python environment or a requirements file."""
if requirements_path:
return self._audit_requirements_file(requirements_path)
return self._audit_installed_packages()
    def _audit_installed_packages(self) -> list[DependencyInfo]:
        """Audit currently installed packages using pip."""
        import sys  # Invoke pip for the *current* interpreter, not whatever is on PATH.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "list", "--format=json"],
capture_output=True,
text=True,
)
if result.returncode != 0:
logger.error("pip list failed: %s", result.stderr)
return []
packages = json.loads(result.stdout)
deps: list[DependencyInfo] = []
for pkg in packages:
name = pkg["name"].lower()
dep = DependencyInfo(
name=name,
version=pkg["version"],
source="pypi",
is_ml_package=name in self.ML_PACKAGES,
)
# Check for typosquatting.
self._check_typosquat(dep)
deps.append(dep)
return deps
def _audit_requirements_file(self, path: Path) -> list[DependencyInfo]:
"""Audit a requirements.txt file."""
deps: list[DependencyInfo] = []
with open(path) as f:
for line in f:
line = line.strip()
if not line or line.startswith("#") or line.startswith("-"):
continue
# Parse package==version or package>=version.
parts = line.split("==")
if len(parts) == 2:
name, version = parts[0].strip().lower(), parts[1].strip()
else:
parts = line.split(">=")
name = parts[0].strip().lower()
version = parts[1].strip() if len(parts) == 2 else "unknown"
dep = DependencyInfo(
name=name,
version=version,
source="pypi",
is_ml_package=name in self.ML_PACKAGES,
)
self._check_typosquat(dep)
deps.append(dep)
return deps
def _check_typosquat(self, dep: DependencyInfo) -> None:
"""Check if a package name might be a typosquat."""
for target, typos in self.TYPOSQUAT_PATTERNS.items():
if dep.name in typos:
dep.known_vulnerabilities.append({
"type": "typosquatting",
"severity": "critical",
"description": f"Package '{dep.name}' may be a typosquat of '{target}'. "
f"Verify the package is legitimate before use.",
})
def check_vulnerabilities(self, deps: list[DependencyInfo]) -> list[DependencyInfo]:
"""Check dependencies against known vulnerability databases using pip-audit."""
try:
result = subprocess.run(
["pip-audit", "--format=json", "--progress-spinner=off"],
capture_output=True,
text=True,
timeout=120,
)
if result.stdout:
audit_data = json.loads(result.stdout)
vuln_map: dict[str, list[dict]] = {}
for vuln in audit_data.get("vulnerabilities", []):
pkg = vuln.get("name", "").lower()
vuln_map.setdefault(pkg, []).append({
"type": "cve",
"severity": "high",
"description": vuln.get("description", ""),
"fix_version": vuln.get("fix_versions", []),
"id": vuln.get("id", ""),
})
for dep in deps:
if dep.name in vuln_map:
dep.known_vulnerabilities.extend(vuln_map[dep.name])
except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as exc:
logger.warning("pip-audit check failed: %s", exc)
return deps
def generate_mlsbom(
self,
model_name: str,
model_version: str,
model_artifacts: list[dict[str, str]],
deps: list[DependencyInfo],
) -> MLSBOM:
"""Generate an ML-SBOM for a model deployment."""
        import sys
        from datetime import timezone  # utcnow() is deprecated; record an aware timestamp
        return MLSBOM(
            generated_at=datetime.now(timezone.utc),
model_name=model_name,
model_version=model_version,
python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
dependencies=deps,
model_artifacts=model_artifacts,
        )

Phase 3: Provenance Verifier
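A useful companion to per-file integrity hashing, which the verifier below performs, is collapsing a whole artifact directory into one deterministic fingerprint that can be pinned and compared across environments. A sketch — the manifest layout is an assumption of this example, not a standard:

```python
import hashlib

def directory_fingerprint(file_hashes: dict[str, str]) -> str:
    """Collapse a {relative_path: sha256} manifest into a single digest.

    Sorting by path makes the result independent of traversal order, so the
    same set of artifacts always yields the same fingerprint, and any changed,
    added, or removed file changes it.
    """
    h = hashlib.sha256()
    for path in sorted(file_hashes):
        h.update(path.encode("utf-8"))
        h.update(b"\x00")                       # separator between path and hash
        h.update(file_hashes[path].encode("ascii"))
        h.update(b"\x00")
    return h.hexdigest()

manifest = {"config.json": "bbb", "model.safetensors": "aaa"}
fp1 = directory_fingerprint(manifest)
fp2 = directory_fingerprint(dict(reversed(list(manifest.items()))))
# fp1 == fp2: insertion order doesn't matter, only content does.
```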
# supply_chain/provenance.py
"""Model provenance verification and lineage tracking."""
from __future__ import annotations
import hashlib
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ProvenanceClaim:
"""A claim about a model's provenance."""
claim_type: str # "origin", "training", "fine_tuning", "conversion", "deployment"
description: str
timestamp: datetime | None = None
actor: str = "" # who performed this step
evidence: dict[str, Any] = field(default_factory=dict)
verified: bool = False
verification_method: str = ""
@dataclass
class ModelLineage:
"""Complete lineage record for a model."""
model_id: str
model_name: str
claims: list[ProvenanceClaim] = field(default_factory=list)
integrity_hashes: dict[str, str] = field(default_factory=dict)
base_model: str = ""
verification_status: str = "unverified" # "verified", "partial", "unverified", "failed"
def add_claim(self, claim: ProvenanceClaim) -> None:
self.claims.append(claim)
def verification_summary(self) -> dict[str, Any]:
total = len(self.claims)
verified = sum(1 for c in self.claims if c.verified)
return {
"model_id": self.model_id,
"total_claims": total,
"verified_claims": verified,
"verification_rate": verified / max(total, 1),
"status": self.verification_status,
}
class ProvenanceVerifier:
"""Verifies model provenance claims and tracks lineage."""
def verify_huggingface_model(
self, repo_id: str, local_path: Path
) -> ModelLineage:
"""Verify a model downloaded from the Hugging Face Hub."""
lineage = ModelLineage(
model_id=repo_id,
model_name=repo_id.split("/")[-1],
)
# Step 1: Verify model card exists and is complete.
model_card_path = local_path / "README.md"
if model_card_path.exists():
self._verify_model_card(model_card_path, lineage)
else:
lineage.add_claim(ProvenanceClaim(
claim_type="documentation",
description="No model card (README.md) found.",
verified=False,
verification_method="file_existence",
))
# Step 2: Verify config.json matches expected architecture.
config_path = local_path / "config.json"
if config_path.exists():
self._verify_config(config_path, lineage)
# Step 3: Compute and record integrity hashes.
for file_path in local_path.rglob("*"):
if file_path.is_file():
hash_val = self._hash_file(file_path)
rel_path = str(file_path.relative_to(local_path))
lineage.integrity_hashes[rel_path] = hash_val
# Step 4: Check for signature files.
self._verify_signatures(local_path, lineage)
# Compute overall status.
verified_count = sum(1 for c in lineage.claims if c.verified)
total = len(lineage.claims)
if total == 0:
lineage.verification_status = "unverified"
elif verified_count == total:
lineage.verification_status = "verified"
elif verified_count > 0:
lineage.verification_status = "partial"
else:
lineage.verification_status = "failed"
return lineage
def _verify_model_card(self, path: Path, lineage: ModelLineage) -> None:
"""Check model card for required provenance fields."""
content = path.read_text()
required_sections = [
("model_description", ["## Model Description", "# Model", "## About"]),
("training_data", ["## Training Data", "## Dataset", "training data"]),
("intended_use", ["## Intended Use", "## Uses", "intended use"]),
("limitations", ["## Limitations", "## Known Limitations", "limitations"]),
("license", ["license:", "## License"]),
]
for section_name, markers in required_sections:
found = any(m.lower() in content.lower() for m in markers)
lineage.add_claim(ProvenanceClaim(
claim_type="documentation",
description=f"Model card section: {section_name}",
verified=found,
verification_method="content_search",
evidence={"section": section_name, "found": found},
))
def _verify_config(self, path: Path, lineage: ModelLineage) -> None:
"""Verify model configuration."""
try:
config = json.loads(path.read_text())
has_architecture = "architectures" in config or "model_type" in config
lineage.add_claim(ProvenanceClaim(
claim_type="configuration",
description="Model architecture declared in config.json",
verified=has_architecture,
verification_method="config_inspection",
evidence={
"architectures": config.get("architectures", []),
"model_type": config.get("model_type", ""),
},
))
# Check for trust_remote_code.
if config.get("auto_map") or config.get("trust_remote_code"):
lineage.add_claim(ProvenanceClaim(
claim_type="security",
description="Model requires trust_remote_code — arbitrary code execution risk",
verified=False, # This is a negative verification.
verification_method="config_inspection",
))
except (json.JSONDecodeError, OSError):
lineage.add_claim(ProvenanceClaim(
claim_type="configuration",
description="config.json is missing or malformed",
verified=False,
))
def _verify_signatures(self, local_path: Path, lineage: ModelLineage) -> None:
"""Check for and verify model signatures."""
sig_files = list(local_path.glob("*.sig")) + list(local_path.glob("*.asc"))
if sig_files:
lineage.add_claim(ProvenanceClaim(
claim_type="signature",
description=f"Found {len(sig_files)} signature file(s).",
verified=True, # File exists; full verification requires GPG.
verification_method="file_existence",
evidence={"signature_files": [str(f.name) for f in sig_files]},
))
else:
lineage.add_claim(ProvenanceClaim(
claim_type="signature",
description="No signature files found. Model integrity cannot be cryptographically verified.",
verified=False,
verification_method="file_existence",
))
@staticmethod
def _hash_file(path: Path) -> str:
sha256 = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
        return sha256.hexdigest()

Phase 4: CLI and CI/CD Integration
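The heart of CI/CD integration is turning findings into exit codes so a pipeline step can block deployment. Before the full click CLI below, here is the threshold logic in isolation — a sketch mirroring the `--fail-on` option:

```python
SEVERITY_RANK = {"info": 0, "low": 1, "medium": 2, "high": 3, "critical": 4}

def exit_code_for(finding_severities: list[str], fail_on: str = "high") -> int:
    """Return 1 if any finding meets or exceeds the fail-on threshold, else 0.

    A non-zero exit code is what lets CI systems (GitHub Actions, GitLab CI,
    Jenkins) fail the job and stop a compromised artifact from shipping.
    """
    threshold = SEVERITY_RANK[fail_on]
    blocking = any(SEVERITY_RANK[s] >= threshold for s in finding_severities)
    return 1 if blocking else 0
```

Keeping this as a pure function, separate from the click command, also makes the gate trivially unit-testable.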
# supply_chain/cli.py
"""CLI interface for the AI supply chain security tool."""
from __future__ import annotations
import json
import sys
from pathlib import Path
import click
from .dependencies import DependencyAuditor
from .provenance import ProvenanceVerifier
from .scanner import ModelArtifactScanner
@click.group()
def cli():
"""AI Supply Chain Security Tool — scan, audit, and verify ML artifacts."""
pass
@cli.command()
@click.argument("model_path", type=click.Path(exists=True))
@click.option("--format", "fmt", type=click.Choice(["text", "json"]), default="text")
@click.option("--fail-on", type=click.Choice(["critical", "high", "medium"]), default="high")
def scan(model_path: str, fmt: str, fail_on: str):
"""Scan a model directory for supply chain security issues."""
scanner = ModelArtifactScanner()
report = scanner.scan_directory(Path(model_path))
if fmt == "json":
click.echo(json.dumps({
"model_path": report.model_path,
"findings": [
{
"severity": f.severity,
"category": f.category,
"title": f.title,
"description": f.description,
"file_path": f.file_path,
"remediation": f.remediation,
}
for f in report.findings
],
"file_hashes": report.file_hashes,
}, indent=2))
else:
click.echo(report.summary())
click.echo()
for finding in report.findings:
click.echo(f" [{finding.severity.upper()}] {finding.title}")
click.echo(f" File: {finding.file_path}")
click.echo(f" {finding.description}")
click.echo(f" Fix: {finding.remediation}")
click.echo()
# Exit with non-zero if findings at fail-on severity.
severity_order = ["medium", "high", "critical"]
fail_index = severity_order.index(fail_on)
blocking = [
f for f in report.findings
if f.severity in severity_order[fail_index:]
]
if blocking:
click.echo(f"{len(blocking)} blocking finding(s) found.")
sys.exit(1)
@cli.command()
@click.option("--requirements", type=click.Path(exists=True), help="Path to requirements.txt")
@click.option("--check-vulns", is_flag=True, help="Check against vulnerability databases")
def audit(requirements: str | None, check_vulns: bool):
"""Audit Python dependencies for supply chain risks."""
auditor = DependencyAuditor()
req_path = Path(requirements) if requirements else None
deps = auditor.audit_environment(req_path)
if check_vulns:
deps = auditor.check_vulnerabilities(deps)
vuln_count = sum(len(d.known_vulnerabilities) for d in deps)
ml_count = sum(1 for d in deps if d.is_ml_package)
click.echo(f"Scanned {len(deps)} packages ({ml_count} ML-related)")
click.echo(f"Vulnerabilities found: {vuln_count}")
for dep in deps:
if dep.known_vulnerabilities:
click.echo(f"\n {dep.name}=={dep.version}")
for vuln in dep.known_vulnerabilities:
click.echo(f" [{vuln['severity'].upper()}] {vuln['type']}: {vuln['description'][:80]}")
if vuln_count > 0:
sys.exit(1)
@cli.command()
@click.argument("model_path", type=click.Path(exists=True))
@click.option("--repo-id", required=True, help="Hugging Face repo ID (e.g., org/model-name)")
def verify(model_path: str, repo_id: str):
"""Verify model provenance and integrity."""
verifier = ProvenanceVerifier()
lineage = verifier.verify_huggingface_model(repo_id, Path(model_path))
summary = lineage.verification_summary()
click.echo(f"Model: {lineage.model_name}")
click.echo(f"Status: {summary['status'].upper()}")
click.echo(f"Claims: {summary['verified_claims']}/{summary['total_claims']} verified")
click.echo()
for claim in lineage.claims:
status = "PASS" if claim.verified else "FAIL"
click.echo(f" [{status}] {claim.claim_type}: {claim.description}")
if summary["status"] == "failed":
        sys.exit(1)

Evaluation Criteria
| Criterion | Weight | Excellent | Satisfactory | Needs Improvement |
|---|---|---|---|---|
| Artifact Scanning | 25% | Deep pickle analysis, tokenizer inspection, config validation, multi-format support | Extension checking with basic pickle scanning | Extension-only checking |
| Dependency Auditing | 20% | CVE lookup, typosquatting detection, ML-SBOM generation | Basic vulnerability scanning | Package listing only |
| Provenance Verification | 25% | Model card validation, signature verification, lineage tracking | Basic provenance checks | No provenance capability |
| CI/CD Integration | 15% | CLI with exit codes, JSON output, configurable thresholds | CLI with basic output | No automation support |
| Code Quality | 15% | Modular design, comprehensive error handling, tested | Working code with some structure | Monolithic, fragile code |
Stretch Goals
- Integrate with Sigstore/cosign for cryptographic model signature verification.
- Build a model registry proxy that transparently scans artifacts on download.
- Implement SLSA provenance attestation generation and verification for ML pipelines.
- Add support for scanning ONNX models for graph-level attacks (malicious custom operators).
References
- Gu, T., et al. (2019). "BadNets: Evaluating Backdooring Attacks on Deep Neural Networks." IEEE Access. https://arxiv.org/abs/1708.06733
- SLSA. (2024). "Supply-chain Levels for Software Artifacts." https://slsa.dev/
- Hugging Face. (2024). "Security at Hugging Face." https://huggingface.co/docs/hub/security