Threat Modeling for AI Infrastructure Using STRIDE
Systematic threat modeling methodology for AI/ML systems using STRIDE, data flow diagrams, and attack trees tailored to machine learning pipelines
Overview
Threat modeling is the structured process of identifying security threats, understanding their potential impact, and prioritizing mitigations before an attacker exploits them. For AI infrastructure, threat modeling is essential because AI systems have unique properties that traditional threat models fail to capture: the training data is part of the attack surface, the model itself can be a target or a weapon, the behavior of the system changes with each retraining cycle, and the boundary between "intended behavior" and "vulnerability" is often unclear for learned systems.
Microsoft's STRIDE framework — Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, and Elevation of Privilege — provides a systematic approach to threat identification that maps well to AI infrastructure when properly adapted. However, naive application of STRIDE to AI systems misses critical threats. For example, data poisoning is a form of "Tampering" that targets the training process rather than a runtime input, and model extraction through inference APIs is "Information Disclosure" that occurs through legitimate use of the system rather than through a vulnerability in the traditional sense.
This article presents a comprehensive threat modeling methodology for AI infrastructure that extends STRIDE with AI-specific threat categories, demonstrates how to create data flow diagrams for common ML architectures, and provides templates and tools for conducting and documenting threat modeling sessions. The methodology aligns with NIST AI RMF, MITRE ATLAS, and the OWASP Machine Learning Security Top 10.
STRIDE Applied to AI Infrastructure
Spoofing in AI Systems
Spoofing threats involve an attacker pretending to be something or someone they are not. In AI infrastructure, spoofing takes several unique forms:
| Spoofing Target | AI-Specific Threat | Impact |
|---|---|---|
| Data source | Inject training data from a spoofed data provider | Model learns from attacker-controlled data |
| Model artifact | Replace legitimate model with attacker's model | Inference produces attacker-controlled outputs |
| Inference client | Impersonate an authorized client to access models | Unauthorized model access, potential extraction |
| Pipeline stage | Spoof a training pipeline component | Inject malicious steps into the ML pipeline |
| Monitoring signal | Fake metrics to hide model degradation | Conceals ongoing attack against the model |
| Feature source | Spoof feature store data at inference time | Manipulate model inputs for specific outcomes |
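The "model artifact" spoofing row is commonly mitigated by pinning an artifact digest at registration time and verifying it before load. A minimal sketch, assuming digests are recorded out of band (the registration workflow is not prescribed here):

```python
"""Verify a model artifact against a pinned SHA-256 digest before loading."""
import hashlib
from pathlib import Path


def sha256_digest(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file so multi-GB weight files do not exhaust memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()


def verify_artifact(path: Path, pinned_digest: str) -> bool:
    """Refuse to load an artifact whose digest differs from the pinned one."""
    return sha256_digest(path) == pinned_digest
```

A serving process would call `verify_artifact` on startup and fail closed on mismatch; digest signing (e.g., Sigstore) adds provenance on top of integrity.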
Tampering in AI Systems
Tampering is the modification of data in transit or at rest. For AI systems, tampering threats are particularly impactful because they can affect the learned behavior of the model itself:
- Training data tampering: Modifying training data to inject backdoors or reduce model accuracy. This is the classic data poisoning attack.
- Model weight tampering: Directly modifying model weights in storage to change inference behavior. Even small perturbations can cause targeted misclassification.
- Feature tampering: Modifying features in the feature store or during feature retrieval to manipulate inference results.
- Pipeline tampering: Modifying training pipeline code, hyperparameters, or configuration to produce compromised models.
- Gradient tampering: In federated learning, malicious participants can send poisoned gradients that steer the global model.
Repudiation in AI Systems
Repudiation threats arise when actions cannot be reliably attributed to an actor. For AI infrastructure this includes:
- Unattributed model deployments: Without signed, logged deployment records, no one can prove who promoted the model version that later caused an incident.
- Missing training lineage: If training runs do not record the exact data, code, and hyperparameters used, a poisoned or faulty model cannot be traced back to its cause.
- Unlogged inference access: Without per-client audit logs, extraction or probing activity cannot be attributed after the fact.
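For the federated-learning case in the last bullet, a standard countermeasure is robust aggregation. A minimal sketch using a coordinate-wise median, which bounds the influence of any single poisoned gradient assuming honest participants form a majority (illustrative only; production systems combine this with norm clipping and participant attestation):

```python
"""Robust gradient aggregation sketch for federated learning."""
from statistics import median


def robust_aggregate(gradients: list[list[float]]) -> list[float]:
    """Coordinate-wise median across participant gradient vectors."""
    return [median(coords) for coords in zip(*gradients)]
```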
Information Disclosure in AI Systems
- Model extraction: Systematically querying the model through its API to reconstruct a functionally equivalent copy.
- Training data extraction: Using model memorization to extract individual training samples through carefully crafted prompts.
- Membership inference: Determining whether a specific data point was used in training.
- Architecture disclosure: Extracting model architecture details through API responses, timing side channels, or metadata endpoints.
- Credential disclosure: AI infrastructure components (training jobs, serving endpoints) often have access to broad cloud credentials that could be exposed through SSRF or other vulnerabilities.
Denial of Service in AI Systems
- Inference resource exhaustion: Crafting inputs that maximize GPU compute, memory, or time (e.g., maximum-length prompts for LLMs).
- Training disruption: Monopolizing GPU scheduler resources, corrupting training data to cause training failures, or crashing training jobs.
- Model degradation: Gradually degrading model performance through data drift injection or strategic evasion, making the model unusable.
- Pipeline DoS: Triggering excessive retraining cycles or overwhelming the model registry.
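Inference resource exhaustion can be partially mitigated by admission control before a request reaches the GPU. A sketch with illustrative limits (the `max_prompt_chars` and `max_output_tokens` budgets are assumptions to be tuned per model and hardware):

```python
"""Pre-admission guard against resource-exhaustion inference requests."""
from dataclasses import dataclass


@dataclass(frozen=True)
class InferenceBudget:
    max_prompt_chars: int = 32_000
    max_output_tokens: int = 1_024


def admit(prompt: str, requested_tokens: int,
          budget: InferenceBudget = InferenceBudget()) -> tuple[bool, str]:
    """Return (admitted, reason); reject before any GPU work is scheduled."""
    if len(prompt) > budget.max_prompt_chars:
        return False, "prompt exceeds length budget"
    if requested_tokens > budget.max_output_tokens:
        return False, "requested output exceeds token budget"
    return True, "ok"
```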
Elevation of Privilege in AI Systems
- Container escape from GPU pods: Exploiting container runtime vulnerabilities with the elevated privileges common in GPU workloads.
- Service account abuse: Training jobs or serving pods with overly broad IAM roles that can be used to access unrelated resources.
- Pipeline privilege escalation: Using pipeline execution capabilities to run arbitrary code with the pipeline service account's permissions.
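Many GPU-pod escape paths begin with settings that static inspection of the pod spec can catch before deployment. A sketch that audits a pod spec (as a dict, following Kubernetes field names) for a few EoP-prone configurations; the rule set is illustrative, not exhaustive:

```python
"""Audit a Kubernetes pod spec dict for EoP-prone settings."""


def audit_pod_spec(spec: dict) -> list[str]:
    """Return human-readable findings for risky pod-level settings."""
    findings = []
    if spec.get("hostPID") or spec.get("hostNetwork"):
        findings.append("pod shares host PID or network namespace")
    for c in spec.get("containers", []):
        sc = c.get("securityContext", {})
        if sc.get("privileged"):
            findings.append(f"container {c.get('name')} runs privileged")
        if sc.get("runAsUser") == 0 or sc.get("runAsNonRoot") is False:
            findings.append(f"container {c.get('name')} runs as root")
    for v in spec.get("volumes", []):
        if "hostPath" in v:
            findings.append(f"volume {v.get('name')} mounts a hostPath")
    return findings
```

Such a check fits naturally as an admission webhook or CI policy gate alongside tools like Kyverno or OPA Gatekeeper.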
Building AI Data Flow Diagrams
Core DFD Elements for ML Systems
Data flow diagrams (DFDs) are the foundation of STRIDE threat modeling. For AI systems, the DFD must capture both the traditional software components and the ML-specific data flows:
"""
AI infrastructure data flow diagram generator.
Creates structured DFD representations for threat modeling
ML systems using STRIDE.
"""
import json
from dataclasses import dataclass, field
from enum import Enum
class ElementType(Enum):
PROCESS = "process"
DATA_STORE = "data_store"
EXTERNAL_ENTITY = "external_entity"
DATA_FLOW = "data_flow"
TRUST_BOUNDARY = "trust_boundary"
class TrustZone(Enum):
EXTERNAL = "external" # Internet, end users
DMZ = "dmz" # API gateways, load balancers
INFERENCE = "inference" # Model serving
TRAINING = "training" # Training cluster
DATA = "data" # Training data storage
CONTROL = "control" # Pipeline orchestration
MONITORING = "monitoring" # Observability
REGISTRY = "registry" # Model/artifact registry
@dataclass
class DFDElement:
"""An element in the data flow diagram."""
id: str
name: str
element_type: ElementType
trust_zone: TrustZone
description: str
technologies: list[str] = field(default_factory=list)
threats: list[str] = field(default_factory=list)
@dataclass
class DataFlow:
"""A data flow between DFD elements."""
id: str
source_id: str
target_id: str
data_description: str
protocol: str
is_encrypted: bool = False
is_authenticated: bool = False
threats: list[str] = field(default_factory=list)
class AIThreatModel:
"""
Threat model for an AI infrastructure deployment.
"""
def __init__(self, system_name: str, description: str):
self.system_name = system_name
self.description = description
self.elements: dict[str, DFDElement] = {}
self.data_flows: list[DataFlow] = []
self.threats: list[dict] = []
def add_element(self, element: DFDElement) -> None:
"""Add a DFD element."""
self.elements[element.id] = element
def add_data_flow(self, flow: DataFlow) -> None:
"""Add a data flow between elements."""
self.data_flows.append(flow)
def apply_stride(self) -> list[dict]:
"""
Systematically apply STRIDE to all elements and data flows.
Returns identified threats.
"""
threats = []
# Apply STRIDE to each element based on type
for elem in self.elements.values():
if elem.element_type == ElementType.PROCESS:
threats.extend(self._stride_process(elem))
elif elem.element_type == ElementType.DATA_STORE:
threats.extend(self._stride_data_store(elem))
elif elem.element_type == ElementType.EXTERNAL_ENTITY:
threats.extend(self._stride_external_entity(elem))
# Apply STRIDE to each data flow
for flow in self.data_flows:
threats.extend(self._stride_data_flow(flow))
# Apply cross-boundary threats
threats.extend(self._cross_boundary_threats())
self.threats = threats
return threats
def _stride_process(self, elem: DFDElement) -> list[dict]:
"""Apply all STRIDE categories to a process element."""
threats = []
prefix = f"{elem.name} ({elem.id})"
# Spoofing
threats.append({
"element": elem.id,
"category": "Spoofing",
"threat": (
f"An attacker could impersonate {prefix} to "
f"interact with upstream or downstream components."
),
"ai_specific": "model" in elem.name.lower() or "inference" in elem.name.lower(),
"mitigation": "Implement mutual TLS with workload identity (SPIFFE).",
})
# Tampering
threats.append({
"element": elem.id,
"category": "Tampering",
"threat": (
f"An attacker could modify the behavior of {prefix} "
f"by altering its inputs, configuration, or code."
),
"ai_specific": "training" in elem.name.lower(),
"mitigation": "Sign artifacts, use read-only filesystems, verify inputs.",
})
# Repudiation
threats.append({
"element": elem.id,
"category": "Repudiation",
"threat": (
f"Actions performed by or through {prefix} may not "
f"be properly logged, allowing denial of activity."
),
"ai_specific": False,
"mitigation": "Implement comprehensive audit logging.",
})
# Information Disclosure
if "model" in elem.name.lower() or "inference" in elem.name.lower():
threats.append({
"element": elem.id,
"category": "Information Disclosure",
"threat": (
f"{prefix} could leak model weights, architecture, "
f"or training data through inference API responses, "
f"error messages, or timing side channels."
),
"ai_specific": True,
"mitigation": (
"Implement rate limiting, response filtering, "
"differential privacy, and timing normalization."
),
})
# Denial of Service
threats.append({
"element": elem.id,
"category": "Denial of Service",
"threat": (
f"{prefix} could be overwhelmed by excessive requests "
f"or resource-intensive operations."
),
"ai_specific": "gpu" in elem.description.lower(),
"mitigation": "Implement resource limits, rate limiting, and circuit breakers.",
})
# Elevation of Privilege
threats.append({
"element": elem.id,
"category": "Elevation of Privilege",
"threat": (
f"A vulnerability in {prefix} could allow an attacker "
f"to gain the privileges of its service account."
),
"ai_specific": False,
"mitigation": "Apply least privilege, use non-root containers.",
})
return threats
def _stride_data_store(self, elem: DFDElement) -> list[dict]:
"""Apply applicable STRIDE categories to a data store."""
threats = []
# Tampering (primary threat for data stores)
threats.append({
"element": elem.id,
"category": "Tampering",
"threat": (
f"Data in {elem.name} could be modified. "
f"{'This directly enables training data poisoning. ' if 'training' in elem.name.lower() or 'data' in elem.name.lower() else ''}"
f"{'Model artifacts could be replaced with malicious versions. ' if 'model' in elem.name.lower() or 'registry' in elem.name.lower() else ''}"
),
"ai_specific": True,
"mitigation": "Enable versioning, access logging, integrity verification.",
})
# Information Disclosure
threats.append({
"element": elem.id,
"category": "Information Disclosure",
"threat": (
f"Unauthorized read access to {elem.name} could "
f"expose {'training data and PII' if 'data' in elem.name.lower() else 'model intellectual property'}."
),
"ai_specific": True,
"mitigation": "Encrypt at rest, enforce IAM policies, audit access.",
})
return threats
def _stride_external_entity(self, elem: DFDElement) -> list[dict]:
"""Spoofing is the primary STRIDE threat for external entities."""
return [{
"element": elem.id,
"category": "Spoofing",
"threat": (
f"An attacker could impersonate {elem.name} to "
f"submit malicious requests or data."
),
"ai_specific": "data" in elem.name.lower(),
"mitigation": "Authenticate all external entities. Validate inputs.",
}]
def _stride_data_flow(self, flow: DataFlow) -> list[dict]:
"""Apply Tampering and Information Disclosure to data flows."""
threats = []
source = self.elements.get(flow.source_id)
target = self.elements.get(flow.target_id)
if not flow.is_encrypted:
threats.append({
"element": flow.id,
"category": "Tampering",
"threat": (
f"Data flow from {source.name if source else flow.source_id} "
f"to {target.name if target else flow.target_id} "
f"({flow.data_description}) is unencrypted. "
f"An attacker could modify data in transit."
),
"ai_specific": "model" in flow.data_description.lower(),
"mitigation": "Enable TLS/mTLS for all data flows.",
})
threats.append({
"element": flow.id,
"category": "Information Disclosure",
"threat": (
f"Unencrypted {flow.data_description} could be "
f"intercepted between {source.name if source else '?'} "
f"and {target.name if target else '?'}."
),
"ai_specific": True,
"mitigation": "Encrypt all data in transit.",
})
if not flow.is_authenticated:
threats.append({
"element": flow.id,
"category": "Spoofing",
"threat": (
f"Data flow to {target.name if target else '?'} "
f"is not authenticated. An attacker could send "
f"forged {flow.data_description}."
),
"ai_specific": "model" in flow.data_description.lower()
or "training" in flow.data_description.lower(),
"mitigation": "Implement mutual authentication.",
})
return threats
def _cross_boundary_threats(self) -> list[dict]:
"""Identify threats at trust boundary crossings."""
threats = []
for flow in self.data_flows:
source = self.elements.get(flow.source_id)
target = self.elements.get(flow.target_id)
if source and target and source.trust_zone != target.trust_zone:
threats.append({
"element": flow.id,
"category": "Trust Boundary Crossing",
"threat": (
f"Data flows from {source.trust_zone.value} zone "
f"({source.name}) to {target.trust_zone.value} zone "
f"({target.name}). Cross-boundary data "
f"({flow.data_description}) requires validation."
),
"ai_specific": True,
"mitigation": (
"Validate all data crossing trust boundaries. "
"Apply defense-in-depth at boundary."
),
})
return threats
def generate_report(self) -> str:
"""Generate a structured threat model report."""
if not self.threats:
self.apply_stride()
ai_threats = [t for t in self.threats if t.get("ai_specific")]
traditional_threats = [t for t in self.threats if not t.get("ai_specific")]
# Count by category
category_counts: dict[str, int] = {}
for t in self.threats:
cat = t["category"]
category_counts[cat] = category_counts.get(cat, 0) + 1
report = {
"system": self.system_name,
"description": self.description,
"summary": {
"total_threats": len(self.threats),
"ai_specific_threats": len(ai_threats),
"traditional_threats": len(traditional_threats),
"by_category": category_counts,
},
"elements": len(self.elements),
"data_flows": len(self.data_flows),
"threats": self.threats,
}
        return json.dumps(report, indent=2, default=str)
Practical Examples
Threat Modeling an LLM Inference Pipeline
"""
Complete threat model for an LLM inference pipeline.
Demonstrates applying STRIDE to a realistic architecture.
"""
# Reuses AIThreatModel, DFDElement, DataFlow, ElementType, and TrustZone defined above.
def create_llm_inference_threat_model() -> "AIThreatModel":
"""
Build a threat model for a typical LLM inference pipeline
deployed on Kubernetes with GPU nodes.
"""
tm = AIThreatModel(
system_name="LLM Inference Pipeline",
description=(
"Production LLM inference system serving API requests "
"through a load balancer, with model artifacts stored "
"in S3 and served via vLLM on GPU nodes."
),
)
# External entities
tm.add_element(DFDElement(
id="user", name="API Client",
element_type=ElementType.EXTERNAL_ENTITY,
trust_zone=TrustZone.EXTERNAL,
description="External application consuming the LLM API",
technologies=["HTTPS", "REST"],
))
tm.add_element(DFDElement(
id="admin", name="ML Engineer",
element_type=ElementType.EXTERNAL_ENTITY,
trust_zone=TrustZone.EXTERNAL,
description="Engineer managing model deployments",
technologies=["kubectl", "mlflow CLI"],
))
# Processes
tm.add_element(DFDElement(
id="gateway", name="API Gateway",
element_type=ElementType.PROCESS,
trust_zone=TrustZone.DMZ,
description="Authenticates requests, rate limits, routes to inference",
technologies=["Kong", "NGINX"],
))
tm.add_element(DFDElement(
id="vllm", name="vLLM Inference Server",
element_type=ElementType.PROCESS,
trust_zone=TrustZone.INFERENCE,
description="GPU-accelerated LLM inference with PagedAttention",
technologies=["vLLM", "CUDA", "PyTorch"],
))
tm.add_element(DFDElement(
id="pipeline", name="Model Deployment Pipeline",
element_type=ElementType.PROCESS,
trust_zone=TrustZone.CONTROL,
description="Automated pipeline that loads new model versions",
technologies=["Argo CD", "Helm"],
))
# Data stores
tm.add_element(DFDElement(
id="s3_models", name="S3 Model Artifact Store",
element_type=ElementType.DATA_STORE,
trust_zone=TrustZone.REGISTRY,
description="Stores trained model weights and configurations",
technologies=["AWS S3", "SSE-KMS"],
))
tm.add_element(DFDElement(
id="mlflow", name="MLflow Model Registry",
element_type=ElementType.DATA_STORE,
trust_zone=TrustZone.REGISTRY,
description="Tracks model versions and deployment stages",
technologies=["MLflow", "PostgreSQL"],
))
tm.add_element(DFDElement(
id="logs", name="Inference Log Store",
element_type=ElementType.DATA_STORE,
trust_zone=TrustZone.MONITORING,
description="Stores inference requests and responses for monitoring",
technologies=["Elasticsearch", "S3"],
))
# Data flows
tm.add_data_flow(DataFlow(
id="df1", source_id="user", target_id="gateway",
data_description="Inference requests (prompts)",
protocol="HTTPS",
is_encrypted=True, is_authenticated=True,
))
tm.add_data_flow(DataFlow(
id="df2", source_id="gateway", target_id="vllm",
data_description="Authenticated inference requests",
protocol="HTTP/gRPC",
is_encrypted=False, # Common gap: internal traffic unencrypted
is_authenticated=False,
))
tm.add_data_flow(DataFlow(
id="df3", source_id="s3_models", target_id="vllm",
data_description="Model weights (multi-GB)",
protocol="HTTPS (S3 API)",
is_encrypted=True, is_authenticated=True,
))
tm.add_data_flow(DataFlow(
id="df4", source_id="admin", target_id="mlflow",
data_description="Model stage transitions",
protocol="HTTPS",
is_encrypted=True, is_authenticated=True,
))
tm.add_data_flow(DataFlow(
id="df5", source_id="mlflow", target_id="pipeline",
data_description="Model deployment triggers",
protocol="Webhook HTTP",
is_encrypted=False,
is_authenticated=False,
))
tm.add_data_flow(DataFlow(
id="df6", source_id="vllm", target_id="logs",
data_description="Inference requests and responses",
protocol="HTTP",
is_encrypted=False,
is_authenticated=False,
))
return tm
if __name__ == "__main__":
tm = create_llm_inference_threat_model()
threats = tm.apply_stride()
report = tm.generate_report()
print(report)
print(f"\nTotal threats identified: {len(threats)}")
    print(f"AI-specific threats: {sum(1 for t in threats if t.get('ai_specific'))}")
Risk Scoring with DREAD Adapted for AI
DREAD (Damage, Reproducibility, Exploitability, Affected users, Discoverability) can be adapted for AI systems by considering AI-specific factors in each dimension:
"""
DREAD risk scoring adapted for AI infrastructure threats.
Adds AI-specific weighting factors to standard DREAD scoring.
"""
from dataclasses import dataclass
from typing import Optional
@dataclass
class AIThreatScore:
"""DREAD score for an AI infrastructure threat."""
threat_name: str
atlas_technique: Optional[str]
# Standard DREAD dimensions (1-10 scale)
damage: int # How bad is it if the attack succeeds?
reproducibility: int # How easy is it to reproduce?
exploitability: int # How easy is it to launch the attack?
affected_users: int # How many users are impacted?
discoverability: int # How easy is it to find the vulnerability?
# AI-specific adjustment factors
model_impact: float = 1.0 # Does this affect model behavior? (1.0-2.0)
data_sensitivity: float = 1.0 # Is training/inference data sensitive? (1.0-2.0)
detection_difficulty: float = 1.0 # How hard to detect AI-specific aspect? (1.0-2.0)
@property
def base_dread_score(self) -> float:
"""Calculate standard DREAD score (0-10)."""
return (
self.damage + self.reproducibility + self.exploitability
+ self.affected_users + self.discoverability
) / 5.0
@property
def ai_adjusted_score(self) -> float:
"""
DREAD score adjusted for AI-specific factors.
AI attacks that corrupt model behavior or are hard to
detect get higher priority.
"""
base = self.base_dread_score
ai_multiplier = (
self.model_impact * self.data_sensitivity * self.detection_difficulty
) ** (1/3) # Geometric mean to prevent extreme values
return min(10.0, base * ai_multiplier)
@property
def priority(self) -> str:
"""Human-readable priority based on AI-adjusted score."""
score = self.ai_adjusted_score
if score >= 8.0:
return "CRITICAL"
elif score >= 6.0:
return "HIGH"
elif score >= 4.0:
return "MEDIUM"
elif score >= 2.0:
return "LOW"
return "INFO"
# Example: Score common AI infrastructure threats
AI_THREAT_SCORES = [
AIThreatScore(
threat_name="Training data poisoning via S3 write access",
atlas_technique="AML.T0020",
damage=9, reproducibility=8, exploitability=6,
affected_users=10, discoverability=4,
model_impact=2.0, data_sensitivity=1.8, detection_difficulty=1.8,
),
AIThreatScore(
threat_name="Model extraction through inference API",
atlas_technique="AML.T0024",
damage=8, reproducibility=9, exploitability=7,
affected_users=8, discoverability=6,
model_impact=1.5, data_sensitivity=1.3, detection_difficulty=1.5,
),
AIThreatScore(
threat_name="RCE via TorchServe management API",
atlas_technique="AML.T0010",
damage=10, reproducibility=9, exploitability=8,
affected_users=10, discoverability=7,
model_impact=1.0, data_sensitivity=1.0, detection_difficulty=1.0,
),
AIThreatScore(
threat_name="GPU memory side-channel between tenants",
atlas_technique="AML.T0024",
damage=7, reproducibility=5, exploitability=4,
affected_users=6, discoverability=3,
model_impact=1.5, data_sensitivity=1.7, detection_difficulty=2.0,
),
]
def generate_risk_report(threats: list[AIThreatScore]) -> str:
"""Generate a prioritized risk report from scored threats."""
sorted_threats = sorted(
threats, key=lambda t: t.ai_adjusted_score, reverse=True,
)
lines = ["AI Infrastructure Threat Risk Report", "=" * 45, ""]
for t in sorted_threats:
lines.append(
f"[{t.priority}] {t.threat_name}"
)
lines.append(
f" Base DREAD: {t.base_dread_score:.1f} | "
f"AI-Adjusted: {t.ai_adjusted_score:.1f}"
)
if t.atlas_technique:
lines.append(f" ATLAS: {t.atlas_technique}")
lines.append("")
return "\n".join(lines)
if __name__ == "__main__":
    print(generate_risk_report(AI_THREAT_SCORES))
Threat Modeling Workshop Facilitation
Running an effective AI threat modeling session requires bringing together both security expertise and ML engineering knowledge. Here is a structured approach for facilitating these sessions:
Participants: Include at minimum one security engineer familiar with STRIDE, one ML engineer who understands the training pipeline, one infrastructure engineer who manages the deployment platform, and one product owner who can assess business impact. For LLM systems, include someone who understands prompt engineering and model behavior.
Pre-work: Before the session, prepare a draft DFD based on architecture documentation and interviews. The DFD should show all data flows including training data ingestion, model artifact movement, inference request flow, and telemetry collection. Distribute the DFD to participants one week before the session for review.
Session structure (typically 2-4 hours):
- Review and refine the DFD (30 minutes)
- Walk through each element and data flow, applying STRIDE (90-120 minutes)
- Score identified threats using DREAD or CVSS (30 minutes)
- Identify top-priority mitigations (30 minutes)
Common blind spots to specifically probe during the session:
- "What happens if someone modifies the training data?" (teams often assume data integrity)
- "What can be learned by querying the model thousands of times?" (model extraction)
- "What credentials does a training job have access to?" (blast radius)
- "Can a compromised model exfiltrate data through its outputs?" (model as exfiltration channel)
STRIDE Threat Catalog for AI
The following table summarizes key STRIDE threats specific to AI infrastructure, beyond traditional IT threats:
| STRIDE Category | AI Threat | ATLAS Technique | Priority |
|---|---|---|---|
| Spoofing | Impersonate data source to inject poisoned data | AML.T0020 | Critical |
| Spoofing | Replace model artifact in registry | AML.T0010 | Critical |
| Tampering | Modify training data in storage | AML.T0020 | Critical |
| Tampering | Alter model weights during transfer | AML.T0010 | High |
| Tampering | Manipulate feature values at inference time | AML.T0043 | High |
| Repudiation | Deny model deployment that caused incident | - | Medium |
| Info Disclosure | Extract model through inference API | AML.T0024 | High |
| Info Disclosure | Extract training data from model | AML.T0024 | High |
| Info Disclosure | Leak GPU memory between tenants | AML.T0024 | High |
| DoS | Resource exhaustion via adversarial prompts | AML.T0029 | High |
| DoS | Training disruption via data corruption | - | Medium |
| EoP | Container escape from privileged GPU pod | - | Critical |
| EoP | Pipeline RCE via model deserialization | AML.T0010 | Critical |
Defense and Mitigation
Conduct threat modeling early and iteratively: Threat model AI systems during design, before deployment, and after significant architecture changes. Include ML engineers in threat modeling sessions — they understand data flows and model behavior that security teams may miss.
Use STRIDE-per-Element systematically: Walk through every DFD element and every data flow, applying each STRIDE category. The structured approach prevents overlooking threats that seem unlikely but have high impact.
Prioritize AI-specific threats: Standard IT threat modeling may rate "data store tampering" as medium risk. For AI systems, training data tampering is critical because it can silently corrupt model behavior in ways that are difficult to detect.
Map to MITRE ATLAS: Cross-reference identified threats with ATLAS techniques to ensure completeness and to leverage the community's knowledge of real-world attacks.
Document assumptions and trust boundaries: Explicitly state what is trusted and what is not. Common unstated assumptions in AI systems (e.g., "training data is trusted," "model artifacts in S3 have not been modified") are frequently the root cause of security gaps.
Revisit after retraining: Each model retraining cycle potentially introduces new threats if the training data, pipeline, or hyperparameters change. Update the threat model accordingly.
Automate threat model validation: Use the structured threat model output to automatically generate security test cases. For each identified threat with a test procedure, create an automated test that verifies whether the mitigation is in place and effective. Run these tests as part of CI/CD and on a scheduled basis to detect regressions. This transforms the threat model from a static document into a living security specification that is continuously validated against the actual system.
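One way to realize this is to map threat categories to executable checks and fail CI when a mitigation check does not pass. A minimal sketch; the `CHECKS` registry and the `mitigation_verified` field are hypothetical stand-ins for real probes (e.g., a TLS scan or an IAM policy query):

```python
"""Turn structured threat-model output into executable security checks."""
from typing import Callable

# Hypothetical registry: category -> check returning True if mitigated.
CHECKS: dict[str, Callable[[dict], bool]] = {
    "Tampering": lambda t: t.get("mitigation_verified", False),
    "Spoofing": lambda t: t.get("mitigation_verified", False),
}


def run_threat_checks(threats: list[dict]) -> list[dict]:
    """Return threats whose category has a check that currently fails."""
    failures = []
    for threat in threats:
        check = CHECKS.get(threat["category"])
        if check is not None and not check(threat):
            failures.append(threat)
    return failures
```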
Consider adversarial ML in the threat model: Traditional STRIDE focuses on infrastructure-level threats. For AI systems, explicitly include adversarial ML attacks such as evasion, poisoning, and model inversion in the threat model. These attacks target the learned behavior of the model rather than the infrastructure, and they require different mitigations (robust training, input validation, output perturbation) than traditional security controls.
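These adversarial-ML categories can be tracked alongside the STRIDE output by appending them to the threat list before scoring. A sketch whose entries and wording are illustrative:

```python
"""Fold adversarial-ML categories into the STRIDE threat list."""

ADVERSARIAL_ML_THREATS = [
    {"category": "Adversarial ML", "ai_specific": True,
     "threat": "Evasion: crafted inputs cause misclassification at inference time",
     "mitigation": "Robust/adversarial training, input validation"},
    {"category": "Adversarial ML", "ai_specific": True,
     "threat": "Poisoning: attacker-influenced training data implants backdoors",
     "mitigation": "Data provenance tracking, anomaly detection on training data"},
    {"category": "Adversarial ML", "ai_specific": True,
     "threat": "Model inversion: outputs reveal properties of training data",
     "mitigation": "Differential privacy, output perturbation"},
]


def extend_with_adversarial_ml(threats: list[dict]) -> list[dict]:
    """Return the STRIDE threats augmented with the adversarial-ML catalog."""
    return threats + ADVERSARIAL_ML_THREATS
```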
References
- Shostack, A. (2014). "Threat Modeling: Designing for Security." Wiley. ISBN: 978-1118809990.
- MITRE. (2024). "ATLAS: Adversarial Threat Landscape for AI Systems." https://atlas.mitre.org/
- Microsoft. (2024). "Threat Modeling for AI/ML Systems." https://learn.microsoft.com/en-us/security/engineering/threat-modeling-aiml
- NIST. (2023). "AI Risk Management Framework (AI RMF 1.0)." https://airc.nist.gov/AI_RMF_Interactivity/
- OWASP. (2025). "OWASP Machine Learning Security Top 10." https://owasp.org/www-project-machine-learning-security-top-10/