# Caching for ML Serving
## The FAQ Bot That Cost $40,000 Per Month
The internal FAQ bot was a success. Employees loved it. Adoption was through the roof. And then the API bill arrived: $40,000 in GPT-4 API calls in a single month. The ML team pulled up the query logs with dread.
They found something stunning. 78% of all queries were semantically identical to queries seen in the previous 30 days. "What is our PTO policy?" asked 847 times with slight variations. "How do I file an expense report?" in 623 different phrasings. "What are the office hours?" 412 times. The same 200 questions, asked over and over in slightly different words, each triggering a full GPT-4 API call at $0.06 per 1K tokens.
The fix was semantic caching. For each incoming query, compute its embedding vector. Look up the nearest neighbor in the cache. If the similarity is above 0.92, return the cached response - skip the GPT-4 call entirely. If below 0.92, run GPT-4, cache the result for next time.
After one week: 71% cache hit rate. API costs dropped from \$40,000 to \$11,600 per month. Response latency improved: cached responses return in 15ms instead of 1.2 seconds. Users could not tell which responses were cached. The semantic caching layer saved \$340,000 per year without changing model quality.
Caching in ML is not just a performance optimization - it is often a fundamental economic enabler.
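The headline numbers follow from simple arithmetic. Below is a minimal cost model. It assumes each cache hit avoids exactly one full-price LLM call and ignores the cost of embedding queries and running Redis (both simplifications). The function names are illustrative:

```python
def monthly_llm_cost(baseline_cost: float, hit_rate: float) -> float:
    """LLM spend after caching: only cache misses still pay for a full call."""
    if not 0.0 <= hit_rate <= 1.0:
        raise ValueError("hit_rate must be in [0, 1]")
    return baseline_cost * (1.0 - hit_rate)


def annual_savings(baseline_cost: float, hit_rate: float) -> float:
    """Twelve months of the difference between uncached and cached spend."""
    return 12 * (baseline_cost - monthly_llm_cost(baseline_cost, hit_rate))


before = 40_000.0
after = monthly_llm_cost(before, hit_rate=0.71)
print(f"Monthly cost after caching: ${after:,.0f}")                # $11,600
print(f"Annual savings: ${annual_savings(before, 0.71):,.0f}")     # $340,800
```

Embedding and cache infrastructure are not free, so treat the result as an upper bound on the saving.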
## Why This Exists - The ML Serving Cost Structure
Traditional API caching assumes exact-match inputs: same URL, same parameters → return cached response. This works for deterministic REST APIs. It does not work for ML inputs, where "What's the weather?" and "How's the weather today?" are semantically identical but textually different.
ML serving has a unique cost structure that makes caching unusually valuable:
- High per-call compute cost: Running GPT-4 costs $0.06 per 1K tokens; even a local LLM call takes 50-200ms of GPU time
- High redundancy: Real applications see significant query repetition, especially in enterprise/consumer FAQ scenarios
- Determinism: Given the same input and sampling parameters, models produce the same or very similar outputs
- Expensive feature computation: Real-time feature lookup (user embeddings, item features) often costs more than the model inference itself
Multiple distinct caching layers exist in the ML serving stack, each solving a different problem. Understanding all of them - and when to apply each - is the mark of a production ML systems engineer.
## Historical Context
Exact-match result caching for web applications has existed since the 1990s. The first notable application to ML serving was recommendation systems in the mid-2010s: computed item embeddings and user preference vectors were expensive to recompute, so they were cached in Redis with TTLs aligned to model update cycles.
KV caching for transformers is an inference-time optimization for autoregressive decoding with the attention mechanism (Vaswani et al., 2017). During autoregressive generation, each new token's attention must attend to all previous tokens. Without KV caching, the key and value projections for all previous tokens are recomputed on every generation step, which is O(n²) projection work for a sequence of length n. KV caching stores the computed K and V matrices so that each step projects only the new token, reducing that work to O(n).
Semantic caching emerged as a practical production technique around 2022-2023, driven by the economics of LLM APIs. GPTCache (2023) was the first open-source library specifically designed for LLM semantic caching. The approach of using embedding similarity for cache lookup was directly inspired by the success of semantic search.
Prefix caching (also called prompt caching or context caching) was introduced by vLLM in 2023 and adopted by Anthropic's API and OpenAI's API. The insight: many LLM requests share a common system prompt. Computing attention KV values for that prefix is expensive and redundant. Prefix caching stores the KV activations for common prefixes, reducing time-to-first-token for requests that share prefixes.
## Layer 1: Result Caching (Exact Match)
The simplest form: hash the input, look up the output, return if found. This works when identical inputs recur exactly and the model is deterministic, for example classifying the same image twice or embedding a fixed piece of text.
```python
# exact_match_cache.py - Redis-backed result cache
import hashlib
import json
import pickle
from typing import Any, Optional

import redis


class InferenceResultCache:
    """Exact-match cache for deterministic ML model outputs."""

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        ttl_seconds: int = 3600,
        key_prefix: str = "ml_result",
    ):
        self.redis = redis.from_url(redis_url, decode_responses=False)
        self.ttl = ttl_seconds
        self.prefix = key_prefix
        # In-process hit/miss counters (export these to Prometheus in production)
        self._hits = 0
        self._misses = 0

    def _make_key(self, model_name: str, inputs: dict) -> str:
        """Create a stable cache key from model name + inputs."""
        # Sort keys for deterministic hashing regardless of input order
        canonical = json.dumps(
            {"model": model_name, "inputs": inputs},
            sort_keys=True,
            ensure_ascii=False,
        )
        digest = hashlib.sha256(canonical.encode()).hexdigest()
        return f"{self.prefix}:{model_name}:{digest}"

    def get(self, model_name: str, inputs: dict) -> Optional[Any]:
        """Return cached result or None."""
        key = self._make_key(model_name, inputs)
        raw = self.redis.get(key)
        if raw is not None:
            self._hits += 1
            # pickle is acceptable only because this Redis instance is private and trusted
            return pickle.loads(raw)
        self._misses += 1
        return None

    def set(self, model_name: str, inputs: dict, result: Any) -> None:
        """Cache a result with TTL."""
        key = self._make_key(model_name, inputs)
        self.redis.setex(key, self.ttl, pickle.dumps(result))

    @property
    def hit_rate(self) -> float:
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def invalidate_model(self, model_name: str) -> int:
        """Invalidate all cached results for a model (call on model update)."""
        pattern = f"{self.prefix}:{model_name}:*"
        cursor = 0
        deleted = 0
        while True:
            # SCAN iterates incrementally instead of blocking Redis the way KEYS would
            cursor, keys = self.redis.scan(cursor, match=pattern, count=1000)
            if keys:
                self.redis.delete(*keys)
                deleted += len(keys)
            if cursor == 0:
                break
        print(f"Invalidated {deleted} cache entries for {model_name}")
        return deleted


# Usage in an async serving endpoint (e.g., a FastAPI route handler)
cache = InferenceResultCache(redis_url="redis://cache:6379", ttl_seconds=86400)


async def classify_image(model_name: str, image_hash: str, image_tensor) -> dict:
    """Classify with exact-match caching."""
    # Key on a hash of the image; a perceptual hash also merges near-duplicate
    # images, which is only safe if they should always share a label
    cache_inputs = {"image_hash": image_hash}

    # Cache lookup (blocking client; redis.asyncio avoids stalling the event loop)
    cached = cache.get(model_name, cache_inputs)
    if cached is not None:
        return {**cached, "_cached": True}

    # Cache miss - run inference (run_inference is a placeholder for your model call)
    result = await run_inference(model_name, image_tensor)

    # Store result
    cache.set(model_name, cache_inputs, result)
    return {**result, "_cached": False}
```
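`InferenceResultCache` needs a Redis server shared by every replica. For a single-process service, or for unit tests, the same exact-match idea can be sketched in-process as a decorator with TTL expiry and LRU eviction. The decorator `ttl_lru_cache` and the `classify` stand-in are hypothetical names. Each process keeps its own copy, so this is not a substitute for a cache shared across pods:

```python
import functools
import hashlib
import inspect
import json
import time
from collections import OrderedDict
from typing import Any, Callable


def ttl_lru_cache(ttl_seconds: float = 3600.0, max_entries: int = 10_000) -> Callable:
    """Hypothetical in-process exact-match cache: TTL expiry plus LRU eviction."""

    def decorator(fn: Callable) -> Callable:
        sig = inspect.signature(fn)
        store: "OrderedDict[str, tuple]" = OrderedDict()  # key -> (stored_at, result)

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Canonical key: bind args to parameter names, fill in defaults, sort keys
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            canonical = json.dumps(dict(bound.arguments), sort_keys=True, ensure_ascii=False)
            key = hashlib.sha256(canonical.encode()).hexdigest()

            now = time.monotonic()
            entry = store.get(key)
            if entry is not None and now - entry[0] < ttl_seconds:
                store.move_to_end(key)  # mark as most recently used
                wrapper.hits += 1
                return entry[1]

            wrapper.misses += 1
            result = fn(*args, **kwargs)
            store[key] = (now, result)
            store.move_to_end(key)
            while len(store) > max_entries:  # evict least recently used
                store.popitem(last=False)
            return result

        wrapper.hits = 0
        wrapper.misses = 0
        return wrapper

    return decorator


@ttl_lru_cache(ttl_seconds=60, max_entries=2)
def classify(text: str, model: str = "sentiment-v1") -> dict:
    # Stand-in for a real model call (hypothetical)
    return {"label": "positive" if "good" in text else "negative"}


print(classify("good movie"))                              # miss: runs the model
print(classify(text="good movie", model="sentiment-v1"))  # hit: same canonical key
print(classify.hits, classify.misses)                      # 1 1
```

The standard library's `functools.lru_cache` has no expiry, so the TTL check is done by hand here. Binding arguments with `inspect.signature` makes positional and keyword calls with the same values share one key.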
## Layer 2: Semantic Similarity Caching
For LLM queries, exact match is too strict. Semantic caching uses embedding similarity to find "close enough" cached answers.
```python
# semantic_cache.py - embedding-based semantic cache
import hashlib
import pickle
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

import faiss
import numpy as np
import redis


@dataclass
class CacheEntry:
    query: str
    response: str
    embedding: np.ndarray
    timestamp: float
    hit_count: int = 0


class SemanticCache:
    """
    Semantic cache using embedding similarity for LLM query matching.
    Uses Redis as the backend store and an in-process FAISS index for fast search.
    """

    def __init__(
        self,
        embedding_model,  # e.g., a SentenceTransformer
        redis_url: str,
        similarity_threshold: float = 0.92,
        max_cache_size: int = 10000,
        ttl_seconds: int = 86400 * 7,  # 1 week
    ):
        self.embedder = embedding_model
        self.redis = redis.from_url(redis_url)
        self.threshold = similarity_threshold
        self.max_size = max_cache_size
        self.ttl = ttl_seconds

        # Exact (brute-force) inner-product index; inner product equals
        # cosine similarity because all stored vectors are L2-normalized
        embedding_dim = self._get_embedding_dim()
        self.index = faiss.IndexFlatIP(embedding_dim)
        self.index_to_key: List[str] = []  # FAISS position -> Redis key

    def _get_embedding_dim(self) -> int:
        """Infer embedding dimension from the model."""
        test_emb = self.embedder.encode(["test"])
        return test_emb.shape[1]

    def _embed(self, text: str) -> np.ndarray:
        """Embed and L2-normalize a string; returns shape (1, dim) float32."""
        embedding = self.embedder.encode([text], normalize_embeddings=True)
        return np.asarray(embedding, dtype=np.float32)

    def lookup(self, query: str) -> Optional[Tuple[str, float]]:
        """
        Find the most similar cached query.
        Returns (cached_response, similarity_score) or None.
        """
        if self.index.ntotal == 0:
            return None

        query_embedding = self._embed(query)
        # Search for the single nearest neighbor
        similarities, indices = self.index.search(query_embedding, 1)
        best_similarity = float(similarities[0][0])
        best_index = int(indices[0][0])

        if best_index < 0 or best_similarity < self.threshold:
            return None  # Not similar enough

        # Retrieve from Redis
        cache_key = self.index_to_key[best_index]
        raw = self.redis.get(cache_key)
        if raw is None:
            return None  # Expired from Redis; the stale FAISS entry is dropped on eviction

        entry: CacheEntry = pickle.loads(raw)
        # Update hit count (setex also refreshes the TTL: sliding expiry)
        entry.hit_count += 1
        self.redis.setex(cache_key, self.ttl, pickle.dumps(entry))
        return entry.response, best_similarity

    def store(self, query: str, response: str) -> None:
        """Cache a query-response pair."""
        embedding = self._embed(query)
        cache_key = f"sem_cache:{hashlib.sha256(query.encode()).hexdigest()}"
        entry = CacheEntry(
            query=query,
            response=response,
            embedding=embedding,
            timestamp=time.time(),
        )
        # Store in Redis
        self.redis.setex(cache_key, self.ttl, pickle.dumps(entry))

        # Add to FAISS index, evicting first if full
        if self.index.ntotal >= self.max_size:
            self._evict_oldest()
        self.index.add(embedding)
        self.index_to_key.append(cache_key)

    def _evict_oldest(self) -> None:
        """Drop the oldest 10% of index entries and rebuild the FAISS index."""
        n_evict = max(1, self.max_size // 10)
        # Insertion order approximates age (a simplification; true LRU
        # eviction needs access tracking)
        candidate_keys = self.index_to_key[n_evict:]

        # Rebuild only from entries still in Redis, so keys and FAISS
        # positions stay aligned even when some entries have expired
        kept_keys, kept_embeddings = [], []
        for key in candidate_keys:
            raw = self.redis.get(key)
            if raw:
                kept_keys.append(key)
                kept_embeddings.append(pickle.loads(raw).embedding)

        self.index.reset()
        self.index_to_key = kept_keys
        if kept_embeddings:
            self.index.add(np.vstack(kept_embeddings))


# Integration with an LLM API (llm_client is assumed to be an async
# OpenAI-style client such as openai.AsyncOpenAI)
async def answer_question(
    question: str,
    semantic_cache: SemanticCache,
    llm_client,
    system_prompt: str,
) -> dict:
    """Answer a question using the semantic cache to avoid redundant LLM calls.

    The cache key ignores system_prompt: flush the cache if the prompt changes.
    """
    start = time.perf_counter()

    # Step 1: Check semantic cache
    cache_result = semantic_cache.lookup(question)
    if cache_result:
        cached_response, similarity = cache_result
        return {
            "answer": cached_response,
            "cache_hit": True,
            "similarity": similarity,
            # Embedding + vector search + Redis read only
            "latency_ms": (time.perf_counter() - start) * 1000,
        }

    # Step 2: Call LLM
    response = await llm_client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
    )
    latency_ms = (time.perf_counter() - start) * 1000
    answer = response.choices[0].message.content

    # Step 3: Cache for future requests
    semantic_cache.store(question, answer)

    return {
        "answer": answer,
        "cache_hit": False,
        "similarity": 0.0,
        "latency_ms": latency_ms,
    }
```
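The lookup rule is easier to test in isolation. The sketch below applies the same nearest-neighbor-plus-threshold rule using brute-force NumPy search. Hand-made 3-dimensional vectors stand in for real sentence embeddings, an assumption for illustration. The class name `InMemorySemanticCache` is hypothetical:

```python
from typing import List, Optional, Tuple

import numpy as np


class InMemorySemanticCache:
    """Brute-force cosine-similarity cache: SemanticCache's lookup rule without
    Redis or FAISS. Vectors are L2-normalized, so a dot product is cosine similarity."""

    def __init__(self, dim: int, threshold: float = 0.92):
        self.threshold = threshold
        self.embeddings = np.empty((0, dim), dtype=np.float32)
        self.responses: List[str] = []

    @staticmethod
    def _normalize(v: np.ndarray) -> np.ndarray:
        v = np.asarray(v, dtype=np.float32).reshape(-1)
        return v / (np.linalg.norm(v) + 1e-12)

    def store(self, embedding: np.ndarray, response: str) -> None:
        self.embeddings = np.vstack([self.embeddings, self._normalize(embedding)])
        self.responses.append(response)

    def lookup(self, embedding: np.ndarray) -> Optional[Tuple[str, float]]:
        if not self.responses:
            return None
        sims = self.embeddings @ self._normalize(embedding)  # one score per cached entry
        best = int(np.argmax(sims))
        if sims[best] < self.threshold:
            return None  # nearest neighbor is not similar enough
        return self.responses[best], float(sims[best])


cache = InMemorySemanticCache(dim=3, threshold=0.92)
cache.store(np.array([1.0, 0.0, 0.0]), "cached PTO answer")
print(cache.lookup(np.array([0.98, 0.10, 0.0])))  # near-duplicate query -> hit
print(cache.lookup(np.array([0.0, 1.0, 0.0])))    # unrelated query -> None
```

Brute-force search costs one dot product per cached entry, which is also what `IndexFlatIP` does. Approximate FAISS indexes such as IVF or HNSW trade a little recall for sublinear lookup once the cache grows large.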
### Choosing the Similarity Threshold
The threshold is the most important parameter. Too high (0.99) → almost no cache hits; too low (0.80) → wrong answers returned for dissimilar questions.
```python
# threshold_calibration.py - find optimal similarity threshold
import numpy as np
from sklearn.metrics import precision_recall_curve


def calibrate_threshold(
    query_pairs: list,  # [(q1, q2, is_equivalent: bool)]
    embedding_model,
) -> float:
    """
    Find optimal threshold using labeled equivalence pairs.
    Returns the threshold that maximizes F1 on is_equivalent classification.
    """
    similarities = []
    labels = []
    for q1, q2, is_equiv in query_pairs:
        e1 = embedding_model.encode([q1], normalize_embeddings=True)
        e2 = embedding_model.encode([q2], normalize_embeddings=True)
        similarity = float(e1[0] @ e2[0])  # cosine similarity of unit vectors
        similarities.append(similarity)
        labels.append(int(is_equiv))

    similarities = np.array(similarities)
    labels = np.array(labels)

    precision, recall, thresholds = precision_recall_curve(labels, similarities)
    f1_scores = 2 * precision * recall / (precision + recall + 1e-8)
    # The final precision/recall point has no matching threshold, so exclude it
    best_idx = int(np.argmax(f1_scores[:-1]))
    best_threshold = thresholds[best_idx]

    print(f"Optimal threshold: {best_threshold:.3f}")
    print(f"At this threshold: precision={precision[best_idx]:.3f}, "
          f"recall={recall[best_idx]:.3f}, F1={f1_scores[best_idx]:.3f}")
    return float(best_threshold)
```
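Maximizing F1 weighs a wrong cached answer and an unnecessary LLM call equally. A wrong answer is usually the costlier error for an FAQ bot. One alternative in that case, a judgment call rather than part of the recipe above, is to choose the lowest threshold that still reaches a target precision on the labeled pairs. That keeps as many hits as possible under the constraint. Below is a dependency-free sketch. The function name and the 0.98 default are assumptions:

```python
import numpy as np


def threshold_for_precision(similarities, labels, min_precision: float = 0.98) -> float:
    """Lowest similarity threshold whose precision on labeled pairs meets the target.

    Recall can only fall as the threshold rises, so the lowest qualifying
    threshold is also the qualifying one with the highest recall.
    """
    sims = np.asarray(similarities, dtype=float)
    y = np.asarray(labels, dtype=int)
    for t in np.unique(sims):  # ascending order: try the lowest threshold first
        predicted_hit = sims >= t  # never empty, because t is one of the scores
        precision = y[predicted_hit].mean()
        if precision >= min_precision:
            return float(t)
    raise ValueError("No threshold reaches the requested precision")


# Toy labeled pairs: similarity scores and whether the two queries are equivalent
sims = [0.95, 0.93, 0.90, 0.88, 0.85, 0.70]
labels = [1, 1, 1, 0, 1, 0]
print(threshold_for_precision(sims, labels, min_precision=0.98))  # 0.9
```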
## Layer 3: KV Cache for Transformer Inference
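During autoregressive generation, each token must attend to all previous tokens. Without KV caching, computing attention for token $t$ requires:

$$
\text{Attention}(q_t, K_{1:t}, V_{1:t}) = \text{softmax}\!\left(\frac{q_t K_{1:t}^{\top}}{\sqrt{d_k}}\right) V_{1:t}
$$

where $K_{1:t}$ and $V_{1:t}$ are the key/value projections of all tokens up to and including $t$. Without caching, you recompute $K_{1:t}$ and $V_{1:t}$ at every step, for $\sum_{t=1}^{n} t = O(n^2)$ total projection work.

With KV caching, you store the K and V tensors from previous steps. Each new token's attention step only computes one new K and V pair and appends it to the cache:

$$
K_{1:t} = \begin{bmatrix} K_{1:t-1} \\ k_t \end{bmatrix}, \quad V_{1:t} = \begin{bmatrix} V_{1:t-1} \\ v_t \end{bmatrix}, \quad k_t = x_t W_K, \; v_t = x_t W_V
$$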
```python
# kv_cache_demo.py - showing KV cache behavior with Hugging Face
import time

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "The transformer architecture was introduced in"
inputs = tokenizer(prompt, return_tensors="pt")

# Without KV cache (use_cache=False): recomputes all K/V at each step
# With KV cache (use_cache=True, default): stores and reuses K/V


def generate_with_options(use_cache: bool, max_new_tokens: int = 50) -> float:
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(
            **inputs,  # input_ids plus attention_mask
            max_new_tokens=max_new_tokens,
            use_cache=use_cache,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )
    return time.perf_counter() - start


# Warmup
for _ in range(3):
    generate_with_options(True)
    generate_with_options(False)

cached_time = np.mean([generate_with_options(True) for _ in range(10)])
uncached_time = np.mean([generate_with_options(False) for _ in range(10)])

print(f"With KV cache:    {cached_time*1000:.1f}ms")
print(f"Without KV cache: {uncached_time*1000:.1f}ms")
print(f"Speedup: {uncached_time/cached_time:.1f}×")

# Typical output for 50 new tokens (absolute numbers vary with hardware):
# With KV cache:    45ms
# Without KV cache: 380ms
# Speedup: 8.4×
```
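The benchmark shows the speed difference but not why caching is safe. Below is a toy single-head, single-layer check in NumPy, using random weights and no real model. It confirms that appending one new K/V row per step gives the same attention outputs as re-projecting the whole prefix. The reason is causal masking: earlier tokens' keys and values never change when later tokens arrive. It also counts projections, n(n+1)/2 without the cache versus n with it:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, n_tokens = 16, 8, 6
W_q, W_k, W_v = (rng.standard_normal((d_model, d_head)) for _ in range(3))
x = rng.standard_normal((n_tokens, d_model))  # toy hidden states for the sequence


def attend(q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention for one query vector."""
    scores = K @ q / np.sqrt(d_head)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V


# Without a cache: every step re-projects all tokens seen so far
no_cache_out, no_cache_projections = [], 0
for t in range(1, n_tokens + 1):
    K, V = x[:t] @ W_k, x[:t] @ W_v
    no_cache_projections += t
    no_cache_out.append(attend(x[t - 1] @ W_q, K, V))

# With a cache: project only the newest token and append it
K_cache = np.empty((0, d_head))
V_cache = np.empty((0, d_head))
cache_out, cache_projections = [], 0
for t in range(1, n_tokens + 1):
    x_t = x[t - 1]
    K_cache = np.vstack([K_cache, x_t @ W_k])
    V_cache = np.vstack([V_cache, x_t @ W_v])
    cache_projections += 1
    cache_out.append(attend(x_t @ W_q, K_cache, V_cache))

print(np.allclose(no_cache_out, cache_out))      # True: identical outputs
print(no_cache_projections, cache_projections)   # 21 6
```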
### KV Cache Memory Sizing
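The KV cache grows linearly with sequence length. For long contexts, KV cache memory can dominate GPU memory consumption:

$$
\text{KV cache bytes} = 2 \times b \times s \times L \times h_{kv} \times d_{head} \times \text{bytes}
$$

where $b$ = batch size, $s$ = sequence length, $L$ = layers, $h_{kv}$ = KV heads, $d_{head}$ = head dim, bytes = 2 for FP16, and the leading factor of 2 counts both K and V.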
```python
def kv_cache_memory_gb(
    batch_size: int, seq_len: int, num_layers: int,
    num_kv_heads: int, head_dim: int, fp16: bool = True
) -> float:
    bytes_per_element = 2 if fp16 else 4
    # Factor of 2: one for K, one for V
    total_bytes = (2 * batch_size * seq_len * num_layers *
                   num_kv_heads * head_dim * bytes_per_element)
    return total_bytes / (1024 ** 3)


# LLaMA-70B: 80 layers, 8 GQA heads, 128 head dim
mem = kv_cache_memory_gb(1, 4096, 80, 8, 128)
print(f"LLaMA-70B KV cache (bs=1, seq=4096): {mem:.2f} GB")
# Output: 1.25 GB per sequence at 4K context
# For batch=32, seq=8192: 80 GB - all of an A100 80GB before counting the weights
```
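Turning the formula around answers the capacity-planning question: how many full-length sequences fit once the weights are loaded? Below is a rough sketch. The function name, the 10% headroom, and the ~13.5 GB figure for a 7B model's FP16 weights are assumptions for illustration:

```python
def max_concurrent_sequences(
    gpu_mem_gb: float,
    weights_gb: float,
    seq_len: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: int = 2,
    reserve_fraction: float = 0.1,
) -> int:
    """Worst-case count of full-length sequences whose KV cache fits beside the weights.

    GB means 1024**3 bytes, matching kv_cache_memory_gb. reserve_fraction is an
    assumed headroom for activations and allocator overhead.
    """
    per_seq_bytes = 2 * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_element
    free_bytes = (gpu_mem_gb * (1 - reserve_fraction) - weights_gb) * 1024**3
    return max(0, int(free_bytes // per_seq_bytes))


# Llama-2-7B-style config (32 layers, 32 KV heads, head_dim 128), ~13.5 GB FP16 weights
print(max_concurrent_sequences(80, 13.5, 4096, 32, 32, 128))  # 29 (2 GB per sequence)
# LLaMA-70B's ~140 GB of FP16 weights do not fit on one 80 GB GPU at all
print(max_concurrent_sequences(80, 140, 4096, 80, 8, 128))    # 0
```

This is a worst case. vLLM's PagedAttention allocates KV blocks on demand, so requests shorter than `seq_len` leave room for more concurrent sequences.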
## Layer 4: Prefix Caching for LLMs
Many LLM applications share a common system prompt across thousands of requests. Computing KV values for that prompt is expensive and redundant. Prefix caching stores the KV activations for commonly used prompt prefixes.
```python
# prefix_cache.py - prefix caching with vLLM
from vllm import LLM, SamplingParams

# Enable prefix caching in vLLM
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    enable_prefix_caching=True,  # key parameter
    gpu_memory_utilization=0.9,
    max_num_seqs=256,
)
sampling_params = SamplingParams(temperature=0.7, max_tokens=200)

# This system prompt is shared across all customer-service requests
SYSTEM_PROMPT = """You are a helpful customer service assistant for Acme Corp.
Our products include: Product A ($99), Product B ($199), Product C ($299).
Our return policy: 30 days, no questions asked.
Our hours: Monday-Friday 9am-6pm EST.
Always be professional and concise."""


def answer_customer_query(user_question: str) -> str:
    # vLLM detects the shared prefix across requests
    # and reuses cached KV activations for SYSTEM_PROMPT
    prompt = f"{SYSTEM_PROMPT}\n\nCustomer: {user_question}\nAssistant:"
    output = llm.generate([prompt], sampling_params)[0]
    return output.outputs[0].text

# When 1000 requests share the same 500-token system prompt,
# prefix caching computes the prompt KV once and reuses it.
# Time-to-first-token for the prefix portion drops sharply
# (e.g., ~200ms to ~30ms; exact numbers depend on GPU and prompt length).
# Memory cost: one copy of the prefix KV cache. Llama-2-7B in FP16 needs
# 32 layers x 32 KV heads x 128 dims x 2 (K, V) x 2 bytes = 0.5 MB per token,
# so ~250 MB for 500 tokens.
```
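How does an engine find a shared prefix without comparing strings? One design, used in spirit by vLLM's automatic prefix caching, splits the token sequence into fixed-size blocks. Each block is keyed by a hash that chains in everything before it. Two requests then reuse exactly the blocks that follow an identical token prefix. The toy below is a sketch of the idea, not vLLM's implementation. It assumes a 16-token block size, uses integer token ids, and stores a placeholder string instead of real KV tensors. The names `block_hashes` and `PrefixBlockCache` are hypothetical:

```python
import hashlib
from typing import Dict, List, Sequence

BLOCK_SIZE = 16  # tokens per KV block (assumed for illustration)


def block_hashes(token_ids: Sequence[int], block_size: int = BLOCK_SIZE) -> List[str]:
    """Hash each full block, chaining in the previous block's hash so a block's
    identity depends on every token before it. A partial trailing block is not cached."""
    hashes: List[str] = []
    parent = ""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = list(token_ids[start:start + block_size])
        parent = hashlib.sha256(f"{parent}|{block}".encode()).hexdigest()
        hashes.append(parent)
    return hashes


class PrefixBlockCache:
    """Toy prefix cache mapping block hash -> stand-in for that block's K/V tensors."""

    def __init__(self) -> None:
        self.blocks: Dict[str, str] = {}

    def prefill(self, token_ids: Sequence[int]) -> int:
        """Return how many prompt tokens actually need prefill computation."""
        hashes = block_hashes(token_ids)
        reused = 0
        for h in hashes:  # reuse stops at the first block not already cached
            if h not in self.blocks:
                break
            reused += 1
        for h in hashes[reused:]:
            self.blocks[h] = "kv"  # placeholder for freshly computed K/V
        return len(token_ids) - reused * BLOCK_SIZE


system = list(range(500))  # 500 shared system-prompt tokens (toy token ids)
cache = PrefixBlockCache()
print(cache.prefill(system + [9001, 9002, 9003]))  # 503: cold cache, full prefill
print(cache.prefill(system + [7001, 7002]))        # 6: 31 blocks (496 tokens) reused
```

Only full blocks are shared. A prefix that diverges after a handful of tokens therefore gets essentially no benefit.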
### Prefix Cache Hit Rate Optimization
Structure your prompts to maximize prefix sharing:
```python
# Good: shared prefix is maximally long
SYSTEM_PROMPT = "You are a helpful assistant. [500 tokens of context...]"
# Request 1: SYSTEM_PROMPT + "Explain photosynthesis"
# Request 2: SYSTEM_PROMPT + "What is machine learning?"
# Both share the entire SYSTEM_PROMPT → 100% prefix hit on the system prompt

# Bad: user-specific content at the start breaks prefix sharing
BAD_PROMPT = "User ID: {user_id}. Your history: {history}. Be helpful."
# Request 1: "User ID: 12345. Your history: [ordered item 1]. Be helpful."
# Request 2: "User ID: 67890. Your history: [returned item 3]. Be helpful."
# Only "User ID: " is shared → effectively no cache benefit

# Solution: move user-specific content to the end
GOOD_PROMPT = "You are a helpful assistant. [shared context]. User context: {user_context}"
# Shared prefix: "You are a helpful assistant. [shared context]. User context: "
# → all requests share this prefix → prefix cache hits
```
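Before deploying a prompt template, it is worth checking how much of the rendered prompt is actually shared. Below is a rough audit helper. The name `shared_prefix_report` is hypothetical. It treats character overlap as a proxy for token overlap, which is only approximate because caching operates on token blocks:

```python
import os
from typing import Iterable


def shared_prefix_report(rendered_prompts: Iterable[str]) -> dict:
    """Measure how much of a set of rendered prompts is one common prefix.

    Characters approximate tokens; the cacheable part is a little shorter in
    practice because only whole token blocks are reused.
    """
    prompts = list(rendered_prompts)
    common = os.path.commonprefix(prompts)  # character-by-character prefix
    avg_len = sum(len(p) for p in prompts) / len(prompts)
    return {"shared_chars": len(common), "shared_fraction": len(common) / avg_len}


BAD = "User ID: {user_id}. Your history: {history}. Be helpful."
GOOD = "You are a helpful assistant. [shared context]. User context: {user_context}"

bad = [BAD.format(user_id=u, history=h)
       for u, h in [("12345", "ordered item 1"), ("67890", "returned item 3")]]
good = [GOOD.format(user_context=c)
        for c in ["id 12345, ordered item 1", "id 67890, returned item 3"]]
print(shared_prefix_report(bad))   # only "User ID: " is shared
print(shared_prefix_report(good))  # most of each prompt is shared
```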
## Layer 5: Feature Caching with Redis
ML serving often requires features that are expensive to compute: user embeddings, item recommendation scores, entity features from a feature store. Caching these in Redis near the serving layer eliminates round-trips to the feature store on every request.
```python
# feature_cache.py - Redis feature cache for ML serving
import pickle
from typing import Any, Dict, List, Optional

import redis


class FeatureCache:
    """
    Cache for pre-computed ML features.
    Reduces feature lookup latency from 10-50ms (feature store) to 0.1-0.5ms (Redis).
    """

    def __init__(
        self,
        redis_url: str,
        default_ttl: int = 300,  # 5 minutes - balance freshness vs hit rate
        item_ttl: int = 3600,    # item embeddings change slowly - 1 hour
        key_prefix: str = "features",
    ):
        # Connection pool shared by all threads using this cache
        pool = redis.ConnectionPool.from_url(
            redis_url,
            max_connections=50,
            decode_responses=False,
        )
        self.redis = redis.Redis(connection_pool=pool)
        self.default_ttl = default_ttl
        self.item_ttl = item_ttl
        self.prefix = key_prefix

    def get_user_features(self, user_id: str) -> Optional[Dict[str, Any]]:
        raw = self.redis.get(f"{self.prefix}:user:{user_id}")
        return pickle.loads(raw) if raw else None

    def set_user_features(
        self, user_id: str, features: Dict[str, Any], ttl: Optional[int] = None
    ) -> None:
        key = f"{self.prefix}:user:{user_id}"
        self.redis.setex(key, ttl or self.default_ttl, pickle.dumps(features))

    def get_item_embedding(self, item_id: str) -> Optional[list]:
        raw = self.redis.get(f"{self.prefix}:item_emb:{item_id}")
        return pickle.loads(raw) if raw else None

    def set_item_embedding(self, item_id: str, embedding: list) -> None:
        """Item embeddings change slowly, so they use the longer item TTL."""
        key = f"{self.prefix}:item_emb:{item_id}"
        self.redis.setex(key, self.item_ttl, pickle.dumps(embedding))

    def mget_user_features(self, user_ids: List[str]) -> Dict[str, Optional[Dict]]:
        """Batch get multiple user features in one Redis round-trip."""
        if not user_ids:
            return {}  # MGET with zero keys is a Redis error
        keys = [f"{self.prefix}:user:{uid}" for uid in user_ids]
        raw_values = self.redis.mget(keys)
        return {uid: pickle.loads(raw) if raw else None
                for uid, raw in zip(user_ids, raw_values)}

    def mget_item_embeddings(self, item_ids: List[str]) -> Dict[str, Optional[list]]:
        """Batch get item embeddings in one Redis round-trip."""
        if not item_ids:
            return {}
        keys = [f"{self.prefix}:item_emb:{iid}" for iid in item_ids]
        raw_values = self.redis.mget(keys)
        return {iid: pickle.loads(raw) if raw else None
                for iid, raw in zip(item_ids, raw_values)}

    def pipeline_set_features(self, user_features: Dict[str, Dict]) -> None:
        """Batch write features using a Redis pipeline (single round-trip)."""
        pipe = self.redis.pipeline(transaction=False)
        for user_id, features in user_features.items():
            key = f"{self.prefix}:user:{user_id}"
            pipe.setex(key, self.default_ttl, pickle.dumps(features))
        pipe.execute()


# Integration in serving (feature_store_client is a placeholder for your
# feature store SDK; the Redis calls here use the blocking client)
async def get_recommendation_features(
    user_id: str,
    item_ids: List[str],
    feature_cache: FeatureCache,
    feature_store_client,
) -> dict:
    """Fetch features with a cache-first (cache-aside) strategy."""
    # Check cache: one GET for the user, one batched MGET for all items
    user_features = feature_cache.get_user_features(user_id)
    cached_items = feature_cache.mget_item_embeddings(item_ids)
    item_embeddings = {iid: emb for iid, emb in cached_items.items() if emb is not None}
    cache_miss_items = [iid for iid, emb in cached_items.items() if emb is None]

    # Fetch misses from the feature store
    if user_features is None or cache_miss_items:
        fetched = await feature_store_client.get_features(
            user_id=user_id if user_features is None else None,
            item_ids=cache_miss_items,
        )
        if user_features is None and fetched.get("user_features"):
            user_features = fetched["user_features"]
            feature_cache.set_user_features(user_id, user_features)
        for item_id, emb in fetched.get("item_embeddings", {}).items():
            item_embeddings[item_id] = emb
            feature_cache.set_item_embedding(item_id, emb)

    return {"user_features": user_features, "item_embeddings": item_embeddings}
```
## Cache Invalidation Strategies
Cache invalidation is the hardest problem in ML serving caching. Models update; feature data becomes stale; cached responses become incorrect.
```python
# cache_invalidation.py - event-driven cache invalidation
import json

import redis


class CacheInvalidationConsumer:
    """
    Listens to model update and data change events,
    invalidates affected cache entries.
    """

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client
        self.pubsub = redis_client.pubsub()

    def _delete_pattern(self, pattern: str, count: int = 1000) -> int:
        """Delete all keys matching a glob pattern using non-blocking SCAN."""
        cursor, deleted = 0, 0
        while True:
            cursor, keys = self.redis.scan(cursor, match=pattern, count=count)
            if keys:
                self.redis.delete(*keys)
                deleted += len(keys)
            if cursor == 0:
                return deleted

    def on_model_update(self, model_name: str, old_version: str, new_version: str):
        """Flush all cached results from the old model version.

        Assumes version-qualified keys: ml_result:{model_name}:{version}:{hash}.
        Without the version segment, the pattern would also wipe entries
        already written by the new version.
        """
        pattern = f"ml_result:{model_name}:{old_version}:*"
        deleted = self._delete_pattern(pattern, count=5000)
        print(f"Model update {old_version} → {new_version}: invalidated {deleted} entries")

    def on_user_activity(self, user_id: str):
        """User took an action - their features may have changed."""
        # Exact key: delete directly
        self.redis.delete(f"features:user:{user_id}")
        # Personalized results, assuming keys like
        # ml_result:recommendations:{user_id}:{hash}. Anchoring user_id between
        # colons avoids deleting other users whose IDs merely contain it.
        self._delete_pattern(f"ml_result:recommendations:{user_id}:*")

    def listen_for_events(self):
        """Consume invalidation events from Redis Pub/Sub.

        Pub/Sub is fire-and-forget: events published while this consumer is
        down are lost. Kafka or Redis Streams give durable delivery if needed.
        """
        self.pubsub.subscribe("model_updates", "user_activity")
        for message in self.pubsub.listen():
            if message["type"] != "message":
                continue
            data = json.loads(message["data"])
            channel = message["channel"]
            if isinstance(channel, bytes):  # bytes unless decode_responses=True
                channel = channel.decode()

            if channel == "model_updates":
                self.on_model_update(
                    data["model_name"],
                    data["old_version"],
                    data["new_version"],
                )
            elif channel == "user_activity":
                self.on_user_activity(data["user_id"])
```
## Production Engineering Notes
### CDN Caching for Model Artifacts
Model files (weights, ONNX files, TRT engines) are large static artifacts that should be distributed via CDN - not downloaded from central storage on every pod startup.
Cache model artifacts in S3 + CloudFront (or GCS + Cloud CDN). Use versioned keys such as `models/resnet50/v1.2.3/model.onnx`, so a new release never collides with a cached copy of an old one. On pod startup, fetch from the CDN; a cache hit is served from a nearby edge node instead of cross-region object storage. This can cut model loading time from 30-120 seconds to 2-8 seconds and avoids the object-storage load and egress cost of N pods all downloading simultaneously.
### Redis Memory Management
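For production feature caches serving large user bases, Redis memory management matters. Set `maxmemory-policy allkeys-lru` (evict the least recently used keys when memory is full) and set `maxmemory` to about 80% of available RAM, leaving headroom for fragmentation and for the copy-on-write memory used by background saves. Monitor `redis-cli info memory` for `used_memory_human` and `mem_fragmentation_ratio`. A ratio above 1.5 indicates significant fragmentation, which you can address with active defragmentation (`activedefrag yes`, Redis 4.0+ built with jemalloc) or, failing that, a restart.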
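The two settings and the fragmentation check can be scripted. Below is a sketch using plain functions, so it can be tested without a server. The helper names are hypothetical, and the sample INFO dictionary is made up for illustration. The redis-py calls in the comment (`config_set`, `info`) are the only part that touches a live instance:

```python
def maxmemory_bytes(total_ram_bytes: int, fraction: float = 0.8) -> int:
    """maxmemory value leaving headroom for fragmentation, background saves, and the OS."""
    return int(total_ram_bytes * fraction)


def memory_warnings(info: dict, max_fragmentation: float = 1.5) -> list:
    """Check a Redis INFO memory dict (e.g. redis-py's r.info("memory")) against the guidance above."""
    warnings = []
    frag = float(info.get("mem_fragmentation_ratio", 1.0))
    if frag > max_fragmentation:
        warnings.append(f"mem_fragmentation_ratio {frag:.2f} exceeds {max_fragmentation}")
    if info.get("maxmemory", 0) == 0:
        warnings.append("maxmemory is 0 (no limit), so eviction never triggers")
    if info.get("maxmemory_policy") != "allkeys-lru":
        warnings.append(f"maxmemory_policy is {info.get('maxmemory_policy')!r}, not 'allkeys-lru'")
    return warnings


# Against a live server (redis-py), for a 16 GB host:
#   r = redis.Redis()
#   r.config_set("maxmemory", maxmemory_bytes(16 * 1024**3))
#   r.config_set("maxmemory-policy", "allkeys-lru")
#   print(memory_warnings(r.info("memory")))
sample_info = {"mem_fragmentation_ratio": 1.72, "maxmemory": 0,
               "maxmemory_policy": "noeviction"}  # made-up values
for w in memory_warnings(sample_info):
    print(w)
```

Managed services such as ElastiCache typically restrict `CONFIG SET`. On those, apply the same values through the provider's parameter groups instead.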
## Common Mistakes
:::danger Caching Stochastic Model Outputs
If your model samples with temperature > 0, or has dropout accidentally left on at inference time (for example, a missing `model.eval()`), the same input can produce different outputs, and a cached result freezes just one of them. For RAG systems where retrieved documents change, result caches can return outdated answers. Use greedy decoding (temperature 0) for responses you intend to cache, or cache only deterministic model components.
:::
:::danger Returning Wrong Answers from Semantic Cache
Setting the similarity threshold too low causes the cache to return responses for semantically different queries. "What is our vacation policy?" and "What is our vacation package pricing?" might have similarity 0.88 but need different answers. Test your threshold with adversarial examples from your domain before deploying. Monitor cache hit rate AND human evaluation of cached response quality.
:::
:::warning Not Invalidating on Model Update
Deploying a new model version without flushing the result cache means users may receive responses from the old model for hours (until TTL expiry). Always integrate cache invalidation into your model deployment pipeline. The safest pattern: version-key your cache entries (model_v1.2:hash), deploy new model version, flush old version's keys, then let the new cache populate organically.
:::
:::warning Prefix Cache Breaks for User-Specific Prompts
Prefix caching only helps when prefixes are actually shared. If you build prompts like f"User {user_id}, your history: {history}. Be helpful.", every user has a unique prefix - zero cache hits. Restructure prompts to put shared content first, user-specific content last. A 100-token system prompt shared by all users generates far more cache value than a 500-token personalized preamble.
:::
## Interview Q&A
Q1: What is semantic caching and when is it more appropriate than exact-match caching for LLM serving?
A: Semantic caching uses embedding similarity to match incoming queries against a cache of previous queries. If the cosine similarity between the embedding of the new query and a cached query exceeds a threshold, the cached response is returned without calling the LLM. Semantic caching is appropriate whenever queries are in natural language and semantic equivalence matters more than textual equivalence: FAQ systems, customer service bots, document QA. Exact-match caching is appropriate for discrete structured inputs like IDs, fixed-format API calls, or deterministic embedding computation (where the same text always produces the same embedding). The tradeoff: semantic caching adds an embedding computation to every request (a small embedding model, far cheaper than an LLM call but not free) and risks returning mismatched responses if the threshold is wrong. Calibrate the threshold on labeled equivalence pairs from your specific domain.
Q2: Explain KV caching in transformer inference. Why is it so critical for autoregressive generation?
A: Autoregressive generation produces tokens one at a time. Each step's attention computation requires the key and value projections of every previous token: $\text{softmax}(q_t K_{1:t}^{\top} / \sqrt{d_k}) V_{1:t}$. Without caching, step $t$ recomputes all previous K and V projections, so the total work is $O(n^2)$ for a sequence of length $n$. With KV caching, K and V projections from steps $1$ through $t-1$ are stored in GPU memory. Step $t$ only computes one new K and one new V vector and appends them to the cache, so total projection work becomes $O(n)$ (the attention dot products themselves still grow with context length). In the 50-token GPT-2 benchmark earlier in this section, caching gave roughly an 8× speedup, and the gap widens with longer generations. The cost is memory: the cache grows linearly with sequence length, and for large models serving large batches at long contexts it can exceed the model weights. For LLaMA-70B at 8K context the KV cache is about 2.5 GB per sequence in FP16, so a batch of 64 needs roughly 160 GB, more than the ~140 GB of FP16 weights. That requires careful memory planning.
Q3: What is prefix caching and how does it improve time-to-first-token for LLMs?
A: Prefix caching stores the KV activations computed for a prompt prefix so that requests sharing the same prefix can skip recomputing it. For a 500-token system prompt shared by 10,000 requests, each request without prefix caching pays the full 500-token prefill cost. With prefix caching, the first request computes and caches the KV activations; subsequent requests reuse them and only compute the user-specific suffix tokens. Time-to-first-token for requests with cached prefixes drops sharply: prefill goes from 600 tokens (500 system + 100 user) to 100 tokens. vLLM implements this as automatic prefix caching (`enable_prefix_caching=True`). It hashes fixed-size KV blocks within its PagedAttention memory pool, keying each block by its own tokens plus every token before it. SGLang's RadixAttention achieves similar reuse with a radix tree. The key engineering requirement: prompts must be structured with shared content first and request-specific content last, because the prefix must be token-for-token identical across requests.
Q4: How do you handle cache invalidation when deploying a new model version?
A: The most robust approach is version-keyed cache entries: include the model version in the cache key (f"ml_result:{model_name}:{model_version}:{input_hash}"). When you deploy v1.3, the cache keys from v1.2 are simply never accessed - they expire via TTL naturally. New requests populate v1.3 keys. This is simpler than active invalidation but wastes Redis memory during the transition period. For faster cleanup, trigger active invalidation as part of deployment: scan and delete all keys matching the old version pattern. For feature caches, event-driven invalidation works well: when a user takes an action that changes their features, publish an event that immediately deletes their feature cache key. Always pair TTLs (as a safety net) with active invalidation (for freshness on model updates).
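Below is a sketch of the version-keyed scheme described in this answer. The `ml_result:{model_name}:{model_version}:{hash}` layout follows the key format quoted above. The helper names are hypothetical:

```python
import hashlib
import json


def versioned_key(model_name: str, model_version: str, inputs: dict) -> str:
    """Cache key that embeds the model version, so every deploy changes every key."""
    canonical = json.dumps(inputs, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    return f"ml_result:{model_name}:{model_version}:{digest}"


def old_version_pattern(model_name: str, old_version: str) -> str:
    """SCAN MATCH pattern for proactively deleting the previous version's keys."""
    return f"ml_result:{model_name}:{old_version}:*"


k_old = versioned_key("classifier", "v1.2", {"text": "hello"})
k_new = versioned_key("classifier", "v1.3", {"text": "hello"})
print(k_old != k_new)                          # True: v1.3 never reads v1.2 entries
print(old_version_pattern("classifier", "v1.2"))
```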
Q5: You are asked to design a caching strategy to reduce inference costs for a FAQ chatbot. How do you approach it?
A: I would implement a multi-layer caching strategy. First, exact-match caching for any inputs that can be canonicalized - normalize whitespace, lowercase, remove punctuation, then hash. This handles "PTO Policy" and "pto policy" as the same input. Second, semantic caching for natural language queries: embed incoming queries with a lightweight sentence transformer (all-MiniLM-L6-v2 is fast and accurate), search a FAISS index of previous query embeddings, return cached response if similarity exceeds ~0.92 (calibrate on labeled pairs). Third, prefix caching at the vLLM layer: structure the system prompt to be identical across all sessions, moving any session-specific context to the end. Fourth, monitor cache hit rate by category. For FAQ bots, I would expect 60-80% hit rate after warm-up, reducing LLM calls proportionally. Set TTL to 1-7 days (FAQ answers do not change frequently) and trigger active invalidation on knowledge base updates.
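The canonicalization step from the first layer might look like the sketch below. Which normalizations are safe depends on the domain, and the function name `canonical_key` is hypothetical:

```python
import hashlib
import re
import string

_STRIP_PUNCT = str.maketrans("", "", string.punctuation)  # ASCII punctuation only


def canonical_key(query: str) -> str:
    """Exact-match key after light normalization: lowercase, strip punctuation,
    collapse whitespace. Aggressive normalization raises the hit rate but can
    merge queries that differ in meaning (e.g. "C" and "C++" become identical)."""
    text = query.lower().translate(_STRIP_PUNCT)
    text = re.sub(r"\s+", " ", text).strip()
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


print(canonical_key("What is our PTO policy?") == canonical_key("what is our  pto policy"))  # True
```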
