Capstone: Build a Complete AI Red Teaming Platform
Design and implement a comprehensive AI red teaming platform with automated attack orchestration, vulnerability tracking, and collaborative reporting.
Overview
This capstone project challenges you to build a production-grade AI red teaming platform from the ground up. Professional AI red teams cannot rely on a single assessment tool; they need integrated platforms that coordinate multiple attack methodologies, track findings across engagements, and produce actionable reports for stakeholders.
The platform you will build serves as the central nervous system for an AI red teaming operation. It orchestrates automated attack campaigns across multiple target models, aggregates results into a unified vulnerability database, and generates reports that map findings to risk frameworks such as OWASP Top 10 for LLM Applications and the MITRE ATLAS matrix. The architecture must be extensible so that new attack techniques and target model integrations can be added without modifying core platform code.
This project draws on the real-world experience of organizations that have built internal red teaming platforms, including Microsoft's PyRIT framework (which provides the conceptual foundation for attack orchestration) and NVIDIA's approach to continuous AI security testing. By the end of this capstone, you will have a working platform that you can use for your own engagements and extend for specific organizational needs.
Project Requirements
Functional Requirements
Core Platform Architecture
The platform must implement the following components:
- Target Registry — A configuration system for registering AI systems under test. Each target includes connection details (API endpoints, authentication), model metadata (provider, version, capabilities), and scope constraints (which attack categories are authorized).
- Attack Module Framework — A plugin architecture where individual attack techniques are encapsulated as modules. Each module declares its required inputs, target compatibility, and expected output format. The framework handles module discovery, dependency resolution, and parallel execution.
- Campaign Orchestrator — A scheduler that assembles attack modules into campaigns, manages execution order (respecting dependencies and rate limits), and handles failures gracefully with retry logic and circuit breakers.
- Result Aggregator — A data pipeline that normalizes outputs from different attack modules into a common vulnerability schema, deduplicates findings, and computes severity scores.
- Reporting Engine — A template-driven system that generates engagement reports in multiple formats (Markdown, HTML, PDF) with embedded evidence, severity distributions, and remediation recommendations.
Non-Functional Requirements
- The platform must handle concurrent campaigns against multiple targets without interference.
- Attack modules must be sandboxed so that a failing module does not crash the entire campaign.
- All interactions with target models must be logged with full request/response pairs for reproducibility.
- The platform must support both interactive (CLI) and programmatic (API) usage patterns.
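The logging requirement in particular deserves an early design decision. One minimal way to satisfy it, sketched here with an illustrative `InteractionLogger` class and JSONL layout that the spec does not mandate, is an append-only log with one JSON object per request/response pair:

```python
import json
from datetime import datetime, timezone
from pathlib import Path


class InteractionLogger:
    """Append-only JSONL log of every request/response pair sent to a target."""

    def __init__(self, log_path: Path) -> None:
        self._log_path = log_path
        self._log_path.parent.mkdir(parents=True, exist_ok=True)

    def record(self, target_id: str, request: str, response: str) -> None:
        entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "target_id": target_id,
            "request": request,
            "response": response,
        }
        # One JSON object per line keeps the log greppable and streamable.
        with open(self._log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")
```

Because each line is self-describing, reproducing a finding later is a matter of replaying the logged request against the same target.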
Technical Specifications
- Language: Python 3.11+
- Async Framework: asyncio with structured concurrency
- Data Storage: SQLite for development, PostgreSQL-compatible for production
- API Layer: FastAPI for the REST interface
- Configuration: YAML-based with JSON Schema validation
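To make the configuration requirement concrete, a registered target might be described by a YAML file like the following sketch (field names follow the TargetConfig model defined in Phase 1; the endpoint, model, and environment-variable values are placeholders):

```yaml
name: "staging-chatbot"
description: "Customer support assistant, staging deployment"
provider: "openai"
endpoint: "https://api.openai.com/v1/chat/completions"
model_identifier: "gpt-4o-mini"
auth_token_env_var: "STAGING_CHATBOT_API_KEY"
max_requests_per_minute: 30
authorized_categories:
  - prompt_injection
  - jailbreak
  - data_extraction
```

Note that only the environment variable name is stored; the token itself never touches disk.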
Implementation Guide
Phase 1: Core Data Model and Target Registry
Start by defining the data model that underpins the entire platform. Every component depends on these foundational types.
# platform/models.py
"""Core data models for the AI red teaming platform."""
from __future__ import annotations
import enum
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class SeverityLevel(str, enum.Enum):
CRITICAL = "critical"
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
INFORMATIONAL = "informational"
class AttackCategory(str, enum.Enum):
PROMPT_INJECTION = "prompt_injection"
JAILBREAK = "jailbreak"
DATA_EXTRACTION = "data_extraction"
PRIVILEGE_ESCALATION = "privilege_escalation"
DENIAL_OF_SERVICE = "denial_of_service"
MODEL_MANIPULATION = "model_manipulation"
OUTPUT_MANIPULATION = "output_manipulation"
SUPPLY_CHAIN = "supply_chain"
class TargetConfig(BaseModel):
"""Configuration for an AI system under test."""
target_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
description: str = ""
provider: str # e.g., "openai", "anthropic", "custom"
endpoint: str
model_identifier: str
auth_header: str = "Authorization"
auth_token_env_var: str # environment variable name, never store tokens directly
max_requests_per_minute: int = 60
authorized_categories: list[AttackCategory] = Field(
default_factory=lambda: list(AttackCategory)
)
metadata: dict[str, Any] = Field(default_factory=dict)
class VulnerabilityFinding(BaseModel):
"""A single vulnerability finding from an attack module."""
finding_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
campaign_id: str
module_name: str
target_id: str
category: AttackCategory
severity: SeverityLevel
title: str
description: str
evidence: list[EvidenceItem] = Field(default_factory=list)
reproduction_steps: list[str] = Field(default_factory=list)
remediation: str = ""
cwe_ids: list[str] = Field(default_factory=list)
atlas_technique_ids: list[str] = Field(default_factory=list)
discovered_at: datetime = Field(default_factory=datetime.utcnow)
verified: bool = False
class EvidenceItem(BaseModel):
"""Evidence supporting a vulnerability finding."""
evidence_type: str # "request_response", "screenshot", "log_excerpt"
content: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
metadata: dict[str, Any] = Field(default_factory=dict)
# Rebuild models to resolve forward references
VulnerabilityFinding.model_rebuild()
class CampaignConfig(BaseModel):
"""Configuration for a red teaming campaign."""
campaign_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
targets: list[str] # target_ids
modules: list[str] # module names
max_concurrent_modules: int = 3
timeout_per_module_seconds: int = 300
stop_on_critical: bool = False
    created_at: datetime = Field(default_factory=datetime.utcnow)

Now build the target registry that manages AI systems under test:
# platform/registry.py
"""Target registry for managing AI systems under test."""
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional
import yaml
from pydantic import ValidationError
from .models import TargetConfig
class TargetRegistry:
"""Manages registration and retrieval of target AI systems."""
def __init__(self, storage_dir: Path) -> None:
self._storage_dir = storage_dir
self._storage_dir.mkdir(parents=True, exist_ok=True)
self._targets: dict[str, TargetConfig] = {}
self._load_persisted_targets()
def _load_persisted_targets(self) -> None:
"""Load target configurations from YAML files on disk."""
for yaml_file in self._storage_dir.glob("*.yaml"):
try:
with open(yaml_file) as f:
data = yaml.safe_load(f)
target = TargetConfig(**data)
self._targets[target.target_id] = target
except (ValidationError, yaml.YAMLError) as exc:
# Log but do not crash — a single corrupt file should not
# prevent the registry from loading other targets.
print(f"Warning: skipping invalid target file {yaml_file}: {exc}")
def register(self, config: TargetConfig) -> str:
"""Register a new target. Returns the target_id."""
# Validate that the auth token environment variable is set
if not os.environ.get(config.auth_token_env_var):
raise ValueError(
f"Environment variable {config.auth_token_env_var!r} is not set. "
"Set it before registering this target."
)
self._targets[config.target_id] = config
self._persist_target(config)
return config.target_id
def get(self, target_id: str) -> Optional[TargetConfig]:
return self._targets.get(target_id)
def list_targets(self) -> list[TargetConfig]:
return list(self._targets.values())
def remove(self, target_id: str) -> bool:
if target_id in self._targets:
del self._targets[target_id]
target_file = self._storage_dir / f"{target_id}.yaml"
target_file.unlink(missing_ok=True)
return True
return False
def _persist_target(self, config: TargetConfig) -> None:
target_file = self._storage_dir / f"{config.target_id}.yaml"
with open(target_file, "w") as f:
            yaml.dump(config.model_dump(mode="json"), f, default_flow_style=False)

Phase 2: Attack Module Framework
The attack module framework uses a plugin pattern. Each module is a Python class that inherits from a base class and registers itself via a decorator.
# platform/modules/base.py
"""Base class and registry for attack modules."""
from __future__ import annotations
import abc
import asyncio
import logging
from typing import Any, ClassVar
from ..models import AttackCategory, TargetConfig, VulnerabilityFinding
logger = logging.getLogger(__name__)
# Global module registry populated by the @register_module decorator.
_MODULE_REGISTRY: dict[str, type[AttackModule]] = {}
def register_module(cls: type[AttackModule]) -> type[AttackModule]:
"""Decorator that registers an attack module class."""
if cls.name in _MODULE_REGISTRY:
raise ValueError(f"Duplicate module name: {cls.name!r}")
_MODULE_REGISTRY[cls.name] = cls
return cls
def get_module_class(name: str) -> type[AttackModule]:
if name not in _MODULE_REGISTRY:
raise KeyError(
f"Unknown module {name!r}. Available: {list(_MODULE_REGISTRY.keys())}"
)
return _MODULE_REGISTRY[name]
def list_modules() -> list[dict[str, Any]]:
return [
{
"name": cls.name,
"description": cls.description,
"categories": [c.value for c in cls.categories],
"version": cls.version,
}
for cls in _MODULE_REGISTRY.values()
]
class AttackModule(abc.ABC):
"""Base class for all attack modules."""
# Subclasses must set these class variables.
name: ClassVar[str]
description: ClassVar[str]
categories: ClassVar[list[AttackCategory]]
version: ClassVar[str] = "0.1.0"
def __init__(self, campaign_id: str, target: TargetConfig) -> None:
self.campaign_id = campaign_id
self.target = target
self.findings: list[VulnerabilityFinding] = []
self._logger = logging.getLogger(f"module.{self.name}")
@abc.abstractmethod
async def run(self) -> list[VulnerabilityFinding]:
"""Execute the attack module and return findings."""
...
async def safe_run(self) -> list[VulnerabilityFinding]:
"""Run the module with error handling so failures are contained."""
try:
return await self.run()
except asyncio.CancelledError:
self._logger.warning("Module %s was cancelled", self.name)
return self.findings # return any partial results
except Exception:
self._logger.exception("Module %s failed with an unhandled error", self.name)
            return self.findings

Here is an example attack module that wraps Garak for automated probing:
# platform/modules/garak_probe.py
"""Attack module that orchestrates Garak probes against a target."""
from __future__ import annotations
import asyncio
import json
import tempfile
from pathlib import Path
from ..models import (
AttackCategory,
EvidenceItem,
SeverityLevel,
TargetConfig,
VulnerabilityFinding,
)
from .base import AttackModule, register_module
@register_module
class GarakProbeModule(AttackModule):
name = "garak_probe"
description = "Run Garak probes against the target model to detect known vulnerability patterns."
categories = [
AttackCategory.PROMPT_INJECTION,
AttackCategory.JAILBREAK,
AttackCategory.DATA_EXTRACTION,
]
version = "1.0.0"
# Garak probe families to execute. Expand as needed.
DEFAULT_PROBES = [
"encoding",
"dan",
"gcg",
"knownbadsignatures",
"misleading",
]
async def run(self) -> list[VulnerabilityFinding]:
"""Execute Garak probes in a subprocess and parse results."""
with tempfile.TemporaryDirectory() as tmpdir:
report_path = Path(tmpdir) / "garak_report.jsonl"
# Build the Garak CLI command.
cmd = [
"python", "-m", "garak",
"--model_type", "openai",
"--model_name", self.target.model_identifier,
"--probes", ",".join(self.DEFAULT_PROBES),
"--report_prefix", str(Path(tmpdir) / "garak_report"),
]
self._logger.info("Starting Garak: %s", " ".join(cmd))
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=self._build_env(),
)
            try:
                stdout, stderr = await asyncio.wait_for(
                    proc.communicate(), timeout=600
                )
            except asyncio.TimeoutError:
                # Kill the subprocess on timeout so it does not linger.
                proc.kill()
                await proc.communicate()
                self._logger.error("Garak timed out after 600s; subprocess killed")
                return self.findings
if proc.returncode != 0:
self._logger.error("Garak exited with code %d: %s", proc.returncode, stderr.decode())
return self.findings
# Parse the JSONL report.
self._parse_garak_results(report_path)
return self.findings
def _build_env(self) -> dict[str, str]:
"""Build environment variables for the Garak subprocess."""
import os
env = os.environ.copy()
# Garak reads the OpenAI key from the environment.
# The platform stores only the env var *name* — resolve it here.
token = os.environ.get(self.target.auth_token_env_var, "")
env["OPENAI_API_KEY"] = token
return env
def _parse_garak_results(self, report_path: Path) -> None:
"""Parse Garak JSONL output into VulnerabilityFinding objects."""
if not report_path.exists():
# Garak may use a slightly different suffix; search for it.
candidates = list(report_path.parent.glob("garak_report*.jsonl"))
if not candidates:
self._logger.warning("No Garak report file found")
return
report_path = candidates[0]
with open(report_path) as f:
for line in f:
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
except json.JSONDecodeError:
continue
if record.get("status") == "fail":
severity = self._map_severity(record)
finding = VulnerabilityFinding(
campaign_id=self.campaign_id,
module_name=self.name,
target_id=self.target.target_id,
                        # Coarse default; a finer mapping per probe family is left as an exercise.
                        category=AttackCategory.JAILBREAK,
severity=severity,
title=f"Garak probe failed: {record.get('probe', 'unknown')}",
description=(
f"The probe '{record.get('probe')}' with detector "
f"'{record.get('detector')}' produced a failure, "
f"indicating the model did not refuse harmful content."
),
evidence=[
EvidenceItem(
evidence_type="request_response",
content=json.dumps(record, indent=2),
)
],
reproduction_steps=[
f"Run Garak with probe: {record.get('probe')}",
f"Use detector: {record.get('detector')}",
f"Input trigger: {record.get('trigger', 'see evidence')}",
],
remediation="Review the model's safety training for this attack category and add input/output guardrails.",
)
self.findings.append(finding)
    @staticmethod
    def _map_severity(record: dict) -> SeverityLevel:
        """Map a Garak result to the platform's severity levels."""
        probe = record.get("probe", "").lower()
        # Jailbreak-style probes (DAN, GCG) indicate higher-impact failures;
        # everything else defaults to medium pending manual triage.
        if any(kw in probe for kw in ("dan", "gcg")):
            return SeverityLevel.HIGH
        return SeverityLevel.MEDIUM

Phase 3: Campaign Orchestrator
The orchestrator assembles modules into campaigns, respects concurrency limits, and handles failures gracefully.
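One caveat before the code: the orchestrator below enforces module concurrency via a semaphore but does not throttle per-target request rates, so honoring TargetConfig.max_requests_per_minute is left to the modules. A minimal token-bucket limiter they could share might look like this sketch (the `RateLimiter` class is illustrative, not part of the platform code):

```python
import asyncio
import time


class RateLimiter:
    """Token bucket: at most `rate` acquisitions per `per` seconds."""

    def __init__(self, rate: int, per: float = 60.0) -> None:
        self._rate = rate
        self._per = per
        self._tokens = float(rate)  # bucket starts full
        self._updated = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self._lock:
            while True:
                now = time.monotonic()
                # Refill tokens in proportion to elapsed time, capped at the burst size.
                self._tokens = min(
                    self._rate,
                    self._tokens + (now - self._updated) * self._rate / self._per,
                )
                self._updated = now
                if self._tokens >= 1:
                    self._tokens -= 1
                    return
                # Sleep just long enough for one token to accumulate.
                await asyncio.sleep((1 - self._tokens) * self._per / self._rate)
```

A module would construct one limiter per target from max_requests_per_minute and `await limiter.acquire()` before each request.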
# platform/orchestrator.py
"""Campaign orchestrator that coordinates attack module execution."""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Any
from .models import CampaignConfig, TargetConfig, VulnerabilityFinding
from .modules.base import get_module_class
from .registry import TargetRegistry
logger = logging.getLogger(__name__)
class CampaignResult:
"""Aggregated results from a campaign run."""
def __init__(self, campaign: CampaignConfig) -> None:
self.campaign = campaign
self.findings: list[VulnerabilityFinding] = []
        self.module_statuses: dict[str, str] = {}  # per-run status: "success" | "failed" | "timeout"
self.started_at: datetime = datetime.utcnow()
self.completed_at: datetime | None = None
@property
def critical_count(self) -> int:
return sum(1 for f in self.findings if f.severity.value == "critical")
def summary(self) -> dict[str, Any]:
return {
"campaign_id": self.campaign.campaign_id,
"campaign_name": self.campaign.name,
"total_findings": len(self.findings),
"by_severity": {
sev: sum(1 for f in self.findings if f.severity.value == sev)
for sev in ("critical", "high", "medium", "low", "informational")
},
"module_statuses": self.module_statuses,
"duration_seconds": (
(self.completed_at - self.started_at).total_seconds()
if self.completed_at
else None
),
}
class Orchestrator:
"""Coordinates campaign execution across targets and modules."""
def __init__(self, registry: TargetRegistry) -> None:
self._registry = registry
async def run_campaign(self, config: CampaignConfig) -> CampaignResult:
"""Execute a full campaign and return aggregated results."""
result = CampaignResult(config)
semaphore = asyncio.Semaphore(config.max_concurrent_modules)
# Resolve targets
targets: list[TargetConfig] = []
for tid in config.targets:
target = self._registry.get(tid)
if target is None:
logger.error("Target %s not found in registry — skipping", tid)
continue
targets.append(target)
if not targets:
logger.error("No valid targets found for campaign %s", config.campaign_id)
result.completed_at = datetime.utcnow()
return result
# Build the task matrix: every (module, target) pair becomes a task.
tasks: list[asyncio.Task] = []
for module_name in config.modules:
for target in targets:
task = asyncio.create_task(
self._run_module(
module_name=module_name,
target=target,
campaign_id=config.campaign_id,
semaphore=semaphore,
timeout=config.timeout_per_module_seconds,
)
)
tasks.append(task)
        # Gather results, with early termination if stop_on_critical is set.
        for coro in asyncio.as_completed(tasks):
            run_key, status, findings = await coro
            result.findings.extend(findings)
            result.module_statuses[run_key] = status
            if config.stop_on_critical and result.critical_count > 0:
                logger.warning("Critical finding detected — stopping campaign early")
                for t in tasks:
                    t.cancel()
                # Await the cancelled tasks so they can clean up without warnings.
                await asyncio.gather(*tasks, return_exceptions=True)
                break
        result.completed_at = datetime.utcnow()
        return result

    async def _run_module(
        self,
        module_name: str,
        target: TargetConfig,
        campaign_id: str,
        semaphore: asyncio.Semaphore,
        timeout: int,
    ) -> tuple[str, str, list[VulnerabilityFinding]]:
        """Run a single module against a single target with concurrency control.

        Returns a ("module:target", status, findings) tuple so the caller can
        record a status per module/target pair rather than one per module.
        """
        run_key = f"{module_name}:{target.name}"
        async with semaphore:
            logger.info("Running module %s against target %s", module_name, target.name)
            try:
                module_cls = get_module_class(module_name)
            except KeyError as exc:
                logger.error(str(exc))
                return run_key, "failed", []
            module = module_cls(campaign_id=campaign_id, target=target)
            try:
                findings = await asyncio.wait_for(module.safe_run(), timeout=timeout)
            except asyncio.TimeoutError:
                logger.error("Module %s timed out after %ds", module_name, timeout)
                return run_key, "timeout", module.findings  # partial results
            return run_key, "success", findings

Phase 4: Result Aggregation and Deduplication
# platform/aggregator.py
"""Result aggregation, deduplication, and severity scoring."""
from __future__ import annotations
from collections import defaultdict
from .models import SeverityLevel, VulnerabilityFinding
# CVSS-inspired weights for computing aggregate risk scores.
SEVERITY_WEIGHTS: dict[SeverityLevel, float] = {
SeverityLevel.CRITICAL: 10.0,
SeverityLevel.HIGH: 7.5,
SeverityLevel.MEDIUM: 5.0,
SeverityLevel.LOW: 2.5,
SeverityLevel.INFORMATIONAL: 0.0,
}
def deduplicate_findings(
findings: list[VulnerabilityFinding],
) -> list[VulnerabilityFinding]:
"""Remove duplicate findings based on category + title similarity."""
seen: dict[str, VulnerabilityFinding] = {}
for finding in findings:
key = f"{finding.target_id}:{finding.category.value}:{finding.title}"
if key not in seen:
seen[key] = finding
else:
# Keep the higher-severity duplicate and merge evidence.
existing = seen[key]
if SEVERITY_WEIGHTS[finding.severity] > SEVERITY_WEIGHTS[existing.severity]:
finding.evidence.extend(existing.evidence)
seen[key] = finding
else:
existing.evidence.extend(finding.evidence)
return list(seen.values())
def compute_risk_score(findings: list[VulnerabilityFinding]) -> float:
"""Compute an aggregate risk score for a set of findings (0-100 scale)."""
if not findings:
return 0.0
raw = sum(SEVERITY_WEIGHTS[f.severity] for f in findings)
# Normalize against a theoretical maximum of 20 critical findings.
return min(100.0, (raw / (20 * 10.0)) * 100.0)
def group_by_category(
findings: list[VulnerabilityFinding],
) -> dict[str, list[VulnerabilityFinding]]:
"""Group findings by attack category for reporting."""
groups: dict[str, list[VulnerabilityFinding]] = defaultdict(list)
for finding in findings:
groups[finding.category.value].append(finding)
    return dict(groups)

Phase 5: Reporting Engine
# platform/reporting.py
"""Report generation engine with multiple output formats."""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Any
from .aggregator import compute_risk_score, deduplicate_findings, group_by_category
from .models import VulnerabilityFinding
from .orchestrator import CampaignResult
class ReportGenerator:
"""Generates engagement reports from campaign results."""
def __init__(self, output_dir: Path) -> None:
self._output_dir = output_dir
self._output_dir.mkdir(parents=True, exist_ok=True)
def generate_markdown(self, result: CampaignResult) -> Path:
"""Generate a Markdown engagement report."""
findings = deduplicate_findings(result.findings)
risk_score = compute_risk_score(findings)
grouped = group_by_category(findings)
lines: list[str] = []
        lines.append("# AI Red Team Engagement Report")
        lines.append("")
lines.append(f"**Campaign**: {result.campaign.name}")
lines.append(f"**Date**: {datetime.utcnow().strftime('%Y-%m-%d')}")
lines.append(f"**Risk Score**: {risk_score:.1f}/100")
lines.append(f"**Total Findings**: {len(findings)}")
lines.append("")
        # Executive summary
        lines.append("## Executive Summary")
        lines.append("")
        lines.append(
            f"This engagement tested {len(result.campaign.targets)} target(s) "
            f"using {len(result.campaign.modules)} attack module(s). "
            f"The assessment identified {len(findings)} unique finding(s) "
            f"with an aggregate risk score of {risk_score:.1f}/100."
        )
        lines.append("")
        # Severity breakdown table (counts computed over the deduplicated
        # findings, so they match the total reported above).
        lines.append("## Findings by Severity")
        lines.append("")
        lines.append("| Severity | Count |")
        lines.append("|----------|-------|")
        for sev in ("critical", "high", "medium", "low", "informational"):
            count = sum(1 for f in findings if f.severity.value == sev)
            lines.append(f"| {sev.capitalize()} | {count} |")
        lines.append("")
lines.append("")
# Detailed findings by category
lines.append("## Detailed Findings")
lines.append("")
for category, cat_findings in grouped.items():
lines.append(f"### {category.replace('_', ' ').title()}")
lines.append("")
for i, finding in enumerate(cat_findings, 1):
lines.append(f"#### {i}. {finding.title}")
lines.append(f"")
lines.append(f"**Severity**: {finding.severity.value.upper()}")
lines.append(f"**Module**: {finding.module_name}")
lines.append(f"")
lines.append(finding.description)
lines.append("")
if finding.reproduction_steps:
lines.append("**Reproduction Steps**:")
for step in finding.reproduction_steps:
lines.append(f"1. {step}")
lines.append("")
if finding.remediation:
lines.append(f"**Remediation**: {finding.remediation}")
lines.append("")
report_content = "\n".join(lines)
report_path = self._output_dir / f"report_{result.campaign.campaign_id}.md"
report_path.write_text(report_content)
        return report_path

Phase 6: FastAPI REST Interface
# platform/api.py
"""REST API for the AI red teaming platform."""
from __future__ import annotations
import asyncio
from pathlib import Path
from fastapi import FastAPI, HTTPException
from .models import CampaignConfig, TargetConfig
from .modules.base import list_modules
from .orchestrator import CampaignResult, Orchestrator
from .registry import TargetRegistry
from .reporting import ReportGenerator
app = FastAPI(title="AI Red Team Platform", version="1.0.0")
# Initialize core components.
_registry = TargetRegistry(storage_dir=Path("./data/targets"))
_orchestrator = Orchestrator(registry=_registry)
_reporter = ReportGenerator(output_dir=Path("./data/reports"))
# In-memory store for campaign results (use a database in production).
_campaign_results: dict[str, CampaignResult] = {}
@app.get("/modules")
async def get_modules():
"""List all registered attack modules."""
return list_modules()
@app.post("/targets")
async def register_target(config: TargetConfig):
"""Register a new target AI system."""
try:
target_id = _registry.register(config)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return {"target_id": target_id}
@app.get("/targets")
async def get_targets():
"""List all registered targets."""
return [t.model_dump(exclude={"auth_token_env_var"}) for t in _registry.list_targets()]
@app.post("/campaigns")
async def launch_campaign(config: CampaignConfig):
"""Launch a new red teaming campaign."""
    # Validate that all targets and modules exist before launching.
    for tid in config.targets:
        if _registry.get(tid) is None:
            raise HTTPException(status_code=404, detail=f"Target {tid} not found")
    known_modules = {m["name"] for m in list_modules()}
    for module_name in config.modules:
        if module_name not in known_modules:
            raise HTTPException(status_code=404, detail=f"Module {module_name!r} not found")
    # Run the campaign in the background so the API returns immediately.
    async def _run():
        result = await _orchestrator.run_campaign(config)
        _campaign_results[config.campaign_id] = result
    # Hold a strong reference so the task is not garbage-collected mid-run.
    task = asyncio.create_task(_run())
    app.state.campaign_tasks = getattr(app.state, "campaign_tasks", set())
    app.state.campaign_tasks.add(task)
    task.add_done_callback(app.state.campaign_tasks.discard)
    return {"campaign_id": config.campaign_id, "status": "launched"}
@app.get("/campaigns/{campaign_id}")
async def get_campaign_status(campaign_id: str):
"""Get the status and results of a campaign."""
result = _campaign_results.get(campaign_id)
if result is None:
return {"campaign_id": campaign_id, "status": "running"}
return result.summary()
@app.get("/campaigns/{campaign_id}/report")
async def get_campaign_report(campaign_id: str):
"""Generate and return the campaign report."""
result = _campaign_results.get(campaign_id)
if result is None:
raise HTTPException(status_code=404, detail="Campaign not found or still running")
report_path = _reporter.generate_markdown(result)
    return {"report_path": str(report_path)}

Phase 7: CLI Interface
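The run command below reads campaign definitions from YAML. A hypothetical campaign file matching the CampaignConfig model (the UUID is a placeholder for a real target_id returned by the register command) might look like:

```yaml
name: "quarterly-assessment"
targets:
  - "3f2b9c1e-0000-0000-0000-000000000000"
modules:
  - "garak_probe"
max_concurrent_modules: 2
timeout_per_module_seconds: 600
stop_on_critical: true
```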
# platform/cli.py
"""Command-line interface for the AI red teaming platform."""
from __future__ import annotations
import asyncio
import json
import sys
from pathlib import Path
import click
import yaml
from .models import CampaignConfig, TargetConfig
from .modules.base import list_modules
from .orchestrator import Orchestrator
from .registry import TargetRegistry
from .reporting import ReportGenerator
@click.group()
@click.option("--data-dir", default="./data", help="Data directory for targets and reports")
@click.pass_context
def cli(ctx: click.Context, data_dir: str) -> None:
"""AI Red Team Platform CLI."""
data_path = Path(data_dir)
ctx.ensure_object(dict)
ctx.obj["registry"] = TargetRegistry(storage_dir=data_path / "targets")
ctx.obj["reporter"] = ReportGenerator(output_dir=data_path / "reports")
@cli.command()
def modules() -> None:
"""List all available attack modules."""
for mod in list_modules():
click.echo(f" {mod['name']} (v{mod['version']})")
click.echo(f" {mod['description']}")
click.echo(f" Categories: {', '.join(mod['categories'])}")
click.echo()
@cli.command()
@click.argument("config_file", type=click.Path(exists=True))
@click.pass_context
def register(ctx: click.Context, config_file: str) -> None:
"""Register a target from a YAML config file."""
with open(config_file) as f:
data = yaml.safe_load(f)
config = TargetConfig(**data)
registry: TargetRegistry = ctx.obj["registry"]
target_id = registry.register(config)
click.echo(f"Registered target: {target_id}")
@cli.command()
@click.argument("campaign_file", type=click.Path(exists=True))
@click.pass_context
def run(ctx: click.Context, campaign_file: str) -> None:
"""Run a campaign from a YAML config file."""
with open(campaign_file) as f:
data = yaml.safe_load(f)
config = CampaignConfig(**data)
registry: TargetRegistry = ctx.obj["registry"]
orchestrator = Orchestrator(registry=registry)
reporter: ReportGenerator = ctx.obj["reporter"]
click.echo(f"Starting campaign: {config.name}")
result = asyncio.run(orchestrator.run_campaign(config))
    click.echo("\nCampaign complete. Summary:")
click.echo(json.dumps(result.summary(), indent=2))
report_path = reporter.generate_markdown(result)
    click.echo(f"\nReport written to: {report_path}")

Evaluation Criteria
Your AI red teaming platform will be evaluated across five dimensions:
| Criterion | Weight | Excellent | Satisfactory | Needs Improvement |
|---|---|---|---|---|
| Architecture | 25% | Clean separation of concerns, extensible plugin system, well-defined interfaces | Working plugin system with minor coupling issues | Monolithic design with hardcoded components |
| Attack Coverage | 25% | 5+ modules covering diverse categories, integrates 2+ external tools | 3-4 modules with external tool integration | Fewer than 3 modules or no external tools |
| Orchestration | 20% | Async execution, concurrency control, graceful error handling, rate limiting | Basic async with error handling | Sequential execution or crashes on module failures |
| Reporting | 15% | Multiple formats, severity scoring, ATLAS/OWASP mapping, remediation guidance | Single format with severity and evidence | Raw findings dump with no structure |
| Code Quality | 15% | Type-annotated, tested, documented, follows Python best practices | Mostly typed and documented | Missing types, no tests, poor organization |
Stretch Goals
- Add a web dashboard using React or HTMX for real-time campaign monitoring.
- Implement a finding correlation engine that identifies attack chains across multiple individual findings.
- Add support for scheduled/recurring campaigns with drift detection (flag new vulnerabilities since the last run).
- Integrate with SIEM/SOAR platforms via webhook notifications.
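For the SIEM/SOAR stretch goal, most of the work beyond transport is payload shaping. A hedged sketch of turning a finding summary into a webhook body (the event_type value and field layout are illustrative, not a SIEM standard):

```python
import json
from datetime import datetime, timezone


def build_webhook_payload(finding: dict) -> str:
    """Shape a finding summary into a JSON body a SIEM/SOAR pipeline could ingest."""
    payload = {
        "event_type": "ai_redteam.finding",  # illustrative event name
        "emitted_at": datetime.now(timezone.utc).isoformat(),
        "severity": finding.get("severity", "informational"),
        "category": finding.get("category", "unknown"),
        "title": finding.get("title", ""),
        "campaign_id": finding.get("campaign_id", ""),
    }
    return json.dumps(payload)
```

Posting the body is then a single HTTP call with your client of choice; retries, request signing, and dead-lettering are where the real integration effort goes.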
References
- Mazeika, M., et al. (2024). "HarmBench: A Standardized Evaluation Framework for Automated Red Teaming and Robust Refusal." arXiv:2402.04249. https://arxiv.org/abs/2402.04249
- Derczynski, L., et al. (2024). "garak: A Framework for Large Language Model Red Teaming." arXiv:2406.11036. https://arxiv.org/abs/2406.11036
- OWASP Foundation. (2025). "OWASP Top 10 for LLM Applications." https://owasp.org/www-project-top-10-for-large-language-model-applications/
- MITRE. (2024). "ATLAS — Adversarial Threat Landscape for AI Systems." https://atlas.mitre.org/