Streaming Inference
Fifty Thousand Posts Per Second
The content moderation team has a problem that did not exist five years ago. The platform now processes 50,000 new posts per second at peak - 3 billion posts per day. A static list of banned words catches 12% of harmful content. Human moderators catch 90% but can review at most 1,000 posts per hour each. At 50K posts per second, you would need 180,000 full-time human moderators working 24/7 just to review every post.
The solution is ML-based content classification at scale. A fine-tuned BERT model achieves 94% accuracy on harmful content detection in 18ms per post. But 18ms per post × 50,000 posts/second is 900,000ms of processing per second: 900 seconds of GPU time for every second of traffic. Scored one post at a time, that workload cannot fit on a single GPU.
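The arithmetic can be checked with a few lines of Python. This is a back-of-the-envelope sketch only: it assumes posts are scored strictly one at a time with no batching, and the helper name `gpu_seconds_per_second` is made up for illustration.

```python
import math


def gpu_seconds_per_second(events_per_s: float, latency_ms_per_event: float) -> float:
    """GPU-seconds of work generated during each wall-clock second."""
    return events_per_s * latency_ms_per_event / 1000.0


# Figures from the text: 50,000 posts/second at 18 ms per post, unbatched.
unbatched_load = gpu_seconds_per_second(50_000, 18.0)

# One GPU supplies at most one GPU-second per second, so the load is also
# a lower bound on the number of GPUs needed without batching.
min_gpus_unbatched = math.ceil(unbatched_load)

print(f"load: {unbatched_load} GPU-seconds per second")
print(f"minimum GPUs without batching: {min_gpus_unbatched}")
```

The point of the exercise is that the model is fast enough per post; the cost comes from issuing one forward pass per post.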
The engineering challenge is not the model. It is the pipeline: getting 50,000 posts per second into the model efficiently, batching them for GPU throughput, processing each one within an acceptable window (content must be moderated before it can be widely amplified - typically 30-60 seconds), handling failures without losing posts or double-processing them, and maintaining state across posts (a user who posts one harmful message in a session has elevated risk for subsequent posts in the same session).
This is streaming inference: the architecture discipline of running ML models on continuous high-volume event streams.
Why This Exists - The Gap Between Events and Predictions
Real-time data arrives as streams: user interactions, transactions, sensor readings, log events. Most ML systems are designed for request-response: a client submits one record, the model returns one prediction. Connecting streams to models naively - one HTTP request per event - is disastrously inefficient:
- 50,000 HTTP requests per second per process: ~25% of CPU is request/response overhead
- No batching: GPU at 2% utilization, model latency of 18ms per post versus roughly 0.5ms per post when amortized across a batch
- No stateful context: each post processed independently, no session awareness
- No fault tolerance: a model server restart loses all in-flight events
- No flow control: traffic spikes directly reach the model server
Streaming inference solves these by introducing a message broker (Kafka) as a durable buffer between event producers and inference workers, with stream processing logic (Flink, Kafka Streams) for stateful windowed features, batch assembly, and exactly-once processing guarantees.
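Batch assembly is the piece that most directly fixes GPU utilization. The sketch below shows one way to express the "flush when full or when the oldest event has waited too long" rule in plain Python. The class name `MicroBatcher` and its methods are hypothetical, and timestamps are passed in explicitly so the behaviour is easy to test; a real worker would read a monotonic clock.

```python
from typing import Any, List, Optional


class MicroBatcher:
    """Collects items and releases a batch on size or age, whichever comes first."""

    def __init__(self, max_size: int, max_wait_s: float):
        self.max_size = max_size
        self.max_wait_s = max_wait_s
        self._items: List[Any] = []
        self._first_ts: Optional[float] = None  # arrival time of the oldest item

    def add(self, item: Any, now_s: float) -> Optional[List[Any]]:
        """Add one item; return a batch if the add made it ready."""
        if not self._items:
            self._first_ts = now_s
        self._items.append(item)
        return self._flush_if_ready(now_s)

    def poll(self, now_s: float) -> Optional[List[Any]]:
        """Call periodically so a partial batch still flushes when traffic stops."""
        return self._flush_if_ready(now_s)

    def _flush_if_ready(self, now_s: float) -> Optional[List[Any]]:
        if not self._items:
            return None
        full = len(self._items) >= self.max_size
        expired = now_s - self._first_ts >= self.max_wait_s
        if full or expired:
            batch, self._items, self._first_ts = self._items, [], None
            return batch
        return None


batcher = MicroBatcher(max_size=3, max_wait_s=0.05)
print(batcher.add("a", now_s=0.00))  # None: batch not ready
print(batcher.add("b", now_s=0.01))  # None
print(batcher.add("c", now_s=0.02))  # ['a', 'b', 'c']: size limit reached
```

The separate `poll` method matters: if the age check only runs when a new item arrives, a quiet stream can leave the last few events waiting indefinitely.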
Historical Context
Apache Kafka (LinkedIn, 2011) was originally built for log aggregation at LinkedIn - moving terabytes of metrics data between systems. Its pub-sub model, partitioned topics, and log-based storage made it ideal as a high-throughput event bus. By 2015, real-time ML pipelines were using Kafka as the backbone for streaming inference at Twitter, Uber, and Netflix.
Apache Flink (ASF project, 2014) provided stateful stream processing with exactly-once semantics, making it the engine of choice for complex ML feature computation over streams. Flink's ability to maintain per-key state across millions of users and compute sliding window aggregations with watermarks addressed the stateful feature problem for workloads too large or too complex for simpler, library-based approaches.
Kafka Streams (Confluent, 2016) offered a simpler alternative for lighter workloads: a library embedded in the application process, no separate cluster needed, but with Kafka's durability and exactly-once guarantees. For inference pipelines that do not need complex state, Kafka Streams reduces operational overhead significantly.
The content moderation use case became canonical around 2020-2022 as platforms faced regulatory pressure to moderate content at scale. The combination of transformer-based classifiers (accurate enough for production) and mature stream processing infrastructure (reliable enough for continuous operation) made streaming inference at billions of events per day a standard pattern.
The Streaming Inference Architecture
Kafka Integration for ML Inference
The inference worker is a Kafka consumer that reads batches of events, runs the model, and writes results back to Kafka:
# kafka_inference_worker.py - GPU batch inference from Kafka stream
import asyncio
import json
import time
from typing import Dict, List

import torch
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from aiokafka.structs import TopicPartition
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class StreamingInferenceWorker:
    """
    Kafka consumer that batches events for GPU inference.
    Each worker handles one Kafka partition.
    """

    def __init__(
        self,
        model_name: str,
        input_topic: str,
        output_topic: str,
        kafka_bootstrap: str,
        max_batch_size: int = 64,
        max_batch_wait_ms: float = 50.0,  # max time to accumulate a batch
        device: str = "cuda",
    ):
        self.input_topic = input_topic
        self.output_topic = output_topic
        self.kafka_bootstrap = kafka_bootstrap
        self.max_batch_size = max_batch_size
        self.max_batch_wait_ms = max_batch_wait_ms
        self.device = device
        # Load model
        print(f"Loading {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()

    async def run(self):
        """Main loop: consume, batch, infer, produce."""
        consumer = AIOKafkaConsumer(
            self.input_topic,
            bootstrap_servers=self.kafka_bootstrap,
            group_id="content-moderation-inference",
            # Disable auto-commit - we commit manually after successful inference
            enable_auto_commit=False,
            auto_offset_reset="earliest",
            # Fetch large batches from Kafka for efficiency
            max_poll_records=200,
            fetch_max_bytes=10 * 1024 * 1024,  # 10MB per fetch
        )
        producer = AIOKafkaProducer(
            bootstrap_servers=self.kafka_bootstrap,
            # Async batching for high throughput
            linger_ms=5,
            max_batch_size=32768,  # aiokafka's name for the producer batch size (bytes)
        )
        await consumer.start()
        await producer.start()
        buffer: List[Dict] = []
        offsets_to_commit: Dict[TopicPartition, int] = {}
        batch_start = time.monotonic()
        try:
            while True:
                # Poll with a timeout so a partial batch is still flushed
                # when no new messages arrive before the deadline.
                elapsed_ms = (time.monotonic() - batch_start) * 1000
                remaining_ms = max(0, int(self.max_batch_wait_ms - elapsed_ms))
                polled = await consumer.getmany(
                    timeout_ms=remaining_ms,
                    max_records=self.max_batch_size - len(buffer),
                )
                for tp, messages in polled.items():
                    for message in messages:
                        # Deserialize event
                        event = json.loads(message.value.decode())
                        event["_partition"] = message.partition
                        event["_offset"] = message.offset
                        buffer.append(event)
                        # Track the next offset to read, for manual commit
                        offsets_to_commit[tp] = message.offset + 1
                if not buffer:
                    # Nothing pending: restart the batch clock and poll again
                    batch_start = time.monotonic()
                    continue
                # Check if we should process the batch
                should_flush = (
                    len(buffer) >= self.max_batch_size
                    or remaining_ms == 0  # wait deadline reached
                )
                if should_flush:
                    # Run inference on batch
                    results = await self._run_batch_inference(buffer)
                    # Write results to output topic
                    await self._produce_results(producer, results)
                    # Commit offsets after successful inference + produce
                    await consumer.commit(offsets_to_commit)
                    buffer = []
                    offsets_to_commit = {}
                    batch_start = time.monotonic()
        finally:
            await consumer.stop()
            await producer.stop()

    async def _run_batch_inference(self, events: List[Dict]) -> List[Dict]:
        """Tokenize and run batch inference on GPU."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_inference, events)

    def _sync_inference(self, events: List[Dict]) -> List[Dict]:
        """Synchronous GPU inference - runs in thread pool."""
        texts = [event["text"] for event in events]
        # Tokenize batch
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Inference
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        # Build results
        results = []
        for event, prob_row in zip(events, probs):
            harmful_score = float(prob_row[1])  # class 1 = harmful
            results.append({
                "post_id": event["post_id"],
                "harmful_score": harmful_score,
                "decision": "remove" if harmful_score > 0.85 else
                            "review" if harmful_score > 0.5 else "allow",
                "inference_timestamp": time.time(),
                "model_version": "bert-v2.1",
            })
        return results

    async def _produce_results(
        self, producer: AIOKafkaProducer, results: List[Dict]
    ):
        """Write inference results to output topic and wait for delivery."""
        tasks = []
        for result in results:
            tasks.append(
                # send_and_wait resolves only once the broker has acknowledged
                # the record, so offsets are never committed ahead of the output.
                producer.send_and_wait(
                    self.output_topic,
                    key=result["post_id"].encode(),
                    value=json.dumps(result).encode(),
                )
            )
        await asyncio.gather(*tasks)
Stateful Stream Processing for ML
Many ML features require state across events: "how many times has this user posted in the last hour?" This requires stateful stream processing with Flink or Kafka Streams.
# flink_ml_features.py - stateful feature computation with PyFlink
import json

from pyflink.common import Duration, WatermarkStrategy
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.functions import KeyedProcessFunction, RuntimeContext
from pyflink.datastream.state import ValueStateDescriptor, ListStateDescriptor


class UserSessionFeatureFunction(KeyedProcessFunction):
    """
    Computes per-user session features for content moderation.
    State: post count, harmful score history, last post time.
    """

    def open(self, runtime_context: RuntimeContext):
        # Per-user state: persisted in Flink's state backend (e.g. RocksDB)
        self.post_count_state = runtime_context.get_state(
            ValueStateDescriptor("post_count_1h", Types.INT())
        )
        self.harmful_history = runtime_context.get_list_state(
            ListStateDescriptor("harmful_scores", Types.FLOAT())
        )
        self.last_post_time = runtime_context.get_state(
            ValueStateDescriptor("last_post_ms", Types.LONG())
        )

    def process_element(self, post_event, ctx):
        """Called for each post, keyed by user_id."""
        # Use the record's event timestamp (ms), not wall-clock time, so
        # replayed and late events produce the same features.
        now_ms = ctx.timestamp()
        # Update post count (approximate - a timer clears it after one hour)
        current_count = (self.post_count_state.value() or 0) + 1
        self.post_count_state.update(current_count)
        if current_count == 1:
            # First post of a new counting period: schedule a single reset
            # one hour later instead of one timer per post.
            ctx.timer_service().register_event_time_timer(now_ms + 3600 * 1000)
        # Time since the previous post (None for the user's first post)
        last_post = self.last_post_time.value()
        time_since_last_post_s = (
            (now_ms - last_post) / 1000.0 if last_post is not None else None
        )
        self.last_post_time.update(now_ms)
        # Collect harmful score history (last 10 posts)
        scores = list(self.harmful_history.get() or [])
        scores.append(post_event.get("harmful_score", 0.0))
        if len(scores) > 10:
            scores = scores[-10:]  # keep last 10
        self.harmful_history.update(scores)
        # Compute session features
        avg_harmful_score = sum(scores) / len(scores) if scores else 0.0
        max_harmful_score = max(scores) if scores else 0.0
        posting_velocity = min(current_count / 3600.0, 1.0)  # posts per second
        # Emit enriched event with session features
        enriched = {
            **post_event,
            "features": {
                "user_post_count_1h": current_count,
                "avg_harmful_score_session": avg_harmful_score,
                "max_harmful_score_session": max_harmful_score,
                "time_since_last_post_s": time_since_last_post_s,
                "posting_velocity": posting_velocity,
            },
        }
        yield enriched

    def on_timer(self, timestamp, ctx):
        """Called by timer - reset hourly post count."""
        self.post_count_state.clear()


def build_streaming_pipeline():
    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_parallelism(32)  # 32 parallel operators
    # Kafka source. create_kafka_source / create_kafka_sink are helpers that
    # are not shown here; they must return a Flink Kafka source and sink
    # that read and write raw bytes.
    kafka_source = create_kafka_source("raw-posts", "localhost:9092")
    posts_stream = env.from_source(
        kafka_source,
        WatermarkStrategy.for_bounded_out_of_orderness(Duration.of_seconds(5)),
        "Kafka Posts Source",
    )
    # Parse and key by user_id
    keyed_stream = posts_stream.map(
        lambda msg: json.loads(msg.decode())
    ).key_by(lambda post: post["user_id"])
    # Apply stateful feature extraction
    enriched_stream = keyed_stream.process(UserSessionFeatureFunction())
    # Write enriched events to Kafka for inference workers
    enriched_stream.map(
        lambda event: json.dumps(event).encode()
    ).sink_to(create_kafka_sink("post-features", "localhost:9092"))
    env.execute("Content Moderation Feature Pipeline")
Exactly-Once Inference
Kafka consumers process messages "at least once" by default - if a worker crashes after processing a message but before committing the offset, the message is reprocessed. For idempotent operations (feature computation) this is fine; for non-idempotent inference (sending moderation actions) it is not.
Exactly-once processing requires coordinating Kafka offset commits with output writes atomically:
# exactly_once.py - exactly-once inference with Kafka transactions
import json

from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
from aiokafka.structs import TopicPartition


class ExactlyOnceInferenceWorker:
    """
    Exactly-once semantics using Kafka transactions.
    Reads from input, writes to output, commits offset atomically.
    """

    async def run(self):
        # Transactional producer - enables atomic write + offset commit
        producer = AIOKafkaProducer(
            bootstrap_servers="localhost:9092",
            transactional_id="inference-worker-0",  # unique per worker instance
            enable_idempotence=True,
        )
        consumer = AIOKafkaConsumer(
            "post-features",
            bootstrap_servers="localhost:9092",
            group_id="inference-workers",
            enable_auto_commit=False,
            isolation_level="read_committed",  # only read committed messages
        )
        # aiokafka sets up the transactional producer inside start();
        # there is no separate init_transactions() call.
        await producer.start()
        await consumer.start()
        batch_size = 64
        records = []
        try:
            async for msg in consumer:
                records.append(msg)
                if len(records) >= batch_size:
                    # Run inference
                    events = [json.loads(r.value.decode()) for r in records]
                    # run_inference is a placeholder: supply your own batch
                    # inference function here.
                    results = run_inference(events)
                    # Atomic transaction: write results + commit offsets
                    async with producer.transaction():
                        # Write inference results
                        for result in results:
                            await producer.send(
                                "moderation-results",
                                value=json.dumps(result).encode(),
                            )
                        # Commit consumer offsets within the transaction.
                        # Records arrive in offset order per partition, so the
                        # last entry written for each partition is the highest.
                        offsets = {
                            TopicPartition(r.topic, r.partition): r.offset + 1
                            for r in records
                        }
                        await producer.send_offsets_to_transaction(
                            offsets, "inference-workers"
                        )
                    records = []
        finally:
            await consumer.stop()
            await producer.stop()
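To see why the atomic write-plus-commit matters, the following toy simulation replaces Kafka with in-memory lists. It is only a sketch: the function names and the `crash_at` parameter are invented for illustration, and a real transaction also depends on downstream consumers reading with `read_committed`.

```python
from typing import List, Optional


def run_at_least_once(log: List[str], start: int, crash_at: Optional[int],
                      output: List[str]) -> int:
    """Produce, then commit. A crash between the two steps leaves output uncommitted."""
    committed = start
    for offset in range(start, len(log)):
        output.append(f"scored:{log[offset]}")  # result is visible immediately
        if crash_at == offset:
            return committed                    # crash before the offset commit
        committed = offset + 1
    return committed


def run_transactional(log: List[str], start: int, crash_at: Optional[int],
                      output: List[str]) -> int:
    """Result and offset become visible together, or not at all."""
    committed = start
    for offset in range(start, len(log)):
        staged = f"scored:{log[offset]}"        # written inside an open transaction
        if crash_at == offset:
            return committed                    # transaction aborted, staged result discarded
        output.append(staged)                   # commit: publish the result and
        committed = offset + 1                  # advance the offset in one step
    return committed


log = ["p0", "p1", "p2"]

at_least_once_output: List[str] = []
resume = run_at_least_once(log, 0, crash_at=1, output=at_least_once_output)
run_at_least_once(log, resume, crash_at=None, output=at_least_once_output)

transactional_output: List[str] = []
resume = run_transactional(log, 0, crash_at=1, output=transactional_output)
run_transactional(log, resume, crash_at=None, output=transactional_output)

print(at_least_once_output)   # p1 is scored twice
print(transactional_output)   # each post is scored once
```

In both runs the worker restarts from the last committed offset; the only difference is whether the result of the interrupted message was already visible.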
Windowed Aggregations as Features
Sliding window aggregations (sum/count/average over last N minutes) are common ML features that require careful time semantics in stream processing:
# windowed_features.py - sliding window feature computation
import time
from collections import deque
from typing import Any, Dict, Optional, Sequence


class SlidingWindowFeatureStore:
    """
    Computes sliding window aggregations for streaming ML features.
    Maintains per-key deques of (timestamp, value) pairs.
    Updates are O(1) amortized; a query scans the events buffered for
    that key. Assumes events for a key are recorded in timestamp order.
    """

    def __init__(self, windows_seconds: Sequence[int] = (60, 300, 3600)):
        self.windows = list(windows_seconds)
        # Per-key event buffers: user_id -> deque[(timestamp_s, value)]
        self._buffers: Dict[str, deque] = {}
        self._max_window = max(self.windows)

    def record(self, key: str, value: float, timestamp_s: Optional[float] = None):
        """Record a new event for a key."""
        # Compare against None so that a timestamp of 0.0 is respected
        ts = time.time() if timestamp_s is None else timestamp_s
        if key not in self._buffers:
            self._buffers[key] = deque()
        self._buffers[key].append((ts, value))
        self._evict_old(key, ts)

    def _evict_old(self, key: str, current_ts: float):
        """Remove events older than the largest window."""
        buf = self._buffers.get(key)
        if buf is None:
            return  # unseen key: nothing to evict
        cutoff = current_ts - self._max_window
        while buf and buf[0][0] < cutoff:
            buf.popleft()

    def get_features(self, key: str, current_ts: Optional[float] = None) -> Dict[str, Any]:
        """
        Compute window aggregations for a key.
        Returns dict of feature_name -> value for each window size.
        """
        ts = time.time() if current_ts is None else current_ts
        self._evict_old(key, ts)
        buf = list(self._buffers.get(key, []))
        features = {}
        for window_s in self.windows:
            cutoff = ts - window_s
            window_events = [(t, v) for t, v in buf if t >= cutoff]
            features[f"count_{window_s}s"] = len(window_events)
            if window_events:
                values = [v for _, v in window_events]
                features[f"sum_{window_s}s"] = sum(values)
                features[f"mean_{window_s}s"] = sum(values) / len(values)
                features[f"max_{window_s}s"] = max(values)
            else:
                features[f"sum_{window_s}s"] = 0.0
                features[f"mean_{window_s}s"] = 0.0
                features[f"max_{window_s}s"] = 0.0
        return features


# Usage in content moderation pipeline
user_harmful_windows = SlidingWindowFeatureStore(windows_seconds=[60, 300, 3600])


def enrich_post_with_user_history(post: dict, user_id: str) -> dict:
    """Add user's recent harmful content history as features."""
    features = user_harmful_windows.get_features(user_id)
    return {
        **post,
        "user_harmful_count_1m": features["count_60s"],
        "user_harmful_count_5m": features["count_300s"],
        "user_harmful_count_1h": features["count_3600s"],
        "user_avg_harmful_score_5m": features["mean_300s"],
    }


def record_harmful_prediction(user_id: str, harmful_score: float):
    """Record inference result for future user history features."""
    if harmful_score > 0.5:  # only record significant scores
        user_harmful_windows.record(user_id, harmful_score)
Ordering Guarantees and Time Semantics
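Streaming inference must handle out-of-order events - a post sent at 14:00:00 may arrive at the processor at 14:00:05 due to network delay. Flink's watermark mechanism handles this by tracking how far event time has progressed and keeping each window open until the watermark passes the window's end.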
The watermark lateness tolerance parameter directly impacts feature accuracy vs latency: a 5-second watermark means you wait 5 seconds before closing a window, giving late events time to arrive. For real-time content moderation where low latency matters, a 2-5 second tolerance is typical. For financial fraud detection where accuracy is paramount, 30-60 second tolerances are acceptable.
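The trade-off can be illustrated with a small simulation of a bounded-out-of-orderness watermark over tumbling windows. This is a simplified model written for illustration, not Flink's implementation: the function name is hypothetical, the watermark is simply the maximum event time seen minus the tolerance, and late events are dropped rather than routed to a side output.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def tumbling_counts_with_watermark(
    event_times_s: Iterable[float],
    window_s: float,
    max_out_of_orderness_s: float,
) -> Tuple[List[Tuple[float, int]], List[float]]:
    """Return (closed windows as (window_start, count), dropped late events).

    Events are given in arrival order; each value is the event's own timestamp.
    """
    open_windows: Dict[float, int] = defaultdict(int)
    closed: List[Tuple[float, int]] = []
    dropped: List[float] = []
    max_seen = float("-inf")
    watermark = float("-inf")
    for ts in event_times_s:
        start = (ts // window_s) * window_s
        if start + window_s <= watermark:
            dropped.append(ts)  # its window was already emitted
            continue
        open_windows[start] += 1
        max_seen = max(max_seen, ts)
        watermark = max_seen - max_out_of_orderness_s
        # Emit every window whose end is at or before the watermark
        ready = sorted(w for w in open_windows if w + window_s <= watermark)
        for window_start in ready:
            closed.append((window_start, open_windows.pop(window_start)))
    return closed, dropped


# Arrival order differs from event-time order: 3 and 9 arrive late but within
# the 5 s tolerance; 2 arrives after its window [0, 10) has been closed.
arrivals = [1, 4, 3, 11, 9, 16, 2]
closed, dropped = tumbling_counts_with_watermark(arrivals, window_s=10, max_out_of_orderness_s=5)
print(closed, dropped)
```

With a tolerance of 0, the event at 9 would also be dropped, because the event at 11 would have closed the first window immediately; a larger tolerance keeps more late events at the price of emitting each window later.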
Production Engineering Notes
Consumer Group Scaling
Each Kafka partition is handled by exactly one consumer in a consumer group. To scale inference workers horizontally, increase partition count. Rule of thumb: number of partitions = max number of consumers you expect. For 50K events/second with 64-event batches, the pipeline must process 50000 / 64 ≈ 782 batches/second. With 5ms inference time per batch, one worker handles 1000/5 = 200 batches/second (12,800 events/second). So ceil(782/200) = 4 workers cover the throughput, and you should provision more than that (for example 8) for redundancy and traffic spikes.
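The sizing calculation is easy to keep as a small helper. The sketch below assumes 64 events per batch and 5 ms of GPU time per batch, and that the GPU is the only bottleneck (tokenization, network, and Kafka overhead are ignored); the function names and the `headroom` factor are illustrative choices.

```python
import math


def events_per_worker_per_s(batch_size: int, batch_latency_ms: float) -> float:
    """Throughput of one worker that runs batches back to back."""
    return batch_size * 1000.0 / batch_latency_ms


def workers_needed(events_per_s: float, batch_size: int, batch_latency_ms: float) -> int:
    """Minimum worker count to keep up with the given event rate."""
    return math.ceil(events_per_s / events_per_worker_per_s(batch_size, batch_latency_ms))


def partitions_to_provision(min_workers: int, headroom: float = 2.0) -> int:
    """Partitions cap consumer parallelism, so size them for the largest fleet you expect."""
    return math.ceil(min_workers * headroom)


per_worker = events_per_worker_per_s(batch_size=64, batch_latency_ms=5.0)
min_workers = workers_needed(50_000, batch_size=64, batch_latency_ms=5.0)
partitions = partitions_to_provision(min_workers)
print(per_worker, min_workers, partitions)
```

Because adding partitions later changes the key-to-partition mapping, it is usually easier to over-provision partitions up front than to repartition a live topic.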
Handling Model Updates in Streaming Pipelines
When you deploy a new model version to streaming inference workers, rolling restarts create a period where different partitions use different model versions. For content moderation (where consistency matters), use blue-green deployment: stand up a new consumer group with the new model, verify it is consuming correctly, then drain the old consumer group.
Common Mistakes
:::danger Not Committing Offsets After Successful Produce
The common Kafka consumer sequence is: process message → produce result to output topic → commit offset. If the process crashes between producing and committing, the message is reprocessed on restart and the output topic gets a duplicate. Mitigation: use Kafka transactions (exactly-once) for critical workflows where duplicates are unacceptable (financial decisions, content removal actions). For analytics where duplicates are tolerable, keep this at-least-once order and de-duplicate downstream; committing before producing avoids duplicates but can lose messages on a crash.
:::
:::danger Using Wall-Clock Time Instead of Event Time for Windows
If your windowed aggregations use time.time() rather than the event's timestamp, events that arrive late are counted in the wrong window. A fraud model that counts "transactions in the last 5 minutes" using processing time will miscount during batch replay (replaying historical data for testing gives wrong feature values) and during network delays. Always propagate event timestamps through the pipeline and use Flink's event-time with watermarks.
:::
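The replay problem is easy to reproduce. In this sketch each event carries both its own timestamp and the time it reached the processor; the helper `count_per_window` and the sample data are made up for illustration.

```python
from collections import Counter
from typing import Iterable, Tuple


def count_per_window(
    events: Iterable[Tuple[float, float]],
    window_s: float,
    use_event_time: bool,
) -> Counter:
    """Count events per tumbling window.

    Each event is (event_time_s, arrival_time_s). The window index is
    derived from whichever clock `use_event_time` selects.
    """
    counts: Counter = Counter()
    for event_ts, arrival_ts in events:
        ts = event_ts if use_event_time else arrival_ts
        counts[int(ts // window_s)] += 1
    return counts


# Four historical transactions spread over ~5 minutes, replayed within one second.
replayed = [(10, 1000), (20, 1000), (290, 1000), (310, 1001)]

by_event_time = count_per_window(replayed, window_s=300, use_event_time=True)
by_arrival_time = count_per_window(replayed, window_s=300, use_event_time=False)
print(dict(by_event_time))    # matches what the live system saw
print(dict(by_arrival_time))  # everything collapses into one window
```

With event time the replay reproduces the original windows (three transactions in the first five minutes, one in the next); with arrival time all four land in a single window, so any "count in the last 5 minutes" feature is inflated.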
:::warning Not Handling Backpressure
When inference workers are slower than the producer, Kafka consumer lag grows. Kafka absorbs the backlog on the broker, but end-to-end latency increases without bound, and a worker that keeps pulling records faster than it can score them will eventually exhaust its own memory. Monitor consumer lag (for example the consumer's records-lag-max metric) and alert when it exceeds a threshold (e.g., 60 seconds of lag). Scale up consumers proactively before lag becomes user-visible.
:::
Interview Q&A
Q1: How do you design a streaming ML inference pipeline for 50K events per second?
A: The architecture has three layers. First, Kafka as the durable event bus - all events write to Kafka before any ML processing, providing backpressure handling, replay capability, and decoupling between producers and inference workers. Second, optional stateful feature enrichment via Flink or Kafka Streams - if features require windowed aggregations (post count last hour, user session history), a stateful Flink operator computes these and writes enriched events to a second Kafka topic. Third, GPU inference workers - each worker consumes from one Kafka partition, accumulates events up to max_batch_size OR max_wait_ms, runs batched GPU inference, and writes results to an output topic. For 50K events/second with 64-event batches and 5ms inference per batch: throughput per worker = 12,800 events/second. You need ceil(50000/12800) = 4 workers minimum, deploy 8 for redundancy. Scale partitions to match worker count.
Q2: What is the difference between exactly-once, at-least-once, and at-most-once semantics in streaming ML?
A: These describe what happens when a worker crashes mid-processing. At-most-once: commit offset before processing. If crash after commit, the message is skipped - data loss. Unacceptable for content moderation (some posts never scored). At-least-once: process message, write result, commit offset. If crash after write but before commit, message is reprocessed - duplicate output. Acceptable for analytics where duplicates can be de-duped downstream. Exactly-once: use Kafka transactions to atomically write result + commit offset. If crash, the transaction is rolled back - message is reprocessed but output is not duplicated because the previous transaction was not committed. Exactly-once is most expensive (coordination overhead) but necessary when duplicates cause problems: double-removing content, double-charging a user, double-counting a fraud signal.
Q3: How do watermarks help with out-of-order events in stream processing?
A: Watermarks are timestamps embedded in the data stream that tell the processor "all events with timestamp earlier than X have now arrived." A processor using event-time windows cannot close a window until it knows all events for that window have arrived - which requires knowing no more late events will come. Watermarks define the cutoff. A watermark at time W means: events with timestamp before W will not arrive in the future. The processor can safely close and emit windows with end time before W. The watermark lateness tolerance (in Flink: for_bounded_out_of_orderness(Duration.of_seconds(5))) means the watermark lags the maximum observed event time by 5 seconds, giving events up to 5 seconds late time to arrive. Higher tolerance = more accurate windows but higher latency (you wait longer before closing each window). For real-time ML, 2-10 seconds is typical; for batch reprocessing, much higher tolerances are acceptable.
Q4: How do you handle stateful features (like "user activity in the last hour") in a streaming ML pipeline without training/serving skew?
A: The key is using the same logic for feature computation in both training and serving. For sliding window features: (1) In serving, maintain per-user state in Flink with event-time windows - each post enriches the next post with the user's history. (2) For training data generation, replay historical events through the same Flink topology in event-time order. Since Flink uses the same state management logic regardless of whether data is live or replayed, the features computed at training time match serving time exactly. The alternative - computing features from a feature store (Redis) at serving time and from batch SQL queries at training time - creates skew whenever the SQL query semantics differ from the Redis computation. Using Flink for both eliminates this divergence.
Q5: What is consumer lag and why does it matter for streaming ML inference?
A: Consumer lag is the difference between the latest offset produced to a Kafka topic and the latest offset committed by a consumer group, measured in messages. Divided by the event rate, it becomes a time: how far behind the live event stream the inference workers are. Lag of 0 means inference is real-time. Lag of 60 seconds means events are processed 60 seconds after they occur. For content moderation, 60 seconds of lag means potentially harmful posts are visible for 60 seconds before being moderated - likely unacceptable for regulatory compliance. For fraud detection, 60 seconds of lag means transactions in the last 60 seconds receive no ML scoring - a high-risk window. Monitor consumer lag as a primary SLA metric (for example kafka_consumergroup_lag in Prometheus via the Kafka exporter). Alert when lag exceeds your application's latency budget. Scale consumers (increase partition count + consumer count) or optimize inference (larger batches, faster model) to maintain acceptable lag.
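As a rough sketch of the lag calculation, the helpers below take per-partition offsets as plain dictionaries (in practice they would come from the Kafka admin API or an exporter) and convert message lag into an approximate age. The function names are hypothetical, and the conversion assumes a steady produce rate.

```python
from typing import Dict


def total_lag(end_offsets: Dict[int, int], committed_offsets: Dict[int, int]) -> int:
    """Sum of (log end offset - committed offset) over all partitions."""
    return sum(
        max(end - committed_offsets.get(partition, 0), 0)
        for partition, end in end_offsets.items()
    )


def lag_seconds(lag_messages: int, produce_rate_per_s: float) -> float:
    """Approximate age of the oldest unprocessed event, assuming a steady produce rate."""
    if produce_rate_per_s <= 0:
        return float("inf")
    return lag_messages / produce_rate_per_s


def over_budget(lag_messages: int, produce_rate_per_s: float, budget_s: float) -> bool:
    """True when the estimated lag exceeds the latency budget."""
    return lag_seconds(lag_messages, produce_rate_per_s) > budget_s


lag = total_lag({0: 1_000, 1: 2_000}, {0: 900, 1: 1_500})
print(lag)                                   # messages behind across both partitions
print(lag_seconds(3_000_000, 50_000))        # backlog expressed in seconds at 50K events/s
print(over_budget(3_000_000, 50_000, 30.0))  # compare against a 30 s budget
```

Alerting on the time-based figure rather than the raw message count keeps the threshold meaningful as traffic grows.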
