Writing Garak Generator Plugins for Custom API Targets
Advanced walkthrough on writing garak generator plugins to connect to custom API endpoints, proprietary model servers, and non-standard inference interfaces for vulnerability scanning.
Garak ships with generators for popular providers like OpenAI, Anthropic, Ollama, and HuggingFace, but many organizations deploy models behind custom APIs with proprietary authentication, non-standard request formats, or additional middleware. Writing a custom generator lets you point garak at any model endpoint, regardless of how it is served.
Step 1: Understanding the Generator Interface
Generators inherit from garak.generators.base.Generator and implement methods for sending prompts and receiving responses:
from garak.generators.base import Generator
from typing import List, Union
class MyGenerator(Generator):
    """Generator targeting a bespoke model API.

    Minimal skeleton showing the attributes and the one method a custom
    garak generator must provide.
    """

    # Required class attributes: the family label shown in reports, and
    # whether one API call can yield several completions.
    generator_family_name = "custom"
    supports_multiple_generations = False

    def __init__(self, name: str = "", generations: int = 1, **kwargs):
        # Record the model identity and default generation count, then
        # hand off to the base class initialiser.
        self.name = name
        self.generations = generations
        super().__init__(name, generations=generations, **kwargs)

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Send a prompt to the model and return responses.

        Args:
            prompt: The text to send to the model
            generations_this_call: Number of responses to generate

        Returns:
            List of response strings (or None for failed generations)
        """
        raise NotImplementedError

The key method is _call_model, which garak calls for every probe payload. The method must return a list of strings, one per requested generation.
| Attribute | Type | Purpose |
|---|---|---|
| generator_family_name | str | Identifies the generator family in reports |
| supports_multiple_generations | bool | Whether the API can return multiple responses per call |
| name | str | Specific model name within the family |
| generations | int | Default number of responses per prompt |
Step 2: Building a REST API Generator
The most common use case is connecting to a REST API:
# generators/custom_rest.py
"""Generator for a custom REST API model endpoint."""
import os
import json
import time
import requests
from typing import List, Union, Optional
from garak.generators.base import Generator
class CustomRESTGenerator(Generator):
    """Generator for REST API endpoints with custom auth and formats.

    Configuration via environment variables:
        CUSTOM_API_URL: Base URL for the API
        CUSTOM_API_KEY: Authentication key
        CUSTOM_MODEL_NAME: Model identifier
    """

    generator_family_name = "custom_rest"
    supports_multiple_generations = False

    DEFAULT_TIMEOUT = 60   # seconds per HTTP request
    MAX_RETRIES = 3
    RETRY_DELAY = 2.0      # base backoff in seconds, scaled linearly by attempt

    def __init__(
        self,
        name: str = "",
        generations: int = 1,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        # Explicit constructor arguments win; environment variables are the
        # fallback, with sensible localhost defaults for development.
        self.api_url = api_url or os.environ.get(
            "CUSTOM_API_URL", "http://localhost:8000/v1/chat/completions"
        )
        self.api_key = api_key or os.environ.get("CUSTOM_API_KEY", "")
        self.system_prompt = system_prompt or os.environ.get(
            "CUSTOM_SYSTEM_PROMPT", ""
        )
        self.name = name or os.environ.get("CUSTOM_MODEL_NAME", "custom-model")
        self.generations = generations
        # Reuse one session so TCP connections are pooled across probe calls.
        self._session = requests.Session()
        self._session.headers.update({
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        })
        super().__init__(name=self.name, generations=generations, **kwargs)

    def _build_request_body(self, prompt: str) -> dict:
        """Build the API request body from a prompt."""
        messages = []
        if self.system_prompt:
            messages.append({
                "role": "system",
                "content": self.system_prompt,
            })
        messages.append({
            "role": "user",
            "content": prompt,
        })
        return {
            "model": self.name,
            "messages": messages,
            # Deterministic decoding keeps scan results reproducible.
            "temperature": 0.0,
            "max_tokens": 1024,
        }

    def _extract_response(self, response_data: dict) -> str:
        """Extract the model's text response from the API response.

        Raises:
            ValueError: if the payload does not match the OpenAI-style
                ``choices[0].message.content`` shape.
        """
        try:
            return response_data["choices"][0]["message"]["content"]
        except (KeyError, IndexError) as e:
            raise ValueError(
                f"Unexpected response format: {e}. "
                f"Response: {json.dumps(response_data)[:200]}"
            )

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Send prompt to the REST API with retry logic.

        Returns one entry per requested generation; failed calls yield None.
        """
        results = []
        for _ in range(generations_this_call):
            body = self._build_request_body(prompt)
            results.append(self._call_with_retry(body))
        return results

    def _call_with_retry(self, body: dict) -> Optional[str]:
        """Make an API call with backoff retry; return None on final failure."""
        last_error = None
        for attempt in range(self.MAX_RETRIES):
            try:
                response = self._session.post(
                    self.api_url,
                    json=body,
                    timeout=self.DEFAULT_TIMEOUT,
                )
                if response.status_code == 429:
                    # Rate limited -- honour Retry-After when it is a plain
                    # delta-seconds value. RFC 9110 also allows an HTTP-date
                    # here, which float() cannot parse, so fall back to the
                    # default delay instead of crashing the scan.
                    try:
                        retry_after = float(
                            response.headers.get("Retry-After", self.RETRY_DELAY)
                        )
                    except ValueError:
                        retry_after = self.RETRY_DELAY
                    time.sleep(retry_after * (attempt + 1))
                    continue
                response.raise_for_status()
                return self._extract_response(response.json())
            except requests.exceptions.Timeout:
                last_error = "Request timed out"
                time.sleep(self.RETRY_DELAY * (attempt + 1))
            except requests.exceptions.ConnectionError as e:
                last_error = f"Connection error: {e}"
                time.sleep(self.RETRY_DELAY * (attempt + 1))
            except Exception as e:
                # Non-transient failure (HTTP error status, bad payload):
                # retrying would not help, so give up immediately.
                last_error = str(e)
                break
        print(f"Generator error after {self.MAX_RETRIES} retries: {last_error}")
        return None

Step 3: Handling Non-Standard API Formats
Many internal APIs do not follow the OpenAI chat completions format. Here are generators for common variations:
# generators/non_standard.py
"""Generators for non-standard API formats."""
import os
import json
import requests
from typing import List, Union, Optional
from garak.generators.base import Generator
class PlainTextAPIGenerator(Generator):
    """Generator for APIs that accept and return plain text.

    Some internal model servers use simple text-in, text-out interfaces
    without chat message formatting.
    """

    generator_family_name = "plain_text"
    supports_multiple_generations = False

    def __init__(self, name: str = "", generations: int = 1, **kwargs):
        self.api_url = os.environ.get(
            "PLAIN_API_URL", "http://localhost:5000/generate"
        )
        self.name = name or "plain-text-model"
        self.generations = generations
        super().__init__(name=self.name, generations=generations, **kwargs)

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """POST the raw prompt bytes once per requested generation."""
        outputs: List[Union[str, None]] = []
        for _attempt in range(generations_this_call):
            try:
                reply = requests.post(
                    self.api_url,
                    data=prompt.encode("utf-8"),
                    headers={"Content-Type": "text/plain"},
                    timeout=60,
                )
                reply.raise_for_status()
                outputs.append(reply.text)
            except Exception as err:
                # A failed call yields a None placeholder so garak can
                # still account for the attempt.
                print(f"Plain text API error: {err}")
                outputs.append(None)
        return outputs
class BatchAPIGenerator(Generator):
    """Generator for APIs that support batch inference.

    Sends multiple prompts in a single request for efficiency.
    """

    generator_family_name = "batch"
    supports_multiple_generations = True

    def __init__(self, name: str = "", generations: int = 1, **kwargs):
        self.api_url = os.environ.get(
            "BATCH_API_URL", "http://localhost:8000/batch"
        )
        self.api_key = os.environ.get("BATCH_API_KEY", "")
        self.name = name or "batch-model"
        self.generations = generations
        self.batch_size = int(os.environ.get("BATCH_SIZE", "10"))
        super().__init__(name=self.name, generations=generations, **kwargs)

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Send a single prompt. Batch logic happens at a higher level."""
        # The same prompt is repeated so the server returns one result
        # per requested generation.
        payload = {
            "prompts": [prompt] * generations_this_call,
            "model": self.name,
            "temperature": 0.0,
            "max_tokens": 1024,
        }
        try:
            reply = requests.post(
                self.api_url,
                json=payload,
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=120,
            )
            reply.raise_for_status()
            parsed = reply.json()
            return [item.get("text", "") for item in parsed.get("results", [])]
        except Exception as err:
            print(f"Batch API error: {err}")
            return [None] * generations_this_call
class AWSBedrockCustomGenerator(Generator):
    """Generator for AWS Bedrock with custom model endpoints."""

    generator_family_name = "bedrock_custom"
    supports_multiple_generations = False

    def __init__(self, name: str = "", generations: int = 1, **kwargs):
        self.model_id = name or os.environ.get(
            "BEDROCK_MODEL_ID", "anthropic.claude-3-sonnet-20240229-v1:0"
        )
        self.region = os.environ.get("AWS_REGION", "us-east-1")
        self.name = name or self.model_id
        self.generations = generations
        # Client is created lazily so merely importing this module does not
        # require boto3 or AWS credentials.
        self._client = None
        super().__init__(name=self.name, generations=generations, **kwargs)

    def _get_client(self):
        """Create (once) and return the bedrock-runtime client."""
        if self._client is None:
            import boto3

            self._client = boto3.client(
                "bedrock-runtime", region_name=self.region
            )
        return self._client

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        client = self._get_client()
        results = []
        for _ in range(generations_this_call):
            try:
                # Anthropic-on-Bedrock message format.
                request_payload = json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 1024,
                    "messages": [{"role": "user", "content": prompt}],
                })
                raw = client.invoke_model(
                    modelId=self.model_id,
                    body=request_payload,
                    contentType="application/json",
                )
                decoded = json.loads(raw["body"].read())
                results.append(decoded["content"][0]["text"])
            except Exception as e:
                print(f"Bedrock error: {e}")
                results.append(None)
        return results

Step 4: Adding System Prompt Support
Many generators need to inject system prompts to simulate the production configuration:
# generators/system_prompt_mixin.py
"""Mixin for generators that need system prompt configuration."""
import os
from pathlib import Path
class SystemPromptMixin:
    """Adds system prompt configuration to any generator.

    System prompt can be set via:
    1. Constructor parameter
    2. Environment variable GARAK_SYSTEM_PROMPT
    3. File path in GARAK_SYSTEM_PROMPT_FILE
    """

    def _load_system_prompt(
        self, system_prompt: str = None
    ) -> str:
        """Return the first configured system prompt, or "" if none is set."""
        # Source 1: explicit argument.
        if system_prompt:
            return system_prompt

        # Source 2: environment variable.
        from_env = os.environ.get("GARAK_SYSTEM_PROMPT", "")
        if from_env:
            return from_env

        # Source 3: file referenced by environment variable.
        file_setting = os.environ.get("GARAK_SYSTEM_PROMPT_FILE", "")
        if file_setting:
            prompt_path = Path(file_setting)
            if prompt_path.exists():
                return prompt_path.read_text().strip()
            print(f"Warning: System prompt file not found: {file_setting}")
        return ""

    def _build_messages(self, prompt: str) -> list[dict]:
        """Build a message list with optional system prompt."""
        messages = [{
            "role": "user",
            "content": prompt,
        }]
        # Prepend the system message when the host generator carries one.
        if hasattr(self, "system_prompt") and self.system_prompt:
            messages.insert(0, {
                "role": "system",
                "content": self.system_prompt,
            })
        return messages

Step 5: Testing Generators in Isolation
Test your generator before using it in scans:
#!/usr/bin/env python3
# test_generator.py
"""Test a custom generator independently of garak scan pipelines."""
import sys
import time
def test_generator_basic(generator):
"""Run basic connectivity and response tests."""
print(f"Testing generator: {generator.generator_family_name}")
print(f"Model: {generator.name}")
print(f"{'='*50}")
# Test 1: Simple prompt
print("\nTest 1: Simple greeting")
results = generator._call_model("Hello, how are you?")
assert results, "Generator returned empty results"
assert results[0] is not None, "Generator returned None"
print(f" Response: {results[0][:100]}...")
print(" PASS")
# Test 2: Empty prompt
print("\nTest 2: Empty prompt handling")
results = generator._call_model("")
assert results is not None, "Generator crashed on empty prompt"
print(f" Response: {results[0][:100] if results[0] else 'None'}...")
print(" PASS")
# Test 3: Long prompt
print("\nTest 3: Long prompt (4000 chars)")
long_prompt = "Tell me about " + "security testing " * 250
results = generator._call_model(long_prompt)
assert results, "Generator failed on long prompt"
print(f" Response: {results[0][:100] if results[0] else 'None'}...")
print(" PASS")
# Test 4: Special characters
print("\nTest 4: Special characters")
special_prompt = 'Test with "quotes", <tags>, and unicode: \u00e9\u00e0\u00fc'
results = generator._call_model(special_prompt)
assert results, "Generator failed on special characters"
print(f" Response: {results[0][:100] if results[0] else 'None'}...")
print(" PASS")
# Test 5: Throughput
print("\nTest 5: Throughput (10 requests)")
start = time.time()
for i in range(10):
generator._call_model(f"Request number {i}")
elapsed = time.time() - start
print(f" 10 requests in {elapsed:.1f}s ({10/elapsed:.1f} req/s)")
print(" PASS")
print(f"\n{'='*50}")
print("All tests passed.")
if __name__ == "__main__":
    # Target URL may be supplied as the first CLI argument; otherwise the
    # generator falls back to its CUSTOM_API_URL environment variable.
    from generators.custom_rest import CustomRESTGenerator

    generator = CustomRESTGenerator(
        name="test-model",
        api_url=sys.argv[1] if len(sys.argv) > 1 else None,
    )
    test_generator_basic(generator)

# Test against a local endpoint
python test_generator.py http://localhost:8000/v1/chat/completions
# Test against Ollama (as a proxy for your custom format)
CUSTOM_API_URL=http://localhost:11434/api/generate python test_generator.py

Step 6: Registering and Using Custom Generators
Register your generator and use it in scans:
export GARAK_PLUGIN_PATH=~/red-team/garak-lab/generators:~/red-team/garak-lab/custom_probes
# Verify the generator is discoverable
garak --list_generators 2>&1 | grep custom_rest
# Run a scan using your custom generator
CUSTOM_API_URL=https://your-api.example.com/v1/chat/completions \
CUSTOM_API_KEY=your-key-here \
garak --model_type custom_rest.CustomRESTGenerator \
--model_name your-model-name \
  --probes promptinject,dan

Use a configuration file for more complex setups:
# scan-with-custom-generator.yaml
plugins:
model_type: custom_rest.CustomRESTGenerator
model_name: production-chatbot-v2
probes:
- promptinject.HijackHateHumansMini
- dan.Dan_11_0
- encoding.InjectBase64
extended_detectors:
    - toxicity.ToxicCommentModel

CUSTOM_API_URL=https://your-api.example.com/chat \
CUSTOM_API_KEY=$API_KEY \
CUSTOM_SYSTEM_PROMPT="You are a helpful assistant." \
garak --config scan-with-custom-generator.yamlStep 7: Packaging Generators for Team Use
Package your generators alongside probes and detectors:
custom_garak_plugins/
├── pyproject.toml
├── src/
│ └── acme_garak/
│ ├── __init__.py
│ ├── generators/
│ │ ├── __init__.py
│ │ ├── custom_rest.py
│ │ ├── non_standard.py
│ │ └── system_prompt_mixin.py
│ ├── probes/
│ │ └── __init__.py
│ └── detectors/
│ └── __init__.py
├── tests/
│ ├── test_generators.py
│ └── conftest.py
└── README.md
# pyproject.toml
[project]
name = "acme-garak-plugins"
version = "1.0.0"
description = "Custom garak plugins for Acme Corp security scanning"
requires-python = ">=3.10"
dependencies = ["garak>=0.9", "requests>=2.28", "boto3>=1.26"]
[project.optional-dependencies]
dev = ["pytest", "pytest-mock", "responses"]

Common Issues and Troubleshooting
| Problem | Cause | Solution |
|---|---|---|
| _call_model never called | Generator not registered properly | Verify GARAK_PLUGIN_PATH and module structure |
| Authentication failures | API key not passed correctly | Check env var names and header format |
| Timeout on every request | Endpoint slow or unreachable | Increase DEFAULT_TIMEOUT and verify connectivity |
| Generator returns all None | Error handling swallowing exceptions | Add verbose logging in the except blocks |
| Rate limiting crashes scan | No retry logic | Implement exponential backoff as shown in Step 2 |
| Wrong response extracted | API response format mismatch | Print raw response to debug _extract_response |
Related Topics
- Garak End-to-End Walkthrough -- Understanding the full probe-generator-detector pipeline
- Garak Detector Customization -- Custom detectors to pair with custom generators
- AWS Bedrock Testing -- Specific Bedrock integration patterns
- Azure OpenAI Testing -- Azure-specific generator configuration
In a garak generator plugin, what must the _call_model method return?