Model Weight Encryption
Encryption at rest and in transit for ML model weights, protecting intellectual property and preventing unauthorized model access.
Overview
Model weights represent significant investment — training a frontier model costs millions of dollars in compute, and even fine-tuned models encode proprietary knowledge that provides competitive advantage. Protecting model weights from unauthorized access is an intellectual property concern, a regulatory concern (if the model was trained on regulated data), and a security concern (a stolen model can be studied to find vulnerabilities more easily than a black-box API).
Encryption addresses two key threats: unauthorized access to model files at rest (on disk, in object storage, in model registries) and interception of model weights in transit (during deployment, distribution to edge devices, or transfer between environments). Without encryption, anyone with access to the storage backend or network path can read model weights directly.
This article covers practical encryption schemes for model weights, key management approaches, and the integration points where encryption fits into ML deployment pipelines. The content aligns with NIST AI RMF Govern 1.7 (protection of AI assets) and addresses the model theft concerns captured in MITRE ATLAS under AML.T0024 (Exfiltration via ML Inference API) and the ML Model Access tactic.
Threat Model
Who Wants Model Weights?
| Threat Actor | Motivation | Access Vector |
|---|---|---|
| Competitors | Replicate capabilities without training cost | Insider threat, storage breach, supply chain |
| Nation-state actors | Strategic AI capabilities | Advanced persistent threats, cloud provider compromise |
| Researchers | Study model internals for vulnerabilities | Download from misconfigured storage |
| Malicious insiders | Sell IP, personal use | Direct access to storage and deployment systems |
| Cloud providers | Theoretically can access customer data | Infrastructure access (mitigated by confidential computing) |
What Encryption Protects Against
Encryption protects model weights from unauthorized access at the storage and network layers. It does NOT protect against:
- Attacks against the running model (the model must be decrypted in GPU memory to serve inference)
- Key management failures (if the attacker obtains the encryption key, encryption is meaningless)
- Side-channel attacks that extract model information from inference behavior
- Model extraction through the API (querying the model to replicate its behavior)
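One property worth making concrete before the implementations: the schemes below all use AES-256-GCM, an authenticated cipher, so any tampering with stored ciphertext is detected at decryption time rather than silently producing corrupted weights. A minimal sketch using the `cryptography` package (the same primitive used throughout this article):

```python
import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag

key = AESGCM.generate_key(bit_length=256)
aesgcm = AESGCM(key)
nonce = os.urandom(12)  # 96-bit nonce, standard for GCM

ciphertext = aesgcm.encrypt(nonce, b"fake model weights", None)

# Flip a single bit in the ciphertext to simulate tampering or corruption
tampered = bytearray(ciphertext)
tampered[0] ^= 0x01

try:
    aesgcm.decrypt(nonce, bytes(tampered), None)
    print("tampering NOT detected")
except InvalidTag:
    # The GCM authentication tag check fails, so decryption refuses to
    # return any plaintext at all
    print("tampering detected")
```

This is why the article favors AES-GCM over unauthenticated modes such as AES-CTR: integrity comes built in, with only a 16-byte tag of overhead per file.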
Encryption at Rest
File-Level Encryption
The most straightforward approach encrypts model files before storing them:
import os
import hashlib
import json
from pathlib import Path
from typing import Dict, Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
class ModelEncryptor:
"""Encrypt and decrypt model weight files using AES-256-GCM."""
HEADER_MAGIC = b"ENCMODEL"
HEADER_VERSION = 1
NONCE_SIZE = 12 # 96 bits for AES-GCM
KEY_SIZE = 32 # 256 bits
def __init__(self, key: Optional[bytes] = None):
"""
Args:
key: 32-byte AES-256 key. If not provided, a new key is generated.
"""
self.key = key or AESGCM.generate_key(bit_length=256)
self._aesgcm = AESGCM(self.key)
def encrypt_file(
self,
input_path: str,
output_path: str,
associated_data: Optional[bytes] = None,
) -> Dict:
"""
Encrypt a model file with AES-256-GCM.
Args:
input_path: Path to the plaintext model file
output_path: Path for the encrypted output
associated_data: Optional AAD (e.g., model name, version) bound to ciphertext
"""
input_path = Path(input_path)
output_path = Path(output_path)
        # Read plaintext into memory (a chunked streaming AEAD scheme would be
        # preferable for multi-gigabyte models; this sketch reads the whole file)
        plaintext = input_path.read_bytes()
plaintext_hash = hashlib.sha256(plaintext).hexdigest()
# Generate random nonce
nonce = os.urandom(self.NONCE_SIZE)
# Encrypt
ciphertext = self._aesgcm.encrypt(nonce, plaintext, associated_data)
# Write encrypted file with header
metadata = {
"version": self.HEADER_VERSION,
"algorithm": "AES-256-GCM",
"plaintext_size": len(plaintext),
"plaintext_hash": plaintext_hash,
"has_aad": associated_data is not None,
}
metadata_bytes = json.dumps(metadata).encode()
with open(output_path, "wb") as f:
f.write(self.HEADER_MAGIC)
f.write(len(metadata_bytes).to_bytes(4, "big"))
f.write(metadata_bytes)
f.write(nonce)
f.write(ciphertext)
return {
"input_path": str(input_path),
"output_path": str(output_path),
"plaintext_size": len(plaintext),
"ciphertext_size": len(ciphertext),
"overhead_bytes": len(ciphertext) - len(plaintext),
"plaintext_hash": plaintext_hash,
}
def decrypt_file(
self,
input_path: str,
output_path: str,
associated_data: Optional[bytes] = None,
) -> Dict:
"""Decrypt an encrypted model file."""
with open(input_path, "rb") as f:
# Read and validate header
magic = f.read(len(self.HEADER_MAGIC))
if magic != self.HEADER_MAGIC:
raise ValueError("Not an encrypted model file")
metadata_size = int.from_bytes(f.read(4), "big")
metadata = json.loads(f.read(metadata_size))
nonce = f.read(self.NONCE_SIZE)
ciphertext = f.read()
# Decrypt
plaintext = self._aesgcm.decrypt(nonce, ciphertext, associated_data)
# Verify integrity
actual_hash = hashlib.sha256(plaintext).hexdigest()
if actual_hash != metadata.get("plaintext_hash"):
raise ValueError("Integrity check failed — file may be corrupted or tampered")
Path(output_path).write_bytes(plaintext)
return {
"output_path": output_path,
"size": len(plaintext),
"integrity_verified": True,
"hash": actual_hash,
        }
Envelope Encryption
For production deployments, use envelope encryption. The model is encrypted with a data encryption key (DEK), and the DEK is encrypted with a key encryption key (KEK) stored in a cloud KMS:
import os
import hashlib
import json
import base64
from typing import Dict, Optional
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
class EnvelopeEncryptor:
"""
Envelope encryption for model files using cloud KMS.
The model is encrypted with a random DEK, and the DEK is encrypted with a KEK in KMS.
"""
def __init__(self, kms_client, kms_key_id: str):
"""
Args:
            kms_client: Cloud KMS client (an AWS KMS-style encrypt/decrypt API is assumed in this sketch)
kms_key_id: KMS key identifier for the KEK
"""
self.kms_client = kms_client
self.kms_key_id = kms_key_id
def encrypt_model(
self,
model_path: str,
output_path: str,
model_metadata: Optional[Dict] = None,
) -> Dict:
"""Encrypt a model file using envelope encryption."""
# Generate a random DEK
dek = AESGCM.generate_key(bit_length=256)
nonce = os.urandom(12)
        # Encrypt the DEK with the KEK via KMS (AWS KMS API shape;
        # EncryptionContext keys and values must be strings)
encrypted_dek = self.kms_client.encrypt(
KeyId=self.kms_key_id,
Plaintext=dek,
EncryptionContext=model_metadata or {},
)["CiphertextBlob"]
# Encrypt the model with the DEK
plaintext = Path(model_path).read_bytes()
aesgcm = AESGCM(dek)
aad = json.dumps(model_metadata or {}).encode()
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
# Store encrypted model with envelope
envelope = {
"version": 1,
"kms_key_id": self.kms_key_id,
"encrypted_dek": base64.b64encode(encrypted_dek).decode(),
"nonce": base64.b64encode(nonce).decode(),
"metadata": model_metadata or {},
            "plaintext_hash": hashlib.sha256(plaintext).hexdigest(),
}
envelope_bytes = json.dumps(envelope).encode()
with open(output_path, "wb") as f:
f.write(b"ENVMODEL")
f.write(len(envelope_bytes).to_bytes(4, "big"))
f.write(envelope_bytes)
f.write(ciphertext)
        # Drop the DEK reference (best effort only; Python cannot reliably
        # zero out immutable bytes objects in memory)
        dek = b"\x00" * len(dek)
return {
"output_path": output_path,
"plaintext_size": len(plaintext),
"ciphertext_size": len(ciphertext),
"kms_key_id": self.kms_key_id,
}
def decrypt_model(self, encrypted_path: str, output_path: str) -> Dict:
"""Decrypt an envelope-encrypted model file."""
with open(encrypted_path, "rb") as f:
magic = f.read(8)
if magic != b"ENVMODEL":
raise ValueError("Not an envelope-encrypted model file")
envelope_size = int.from_bytes(f.read(4), "big")
envelope = json.loads(f.read(envelope_size))
ciphertext = f.read()
# Decrypt the DEK using KMS
encrypted_dek = base64.b64decode(envelope["encrypted_dek"])
dek = self.kms_client.decrypt(
CiphertextBlob=encrypted_dek,
EncryptionContext=envelope.get("metadata", {}),
)["Plaintext"]
# Decrypt the model with the DEK
nonce = base64.b64decode(envelope["nonce"])
aad = json.dumps(envelope.get("metadata", {})).encode()
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(nonce, ciphertext, aad)
# Verify integrity
        actual_hash = hashlib.sha256(plaintext).hexdigest()
if actual_hash != envelope.get("plaintext_hash"):
raise ValueError("Integrity check failed")
Path(output_path).write_bytes(plaintext)
        dek = b"\x00" * len(dek)  # Best effort; Python cannot reliably zero bytes
return {
"output_path": output_path,
"size": len(plaintext),
"integrity_verified": True,
        }
Storage Backend Encryption
Most cloud storage services provide server-side encryption. While this protects against physical disk theft, it does not protect against compromised storage credentials:
import json
import boto3
from typing import Dict
class S3ModelStorageSecurity:
"""Configure and audit S3 encryption for model storage."""
def __init__(self, bucket_name: str):
self.bucket = bucket_name
self.s3 = boto3.client("s3")
def configure_default_encryption(self, kms_key_id: str) -> Dict:
"""Enable default SSE-KMS encryption on the model bucket."""
self.s3.put_bucket_encryption(
Bucket=self.bucket,
ServerSideEncryptionConfiguration={
"Rules": [
{
"ApplyServerSideEncryptionByDefault": {
"SSEAlgorithm": "aws:kms",
"KMSMasterKeyID": kms_key_id,
},
"BucketKeyEnabled": True,
}
]
},
)
return {"bucket": self.bucket, "encryption": "SSE-KMS", "key_id": kms_key_id}
def audit_encryption(self) -> Dict:
"""Audit encryption configuration of the model storage bucket."""
findings = []
# Check default encryption
try:
enc = self.s3.get_bucket_encryption(Bucket=self.bucket)
rules = enc["ServerSideEncryptionConfiguration"]["Rules"]
for rule in rules:
algo = rule["ApplyServerSideEncryptionByDefault"]["SSEAlgorithm"]
findings.append({
"check": "default_encryption",
"status": "PASS",
"algorithm": algo,
})
except self.s3.exceptions.ClientError:
findings.append({
"check": "default_encryption",
"status": "FAIL",
"detail": "No default encryption configured",
})
# Check if bucket policy enforces encryption
try:
policy = json.loads(
self.s3.get_bucket_policy(Bucket=self.bucket)["Policy"]
)
has_encryption_enforcement = any(
"s3:x-amz-server-side-encryption" in json.dumps(stmt)
for stmt in policy.get("Statement", [])
)
findings.append({
"check": "encryption_enforcement",
"status": "PASS" if has_encryption_enforcement else "WARNING",
"detail": (
"Bucket policy enforces encryption"
if has_encryption_enforcement
else "No bucket policy enforcement for encryption"
),
})
except self.s3.exceptions.ClientError:
findings.append({
"check": "encryption_enforcement",
"status": "WARNING",
"detail": "No bucket policy found",
})
        return {"bucket": self.bucket, "findings": findings}
Encryption in Transit
Model Distribution Security
When distributing models from a registry to serving infrastructure, encrypt the transfer:
import hashlib
import ssl
from typing import Dict
from pathlib import Path
class SecureModelDistributor:
"""Securely distribute encrypted model files to serving infrastructure."""
def __init__(self, encryption_key: bytes):
self.encryptor = ModelEncryptor(key=encryption_key)
def prepare_for_distribution(
self,
model_path: str,
model_name: str,
model_version: str,
) -> Dict:
"""Encrypt and prepare a model for secure distribution."""
encrypted_path = f"{model_path}.encrypted"
# Bind model identity to the ciphertext via AAD
aad = f"{model_name}:{model_version}".encode()
result = self.encryptor.encrypt_file(
input_path=model_path,
output_path=encrypted_path,
associated_data=aad,
)
# Generate distribution manifest
manifest = {
"model_name": model_name,
"model_version": model_version,
"encrypted_file": encrypted_path,
"plaintext_hash": result["plaintext_hash"],
"ciphertext_hash": hashlib.sha256(
Path(encrypted_path).read_bytes()
).hexdigest(),
"encryption_algorithm": "AES-256-GCM",
}
return manifest
@staticmethod
def create_tls_context(
cert_path: str,
key_path: str,
ca_path: str,
) -> ssl.SSLContext:
"""Create a TLS context for secure model transfer."""
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.minimum_version = ssl.TLSVersion.TLSv1_3
context.load_cert_chain(cert_path, key_path)
context.load_verify_locations(ca_path)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
        return context
Secure Loading at Inference Time
The model must be decrypted before loading into GPU memory. Minimize the time decrypted weights exist on disk:
import tempfile
import os
from typing import Dict
from pathlib import Path
class SecureModelLoader:
"""Load encrypted models with minimal plaintext exposure."""
def __init__(self, decryption_key: bytes):
self.encryptor = ModelEncryptor(key=decryption_key)
def load_encrypted_model(
self,
encrypted_path: str,
model_class,
aad: bytes = None,
device: str = "cuda:0",
) -> Dict:
"""
Decrypt and load a model, minimizing plaintext disk exposure.
Uses a tmpfs mount to avoid writing decrypted weights to persistent storage.
"""
# Use tmpfs (memory-backed filesystem) if available
tmpfs_dir = "/dev/shm" # Default tmpfs on Linux
        if not os.path.isdir(tmpfs_dir):
            tmpfs_dir = tempfile.gettempdir()  # Fallback may be persistent disk
decrypted_path = os.path.join(tmpfs_dir, f"model_{os.getpid()}.bin")
try:
# Decrypt to tmpfs
result = self.encryptor.decrypt_file(
input_path=encrypted_path,
output_path=decrypted_path,
associated_data=aad,
)
# Load model
import torch
state_dict = torch.load(
decrypted_path,
map_location=device,
weights_only=True,
)
model = model_class()
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return {
"success": True,
"model": model,
"integrity_verified": result["integrity_verified"],
}
finally:
            # Best-effort deletion: overwrite then unlink. Overwriting is not a
            # guaranteed secure erase on SSDs or journaling filesystems.
if os.path.exists(decrypted_path):
file_size = os.path.getsize(decrypted_path)
with open(decrypted_path, "wb") as f:
f.write(os.urandom(file_size)) # Overwrite with random data
                os.remove(decrypted_path)
Key Management
Key Rotation
Model encryption keys must be rotated regularly. Implement a rotation strategy that re-encrypts models with new keys:
import os
from datetime import datetime, timezone
from typing import Dict, List
class ModelKeyRotation:
"""Manage encryption key rotation for model files."""
def __init__(self, kms_client, kms_key_id: str):
self.kms = kms_client
self.key_id = kms_key_id
def rotate_key(self) -> Dict:
"""Trigger a key rotation in KMS."""
# AWS KMS automatic rotation creates a new backing key
# while keeping the same key ID
self.kms.enable_key_rotation(KeyId=self.key_id)
return {
"key_id": self.key_id,
"rotation_enabled": True,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
def re_encrypt_model(
self,
encrypted_path: str,
new_kms_key_id: str,
) -> Dict:
"""Re-encrypt a model file with a new KMS key (for key migration)."""
envelope_enc = EnvelopeEncryptor(self.kms, self.key_id)
# Decrypt with old key
        import tempfile
        # Decrypt to tmpfs when available so plaintext avoids persistent disk
        tmp_dir = "/dev/shm" if os.path.isdir("/dev/shm") else None
        with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
            tmp_path = tmp.name
try:
envelope_enc.decrypt_model(encrypted_path, tmp_path)
# Re-encrypt with new key
new_enc = EnvelopeEncryptor(self.kms, new_kms_key_id)
new_enc.encrypt_model(tmp_path, encrypted_path)
return {
"path": encrypted_path,
"old_key": self.key_id,
"new_key": new_kms_key_id,
"re_encrypted": True,
}
finally:
if os.path.exists(tmp_path):
                os.remove(tmp_path)
Performance Considerations
| Operation | Overhead | Impact on Deployment |
|---|---|---|
| AES-256-GCM encryption | ~1 GB/s on modern CPUs | Seconds for typical model files |
| KMS key decryption | 50-200ms network RTT | One-time per model load |
| tmpfs decryption buffer | RAM usage = model size | Must have sufficient RAM |
| TLS 1.3 transfer | ~1-3% throughput overhead | Negligible for model downloads |
For most deployments, the encryption overhead is negligible compared to model loading and GPU initialization time. The dominant cost is the KMS call for envelope encryption, which adds 50-200ms of latency per model load.
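To put those numbers together, a back-of-envelope estimate of per-load decryption overhead can be computed directly. The throughput and RTT figures below are the assumed values from the table above, not measurements from any specific deployment:

```python
def estimated_decrypt_overhead_s(
    model_size_gb: float,
    aes_throughput_gb_s: float = 1.0,  # assumed AES-GCM throughput per core
    kms_rtt_s: float = 0.15,           # assumed KMS decrypt round-trip
) -> float:
    """Estimate the seconds added to each model load by envelope decryption."""
    # One KMS call to unwrap the DEK, then bulk AES-GCM over the whole file
    return model_size_gb / aes_throughput_gb_s + kms_rtt_s

# A 14 GB model: roughly 14 s of AES work plus one KMS round-trip
print(f"{estimated_decrypt_overhead_s(14):.2f} s")
```

Against a cold-start that already spends tens of seconds on weight loading and GPU initialization, this overhead is rarely the bottleneck; it matters mainly for autoscaling scenarios with frequent cold starts.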
Defensive Recommendations
- Encrypt all model files at rest using AES-256-GCM with envelope encryption
- Use cloud KMS for key management — never store encryption keys alongside encrypted models
- Enable server-side encryption on all storage backends (S3, GCS, Azure Blob)
- Use TLS 1.3 for all model file transfers
- Decrypt to tmpfs (/dev/shm) to avoid writing plaintext to persistent storage
- Rotate encryption keys regularly and re-encrypt models after key rotation
- Bind model identity to ciphertext using authenticated encryption (AAD) to prevent model swaps
- Audit key access — all KMS decryption operations should be logged and monitored
- Consider confidential computing (Intel TDX, AMD SEV-SNP, NVIDIA H100 CC) for protecting decrypted weights in memory
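The AAD-binding recommendation is easy to demonstrate in isolation: with authenticated encryption, a ciphertext bound to one model identity cannot be decrypted under another, even with the correct key, which blocks attacks that swap one encrypted model file for another. A sketch (the model names are illustrative):

```python
import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag

key = AESGCM.generate_key(bit_length=256)
aesgcm = AESGCM(key)
nonce = os.urandom(12)

# Bind the ciphertext to a specific model name and version via AAD
ciphertext = aesgcm.encrypt(nonce, b"weights", b"sentiment-model:v3")

# The matching identity decrypts normally
assert aesgcm.decrypt(nonce, ciphertext, b"sentiment-model:v3") == b"weights"

# A swapped identity fails the tag check, even though the key is correct
try:
    aesgcm.decrypt(nonce, ciphertext, b"sentiment-model:v4")
except InvalidTag:
    print("model identity mismatch detected")
```

The AAD itself is not secret and is not stored in the ciphertext; both sides must supply the same bytes, which is why the distribution manifest carries the model name and version alongside the encrypted file.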
References
- NIST SP 800-38D — Recommendation for Block Cipher Modes of Operation: Galois/Counter Mode (GCM)
- AWS KMS Documentation — https://docs.aws.amazon.com/kms/
- NIST AI RMF — Govern 1.7 (Protection of AI assets and intellectual property)
- MITRE ATLAS — AML.T0024 (Exfiltration via ML Inference API); ML Model Access tactic
- NVIDIA Confidential Computing — https://developer.nvidia.com/confidential-computing