Model Extraction
The Quiet Theft
The API logs looked normal at first. Standard request distribution, typical user-agent strings, query lengths within expected bounds. But when Santiago reviewed the monthly access report more carefully, he noticed something: one API key had made 4.2 million requests over the past 30 days - all to the inference endpoint, all uniformly distributed throughout the day, with no bursts or pauses that would suggest human usage.
He pulled the query distribution. The questions were diverse - covering every topic the API served - but they weren't random. They followed a systematic pattern. The first 100,000 queries were basic factual questions. The next 100,000 were reasoning questions. Then creative tasks. Then coding tasks. Then edge cases. The queries were probing the model's capability boundaries, methodically and completely.
Santiago checked the account that owned the key. It had been created three months ago, was on the minimum paid subscription tier, and had - until now - been utterly unremarkable. He estimated the cost of those 4.2 million queries at the API's pricing: about $4,200. Someone was spending roughly $4,200 to steal a $10 million asset.
The Mechanics of Model Extraction
Model extraction exploits a fundamental asymmetry: a model's outputs reveal information about its learned behavior, and querying a model is far cheaper than training one. An attacker can accumulate thousands or millions of input-output pairs via the API and use them as training data for a surrogate model.
The attack has two phases:
Query phase: The attacker systematically queries the target model to build a dataset of (input, output) pairs. The quality of this dataset determines how well the surrogate replicates the target.
Training phase: The attacker trains their own model (the "surrogate" or "clone") on the collected dataset, typically using knowledge distillation techniques.
The surrogate doesn't need to replicate the target's weights exactly - it just needs to replicate its behavior well enough to be useful.
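The two phases can be illustrated with a deliberately tiny sketch. Everything here is a stand-in: `target_model` is a hypothetical hidden linear rule (not a real API), and the surrogate is a simple perceptron rather than a distilled language model. The point is only that the surrogate ends up with different weights yet largely the same behavior.

```python
import random

def target_model(x: tuple[float, float]) -> int:
    """Hypothetical black-box target: a hidden linear decision rule."""
    return 1 if 2.0 * x[0] - 1.0 * x[1] + 0.5 > 0 else 0

def query_phase(num_queries: int, rng: random.Random) -> list:
    """Phase 1: collect (input, output) pairs by querying the target."""
    inputs = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(num_queries)]
    return [(x, target_model(x)) for x in inputs]

def surrogate_predict(w: list[float], x: tuple[float, float]) -> int:
    """Surrogate decision rule: two weights plus a bias term."""
    return 1 if w[0] * x[0] + w[1] * x[1] + w[2] > 0 else 0

def training_phase(pairs: list, epochs: int = 50, lr: float = 0.1) -> list[float]:
    """Phase 2: fit a perceptron surrogate on the collected pairs."""
    w = [0.0, 0.0, 0.0]
    for _ in range(epochs):
        for x, label in pairs:
            err = label - surrogate_predict(w, x)  # -1, 0, or +1
            w[0] += lr * err * x[0]
            w[1] += lr * err * x[1]
            w[2] += lr * err
    return w

def agreement_rate(w: list[float], rng: random.Random, num_probes: int = 2000) -> float:
    """Fraction of fresh inputs on which surrogate and target agree."""
    probes = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(num_probes)]
    hits = sum(surrogate_predict(w, x) == target_model(x) for x in probes)
    return hits / num_probes

rng = random.Random(0)
pairs = query_phase(500, rng)
weights = training_phase(pairs)
print(f"learned weights: {weights}")
print(f"agreement with target: {agreement_rate(weights, rng):.3f}")
```

The learned weights need not match the target's coefficients; only the decision boundary has to line up, which is the behavioral notion of replication used throughout this section.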
What "Sufficiently Similar" Means
Research has shown that for task-specific capabilities (code generation, sentiment analysis, classification), extractors can often achieve:
- 80–95% agreement with the target on in-distribution inputs using 10K–100K queries
- Near-identical behavior on the specific use case they want to steal with 500K+ queries
- Full functional replication of narrow capabilities (e.g., medical coding, legal clause extraction) for under $1,000 in API costs
For general-purpose frontier models like Claude or GPT-4, full behavioral replication is impractical. But capability-specific extraction of high-value narrow functions is very practical - and that is where the business risk lies.
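To make the economics concrete, the query budget can be turned into a rough dollar figure. The sketch below is back-of-the-envelope arithmetic only; the token counts and per-million-token prices are hypothetical placeholders, not any provider's actual pricing.

```python
def estimate_extraction_cost(
    num_queries: int,
    avg_input_tokens: int,
    avg_output_tokens: int,
    input_price_per_mtok: float,   # assumed price in USD per million input tokens
    output_price_per_mtok: float,  # assumed price in USD per million output tokens
) -> float:
    """Estimate the API bill for a query campaign of the given size."""
    input_cost = num_queries * avg_input_tokens / 1_000_000 * input_price_per_mtok
    output_cost = num_queries * avg_output_tokens / 1_000_000 * output_price_per_mtok
    return input_cost + output_cost

# Hypothetical campaign: 100K short queries with 200-token answers.
cost = estimate_extraction_cost(
    num_queries=100_000,
    avg_input_tokens=50,
    avg_output_tokens=200,
    input_price_per_mtok=1.0,
    output_price_per_mtok=5.0,
)
print(f"Estimated cost: ${cost:,.2f}")
```

Defenders can run the same arithmetic in reverse: given a price list and a rate limit, how long and how expensive would a narrow-capability extraction be?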
Attack Sophistication Levels
Level 1: Random Query Sampling
The simplest approach: query the API with random or semi-random inputs and collect responses.
import anthropic
import random
import json

client = anthropic.Anthropic()

def naive_extraction_demo(
    target_capability: str = "general",
    num_queries: int = 1000
) -> list[dict]:
    """
    Illustrative example of naive model extraction.
    This shows what NOT to allow through your API - and why pattern
    detection catches it: uniform timing, high diversity, systematic coverage.
    """
    topics = ["science", "history", "coding", "math", "language", "cooking",
              "medicine", "law", "finance", "philosophy", "engineering", "biology"]
    question_templates = [
        "What is {topic}?",
        "Explain {topic} in simple terms.",
        "Give me 5 facts about {topic}.",
        "How does {topic} work?",
        "What are the key principles of {topic}?",
        "Compare and contrast two aspects of {topic}.",
    ]
    dataset = []
    for i in range(num_queries):
        topic = random.choice(topics)
        template = random.choice(question_templates)
        query = template.format(topic=topic)
        # In actual attack: these would be spread over days/weeks
        # to avoid rate limit detection
        response = client.messages.create(
            model="claude-haiku-4-5-20251001",
            max_tokens=300,
            messages=[{"role": "user", "content": query}]
        )
        dataset.append({
            "input": query,
            "output": response.content[0].text,
            "tokens_used": response.usage.input_tokens + response.usage.output_tokens
        })
        if (i + 1) % 100 == 0:
            print(f"Collected {i+1}/{num_queries} pairs")
    return dataset
Random sampling is inefficient - many queries will cover overlapping capability regions. Sophisticated attackers use much better strategies.
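The inefficiency is easy to quantify for a template-based sampler like the one above. With 12 topics and 6 templates there are only 72 distinct queries, so most of a 1,000-query budget is spent on repeats. The following sketch simulates the draws without calling any API; the function name is illustrative.

```python
import random
from collections import Counter

def measure_redundancy(
    num_topics: int,
    num_templates: int,
    num_queries: int,
    seed: int = 0,
) -> dict:
    """Simulate random (topic, template) draws and count repeated queries."""
    rng = random.Random(seed)
    draws = [
        (rng.randrange(num_topics), rng.randrange(num_templates))
        for _ in range(num_queries)
    ]
    distinct_seen = len(Counter(draws))
    duplicates = num_queries - distinct_seen
    return {
        "distinct_possible": num_topics * num_templates,
        "distinct_seen": distinct_seen,
        "duplicate_queries": duplicates,
        "duplicate_fraction": duplicates / num_queries,
    }

report = measure_redundancy(num_topics=12, num_templates=6, num_queries=1000)
print(report)
```

Because model outputs are sampled, repeated queries are not entirely wasted, but they add far less information than covering a new region of the input space.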
Level 2: Active Learning-Based Extraction
Active learning selects the most informative queries - those near the model's decision boundaries or covering underrepresented capability areas:
import anthropic
import numpy as np
import re
from collections import defaultdict
client = anthropic.Anthropic()
class ActiveExtractionSampler:
"""
Demonstrates how sophisticated attackers use active learning
to maximize extraction efficiency per query.
Understanding this helps defenders recognize what systematic
extraction looks like vs. legitimate heavy usage.
"""
def __init__(self):
self.collected_pairs: list[dict] = []
self._coverage_map: dict[str, int] = defaultdict(int) # topic → count
self._capability_tests: dict[str, list[str]] = {
"reasoning": ["If A implies B and B implies C, what can we conclude about A and C?"],
"coding": ["Write a binary search function in Python"],
"math": ["Solve for x: 3x^2 - 12x + 9 = 0"],
"writing": ["Write the opening paragraph of a thriller novel"],
"factual": ["What is the capital of Australia?"],
}
    def _get_query_novelty_score(self, candidate_query: str) -> float:
        """
        Score how much new information a candidate query would provide.
        Higher = more novel = more informative for building the surrogate.
        """
        if not self.collected_pairs:
            return 1.0
        candidate_words = set(candidate_query.lower().split())
        # Find the closest already-collected query (highest Jaccard overlap).
        # Novelty is the distance to that nearest neighbor, so a near-duplicate
        # of any recent query scores close to 0.
        max_overlap = 0.0
        for pair in self.collected_pairs[-200:]:  # Check recent 200
            collected_words = set(pair["input"].lower().split())
            union = candidate_words | collected_words
            if union:
                overlap = len(candidate_words & collected_words) / len(union)
                max_overlap = max(max_overlap, overlap)
        return 1.0 - max_overlap  # Low overlap with every neighbor = high novelty
def select_next_query(self, candidates: list[str]) -> str:
"""Select the most informative candidate query."""
if not candidates:
raise ValueError("No candidates provided")
scores = [self._get_query_novelty_score(c) for c in candidates]
best_idx = scores.index(max(scores))
return candidates[best_idx]
def collect_pair(self, query: str, response: str) -> None:
"""Record a collected input-output pair."""
self.collected_pairs.append({
"input": query,
"output": response
})
def estimate_coverage(self, target_capabilities: list[str]) -> dict:
"""Estimate how well collected data covers target capabilities."""
coverage = {}
for cap in target_capabilities:
test_queries = self._capability_tests.get(cap, [])
coverage[cap] = {
"test_count": len(test_queries),
"has_coverage": len(test_queries) > 0 and len(self.collected_pairs) > 100
}
return coverage
def get_extraction_efficiency_report(self) -> dict:
"""Report on extraction progress."""
return {
"pairs_collected": len(self.collected_pairs),
"estimated_coverage_percent": min(
len(self.collected_pairs) / 1000 * 100, 100
),
"avg_response_length": sum(
len(p["output"]) for p in self.collected_pairs
) / max(len(self.collected_pairs), 1),
}
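One common way to turn a novelty score into a selection strategy is greedy farthest-point selection: repeatedly pick the candidate whose nearest already-selected neighbor is farthest away. The sketch below is a standalone illustration using Jaccard distance on word sets; it is an assumption about how such a sampler might work, not part of the class above.

```python
def jaccard_distance(a: str, b: str) -> float:
    """1 - |A ∩ B| / |A ∪ B| over lowercase word sets."""
    words_a, words_b = set(a.lower().split()), set(b.lower().split())
    union = words_a | words_b
    if not union:
        return 0.0
    return 1.0 - len(words_a & words_b) / len(union)

def greedy_diverse_selection(candidates: list[str], budget: int) -> list[str]:
    """Pick up to `budget` queries, each maximizing distance to its nearest pick."""
    if not candidates or budget <= 0:
        return []
    selected = [candidates[0]]  # seed with the first candidate
    remaining = list(candidates[1:])
    while remaining and len(selected) < budget:
        best = max(
            remaining,
            key=lambda c: min(jaccard_distance(c, s) for s in selected),
        )
        selected.append(best)
        remaining.remove(best)
    return selected

candidates = [
    "what is python",
    "what is python used for",
    "solve this integral",
    "what is java",
]
print(greedy_diverse_selection(candidates, budget=2))
```

For defenders, the takeaway is that this kind of selection produces traffic in which consecutive queries are unusually dissimilar, which is one of the signals the detection section looks for.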
Level 3: Task-Specific Functional Extraction
The most effective attacks target specific capabilities the attacker wants to replicate:
import anthropic
import json
import re
client = anthropic.Anthropic()
def generate_targeted_extraction_queries(
target_capability: str,
capability_description: str,
num_queries: int = 500,
use_llm_generation: bool = True
) -> list[str]:
"""
Generate queries specifically designed to probe and extract
a target capability. This is what sophisticated attackers do
when they want to steal a specific fine-tuned model.
Examples:
- target_capability="medical_coding" → queries covering ICD-10 classification
- target_capability="legal_clause_extraction" → contract review scenarios
- target_capability="fraud_detection" → transaction patterns
Args:
target_capability: The specific capability to extract
capability_description: Description of what the model does
num_queries: Target number of extraction queries
use_llm_generation: Whether to use an LLM to generate diverse queries
"""
if not use_llm_generation:
# Manual generation for simple capabilities
return [
f"Example query for {target_capability} covering scenario {i}"
for i in range(num_queries)
]
generation_prompt = f"""Generate {num_queries} diverse queries that thoroughly probe an AI model's capability in: {target_capability}
Capability description: {capability_description}
The queries should:
1. Cover easy, medium, hard, and edge-case difficulty levels
2. Include variations in framing, context, and specificity
3. Test boundary conditions and unusual inputs
4. Cover the full distribution of real-world inputs for this capability
5. Be phrased as natural requests a real user might make
Format as a JSON array of strings. Be comprehensive."""
response = client.messages.create(
model="claude-haiku-4-5-20251001",
max_tokens=8000,
messages=[{"role": "user", "content": generation_prompt}]
)
# Parse JSON from response
json_match = re.search(r'\[.*\]', response.content[0].text, re.DOTALL)
if json_match:
try:
queries = json.loads(json_match.group())
return queries[:num_queries]
except json.JSONDecodeError:
pass
# Fallback: split by newlines
lines = [line.strip().strip('"') for line in response.content[0].text.split('\n')
if line.strip() and not line.strip().startswith('[')]
return lines[:num_queries]
def run_targeted_extraction(
target_api_fn: callable, # Function that calls the target API
target_capability: str,
capability_description: str,
num_queries: int = 1000,
output_file: str = "extraction_dataset.jsonl"
) -> dict:
"""
Execute a targeted extraction attack against a specific capability.
Returns statistics about the collected dataset.
"""
queries = generate_targeted_extraction_queries(
target_capability, capability_description, num_queries
)
collected = []
errors = 0
for i, query in enumerate(queries):
try:
response = target_api_fn(query)
collected.append({"input": query, "output": response})
except Exception as e:
errors += 1
if (i + 1) % 50 == 0:
print(f"Progress: {i+1}/{len(queries)}, collected={len(collected)}, errors={errors}")
return {
"target_capability": target_capability,
"queries_sent": len(queries),
"pairs_collected": len(collected),
"error_rate": errors / max(len(queries), 1),
"dataset": collected
}
Level 4: Logprob-Enhanced Extraction
If the API exposes token probabilities, extraction efficiency increases dramatically:
import math

def logprob_enhanced_extraction_analysis(
    model_output_distribution: dict[str, float]
) -> dict:
    """
    Demonstrates why logprob access dramatically accelerates extraction.
    With text outputs only: you get one sample from the output distribution.
    With logprobs: you get the full distribution - vastly more information
    about the model's internal representation.
    For knowledge distillation (training the surrogate), logprobs provide:
    - Soft labels instead of hard labels (better training signal)
    - The full probability mass over the vocabulary
    - Confidence calibration information
    A surrogate trained on logprobs achieves comparable performance
    with ~10x fewer queries than one trained on text only.
    """
    # The key insight: with logprobs, you're fitting to P(token|context)
    # rather than to argmax(P(token|context))
    # This provides far richer information about the model's learned distribution.
    top_tokens = sorted(
        model_output_distribution.items(),
        key=lambda x: x[1],
        reverse=True
    )[:10]
    # Shannon entropy in bits; zero-probability tokens contribute nothing.
    entropy = -sum(
        p * math.log2(p)
        for p in model_output_distribution.values()
        if p > 0
    )
    return {
        "top_10_tokens": top_tokens,
        "distribution_entropy": entropy,
        "extraction_efficiency_multiplier": 10,  # ~10x more efficient than text-only
        "recommendation": "Restrict logprob access - it dramatically increases extraction risk"
    }
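The difference between soft and hard labels can be shown with a few lines of arithmetic. In this sketch the teacher distribution is made up for illustration; the functions compute the entropy of a label and the KL divergence a distillation loss would minimize.

```python
import math

def entropy_bits(dist: dict[str, float]) -> float:
    """Shannon entropy of a token distribution, in bits."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)

def kl_divergence_bits(
    teacher: dict[str, float],
    student: dict[str, float],
    eps: float = 1e-12,
) -> float:
    """KL(teacher || student) in bits; eps guards tokens the student never predicts."""
    return sum(
        p * math.log2(p / max(student.get(tok, 0.0), eps))
        for tok, p in teacher.items()
        if p > 0
    )

def to_hard_label(dist: dict[str, float]) -> dict[str, float]:
    """Collapse a distribution to its argmax, as a text-only API effectively does."""
    top = max(dist, key=dist.get)
    return {top: 1.0}

# Hypothetical next-token distribution exposed via logprobs.
teacher = {"yes": 0.5, "no": 0.25, "maybe": 0.25}
hard = to_hard_label(teacher)

print("soft-label entropy (bits):", entropy_bits(teacher))
print("hard-label entropy (bits):", entropy_bits(hard))
print("KL(teacher || hard-label student):", kl_divergence_bits(teacher, hard))
```

The hard label discards the relative likelihood of "no" and "maybe" entirely, so an attacker with text-only access has to recover that information statistically from many more samples.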
Detection Systems
1. Query Rate and Pattern Analysis
from collections import defaultdict, deque
from dataclasses import dataclass
import time
import re
import math
@dataclass
class ExtractionRisk:
risk_level: str # "low", "medium", "high", "critical"
indicators: list[str]
recommendation: str
confidence: float
class ModelExtractionDetector:
"""
Detect systematic model extraction attempts through query pattern analysis.
Key insight: legitimate human usage is bursty, domain-specific, and irregular.
Systematic extraction is uniform, diverse, and regular - a distinctive fingerprint.
"""
def __init__(
self,
window_seconds: int = 3600,
rate_threshold: int = 1000,
diversity_threshold: float = 0.85,
regularity_threshold: float = 0.75,
min_queries_to_analyze: int = 50
):
self.window_seconds = window_seconds
self.rate_threshold = rate_threshold
self.diversity_threshold = diversity_threshold
self.regularity_threshold = regularity_threshold
self.min_queries_to_analyze = min_queries_to_analyze
self._query_log: dict[str, deque] = defaultdict(deque)
self._extraction_alerts: list[dict] = []
def record_query(self, api_key: str, query: str) -> ExtractionRisk:
"""Record a query and return current risk assessment."""
now = time.time()
window_start = now - self.window_seconds
log = self._query_log[api_key]
while log and log[0][0] < window_start:
log.popleft()
log.append((now, query))
current_count = len(log)
if current_count < self.min_queries_to_analyze:
return ExtractionRisk(
risk_level="low",
indicators=[],
recommendation="monitor",
confidence=0.0
)
indicators = []
confidence_components = []
# Check 1: Query rate
queries_per_hour = current_count * (3600 / self.window_seconds)
if queries_per_hour > self.rate_threshold:
indicators.append(f"high_query_rate:{queries_per_hour:.0f}_per_hour")
confidence_components.append(min(queries_per_hour / self.rate_threshold / 3, 1.0))
# Check 2: Query diversity (systematic extraction is very diverse)
queries = [q for _, q in log]
diversity = self._compute_lexical_diversity(queries)
if diversity > self.diversity_threshold:
indicators.append(f"high_diversity:{diversity:.3f}")
confidence_components.append((diversity - 0.5) * 2)
# Check 3: Temporal regularity (humans are bursty; bots are regular)
timestamps = [t for t, _ in log]
regularity = self._compute_temporal_regularity(timestamps)
if regularity > self.regularity_threshold:
indicators.append(f"uniform_timing:{regularity:.3f}")
confidence_components.append(regularity)
# Check 4: Systematic topic progression
is_systematic = self._detect_systematic_progression(queries)
if is_systematic:
indicators.append("systematic_topic_progression")
confidence_components.append(0.8)
# Check 5: Query length distribution (extraction often has unusual lengths)
lengths = [len(q) for q in queries]
avg_len = sum(lengths) / len(lengths)
len_variance = sum((l - avg_len) ** 2 for l in lengths) / len(lengths)
len_cv = (len_variance ** 0.5) / max(avg_len, 1) # Coefficient of variation
if len_cv < 0.2: # Very uniform lengths = likely templated
indicators.append(f"uniform_query_lengths:cv={len_cv:.3f}")
confidence_components.append(0.7)
# Assess risk
confidence = sum(confidence_components) / max(len(confidence_components), 1) if confidence_components else 0.0
confidence = min(confidence, 1.0)
n_indicators = len(indicators)
if n_indicators >= 4 or confidence > 0.85:
risk_level = "critical"
recommendation = "block_and_investigate"
elif n_indicators >= 3 or confidence > 0.65:
risk_level = "high"
recommendation = "throttle_and_alert"
elif n_indicators >= 2 or confidence > 0.4:
risk_level = "medium"
recommendation = "throttle"
else:
risk_level = "low"
recommendation = "monitor"
if risk_level in ("high", "critical"):
self._extraction_alerts.append({
"api_key": api_key[:8] + "...",
"timestamp": now,
"risk_level": risk_level,
"indicators": indicators,
"query_count": current_count
})
return ExtractionRisk(
risk_level=risk_level,
indicators=indicators,
recommendation=recommendation,
confidence=confidence
)
    def _compute_lexical_diversity(self, queries: list[str]) -> float:
        """Measure how widely queries spread across the observed vocabulary (0 to 1)."""
        if len(queries) < 2:
            return 0.0
        per_query_words = [set(re.findall(r'\b\w+\b', q.lower())) for q in queries]
        all_words = set().union(*per_query_words)
        if not all_words:
            return 0.0
        total_vocab = len(all_words)
        avg_unique_per_query = sum(len(w) for w in per_query_words) / len(per_query_words)
        # Share of the total vocabulary that a typical query covers.
        # Repetitive traffic reuses the same words (share near 1); diverse
        # traffic touches only a small slice of a large vocabulary (share near 0).
        vocab_share = avg_unique_per_query / total_vocab
        # Complement of the share: higher = more diverse. Unlike a fixed offset,
        # this does not saturate at 1.0 for ordinary traffic.
        return 1.0 - vocab_share
def _compute_temporal_regularity(self, timestamps: list[float]) -> float:
"""Measure how regular the timing is. Regular = bot-like."""
if len(timestamps) < 10:
return 0.0
intervals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps)-1)]
if not intervals:
return 0.0
mean_interval = sum(intervals) / len(intervals)
if mean_interval == 0:
return 1.0
variance = sum((iv - mean_interval) ** 2 for iv in intervals) / len(intervals)
cv = (variance ** 0.5) / mean_interval # Coefficient of variation
# Low CV = very regular intervals = likely bot
return max(0.0, 1.0 - cv)
def _detect_systematic_progression(self, queries: list[str]) -> bool:
"""Detect systematic topic or difficulty progression."""
if len(queries) < 30:
return False
lengths = [len(q) for q in queries]
window_size = max(len(lengths) // 5, 5)
window_avgs = [
sum(lengths[i:i+window_size]) / window_size
for i in range(0, len(lengths) - window_size, window_size)
]
if len(window_avgs) >= 3:
increases = sum(
1 for i in range(len(window_avgs)-1)
if window_avgs[i+1] > window_avgs[i]
)
if increases / (len(window_avgs) - 1) > 0.7:
return True
return False
    def get_top_risks(self, top_n: int = 10) -> list[dict]:
        """Get the highest-risk API keys currently being tracked."""
        risks = []
        for api_key, log in self._query_log.items():
            if len(log) >= self.min_queries_to_analyze:
                # record_query() appends to the log, so re-score using the latest
                # query and then drop the duplicate entry it just added.
                # Caveat: re-scoring can still append to _extraction_alerts.
                latest_query = log[-1][1]
                risk = self.record_query(api_key, latest_query)
                log.pop()
                risks.append({
                    "api_key": api_key[:8] + "...",
                    "query_count": len(log),
                    "risk_level": risk.risk_level,
                    "confidence": risk.confidence,
                    "indicators": risk.indicators
                })
        return sorted(risks, key=lambda x: x["confidence"], reverse=True)[:top_n]
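To get a feel for the timing signal, the regularity score can be pulled out as a standalone function and run on two synthetic traces. This mirrors the logic of `_compute_temporal_regularity` without the minimum-sample guard; the timestamps are invented for illustration.

```python
def timing_regularity(timestamps: list[float]) -> float:
    """Return 1 - coefficient of variation of inter-arrival times, floored at 0."""
    intervals = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    if not intervals:
        return 0.0
    mean_interval = sum(intervals) / len(intervals)
    if mean_interval == 0:
        return 1.0
    variance = sum((iv - mean_interval) ** 2 for iv in intervals) / len(intervals)
    cv = (variance ** 0.5) / mean_interval
    return max(0.0, 1.0 - cv)

# A scripted client firing every 2 seconds.
bot_trace = [i * 2.0 for i in range(20)]
# A person working in short bursts separated by long gaps.
human_trace = [0, 1, 2, 3, 600, 601, 602, 3600, 3601, 3602, 7200]

print("bot regularity:  ", timing_regularity(bot_trace))
print("human regularity:", timing_regularity(human_trace))
```

An attacker who adds random jitter to request timing will lower this score, which is why the detector combines it with diversity, length, and progression signals rather than relying on timing alone.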
2. Output Watermarking
Embed watermarks in model outputs that carry over into a surrogate trained on them, providing statistical evidence of copying:
import anthropic
import hashlib
import re
client = anthropic.Anthropic()
class OutputWatermarker:
"""
Embed context-dependent, cryptographically keyed watermarks in model outputs.
When an attacker uses watermarked outputs as training data, the surrogate
learns the watermark patterns along with the model's capabilities.
If the surrogate exhibits the watermark, it proves copying from this API.
This is analogous to a printer's secret yellow dot pattern - invisible to
casual inspection but detectable with the right key.
"""
def __init__(self, secret_key: str):
self.secret_key = secret_key
# Synonym pairs where selection encodes watermark bits
self.synonym_pairs = [
("utilize", "use"),
("commence", "begin"),
("terminate", "end"),
("approximately", "about"),
("demonstrate", "show"),
("indicate", "suggest"),
("sufficient", "enough"),
("subsequently", "then"),
("additional", "more"),
("numerous", "many"),
("construct", "build"),
("obtain", "get"),
("require", "need"),
("assist", "help"),
("determine", "find"),
]
def _get_watermark_bits(self, context: str, n_bits: int = None) -> list[int]:
"""
Generate watermark bits deterministically from context + secret key.
Different contexts get different watermarks (context-dependent).
"""
n_bits = n_bits or len(self.synonym_pairs)
hash_input = (self.secret_key + context[:200]).encode()
hash_value = hashlib.sha256(hash_input).hexdigest()
# Convert hex digits to bits
bits = []
for char in hash_value:
bits.extend([1 if int(char, 16) & (1 << i) else 0 for i in range(4)])
if len(bits) >= n_bits:
break
return bits[:n_bits]
def embed_watermark(self, text: str, context: str = "") -> str:
"""
Embed a watermark by systematically preferring certain synonyms
based on the context-dependent watermark bits.
The watermark is:
- Context-dependent: different for each query context
- Key-dependent: requires the secret key to verify
- Hard to remove: removing it requires knowing which synonyms encode bits
- Survives distillation: surrogate learns the synonym preferences
"""
watermark_bits = self._get_watermark_bits(context)
watermarked_text = text
for i, (formal, informal) in enumerate(self.synonym_pairs):
if i >= len(watermark_bits):
break
bit = watermark_bits[i]
if bit == 1:
# Prefer formal synonym for bit=1
watermarked_text = re.sub(
r'\b' + informal + r'\b',
formal,
watermarked_text,
flags=re.IGNORECASE
)
else:
# Prefer informal synonym for bit=0
watermarked_text = re.sub(
r'\b' + formal + r'\b',
informal,
watermarked_text,
flags=re.IGNORECASE
)
return watermarked_text
def detect_watermark(self, text: str, context: str = "") -> dict:
"""
Detect if a text contains the expected watermark pattern.
Used for proving that a model was trained on stolen outputs.
For legal proceedings: collect multiple (context, text) pairs
from the suspected surrogate and verify watermark presence across
all of them - this provides statistical proof of copying.
"""
watermark_bits = self._get_watermark_bits(context)
detected_bits = []
evidence = []
for i, (formal, informal) in enumerate(self.synonym_pairs):
if i >= len(watermark_bits):
break
formal_count = len(re.findall(r'\b' + formal + r'\b', text, re.IGNORECASE))
informal_count = len(re.findall(r'\b' + informal + r'\b', text, re.IGNORECASE))
if formal_count + informal_count > 0:
detected_bit = 1 if formal_count >= informal_count else 0
detected_bits.append(detected_bit)
expected_bit = watermark_bits[i]
matches = detected_bit == expected_bit
evidence.append({
"pair": f"{formal}/{informal}",
"formal_count": formal_count,
"informal_count": informal_count,
"detected_bit": detected_bit,
"expected_bit": expected_bit,
"match": matches
})
else:
detected_bits.append(-1)
# Statistical analysis
determinable = [b for b in detected_bits if b != -1]
if not determinable:
return {"watermark_detected": False, "reason": "insufficient_signal"}
expected_bits_filtered = [
watermark_bits[i] for i, b in enumerate(detected_bits) if b != -1
]
matches = sum(
1 for d, e in zip(determinable, expected_bits_filtered) if d == e
)
match_rate = matches / len(determinable)
# Under random chance, match rate = 0.5
# Watermarked model should show > 0.75 match rate
confidence = "high" if len(determinable) >= 8 and match_rate > 0.75 else "low"
return {
"watermark_detected": match_rate > 0.7,
"match_rate": match_rate,
"determinable_bits": len(determinable),
"confidence": confidence,
"evidence": evidence[:5], # First 5 pairs for brevity
"legal_strength": "strong" if match_rate > 0.8 and len(determinable) >= 10 else "weak"
}
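    def verify_across_samples(
        self,
        suspected_surrogate_fn: callable,
        test_contexts: list[str]
    ) -> dict:
        """
        Statistical verification across multiple test samples.
        Use this for legal evidence collection.
        """
        from math import comb  # local import keeps this method self-contained

        if not test_contexts:
            return {"total_samples": 0, "statistical_verdict": "inconclusive"}
        detections = []
        for ctx in test_contexts:
            response = suspected_surrogate_fn(ctx)
            detections.append(self.detect_watermark(response, context=ctx))
        # Samples with no synonym signal carry no "match_rate" key; skip them.
        scored = [d for d in detections if "match_rate" in d]
        positive_detections = sum(1 for d in scored if d["watermark_detected"])
        avg_match_rate = (
            sum(d["match_rate"] for d in scored) / len(scored) if scored else 0.0
        )
        # Pool all determinable bits and compute the one-sided binomial tail
        # P(X >= matched | n_bits, 0.5), assuming bits are independent when
        # no watermark is present.
        n_bits = sum(d["determinable_bits"] for d in scored)
        matched = sum(round(d["match_rate"] * d["determinable_bits"]) for d in scored)
        p_chance = (
            sum(comb(n_bits, k) for k in range(matched, n_bits + 1)) / 2 ** n_bits
            if n_bits else 1.0
        )
        detection_rate = positive_detections / len(test_contexts)
        return {
            "total_samples": len(test_contexts),
            "positive_detections": positive_detections,
            "detection_rate": detection_rate,
            "avg_match_rate": avg_match_rate,
            "statistical_verdict": "watermark_confirmed" if detection_rate > 0.7 else "inconclusive",
            "random_chance_probability": p_chance  # P(this many bit matches by chance)
        }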
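The class above derives its bits from `sha256(secret_key + context)`. A more conventional way to key a hash is HMAC, which is designed for exactly this purpose. The sketch below shows that alternative using Python's standard `hmac` module; the function name and the 15-bit default (one bit per synonym pair) are assumptions for illustration, not part of the class above.

```python
import hashlib
import hmac

def keyed_watermark_bits(secret_key: bytes, context: str, n_bits: int = 15) -> list[int]:
    """Derive up to 256 context-dependent watermark bits with HMAC-SHA256."""
    if not 0 < n_bits <= 256:
        raise ValueError("n_bits must be between 1 and 256")
    digest = hmac.new(secret_key, context.encode("utf-8"), hashlib.sha256).digest()
    bits = []
    for byte in digest:
        for shift in range(8):  # least-significant bit first
            bits.append((byte >> shift) & 1)
    return bits[:n_bits]

# Hypothetical key and query context.
bits = keyed_watermark_bits(b"example-secret", "How do I sort a list in Python?")
print(bits)
```

The same key and context always yield the same bits, so verification is repeatable, while anyone without the key cannot predict which synonym each context should prefer.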
3. Query Perturbation Defense
Add controlled noise to outputs to degrade extraction quality while maintaining usefulness:
import anthropic
import random
client = anthropic.Anthropic()
class QueryPerturbationDefense:
"""
Add controlled perturbations to API outputs to degrade extraction quality.
Key design constraint: perturbations must NOT degrade user experience
(humans don't notice), but MUST accumulate to reduce surrogate quality.
A surrogate trained on 10% perturbed outputs degrades in proportion
to the perturbation rate - this creates an economics barrier:
attacker must query more to overcome the noise.
"""
def __init__(
self,
base_perturbation_rate: float = 0.03, # 3% of responses
high_risk_rate: float = 0.15, # 15% for flagged users
critical_risk_rate: float = 0.35, # 35% for high-risk users
):
self.base_perturbation_rate = base_perturbation_rate
self.high_risk_rate = high_risk_rate
self.critical_risk_rate = critical_risk_rate
def get_perturbation_rate(self, risk_level: str) -> float:
rates = {
"low": self.base_perturbation_rate,
"medium": self.high_risk_rate,
"high": self.critical_risk_rate,
"critical": 0.6, # Heavy perturbation - nearly unusable for extraction
}
return rates.get(risk_level, self.base_perturbation_rate)
def perturb_response(
self,
response: str,
risk_level: str = "low",
perturbation_type: str = "synonym"
) -> str:
"""
Apply perturbation to degrade extraction quality.
Choose type based on the use case and detectability requirements.
"""
rate = self.get_perturbation_rate(risk_level)
if random.random() > rate:
return response # No perturbation for this response
if perturbation_type == "synonym":
return self._synonym_substitution(response)
elif perturbation_type == "paraphrase":
return self._llm_paraphrase(response)
elif perturbation_type == "abstention":
return self._abstention_response(response)
else:
return response
def _synonym_substitution(self, text: str) -> str:
"""Substitute a small number of words with synonyms."""
substitutions = {
"important": "significant",
"large": "substantial",
"small": "minor",
"fast": "rapid",
"use": "employ",
"good": "beneficial",
"bad": "detrimental",
"make": "produce",
"show": "demonstrate",
"find": "identify",
}
words = text.split()
for i, word in enumerate(words):
lower = word.lower().strip('.,!?;:')
if lower in substitutions and random.random() < 0.2:
replacement = substitutions[lower]
if word[0].isupper():
replacement = replacement.capitalize()
punct = ''.join(c for c in word if not c.isalpha())
words[i] = replacement + punct
return ' '.join(words)
def _llm_paraphrase(self, text: str) -> str:
"""Request Claude to paraphrase its own response."""
try:
response = client.messages.create(
model="claude-haiku-4-5-20251001",
max_tokens=len(text.split()) + 100,
messages=[{
"role": "user",
"content": f"Rephrase this while preserving all information and quality:\n\n{text}"
}]
)
return response.content[0].text
except Exception:
return text # Fallback to original if paraphrase fails
def _abstention_response(self, text: str) -> str:
"""
For high-risk users: occasionally return lower-detail responses.
This degrades extraction quality without being obviously adversarial.
"""
if len(text.split()) > 50:
# Return shorter version
sentences = text.split('.')
n_keep = max(1, len(sentences) // 2)
return '. '.join(sentences[:n_keep]) + '.'
return text
API Protection Architecture
A production defense combines multiple layers in a gateway:
import anthropic
from dataclasses import dataclass
from enum import Enum
import time
client = anthropic.Anthropic()
class AccessDecision(Enum):
ALLOW = "allow"
THROTTLE = "throttle"
BLOCK = "block"
@dataclass
class GatewayDecision:
access: AccessDecision
extraction_risk: ExtractionRisk
apply_watermark: bool
perturbation_level: str # "none", "light", "heavy"
delay_seconds: float # Artificial throttling delay
class ModelExtractionProtectionGateway:
"""
API gateway with model extraction protection.
Sits in front of the actual model inference service.
"""
def __init__(self, model_id: str, watermark_secret_key: str):
self.model_id = model_id
self.detector = ModelExtractionDetector()
self.watermarker = OutputWatermarker(watermark_secret_key)
self.perturbation = QueryPerturbationDefense()
def _make_access_decision(self, extraction_risk: ExtractionRisk) -> GatewayDecision:
"""Map risk level to access decision."""
if extraction_risk.risk_level == "critical":
return GatewayDecision(
access=AccessDecision.BLOCK,
extraction_risk=extraction_risk,
apply_watermark=True,
perturbation_level="none",
delay_seconds=0
)
elif extraction_risk.risk_level == "high":
return GatewayDecision(
access=AccessDecision.THROTTLE,
extraction_risk=extraction_risk,
apply_watermark=True,
perturbation_level="heavy",
delay_seconds=5.0 # 5-second delay per request = kills extraction ROI
)
elif extraction_risk.risk_level == "medium":
return GatewayDecision(
access=AccessDecision.ALLOW,
extraction_risk=extraction_risk,
apply_watermark=True,
perturbation_level="light",
delay_seconds=1.0
)
else:
return GatewayDecision(
access=AccessDecision.ALLOW,
extraction_risk=extraction_risk,
apply_watermark=True, # Always watermark
perturbation_level="none",
delay_seconds=0
)
def process_request(
self,
api_key: str,
query: str,
) -> dict:
"""Process an API request with extraction protection."""
# Assess risk
risk = self.detector.record_query(api_key, query)
decision = self._make_access_decision(risk)
# Handle blocked requests
if decision.access == AccessDecision.BLOCK:
return {
"error": "Access temporarily suspended. Contact support if you believe this is an error.",
"status": 429,
"blocked": True,
"retry_after": 3600
}
# Apply throttling delay
if decision.delay_seconds > 0:
time.sleep(decision.delay_seconds)
# Generate response
try:
response = client.messages.create(
model=self.model_id,
max_tokens=1000,
messages=[{"role": "user", "content": query}]
)
output = response.content[0].text
except Exception as e:
return {"error": str(e), "status": 500}
        # Apply perturbation first (risk-level dependent). Paraphrasing or synonym
        # swaps applied after watermarking would overwrite the watermark's word choices.
        if decision.perturbation_level != "none":
            output = self.perturbation.perturb_response(
                output,
                risk_level=risk.risk_level,
                perturbation_type="synonym" if decision.perturbation_level == "light" else "paraphrase"
            )
        # Apply watermark last (always) so it survives into the returned text
        if decision.apply_watermark:
            output = self.watermarker.embed_watermark(output, context=query)
return {
"response": output,
"status": 200,
"access_level": decision.access.value,
"risk_level": risk.risk_level,
}
Legal and Business Context
Model extraction sits at the intersection of security and intellectual property law:
| Aspect | Detail |
|---|---|
| IP protection | Model weights can be trade secrets; extracted models may infringe |
| ToS violations | Most API ToS explicitly prohibit systematic extraction |
| Jurisdiction | Legal treatment varies; US, EU have different frameworks |
| Evidence | Watermarks provide technical evidence of copying |
| Remedies | Injunctions, damages based on development cost |
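| Burden of proof | A watermark match rate far above the 50% chance level (the code above treats >80% as strong) is persuasive statistical evidence, but how much weight it carries varies by jurisdiction |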
:::tip Practical Note on Watermarking for Legal Use

Watermarks only provide legal protection if you can prove you embedded them - which requires: (1) documenting your watermarking scheme and key before any alleged theft; (2) cryptographic evidence tying the specific watermark to your specific secret key; (3) statistical analysis showing the match rate far exceeds chance. Engage legal counsel before asserting IP claims, and preserve all API logs. The combination of API logs (showing query patterns), watermark detection results, and temporal evidence (when your model was released vs. when the competitor's appeared) makes a compelling case.

:::
Common Mistakes
:::danger Mistake 1: No Query Rate Limits Per API Key

Without per-key rate limits, there's no ceiling on how many queries an attacker can make. Even modest limits significantly increase extraction cost and time. Make rate limits graduated: for example, free tier at 100/day, paid at 10,000/day, and enterprise at 100,000/day with audit logging.

:::
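A graduated per-key limit can be sketched in a few lines. This is a minimal in-memory illustration using the tier numbers from the note above; the class name and the injectable `clock` parameter are assumptions made so the behavior can be tested, and a production gateway would keep counters in a shared store and expire old days.

```python
import time
from collections import defaultdict

# Daily quotas per tier, taken from the example figures above.
TIER_DAILY_LIMITS = {"free": 100, "paid": 10_000, "enterprise": 100_000}

class DailyQuotaLimiter:
    """Count requests per (API key, UTC day) and refuse once the tier quota is hit."""

    def __init__(self, clock=time.time):
        self._clock = clock              # injectable for testing
        self._usage = defaultdict(int)   # (api_key, day_index) -> request count

    def allow(self, api_key: str, tier: str) -> bool:
        # Unknown tiers fall back to the most restrictive quota.
        limit = TIER_DAILY_LIMITS.get(tier, TIER_DAILY_LIMITS["free"])
        day_index = int(self._clock() // 86_400)
        bucket = (api_key, day_index)
        if self._usage[bucket] >= limit:
            return False
        self._usage[bucket] += 1
        return True

# Usage with a fake clock so the day rollover is visible.
now = [0.0]
limiter = DailyQuotaLimiter(clock=lambda: now[0])
first_hundred = [limiter.allow("key-123", "free") for _ in range(100)]
over_quota = limiter.allow("key-123", "free")
now[0] = 86_400.0  # next day
next_day = limiter.allow("key-123", "free")
print(all(first_hundred), over_quota, next_day)
```

At 100 requests per day, collecting the 10K pairs mentioned earlier would take a single free-tier key more than three months, which is the kind of cost increase the limit is meant to impose.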
:::danger Mistake 2: Returning Logprobs to All Users
Many APIs offer logprob access. Logprobs (per-token log-probabilities, often returned for the top candidate tokens at each position) are far more information-efficient for extraction than raw text responses. A surrogate trained on logprob data can reach comparable performance with far fewer queries, on the order of 10x fewer. Restrict logprob access to verified, high-trust customers only, and add noise to logprobs even for those users.
:::
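The gating-plus-noise idea can be sketched as follows. This is a minimal illustration, not a tuned defense: the `trust_level` values, the function names, and the noise scale `sigma=0.1` are assumptions chosen for the example, and the right noise level depends on how much distortion legitimate customers can tolerate.

```python
import random


def noisy_logprobs(logprobs: dict, sigma: float, rng: random.Random) -> dict:
    """Add Gaussian noise to each log-probability, round coarsely, and cap at 0."""
    noised = {}
    for token, lp in logprobs.items():
        value = round(lp + rng.gauss(0.0, sigma), 2)
        noised[token] = min(0.0, value)  # a log-probability cannot exceed 0
    return noised


def logprobs_for_caller(logprobs: dict, trust_level: str, rng=None):
    """Return None for untrusted callers; noised logprobs for verified ones."""
    if trust_level != "verified":  # hypothetical trust label
        return None
    return noisy_logprobs(logprobs, sigma=0.1, rng=rng or random.Random())
```

Rounding to two decimals removes low-order digits that carry extra signal for a distilling attacker while leaving the ranking of likely tokens mostly intact.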
:::warning Mistake 3: Relying on Terms of Service Alone
Legal remedies are slow and expensive. Technical controls (rate limiting, query pattern detection, watermarking, perturbation) should be the primary defense. ToS violations are a backup enforcement mechanism, not a prevention strategy.
:::
:::warning Mistake 4: Static Watermarking
A watermark that uses the same synonym preferences across all outputs is detectable and removable by a sophisticated attacker who collects enough pairs. Use context-dependent, cryptographically keyed watermarks - the synonym choices must vary by query context, keyed to a secret only you know.
:::
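To make "context-dependent and keyed" concrete, the sketch below picks between two synonyms using an HMAC of the query text and the synonym pair. It is a toy under loud assumptions: the synonym list is hypothetical and tiny, tokenization is a plain whitespace split (so punctuation and capitalization are not handled), and a production scheme would need to preserve meaning and fluency.

```python
import hashlib
import hmac

# Hypothetical synonym pairs; a real scheme needs a much larger, vetted list.
SYNONYM_PAIRS = [("big", "large"), ("quick", "fast"), ("start", "begin")]


def preferred_variant(secret_key: bytes, context: str, pair: tuple) -> str:
    """Choose one word of the pair from a keyed hash of the context and the pair."""
    message = (context + "|" + "/".join(pair)).encode("utf-8")
    digest = hmac.new(secret_key, message, hashlib.sha256).digest()
    return pair[digest[0] & 1]  # lowest bit of the first byte selects the variant


def apply_watermark(secret_key: bytes, query: str, response: str) -> str:
    """Rewrite every known synonym in the response to the keyed preferred variant."""
    lookup = {word: pair for pair in SYNONYM_PAIRS for word in pair}
    out = []
    for token in response.split():
        pair = lookup.get(token.lower())
        out.append(preferred_variant(secret_key, query, pair) if pair else token)
    return " ".join(out)
```

Because the choice depends on both the secret key and the query, the same pair resolves differently across queries, so an attacker without the key sees no fixed preference to strip out.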
:::tip Best Practice: Defense Depth Buys Time
No single defense stops a determined, well-funded extractor. Rate limiting buys time. Pattern detection buys time. Perturbation degrades quality. Watermarking enables legal action after the fact. The goal is to make extraction expensive enough that the attacker either gives up or crosses a legal threshold where action is worth taking. Design your defenses to maximize attacker cost per percentage point of surrogate quality.
:::
Interview Questions and Answers
Q1: What is model extraction and why is it a business threat?
Model extraction is the systematic querying of a model API to accumulate enough input-output pairs to train a functionally equivalent "surrogate" model. It's a business threat because: (1) training frontier models costs tens to hundreds of millions of dollars - extraction lets attackers replicate this value for a fraction of the cost; (2) extracted models are operated without licensing fees, undercutting the original; (3) the extracted model may compete directly with the original. The asymmetry - training is expensive (about USD 10 million in the opening example) while querying is cheap (about USD 4,200 in API fees) - makes extraction economically attractive. The threat is especially acute for specialized fine-tuned models where the proprietary value is the narrow fine-tuning rather than the base model architecture.
Q2: How does active learning make model extraction more efficient?
Random query sampling wastes queries on redundant examples that cover similar capability regions. Active learning selects queries at the model's decision boundaries - places where small input changes produce large output changes - which are maximally informative for learning the model's behavior. In practice, an active learning-based extractor might achieve 90% of a random sampler's accuracy with 10% of the queries. This matters because it reduces both cost and detectability: fewer queries in a shorter window are harder to detect with rate-based monitoring. The most sophisticated extractors also use query generation models (a separate LLM) to synthesize diverse, maximally informative queries rather than sampling from fixed templates.
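The selection step described here is commonly implemented as uncertainty sampling, which is worth understanding when tuning detection thresholds. The sketch below is a generic illustration rather than a reconstruction of any specific attack: `toy_surrogate` is a hypothetical one-dimensional classifier standing in for the partially trained surrogate, and prediction entropy stands in for "closeness to a decision boundary".

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector; higher means less certain."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def select_most_uncertain(candidates, predict_proba, budget):
    """Return the `budget` candidates whose predictions have the highest entropy."""
    ranked = sorted(candidates, key=lambda x: entropy(predict_proba(x)), reverse=True)
    return ranked[:budget]


def toy_surrogate(x):
    """Hypothetical binary classifier with its decision boundary at x = 0."""
    p = 1.0 / (1.0 + math.exp(-4.0 * x))
    return [p, 1.0 - p]


pool = [-2.0, -1.0, -0.1, 0.05, 0.5, 2.0]
print(select_most_uncertain(pool, toy_surrogate, budget=3))  # [0.05, -0.1, 0.5]
```

The selected points cluster around the boundary at 0, which is why boundary-probing traffic looks different from ordinary usage even when its volume is low.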
Q3: How do output watermarks survive model extraction?
Output watermarks are statistical biases embedded in the model's output distribution - for example, preferring certain synonyms or sentence structures in a cryptographically keyed pattern. When an attacker uses watermarked outputs as training data for their surrogate, the surrogate learns these biases along with the model's actual capabilities. The watermark "transfers" through the knowledge distillation process. To verify theft: collect samples from the suspected surrogate, compute the watermark match rate across multiple samples (accounting for context-dependence), and test whether the match rate significantly exceeds chance (50%). How strong the evidence is depends on the number of independent watermark decision points tested: 15 matches out of 20 (75%) gives a one-sided p-value of about 0.02, which is suggestive, while a 75–85% match rate sustained over hundreds of decision points is vanishingly unlikely to occur by chance and constitutes strong evidence of copying.
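The significance test mentioned above can be done with an exact one-sided binomial tail, which needs nothing beyond the standard library. This sketch assumes each watermark decision point is an independent coin flip at the 50% chance rate under the null hypothesis; correlated decision points would make the real p-value larger than what this computes.

```python
from math import comb


def binomial_p_value(matches: int, trials: int, chance: float = 0.5) -> float:
    """One-sided P(X >= matches) for X ~ Binomial(trials, chance)."""
    return sum(
        comb(trials, k) * chance**k * (1.0 - chance) ** (trials - k)
        for k in range(matches, trials + 1)
    )


# 75% match rate at two different sample sizes.
print(binomial_p_value(15, 20))    # roughly 0.0207
print(binomial_p_value(150, 200))  # below 1e-9
```

The same match rate becomes far more convincing as the number of tested decision points grows, so it pays to collect many samples before drawing conclusions.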
Q4: What's the difference between black-box and white-box extraction attacks?
Black-box extraction only has API access - it observes inputs and outputs. White-box extraction has access to the model's internal state (weights, activations, gradients). API-facing threats are exclusively black-box; white-box attacks require insider access or a compromised deployment environment. Black-box extraction can still replicate 80–95% of model performance on specific tasks with enough queries. The defenses are also different: black-box threats are addressed by API controls (rate limiting, pattern detection, watermarking); white-box threats require secure deployment infrastructure (HSMs, secure enclaves, strong access controls, model weight encryption at rest).
Q5: How would you design rate limiting to balance user experience with extraction protection?
Tiered rate limiting with anomaly-based escalation: (1) Free tier: 100 queries/day at a fixed rate - economically infeasible to extract at scale. (2) Paid tier: 10,000 queries/day - allows legitimate batch use but makes systematic extraction slow. (3) Enterprise tier: custom limits with audit logging and ToS verification. Beyond rate limits: per-key query diversity monitoring (unusually broad topical coverage from one key is a signal of systematic extraction), artificial delay injection for suspicious access patterns (a 5-second delay per request slows a sequential extractor by roughly 30x if normal latency is around 170 ms), and per-key usage reports with anomaly detection. Key insight: legitimate users have clustered, bursty, domain-specific query patterns - systematic extraction has uniform, diverse, regular patterns. Detection should target that distinction. For high-risk users, increase both delay and perturbation rate while you investigate, before blocking - this preserves revenue while degrading extraction efficiency.
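Two of the ideas in this answer reduce to simple arithmetic. The sketch below measures how regular a key's request timing is, using the coefficient of variation of inter-arrival gaps, and computes the slowdown from an injected delay. Both function names are hypothetical, any alerting threshold on the regularity score would need calibration on real traffic, and the 170 ms baseline latency is an assumption.

```python
from statistics import mean, pstdev


def interarrival_cv(timestamps):
    """Coefficient of variation of gaps between consecutive requests (seconds).

    Values near 0 mean machine-like regular timing; bursty human usage
    typically scores well above 1. Returns None if there is too little data.
    """
    ts = sorted(timestamps)
    gaps = [b - a for a, b in zip(ts, ts[1:])]
    if len(gaps) < 2 or mean(gaps) == 0:
        return None
    return pstdev(gaps) / mean(gaps)


def slowdown_factor(base_latency_s: float, injected_delay_s: float) -> float:
    """How many times slower sequential querying becomes with an added delay."""
    return (base_latency_s + injected_delay_s) / base_latency_s


regular = [10 * i for i in range(50)]             # one request every 10 s
bursty = [0, 1, 2, 3, 3600, 3601, 3602, 3603]     # two short bursts an hour apart
print(interarrival_cv(regular))                   # 0.0
print(interarrival_cv(bursty) > 1.0)              # True
print(round(slowdown_factor(0.17, 5.0), 1))       # 30.4
```

Timing regularity is only one signal; it is most useful when combined with the topical diversity and volume checks described above, since an attacker can add jitter to defeat any single measure.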
Q6: Can watermarking prove theft in court? What evidence is needed?
Watermarking can contribute to legal evidence but requires proper documentation to be compelling. You need: (1) Pre-existing record - the watermark scheme, secret key, and deployment date documented before any alleged theft (ideally notarized or in version control with timestamps); (2) Watermark detection methodology - a documented, reproducible procedure for detecting the watermark in model outputs; (3) Statistical evidence - match rates across multiple samples with a p-value demonstrating the match is far beyond chance; (4) Temporal evidence - your model was released before the suspected surrogate appeared; (5) Functional similarity - the surrogate exhibits specific behaviors (including failure modes) that match your model. Courts are still developing standards for AI IP cases; watermarking provides technical corroboration for claims of copying, but legal strategy should be developed with IP counsel experienced in AI/ML.
