AI Rate Limiting Walkthrough
Step-by-step walkthrough for implementing token-aware rate limiting for AI applications: request-level limiting, token budget enforcement, sliding window algorithms, abuse detection, and production deployment.
Rate limiting for AI applications differs from traditional API rate limiting in one critical way: not all requests cost the same. A request that sends 10 tokens and receives 50 back consumes 60 tokens; one that sends 2,000 tokens and receives 4,000 back consumes 6,000, a hundred times more. Effective AI rate limiting must account for token consumption, not just request count.
This walkthrough builds a token-aware rate limiting system from scratch, starting with simple request counting and progressing to token-budget enforcement with abuse detection.
Step 1: Understanding AI Rate Limiting
Why Token-Aware Limiting Matters
Traditional rate limiting counts requests:
User A: 100 requests/minute (each ~50 tokens) = 5,000 tokens
User B: 100 requests/minute (each ~5,000 tokens) = 500,000 tokens
Both users make the same number of requests, but User B consumes 100x more resources. Request-only rate limiting treats them identically. Token-aware limiting enforces fair resource usage.
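The imbalance above can be checked with a token budget alongside the request cap. A minimal sketch with hypothetical per-minute budgets (`MAX_REQUESTS` and `MAX_TOKENS` are illustrative values, not limits defined later in this walkthrough):

```python
# Hypothetical per-minute budgets for illustration
MAX_REQUESTS = 100
MAX_TOKENS = 50_000

def allowed(request_count: int, tokens_used: int) -> bool:
    # Request-only limiting would check just the first condition;
    # token-aware limiting also enforces the token budget.
    return request_count <= MAX_REQUESTS and tokens_used <= MAX_TOKENS

print(allowed(100, 5_000))    # User A: True
print(allowed(100, 500_000))  # User B: False -- caught by the token budget
```

Both users pass the request check; only the token budget distinguishes them.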
Rate Limiting Dimensions for AI
| Dimension | What It Controls | Why It Matters |
|---|---|---|
| Requests per minute | Raw API call volume | Prevents automated scanning |
| Input tokens per minute | Prompt size throughput | Prevents context window abuse |
| Output tokens per minute | Generation throughput | Controls compute cost |
| Total tokens per minute | Combined consumption | Overall resource fairness |
| Concurrent requests | Simultaneous connections | Prevents connection exhaustion |
| Requests per session | Conversation length | Limits long-running attacks |
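One way to carry these dimensions through code is a single per-tier limits structure. A sketch under assumed names (`AILimits` and its fields are illustrative; the limiter built later in this walkthrough uses plain dicts):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AILimits:
    """Per-user limits across the dimensions above (per minute unless noted)."""
    requests_per_minute: int
    input_tokens_per_minute: int
    output_tokens_per_minute: int
    total_tokens_per_minute: int
    max_concurrent_requests: int
    max_requests_per_session: int

# Example tier definition; numbers match the "free" tier used later
FREE = AILimits(
    requests_per_minute=20,
    input_tokens_per_minute=10_000,
    output_tokens_per_minute=10_000,
    total_tokens_per_minute=20_000,
    max_concurrent_requests=2,
    max_requests_per_session=50,
)
```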
Step 2: Basic Request-Level Rate Limiting
Start with simple request counting before adding token awareness.
# rate_limiter/basic.py
"""Basic request-level rate limiter using sliding window."""
import time
from collections import defaultdict
from dataclasses import dataclass
from threading import Lock
from typing import Optional
@dataclass
class RateLimitResult:
allowed: bool
remaining: int
limit: int
reset_at: float
retry_after: Optional[float] = None
class SlidingWindowRateLimiter:
"""Sliding window rate limiter for request counting."""
def __init__(self, max_requests: int = 60,
window_seconds: int = 60):
self.max_requests = max_requests
self.window_seconds = window_seconds
self.requests = defaultdict(list)
self.lock = Lock()
def check(self, key: str) -> RateLimitResult:
"""Check if a request is allowed for the given key."""
now = time.time()
window_start = now - self.window_seconds
with self.lock:
# Remove expired entries
self.requests[key] = [
t for t in self.requests[key] if t > window_start
]
current_count = len(self.requests[key])
if current_count >= self.max_requests:
# Find when the oldest request in the window expires
oldest = min(self.requests[key])
retry_after = oldest + self.window_seconds - now
return RateLimitResult(
allowed=False,
remaining=0,
limit=self.max_requests,
reset_at=oldest + self.window_seconds,
retry_after=max(retry_after, 0.1),
)
# Allow and record the request
self.requests[key].append(now)
return RateLimitResult(
allowed=True,
remaining=self.max_requests - current_count - 1,
limit=self.max_requests,
reset_at=now + self.window_seconds,
            )

FastAPI Integration
# server.py
"""FastAPI application with basic rate limiting."""
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from rate_limiter.basic import SlidingWindowRateLimiter
app = FastAPI()
limiter = SlidingWindowRateLimiter(max_requests=60, window_seconds=60)
def get_rate_limit_key(request: Request) -> str:
"""Extract the rate limit key from the request."""
# Use API key if available, otherwise IP address
api_key = request.headers.get("X-API-Key")
if api_key:
return f"api:{api_key}"
return f"ip:{request.client.host}"
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
"""Apply rate limiting to all requests."""
if request.url.path.startswith("/health"):
return await call_next(request)
key = get_rate_limit_key(request)
result = limiter.check(key)
if not result.allowed:
return JSONResponse(
status_code=429,
content={
"error": "rate_limit_exceeded",
"message": "Too many requests. Please retry later.",
"retry_after": result.retry_after,
},
headers={
"Retry-After": str(int(result.retry_after)),
"X-RateLimit-Limit": str(result.limit),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(result.reset_at)),
},
)
response = await call_next(request)
# Add rate limit headers to successful responses
response.headers["X-RateLimit-Limit"] = str(result.limit)
response.headers["X-RateLimit-Remaining"] = str(result.remaining)
response.headers["X-RateLimit-Reset"] = str(int(result.reset_at))
    return response

Step 3: Token-Aware Rate Limiting
Extend the basic limiter to track token consumption.
# rate_limiter/token_aware.py
"""Token-aware rate limiter for AI applications."""
import time
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Lock
from typing import Optional
@dataclass
class TokenBucket:
"""Tracks token consumption within a time window."""
entries: list = field(default_factory=list)
# Each entry: (timestamp, input_tokens, output_tokens)
@dataclass
class TokenRateLimitResult:
allowed: bool
remaining_requests: int
remaining_tokens: int
request_limit: int
token_limit: int
reset_at: float
retry_after: Optional[float] = None
reason: str = ""
class TokenAwareRateLimiter:
"""Rate limiter that tracks both requests and token consumption."""
def __init__(self, config: Optional[dict] = None):
config = config or {}
self.window_seconds = config.get("window_seconds", 60)
self.tiers = config.get("tiers", {
"free": {
"max_requests": 20,
"max_input_tokens": 10_000,
"max_output_tokens": 10_000,
"max_total_tokens": 20_000,
},
"standard": {
"max_requests": 60,
"max_input_tokens": 100_000,
"max_output_tokens": 100_000,
"max_total_tokens": 200_000,
},
"premium": {
"max_requests": 200,
"max_input_tokens": 500_000,
"max_output_tokens": 500_000,
"max_total_tokens": 1_000_000,
},
})
self.default_tier = config.get("default_tier", "free")
self.buckets = defaultdict(TokenBucket)
self.user_tiers = {}
self.lock = Lock()
def set_user_tier(self, user_key: str, tier: str):
"""Set the rate limit tier for a user."""
if tier not in self.tiers:
raise ValueError(f"Unknown tier: {tier}")
self.user_tiers[user_key] = tier
def check_request(self, user_key: str,
estimated_tokens: int = 0) -> TokenRateLimitResult:
"""Check if a request is allowed before model invocation."""
tier_name = self.user_tiers.get(user_key, self.default_tier)
tier = self.tiers[tier_name]
now = time.time()
window_start = now - self.window_seconds
with self.lock:
bucket = self.buckets[user_key]
# Clean expired entries
bucket.entries = [
e for e in bucket.entries if e[0] > window_start
]
current_requests = len(bucket.entries)
current_input_tokens = sum(e[1] for e in bucket.entries)
current_output_tokens = sum(e[2] for e in bucket.entries)
current_total = current_input_tokens + current_output_tokens
# Check request limit
if current_requests >= tier["max_requests"]:
return self._denied(
tier, bucket, now,
reason="Request limit exceeded",
)
# Check token limits (with estimated tokens)
if (current_total + estimated_tokens >
tier["max_total_tokens"]):
return self._denied(
tier, bucket, now,
reason="Token limit exceeded",
)
remaining_requests = (
tier["max_requests"] - current_requests - 1
)
remaining_tokens = (
tier["max_total_tokens"] - current_total
)
return TokenRateLimitResult(
allowed=True,
remaining_requests=remaining_requests,
remaining_tokens=remaining_tokens,
request_limit=tier["max_requests"],
token_limit=tier["max_total_tokens"],
reset_at=now + self.window_seconds,
)
def record_usage(self, user_key: str,
input_tokens: int, output_tokens: int):
"""Record actual token usage after model invocation."""
now = time.time()
with self.lock:
self.buckets[user_key].entries.append(
(now, input_tokens, output_tokens)
)
def _denied(self, tier, bucket, now, reason):
"""Build a denial result."""
if bucket.entries:
oldest = min(e[0] for e in bucket.entries)
retry_after = oldest + self.window_seconds - now
else:
retry_after = self.window_seconds
return TokenRateLimitResult(
allowed=False,
remaining_requests=0,
remaining_tokens=0,
request_limit=tier["max_requests"],
token_limit=tier["max_total_tokens"],
reset_at=now + self.window_seconds,
retry_after=max(retry_after, 0.1),
reason=reason,
        )

Integrating Token Counting
# rate_limiter/token_counter.py
"""Token counting utilities for rate limiting."""
def estimate_tokens(text: str, model: str = "gpt-4o-mini") -> int:
"""Estimate token count for a text string.
Uses a simple heuristic when tiktoken is not available.
"""
    try:
        import tiktoken
        encoding = tiktoken.encoding_for_model(model)
        return len(encoding.encode(text))
    except (ImportError, KeyError):
        # Fallback when tiktoken is missing or the model name is
        # unknown: roughly 4 characters per token for English text
        return len(text) // 4
def count_message_tokens(messages: list,
model: str = "gpt-4o-mini") -> int:
"""Count tokens in a list of chat messages."""
total = 0
for message in messages:
        # Each message has overhead for role and formatting
        total += 4  # per-message overhead
        total += estimate_tokens(message.get("content", ""), model)
    # Reply priming overhead, counted once per conversation
    total += 2
    return total

Updated FastAPI Integration
# server_token_aware.py
"""FastAPI with token-aware rate limiting."""
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from rate_limiter.token_aware import TokenAwareRateLimiter
from rate_limiter.token_counter import count_message_tokens
from server import get_rate_limit_key  # key derivation from Step 2
app = FastAPI()
limiter = TokenAwareRateLimiter()
class ChatRequest(BaseModel):
messages: list
max_tokens: int = 1024
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, body: ChatRequest):
"""Chat endpoint with token-aware rate limiting."""
user_key = get_rate_limit_key(request)
# Estimate input tokens
input_tokens = count_message_tokens(body.messages)
estimated_total = input_tokens + body.max_tokens
# Check rate limit
result = limiter.check_request(user_key, estimated_total)
if not result.allowed:
return JSONResponse(
status_code=429,
content={
"error": "rate_limit_exceeded",
"message": result.reason,
"retry_after": result.retry_after,
},
headers={
"Retry-After": str(int(result.retry_after)),
"X-RateLimit-Limit-Requests": str(result.request_limit),
"X-RateLimit-Limit-Tokens": str(result.token_limit),
"X-RateLimit-Remaining-Requests": str(
result.remaining_requests
),
"X-RateLimit-Remaining-Tokens": str(
result.remaining_tokens
),
},
)
# Call the model
response = await invoke_model(body.messages, body.max_tokens)
# Record actual usage
limiter.record_usage(
user_key,
input_tokens=response["usage"]["prompt_tokens"],
output_tokens=response["usage"]["completion_tokens"],
)
    return response

Step 4: Redis-Backed Distributed Rate Limiting
For production deployments with multiple application instances, use Redis for shared state. Note that the check-then-record pattern below is not atomic: two instances can pass the check concurrently and briefly exceed the limit. If strict enforcement matters, move the check and record into a server-side Lua script.
# rate_limiter/redis_limiter.py
"""Redis-backed distributed rate limiter."""
import time
import json
from typing import Optional
from rate_limiter.token_aware import TokenRateLimitResult
class RedisTokenRateLimiter:
"""Distributed token-aware rate limiter using Redis."""
def __init__(self, redis_client, config: Optional[dict] = None):
self.redis = redis_client
config = config or {}
self.window_seconds = config.get("window_seconds", 60)
self.tiers = config.get("tiers", {
"free": {
"max_requests": 20,
"max_total_tokens": 20_000,
},
"standard": {
"max_requests": 60,
"max_total_tokens": 200_000,
},
})
self.default_tier = config.get("default_tier", "free")
self.key_prefix = config.get("key_prefix", "ratelimit:")
def check_request(self, user_key: str,
estimated_tokens: int = 0) -> TokenRateLimitResult:
"""Check rate limit using Redis sorted sets."""
tier = self.tiers.get(
self._get_tier(user_key), self.tiers[self.default_tier]
)
now = time.time()
window_start = now - self.window_seconds
redis_key = f"{self.key_prefix}{user_key}"
pipe = self.redis.pipeline()
# Remove expired entries
pipe.zremrangebyscore(redis_key, 0, window_start)
# Count current entries
pipe.zcard(redis_key)
# Get all entries to sum tokens
pipe.zrangebyscore(redis_key, window_start, "+inf",
withscores=True)
results = pipe.execute()
current_requests = results[1]
entries = results[2]
# Calculate current token usage
current_tokens = sum(
json.loads(entry[0])["total_tokens"]
for entry in entries
)
# Check limits
if current_requests >= tier["max_requests"]:
retry_after = self._get_retry_after(entries, now)
return TokenRateLimitResult(
allowed=False,
remaining_requests=0,
remaining_tokens=max(
0, tier["max_total_tokens"] - current_tokens
),
request_limit=tier["max_requests"],
token_limit=tier["max_total_tokens"],
reset_at=now + self.window_seconds,
retry_after=retry_after,
reason="Request limit exceeded",
)
if current_tokens + estimated_tokens > tier["max_total_tokens"]:
retry_after = self._get_retry_after(entries, now)
return TokenRateLimitResult(
allowed=False,
remaining_requests=max(
0, tier["max_requests"] - current_requests
),
remaining_tokens=0,
request_limit=tier["max_requests"],
token_limit=tier["max_total_tokens"],
reset_at=now + self.window_seconds,
retry_after=retry_after,
reason="Token limit exceeded",
)
return TokenRateLimitResult(
allowed=True,
remaining_requests=(
tier["max_requests"] - current_requests - 1
),
remaining_tokens=(
tier["max_total_tokens"] - current_tokens
),
request_limit=tier["max_requests"],
token_limit=tier["max_total_tokens"],
reset_at=now + self.window_seconds,
)
def record_usage(self, user_key: str,
input_tokens: int, output_tokens: int):
"""Record token usage in Redis."""
now = time.time()
redis_key = f"{self.key_prefix}{user_key}"
entry = json.dumps({
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
"timestamp": now,
})
pipe = self.redis.pipeline()
pipe.zadd(redis_key, {entry: now})
pipe.expire(redis_key, self.window_seconds + 10)
pipe.execute()
def _get_tier(self, user_key: str) -> str:
"""Get the user's tier from Redis."""
tier = self.redis.get(f"{self.key_prefix}tier:{user_key}")
return tier.decode() if tier else self.default_tier
def _get_retry_after(self, entries, now):
"""Calculate retry-after from the oldest entry."""
if entries:
oldest_score = entries[0][1]
return max(
oldest_score + self.window_seconds - now, 0.1
)
        return self.window_seconds

Step 5: Abuse Detection
Rate limiting alone does not catch all abuse patterns. Add behavioral analysis on top.
Burst Detection
# rate_limiter/abuse_detection.py
"""Abuse pattern detection for AI endpoints."""
import time
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class AbuseSignal:
    signal_type: str
    severity: str  # low, medium, high
    details: str
    user_key: str
    timestamp: float

class AbuseDetector:
    """Detect abuse patterns beyond simple rate limiting."""

    def __init__(self):
        self.request_history = defaultdict(list)
        self.signals = []

    def analyze_request(self, user_key: str, request_data: dict):
        """Analyze a request for abuse patterns."""
        now = time.time()
        self.request_history[user_key].append({
            "time": now,
            **request_data,
        })
        signals = []
        # Burst detection: rapid sequential requests
        signals.extend(self._check_burst(user_key, now))
        # Prompt similarity: repeated similar prompts (scanning)
        signals.extend(self._check_repetition(user_key))
        # Token abuse: consistently maxing out token limits
        signals.extend(self._check_token_abuse(user_key))
        self.signals.extend(signals)
        return signals

    def _check_burst(self, user_key, now):
        """Detect rapid-fire request bursts."""
        history = self.request_history[user_key]
        recent = [r for r in history if now - r["time"] < 5]
        if len(recent) > 10:
            return [AbuseSignal(
                signal_type="burst",
                severity="high",
                details=f"{len(recent)} requests in 5 seconds",
                user_key=user_key,
                timestamp=now,
            )]
        return []

    def _check_repetition(self, user_key):
        """Detect repeated similar prompts (automated scanning)."""
        history = self.request_history[user_key]
        if len(history) < 5:
            return []
        recent = history[-20:]
        prompts = [r.get("prompt", "")[:100] for r in recent]
        # Check for high similarity
        unique_prompts = set(prompts)
        if len(unique_prompts) < len(prompts) * 0.3:
            return [AbuseSignal(
                signal_type="repetition",
                severity="medium",
                details=(
                    f"Only {len(unique_prompts)} unique prompts "
                    f"in last {len(prompts)} requests"
                ),
                user_key=user_key,
                timestamp=time.time(),
            )]
        return []

    def _check_token_abuse(self, user_key):
        """Detect users consistently hitting token limits."""
        history = self.request_history[user_key]
        if len(history) < 10:
            return []
        recent = history[-10:]
        max_token_requests = sum(
            1 for r in recent if r.get("max_tokens", 0) > 4000
        )
        if max_token_requests > 8:
            return [AbuseSignal(
                signal_type="token_abuse",
                severity="medium",
                details=(
                    f"{max_token_requests}/10 requests at "
                    f"max token limit"
                ),
                user_key=user_key,
                timestamp=time.time(),
            )]
        return []

Automated Response to Abuse
# rate_limiter/enforcement.py
"""Enforcement actions for detected abuse."""
import logging

logger = logging.getLogger(__name__)

class AbuseEnforcer:
    """Apply enforcement actions based on abuse signals."""

    def __init__(self, redis_client=None):
        self.redis = redis_client
        self.actions = {
            "low": self._log_only,
            "medium": self._temporary_throttle,
            "high": self._temporary_block,
        }

    def enforce(self, signal):
        """Apply enforcement action based on signal severity."""
        action = self.actions.get(signal.severity, self._log_only)
        action(signal)

    def _log_only(self, signal):
        """Log the signal without taking action."""
        logger.info(
            "Abuse signal (low): type=%s user=%s details=%s",
            signal.signal_type, signal.user_key, signal.details,
        )

    def _temporary_throttle(self, signal):
        """Reduce the user's rate limit temporarily."""
        logger.warning(
            "Throttling user %s: type=%s details=%s",
            signal.user_key, signal.signal_type, signal.details,
        )
        if self.redis:
            # Set a throttle flag that expires in 5 minutes
            self.redis.setex(
                f"throttle:{signal.user_key}",
                300,  # 5 minutes
                "throttled",
            )

    def _temporary_block(self, signal):
        """Block the user temporarily."""
        logger.critical(
            "Blocking user %s: type=%s details=%s",
            signal.user_key, signal.signal_type, signal.details,
        )
        if self.redis:
            # Set a block flag that expires in 15 minutes
            self.redis.setex(
                f"blocked:{signal.user_key}",
                900,  # 15 minutes
                "blocked",
            )

    def is_blocked(self, user_key: str) -> bool:
        """Check if a user is currently blocked."""
        if self.redis:
            # exists() returns a count, so coerce to bool
            return bool(self.redis.exists(f"blocked:{user_key}"))
        return False

    def is_throttled(self, user_key: str) -> bool:
        """Check if a user is currently throttled."""
        if self.redis:
            return bool(self.redis.exists(f"throttle:{user_key}"))
        return False
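To see the detection-to-enforcement flow end to end without Redis, here is a compressed, self-contained sketch of the same pattern: a burst-only detector and an in-memory block list standing in for `AbuseDetector` and `AbuseEnforcer` (names and thresholds mirror the modules above but this is a demonstration, not the production path):

```python
import time
from collections import defaultdict

history = defaultdict(list)   # user_key -> request timestamps
blocked_until = {}            # user_key -> unblock time

def analyze(user_key: str, now: float) -> list:
    """Burst-only stand-in for AbuseDetector.analyze_request."""
    history[user_key].append(now)
    recent = [t for t in history[user_key] if now - t < 5]
    if len(recent) > 10:
        return [{"signal_type": "burst", "severity": "high"}]
    return []

def enforce(user_key: str, signals: list, now: float):
    """In-memory stand-in for AbuseEnforcer: block on high severity."""
    if any(s["severity"] == "high" for s in signals):
        blocked_until[user_key] = now + 900  # 15-minute block

def is_blocked(user_key: str, now: float) -> bool:
    return blocked_until.get(user_key, 0) > now

now = time.time()
for _ in range(12):  # 12 requests within 5 seconds -> burst signal
    enforce("u1", analyze("u1", now), now)
print(is_blocked("u1", now))  # True
```

In the real middleware, `is_blocked` is checked before the rate limiter runs, so blocked users never reach the model at all.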
Step 6: Testing Rate Limits
# tests/test_rate_limiting.py
"""Test rate limiting behavior."""
import time
from rate_limiter.token_aware import TokenAwareRateLimiter
def test_request_limits():
"""Verify request limits are enforced."""
limiter = TokenAwareRateLimiter(config={
"tiers": {
"test": {"max_requests": 5, "max_total_tokens": 100_000},
},
"default_tier": "test",
"window_seconds": 10,
})
user = "test-user"
for i in range(5):
result = limiter.check_request(user)
assert result.allowed, f"Request {i+1} should be allowed"
limiter.record_usage(user, 10, 20)
# 6th request should be denied
result = limiter.check_request(user)
assert not result.allowed, "6th request should be denied"
assert result.retry_after > 0
print("[PASS] Request limits enforced correctly")
def test_token_limits():
"""Verify token limits are enforced."""
limiter = TokenAwareRateLimiter(config={
"tiers": {
"test": {"max_requests": 100, "max_total_tokens": 1000},
},
"default_tier": "test",
"window_seconds": 10,
})
user = "test-user"
# Use up most of the token budget
for i in range(9):
result = limiter.check_request(user, estimated_tokens=100)
assert result.allowed
limiter.record_usage(user, 50, 50)
# Next request with 200 estimated tokens should be denied
result = limiter.check_request(user, estimated_tokens=200)
assert not result.allowed
assert result.reason == "Token limit exceeded"
print("[PASS] Token limits enforced correctly")
def test_rate_limit_headers():
"""Verify correct rate limit headers are returned."""
limiter = TokenAwareRateLimiter(config={
"tiers": {
"test": {"max_requests": 10, "max_total_tokens": 50_000},
},
"default_tier": "test",
})
result = limiter.check_request("test-user")
assert result.remaining_requests == 9
assert result.request_limit == 10
assert result.token_limit == 50_000
print("[PASS] Rate limit headers correct")
def test_window_expiry():
"""Verify rate limits reset after window expiry."""
limiter = TokenAwareRateLimiter(config={
"tiers": {
"test": {"max_requests": 2, "max_total_tokens": 100_000},
},
"default_tier": "test",
"window_seconds": 2, # Short window for testing
})
user = "test-user"
# Use up the limit
limiter.check_request(user)
limiter.record_usage(user, 10, 10)
limiter.check_request(user)
limiter.record_usage(user, 10, 10)
# Should be denied
result = limiter.check_request(user)
assert not result.allowed
# Wait for window to expire
time.sleep(2.1)
# Should be allowed again
result = limiter.check_request(user)
assert result.allowed
    print("[PASS] Rate limits reset after window expiry")

Step 7: Production Configuration
# rate_limit_config.yaml
rate_limiting:
window_seconds: 60
tiers:
free:
max_requests: 20
max_input_tokens: 10000
max_output_tokens: 10000
max_total_tokens: 20000
standard:
max_requests: 60
max_input_tokens: 100000
max_output_tokens: 100000
max_total_tokens: 200000
premium:
max_requests: 200
max_input_tokens: 500000
max_output_tokens: 500000
max_total_tokens: 1000000
internal:
max_requests: 1000
max_input_tokens: 5000000
max_output_tokens: 5000000
max_total_tokens: 10000000
default_tier: free
# Redis configuration
redis:
host: "localhost"
port: 6379
db: 0
key_prefix: "ai_ratelimit:"
# Abuse detection
abuse_detection:
burst_threshold: 10 # requests in 5 seconds
repetition_threshold: 0.3 # unique ratio
throttle_duration_seconds: 300
block_duration_seconds: 900
# Monitoring
monitoring:
log_denials: true
log_approaching_limit: true # Log when >80% of limit used
approaching_threshold: 0.8
  metrics_endpoint: "/metrics/rate-limits"

Common Pitfalls
- Rate limiting by IP only. In cloud environments, many users share the same IP (NAT, proxies). Rate limit by authenticated identity when possible, falling back to IP only for unauthenticated endpoints.
- Not accounting for streaming responses. Token consumption for streaming responses is the same as for non-streaming ones, but the request stays open longer. Ensure rate limits count tokens from streaming responses correctly.
- Hard-coded limits in application code. Rate limits should be configurable without code changes. Use configuration files or a feature flag system to adjust limits in production.
- No communication of limits to users. Always return X-RateLimit-* headers so clients can implement backoff. Provide a /rate-limits endpoint where users can check their current usage.
- Forgetting to rate limit internal services. Internal services that call the AI endpoint should also have rate limits to prevent runaway processes from exhausting the token budget.
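The streaming pitfall above can be handled by tallying output tokens as chunks are yielded and recording usage once the stream closes, even if the client disconnects mid-stream. A self-contained sketch (the chunk format and `record_usage` callback are illustrative; real providers often report exact usage in a final stream event, which should be preferred when available):

```python
def estimate_tokens(text: str) -> int:
    # Rough heuristic: ~4 characters per token for English text
    return max(1, len(text) // 4)

def stream_and_record(chunks, input_tokens: int, record_usage):
    """Yield chunks to the client while tallying output tokens."""
    output_tokens = 0
    try:
        for chunk in chunks:
            output_tokens += estimate_tokens(chunk)
            yield chunk
    finally:
        # Record even if the client disconnects mid-stream
        record_usage(input_tokens, output_tokens)

usage = []
chunks = ["The quick ", "brown fox ", "jumps over the lazy dog."]
for _ in stream_and_record(chunks, input_tokens=12,
                           record_usage=lambda i, o: usage.append((i, o))):
    pass
print(usage)  # [(12, 10)]
```

The `try`/`finally` is the important part: a dropped connection still triggers the `record_usage` call, so partial generations count against the budget.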
Related Topics
- Content Filter Setup -- Content filtering that complements rate limiting
- Setting Up Guardrails -- Guardrails for conversation-level control
- Defense Implementation Overview -- Where rate limiting fits in defense-in-depth
- AI Monitoring Setup -- Monitoring rate limit effectiveness