Stream Processing Patterns for ML Pipelines
import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem';
Reading time: 40 minutes | Interview relevance: Very High | Target roles: Data Engineer, ML Engineer, AI Platform Engineer
The Production Scenario
A senior data engineer at ShipFast, a last-mile logistics company, is handed a requirements document with a single sentence at the top: "Build a real-time delivery ETA prediction system. P99 accuracy within 8 minutes." Below that sentence are 30 feature definitions that the model needs for every shipment status update arriving at 20,000 events per second.
She reads through the features. Some are straightforward - package weight, declared value, destination zip code. Some require database lookups - the assigned driver's historical on-time rate, the vehicle's current capacity utilization. Some need stream-stream joins - joining the shipment event with a weather data stream to get current route conditions. Some need windowed aggregations - average delivery time on this route over the last 7 days. And at least three features can only be computed after delayed ground truth arrives (the actual delivery time), which feeds back into the model retraining pipeline.
Without a pattern vocabulary, this becomes a pile of ad-hoc code. One engineer writes a custom database poller. Another hard-codes a static driver profile lookup. A third tries to join two Kafka topics with a hand-rolled time-based buffer that occasionally loses events. Within a month, the pipeline has seven different approaches to the same class of problems, none of them correct under all failure modes.
Pattern literacy is what separates a senior stream processing engineer from a junior one. The patterns in this lesson are not academic - they are the solutions to the seven recurring problems that appear in almost every non-trivial streaming ML pipeline. By the time you finish reading, you will be able to look at any ML feature specification and identify which pattern it requires.
Why Patterns Exist
Stream processing for ML has the same recurring problem structure that all distributed systems have: data arrives from multiple sources, at different rates, out of order, with varying latency, and needs to be joined and transformed before it is useful for a model. The challenge is that unbounded streams do not have the nice properties of finite datasets - you cannot sort the entire stream, you cannot wait for all data to arrive before joining, and you cannot safely reprocess the same data after a failure without careful attention to idempotency.
The patterns in this lesson emerged from production ML platforms such as Uber's Michelangelo, LinkedIn's Feathr, and Airbnb's Chronon. They reflect broad industry practice for solving the seven core stream-ML integration problems correctly.
The Pattern Catalog
Pattern 1: Stream Enrichment
Problem: A high-volume event stream (shipment updates at 20K events/sec) needs to be enriched with data from a slowly changing reference dataset (driver profiles, updated every few hours). Querying the reference database for every event would add a synchronous database read to each one, costing 5–50ms of latency per event.
Three approaches exist, and the right choice depends on the reference dataset size and update frequency.
Approach A: Broadcast State (Flink)
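When the reference dataset fits comfortably in memory (tens of megabytes up to roughly a gigabyte) and updates infrequently, broadcast it to every parallel instance of the enrichment operator. Every instance holds a full copy of the reference data. Broadcast state always lives on the heap, even with the RocksDB state backend. Lookups are local, sub-millisecond, and need no network call.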
```python
import json

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.functions import BroadcastProcessFunction
from pyflink.datastream.state import MapStateDescriptor

# Broadcast state descriptor: driver_id -> JSON-encoded driver profile
DRIVER_STATE_DESC = MapStateDescriptor(
    "driver_profiles",
    Types.STRING(),  # key: driver_id
    Types.STRING(),  # value: JSON-encoded driver profile
)

# Fallback values for drivers whose profile has not been broadcast (yet)
DEFAULT_DRIVER = {"on_time_rate": 0.75, "avg_speed_mph": 35.0, "experience_years": 0}


class EnrichShipmentWithDriver(BroadcastProcessFunction):
    """
    Enriches shipment events with driver profile data.
    Driver profiles are broadcast to all operator instances.
    """

    def process_element(self, shipment_raw: str, ctx):
        """Called for every event on the main (shipment) stream."""
        shipment = json.loads(shipment_raw)
        driver_id = shipment.get("driver_id")

        # Read-only view of broadcast state: a local in-memory lookup, no network call
        broadcast_state = ctx.get_broadcast_state(DRIVER_STATE_DESC)
        driver_json = broadcast_state.get(driver_id) if driver_id else None
        # Flink does not order the two inputs relative to each other, so a shipment
        # can arrive before its driver's profile: fall back to defaults
        driver = json.loads(driver_json) if driver_json else DEFAULT_DRIVER

        shipment["driver_on_time_rate"] = driver.get("on_time_rate", DEFAULT_DRIVER["on_time_rate"])
        shipment["driver_avg_speed_mph"] = driver.get("avg_speed_mph", DEFAULT_DRIVER["avg_speed_mph"])
        shipment["driver_experience_yr"] = driver.get("experience_years", DEFAULT_DRIVER["experience_years"])
        # PyFlink process functions emit results by yielding them
        yield json.dumps(shipment)

    def process_broadcast_element(self, driver_update_raw: str, ctx):
        """Called on every parallel instance for each driver profile update."""
        driver_update = json.loads(driver_update_raw)
        broadcast_state = ctx.get_broadcast_state(DRIVER_STATE_DESC)
        broadcast_state.put(driver_update["driver_id"], json.dumps(driver_update))


def build_enrichment_pipeline(env: StreamExecutionEnvironment, shipment_kafka_source,
                              driver_kafka_source, watermark_strategy):
    # Main stream: high-volume shipment events
    shipment_stream = env.from_source(shipment_kafka_source, watermark_strategy, "Shipment Source")
    # Control stream: driver profile updates (low volume, from Kafka or CDC)
    driver_stream = env.from_source(driver_kafka_source, watermark_strategy, "Driver Profile Source")

    # Broadcast the driver stream to every parallel instance of the enrichment operator
    broadcast_driver_stream = driver_stream.broadcast(DRIVER_STATE_DESC)

    # Connect the main stream with the broadcast stream and apply enrichment
    return (
        shipment_stream
        .connect(broadcast_driver_stream)
        .process(EnrichShipmentWithDriver(), output_type=Types.STRING())
    )
```
Approach B: Async I/O (Feature Store Lookup)
When the reference dataset is too large for broadcast (gigabytes or more), or when every lookup needs the freshest value, use Async I/O. This pattern issues non-blocking requests to an external store (Redis, DynamoDB, the Feast online store). The Flink operator keeps processing other events while lookups are in flight.
```python
import asyncio
import json
import threading

import redis.asyncio as aioredis  # the aioredis package was merged into redis-py (>= 4.2)
from pyflink.datastream import AsyncDataStream
from pyflink.datastream.functions import AsyncFunction, ResultFuture


class AsyncDriverProfileLookup(AsyncFunction):
    """
    Non-blocking lookup of driver profiles from a Redis feature store.
    Async I/O allows up to `capacity` concurrent in-flight requests.

    Note: the AsyncFunction/ResultFuture contract here mirrors Flink's Java
    async I/O API. Python support for async I/O differs between Flink
    versions, so check your PyFlink release before using this as-is.
    """

    def open(self, runtime_context):
        # One event loop and one Redis client per operator instance.
        # The loop must actually run (here on a daemon thread); otherwise coroutines
        # scheduled with run_coroutine_threadsafe are never executed.
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.loop_thread.start()
        self.redis = aioredis.from_url("redis://redis:6379", decode_responses=True)

    def close(self):
        self.loop.call_soon_threadsafe(self.loop.stop)

    def async_invoke(self, shipment_raw: str, result_future: ResultFuture):
        shipment = json.loads(shipment_raw)
        driver_id = shipment.get("driver_id", "")

        async def do_lookup():
            try:
                driver = await self.redis.hgetall(f"driver:{driver_id}")
                shipment["driver_on_time_rate"] = float(driver.get("on_time_rate", 0.75))
                shipment["driver_avg_speed_mph"] = float(driver.get("avg_speed_mph", 35.0))
                result_future.complete([json.dumps(shipment)])
            except Exception as e:
                result_future.complete_exceptionally(e)

        # Schedule on the background loop and return immediately (non-blocking)
        asyncio.run_coroutine_threadsafe(do_lookup(), self.loop)


def apply_async_enrichment(stream, capacity: int = 100, timeout_ms: int = 500):
    """
    Apply async I/O enrichment to a stream.
    capacity: max concurrent in-flight async requests per parallel operator instance
    timeout_ms: fail the request if Redis doesn't respond within this window
    """
    # Argument shape follows Java's AsyncDataStream.unorderedWait(stream, fn, timeout,
    # TimeUnit, capacity). TimeUnit is the Java enum and is not defined in Python:
    # adapt the timeout arguments to the signature of your PyFlink version.
    return AsyncDataStream.unordered_wait(
        stream,
        AsyncDriverProfileLookup(),
        timeout=timeout_ms,
        time_unit=TimeUnit.MILLISECONDS,
        capacity=capacity,
    )
```
Use unordered_wait (not ordered_wait) when event ordering is not critical. ordered_wait buffers results until they can be emitted in order, which increases memory usage and latency. For feature enrichment where each event is independent, unordered_wait gives better throughput.
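The two properties described above can be seen without Flink at all. This plain-asyncio sketch bounds the number of requests in flight (the role of `capacity`) and emits results in completion order (what `unordered_wait` does). `fake_lookup` is a hypothetical stand-in for the Redis call, with random 1-20 ms latency.

```python
import asyncio
import random


async def fake_lookup(key: int) -> tuple:
    # Hypothetical external-store call with variable latency (1-20 ms)
    await asyncio.sleep(random.uniform(0.001, 0.02))
    return key, f"profile-{key}"


async def enrich_unordered(keys, capacity: int):
    """Run lookups with at most `capacity` in flight; collect results as they finish."""
    semaphore = asyncio.Semaphore(capacity)
    in_flight = 0
    peak = 0

    async def bounded(key):
        nonlocal in_flight, peak
        async with semaphore:  # blocks when `capacity` requests are already pending
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await fake_lookup(key)
            finally:
                in_flight -= 1

    tasks = [asyncio.create_task(bounded(k)) for k in keys]
    results = []
    for finished in asyncio.as_completed(tasks):  # completion order, not input order
        results.append(await finished)
    return results, peak


results, peak = asyncio.run(enrich_unordered(range(50), capacity=8))
print(len(results), peak)
```

Because latencies vary, `results` usually comes back in a different order than the input keys. An ordered variant would have to hold fast results until slower, earlier ones finish, which is the extra buffering the text attributes to `ordered_wait`.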
Approach C: Cache-Aside
For reference data that is small per key but has too many distinct keys for broadcast, cache the most recently accessed entries in an in-operator LRU cache. On a cache miss, make a synchronous lookup, which is acceptable when the miss rate is low. This is the simplest approach. It does need careful tuning of the cache size and of how stale an entry may become, because cached profiles do not see updates until they are evicted or refreshed.
```python
import json
from collections import OrderedDict

import redis
from pyflink.datastream.functions import MapFunction


class CachedDriverEnrichment(MapFunction):
    """
    In-operator LRU cache in front of Redis for driver profile lookups.
    Best when key cardinality is too large for broadcast state, but a hot
    subset of keys serves most lookups (low miss rate).
    """

    MAX_CACHE_SIZE = 10_000

    def open(self, runtime_context):
        self.redis_client = redis.Redis(host="redis", port=6379, decode_responses=True)
        # OrderedDict tracks recency: least recently used first, most recent last
        self._cache = OrderedDict()

    def _get_driver(self, driver_id: str) -> dict:
        if driver_id in self._cache:
            self._cache.move_to_end(driver_id)  # mark as most recently used
            return self._cache[driver_id]
        # Cache miss: synchronous Redis read (briefly blocks this operator).
        # Entries never expire here, so a cached profile can be stale indefinitely.
        profile = self.redis_client.hgetall(f"driver:{driver_id}")  # {} if key is absent
        self._cache[driver_id] = profile
        if len(self._cache) > self.MAX_CACHE_SIZE:
            self._cache.popitem(last=False)  # evict the least recently used entry
        return profile

    def map(self, shipment_raw: str) -> str:
        shipment = json.loads(shipment_raw)
        driver = self._get_driver(shipment.get("driver_id", ""))
        shipment["driver_on_time_rate"] = float(driver.get("on_time_rate", 0.75))
        return json.dumps(shipment)
```
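The two tuning knobs mentioned above, cache size and staleness, can be made explicit. This is a minimal framework-free sketch of a cache-aside LRU whose entries also expire after `ttl_s` seconds, with hit and miss counters you could export as metrics. `TTLLRUCache` and the `loader` callable (standing in for the Redis `HGETALL`) are hypothetical names. An injectable clock keeps the example deterministic.

```python
import time
from collections import OrderedDict


class TTLLRUCache:
    """Cache-aside LRU: entries are evicted by recency and also expire after ttl_s seconds."""

    def __init__(self, loader, max_size: int, ttl_s: float, clock=time.monotonic):
        self.loader = loader          # called on a miss, e.g. a Redis HGETALL
        self.max_size = max_size
        self.ttl_s = ttl_s
        self.clock = clock
        self._data = OrderedDict()    # key -> (loaded_at, value), least recently used first
        self.hits = 0
        self.misses = 0

    def get(self, key):
        now = self.clock()
        entry = self._data.get(key)
        if entry is not None and now - entry[0] < self.ttl_s:
            self.hits += 1
            self._data.move_to_end(key)
            return entry[1]
        # Miss, or the entry is too old: reload from the backing store
        self.misses += 1
        value = self.loader(key)
        self._data[key] = (now, value)
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used key
        return value


fake_now = [0.0]
loads = []

def load_profile(driver_id):
    loads.append(driver_id)
    return {"driver_id": driver_id}

cache = TTLLRUCache(load_profile, max_size=2, ttl_s=60, clock=lambda: fake_now[0])
cache.get("d1"); cache.get("d1"); cache.get("d2")
cache.get("d3")      # cache is full: evicts d1, the least recently used
fake_now[0] = 61.0
cache.get("d2")      # older than 60 s: reloaded even though it is cached
```

The hit rate, `hits / (hits + misses)`, is the number to watch when sizing the cache. The TTL bounds how long an updated driver profile can go unnoticed.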
Pattern 2: Stream-Stream Join
Problem: Joining a shipment event stream with a weather data stream to produce a combined event with real-time route conditions. Both streams are unbounded and out-of-order. You cannot buffer all of stream A waiting for matching events from stream B - you will run out of memory.
The fundamental constraint: you must define a bounded time window within which a match is expected. An interval join in Flink says: "join event $a$ at time $t_a$ with events $b$ from stream B where $t_a + \text{lowerBound} \le t_b \le t_a + \text{upperBound}$."
Interval Join
```python
import json

from pyflink.common import Duration, Types
from pyflink.datastream.functions import ProcessJoinFunction


class JoinShipmentWithWeather(ProcessJoinFunction):
    """
    Joins shipment events with weather observations for the same route zone.
    For every shipment event, matches weather observations from -5 min to +5 min.
    """

    def process_element(self, shipment_raw: str, weather_raw: str, ctx):
        # Both streams are keyed by zone, so only same-zone pairs reach this method
        shipment = json.loads(shipment_raw)
        weather = json.loads(weather_raw)

        shipment["weather_temp_f"] = weather.get("temp_f", 70)
        shipment["weather_precip_in"] = weather.get("precip_in", 0.0)
        shipment["weather_visibility_mi"] = weather.get("visibility_mi", 10.0)
        shipment["weather_wind_mph"] = weather.get("wind_mph", 0.0)
        yield json.dumps(shipment)


def build_interval_join(shipment_stream, weather_stream):
    """
    Interval join: match each shipment event with weather events
    within a 10-minute span (±5 minutes) around its event time.

    State: Flink buffers both streams' events until the watermark passes the join bounds.
    Memory cost: proportional to event rate × bound size.
    Output: one record per matching pair. If several observations fall inside the
    bounds, the shipment is emitted several times; deduplicate downstream if needed.
    Both inputs need event-time timestamps and watermarks.
    """
    return (
        shipment_stream
        .key_by(lambda e: json.loads(e).get("route_zone", ""), key_type=Types.STRING())
        .interval_join(
            weather_stream.key_by(lambda e: json.loads(e).get("zone", ""), key_type=Types.STRING())
        )
        .between(Duration.of_minutes(-5), Duration.of_minutes(5))
        .process(JoinShipmentWithWeather(), output_type=Types.STRING())
    )
```
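The matching rule can be stated precisely without Flink. This batch sketch of interval-join semantics works on `(key, timestamp_ms, payload)` tuples with inclusive bounds, which are Flink's default. It is a nested loop for illustration only; Flink evaluates the same predicate incrementally against buffered state. The sample zones and payload names are made up.

```python
from collections import defaultdict


def interval_join(left, right, lower_ms: int, upper_ms: int):
    """Emit (left_payload, right_payload) when keys match and
    l.ts + lower_ms <= r.ts <= l.ts + upper_ms (inclusive bounds)."""
    right_by_key = defaultdict(list)
    for r in right:
        right_by_key[r[0]].append(r)
    pairs = []
    for key, ts, payload in left:
        for _, r_ts, r_payload in right_by_key[key]:
            if ts + lower_ms <= r_ts <= ts + upper_ms:
                pairs.append((payload, r_payload))
    return pairs


MIN = 60_000
shipments = [("zone-7", 10 * MIN, "s1"), ("zone-9", 10 * MIN, "s2")]
weather = [
    ("zone-7", 6 * MIN, "w-early"),   # 4 min before s1: match
    ("zone-7", 14 * MIN, "w-late"),   # 4 min after s1: match
    ("zone-7", 16 * MIN, "w-stale"),  # 6 min after s1: outside the bounds
    ("zone-8", 10 * MIN, "w-other"),  # different key: never compared
]
pairs = interval_join(shipments, weather, -5 * MIN, 5 * MIN)
```

Two consequences follow. `s1` appears twice in `pairs`, once per matching observation. `s2` has no weather in its zone and produces no output at all, because an interval join is an inner join. If every shipment must be scored, unmatched events need their own path, for example a default weather feature.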
Windowed Join
For cases where both streams share the same window boundaries (e.g., "join click events with purchase events in the same 5-minute window"), use a windowed join. Unlike interval join, both streams must have events in the same window for a match - there is no time-relative matching.
```python
import json

from pyflink.common.time import Time
from pyflink.datastream.functions import JoinFunction
from pyflink.datastream.window import TumblingEventTimeWindows


class JoinClickWithPurchase(JoinFunction):
    def join(self, click_raw: str, purchase_raw: str) -> str:
        click = json.loads(click_raw)
        purchase = json.loads(purchase_raw)
        return json.dumps({
            "user_id": click["user_id"],
            "product_id": click["product_id"],
            "click_ts": click["event_ts"],
            "purchase_ts": purchase["event_ts"],
            # Can be negative: every click is paired with every purchase of the
            # same user in the window, whichever came first
            "time_to_purchase_s": (purchase["event_ts"] - click["event_ts"]) / 1000,
            "purchase_amount": purchase["amount_usd"],
        })


def build_windowed_join(click_stream, purchase_stream):
    """
    Join click events with purchase events of the same user in the same
    5-minute tumbling window to derive time-to-purchase features.
    Conversion rates also need the clicks without a purchase, which this inner join drops.
    The join/where/equal_to/window/apply chain follows Flink's Java DataStream API;
    check that your PyFlink version exposes it, or express the join in Flink SQL.
    """
    return (
        click_stream
        .join(purchase_stream)
        .where(lambda e: json.loads(e)["user_id"])
        .equal_to(lambda e: json.loads(e)["user_id"])
        .window(TumblingEventTimeWindows.of(Time.minutes(5)))
        .apply(JoinClickWithPurchase())
    )
```
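"No time-relative matching" has a concrete cost at window boundaries. This sketch assigns timestamps to tumbling windows by flooring to the window size, which matches a zero-offset tumbling assigner for non-negative timestamps. It shows a click and a purchase two seconds apart that never meet in a 5-minute windowed join, while an interval predicate over `[0, 5 min]` matches them. The timestamps are illustrative.

```python
MIN = 60_000


def tumbling_window_start(ts_ms: int, size_ms: int) -> int:
    # Floor the timestamp to the start of its fixed-size window (zero offset)
    return ts_ms - (ts_ms % size_ms)


def same_tumbling_window(ts_a: int, ts_b: int, size_ms: int) -> bool:
    return tumbling_window_start(ts_a, size_ms) == tumbling_window_start(ts_b, size_ms)


click_ts = 4 * MIN + 59_000     # 00:04:59
purchase_ts = 5 * MIN + 1_000   # 00:05:01, two seconds later

windowed_match = same_tumbling_window(click_ts, purchase_ts, 5 * MIN)
interval_match = 0 <= purchase_ts - click_ts <= 5 * MIN
print(windowed_match, interval_match)
```

Use a windowed join when the feature is defined per window, such as conversions within each 5-minute bucket. Use an interval join when the feature is defined relative to each event.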
Stream-stream joins are memory-intensive, because each stream's events stay in state until the watermark passes the join bound. With ±5-minute interval bounds, each event is retained for roughly 5 minutes plus watermark delay. At 10K events/sec per stream, that is about 10K × 300 s = 3 million buffered events per stream, or roughly 6 million across both. Size TaskManager memory, or move state to the RocksDB backend, based on your event rate, bound size, and event size.
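The estimate above can be wrapped in a small sizing helper. This is a back-of-the-envelope sketch, not a Flink API. It assumes each left event is retained for about `upper_s` and each right event for about `-lower_s` (plus watermark lag), that keys spread evenly across subtasks, and a hypothetical serialized size of 500 bytes per event. State-backend overhead comes on top.

```python
def interval_join_buffer(rate_per_s: float, lower_s: float, upper_s: float,
                         bytes_per_event: int = 0, watermark_lag_s: float = 0.0,
                         parallelism: int = 1):
    """
    Rough buffered-event estimate for a two-sided interval join with bounds
    [lower_s, upper_s], at `rate_per_s` events/sec on EACH stream.
    Returns (events_per_subtask, bytes_per_subtask).
    """
    left_retention_s = max(upper_s, 0) + watermark_lag_s    # left events wait for later right events
    right_retention_s = max(-lower_s, 0) + watermark_lag_s  # right events wait for later left events
    total_events = rate_per_s * (left_retention_s + right_retention_s)
    per_subtask = total_events / parallelism                # assumes evenly distributed keys
    return per_subtask, per_subtask * bytes_per_event


events, nbytes = interval_join_buffer(10_000, -300, 300, bytes_per_event=500)
events_p4, _ = interval_join_buffer(10_000, -300, 300, parallelism=4)
```

Under these assumptions, the ±5-minute join at 10K events/sec holds about 6 million events, roughly 3 GB at 500 bytes each. Spread over 4 subtasks, that is about 1.5 million events per instance.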
Pattern 3: Aggregation and Windowing
Problem: Computing features like "average delivery time on this route over the last 7 days" or "number of deliveries in the last hour from this hub." These require aggregating events within time-bounded windows.
Window Type Decision Matrix
| Window Type | Use Case | State Cost | Latency |
|---|---|---|---|
| Tumbling | Fixed-period summaries (hourly, daily KPIs) | Low | Period duration |
| Sliding | Rolling windows, "in the last N minutes" | Higher (window overlap) | Slide interval |
| Session | User sessions, activity bursts | Variable | Gap timeout |
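The "State Cost" column follows from how many windows each event belongs to. This framework-free sketch enumerates window assignments the way event-time assigners do for tumbling and sliding windows. Sessions are approximated by grouping sorted timestamps whose gaps do not exceed `gap`. A sliding window of size 10 with slide 1 places every event in 10 windows, while a tumbling window places it in exactly one. Timestamps are in arbitrary units.

```python
def tumbling_windows(ts: int, size: int):
    start = ts - ts % size
    return [(start, start + size)]


def sliding_windows(ts: int, size: int, slide: int):
    """Every [start, start + size) window, with start a multiple of `slide`, that contains ts."""
    windows = []
    start = ts - ts % slide
    while start > ts - size:
        windows.append((start, start + size))
        start -= slide
    return windows


def session_windows(timestamps, gap: int):
    """Group timestamps into sessions; a session ends `gap` after its last event."""
    sessions = []
    for ts in sorted(timestamps):
        if sessions and ts - sessions[-1][1] <= gap:
            sessions[-1] = (sessions[-1][0], ts)  # extend the open session
        else:
            sessions.append((ts, ts))             # start a new session
    return [(start, last + gap) for start, last in sessions]


print(len(tumbling_windows(7, 10)), len(sliding_windows(7, 10, 1)))
print(session_windows([0, 1, 2, 10, 11], gap=3))
```

The number of session windows depends on the data itself, which is why the table lists their state cost as variable.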
Computing Approximate Percentiles
Exact percentile computation (for example SQL PERCENTILE_CONT) on a stream requires keeping and sorting every value in the window, which is impractical for large windows. For streaming percentiles, use probabilistic data structures:
T-Digest - accurate percentile estimation with bounded error at the tails. Used by Elasticsearch, InfluxDB. Supports add() and quantile() operations. Memory: $O(\delta)$, where $\delta$ is the compression parameter.
Count-Min Sketch - frequency estimation for heavy-hitter detection. Not for percentiles, but useful for "top K delivery zones by event count."
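HyperLogLog - cardinality estimation ("how many distinct drivers were active in the last hour"). Memory is fixed and independent of cardinality (a few KB is typical), with a standard error of about $1.04/\sqrt{m}$ for $m$ registers.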
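Of the three structures, the Count-Min Sketch is short enough to show in full. This is a minimal sketch, not a production library. It uses `depth` rows of `width` counters, with one salted BLAKE2b hash per row. Because each estimate is the minimum over rows of counters that only ever grow, it can overcount through collisions but never undercounts. That property makes it suitable for the "top K delivery zones" use case. The zone counts are made up.

```python
import hashlib


class CountMinSketch:
    """Approximate per-key counts in fixed memory; estimates are never below the true count."""

    def __init__(self, width: int = 1024, depth: int = 4):
        self.width = width
        self.depth = depth
        self.table = [[0] * width for _ in range(depth)]

    def _buckets(self, key: str):
        # One independent hash per row, derived by salting BLAKE2b with the row index
        for row in range(self.depth):
            digest = hashlib.blake2b(key.encode(), digest_size=8,
                                     salt=row.to_bytes(8, "little")).digest()
            yield row, int.from_bytes(digest, "little") % self.width

    def add(self, key: str, count: int = 1):
        for row, col in self._buckets(key):
            self.table[row][col] += count

    def estimate(self, key: str) -> int:
        return min(self.table[row][col] for row, col in self._buckets(key))


cms = CountMinSketch(width=1024, depth=4)
true_counts = {"zone-1": 5000, "zone-2": 1200}
true_counts.update({f"zone-{i}": 3 for i in range(3, 400)})
for zone, n in true_counts.items():
    cms.add(zone, n)
top_two = sorted(true_counts, key=cms.estimate, reverse=True)[:2]
```

Memory is `width × depth` counters regardless of how many zones appear. In a stream job you would keep a small candidate set of keys alongside the sketch, since the sketch itself cannot enumerate keys.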
```python
import json
import random

from pyflink.datastream.functions import AggregateFunction

# Lightweight online percentile approximation using reservoir sampling.
# In production, prefer a dedicated sketch such as t-digest (e.g. the `tdigest` package).


class RunningStatsAccumulator:
    """Tracks count, sum, mean, and an approximate 95th percentile via a reservoir sample."""

    RESERVOIR_SIZE = 1000

    def __init__(self):
        self.count = 0
        self.total = 0.0
        self.reservoir = []  # uniform random sample of at most RESERVOIR_SIZE values

    def add(self, value: float):
        self.count += 1
        self.total += value
        if len(self.reservoir) < self.RESERVOIR_SIZE:
            self.reservoir.append(value)
        else:
            # Algorithm R: keep the new value with probability RESERVOIR_SIZE / count
            idx = random.randint(0, self.count - 1)
            if idx < self.RESERVOIR_SIZE:
                self.reservoir[idx] = value

    def p95(self) -> float:
        if not self.reservoir:
            return 0.0
        sorted_sample = sorted(self.reservoir)
        idx = int(0.95 * len(sorted_sample))
        return sorted_sample[min(idx, len(sorted_sample) - 1)]

    def mean(self) -> float:
        return self.total / self.count if self.count > 0 else 0.0


class RouteDeliveryTimeAggregator(AggregateFunction):
    """
    Computes delivery time statistics per route in a tumbling window.
    Outputs: mean delivery time, p95 delivery time, total deliveries.
    """

    def create_accumulator(self) -> RunningStatsAccumulator:
        return RunningStatsAccumulator()

    def add(self, event_raw: str, acc: RunningStatsAccumulator) -> RunningStatsAccumulator:
        event = json.loads(event_raw)
        acc.add(event.get("delivery_time_minutes", 0.0))
        return acc

    def get_result(self, acc: RunningStatsAccumulator) -> str:
        return json.dumps({
            "count": acc.count,
            "mean_delivery_min": round(acc.mean(), 2),
            "p95_delivery_min": round(acc.p95(), 2),
        })

    def merge(self, a: RunningStatsAccumulator, b: RunningStatsAccumulator) -> RunningStatsAccumulator:
        # Approximately weighted merge: draw from each reservoir in proportion to the
        # number of events it summarizes (plain concatenation would over-represent `a`)
        merged = RunningStatsAccumulator()
        merged.count = a.count + b.count
        merged.total = a.total + b.total
        pool_a, pool_b = a.reservoir[:], b.reservoir[:]
        random.shuffle(pool_a)
        random.shuffle(pool_b)
        while len(merged.reservoir) < merged.RESERVOIR_SIZE and (pool_a or pool_b):
            take_a = pool_a and (not pool_b or random.random() < a.count / merged.count)
            merged.reservoir.append(pool_a.pop() if take_a else pool_b.pop())
        return merged
```
Pattern 4: CDC to Feature Store
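Problem: The transactional database holds the ground truth for driver availability, vehicle capacity, and hub status. The feature pipeline needs this data in real time with sub-second latency. Polling is too slow and too expensive. The solution is change data capture (CDC). A connector such as Debezium reads the database's transaction log and publishes every row-level insert, update, and delete to Kafka. A consumer then turns each change into features and writes them to the online store.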
```python
from __future__ import annotations  # allows `dict | None` on Python < 3.10

import redis

REDIS = redis.Redis(host="redis", port=6379, decode_responses=True)
FEATURE_TTL_S = 7 * 24 * 3600  # drop profiles not refreshed within 7 days


def process_driver_cdc(debezium_msg: dict) -> dict | None:
    """
    Transform a Debezium CDC envelope for the drivers table into ML features
    and upsert them into Redis. Returns None for delete operations.
    Assumes op/before/after/ts_ms at the top level; with the JSON converter's
    schemas enabled, read debezium_msg["payload"] first.
    """
    op = debezium_msg.get("op")
    after = debezium_msg.get("after")

    if op == "d" or after is None:
        # "before" can be present but null, hence `or {}`
        driver_id = (debezium_msg.get("before") or {}).get("driver_id")
        if driver_id:
            REDIS.delete(f"driver:{driver_id}")
        return None

    # Compute derived features from raw database columns (op is c, u, or r for snapshot reads)
    total_deliveries = after.get("total_deliveries") or 0
    on_time_count = after.get("on_time_deliveries") or 0

    features = {
        "driver_id": after["driver_id"],
        "on_time_rate": round(on_time_count / max(total_deliveries, 1), 4),
        "avg_speed_mph": float(after.get("avg_speed_mph", 35.0)),
        "experience_years": int(after.get("tenure_days", 0)) // 365,
        "current_capacity": float(after.get("vehicle_capacity_lb", 0)),
        "is_active": after.get("status") == "active",
        "updated_ts": debezium_msg.get("ts_ms", 0),
    }

    # Write to Redis for sub-millisecond online lookup. HSET and EXPIRE go in one
    # pipeline; both are idempotent, so re-applying a replayed change is safe.
    key = f"driver:{after['driver_id']}"
    pipe = REDIS.pipeline()
    pipe.hset(key, mapping={k: str(v) for k, v in features.items()})
    pipe.expire(key, FEATURE_TTL_S)
    pipe.execute()
    return features
```
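The semantics behind the function above can be tested without Redis or Kafka. This sketch applies Debezium-style change events (`op` of `c`, `u`, `r`, or `d`, plus Kafka tombstones) to an in-memory dict that stands in for the online store. Upserts overwrite and deletes remove, so replaying the same per-key sequence leaves the store unchanged. This relies on per-key ordering, which Debezium provides by keying Kafka messages on the row's primary key. `apply_change` and the sample rows are hypothetical.

```python
def apply_change(store: dict, event) -> None:
    """Apply one Debezium-style change event to a driver_id -> row store."""
    if event is None:            # Kafka tombstone that follows a delete: nothing to do
        return
    op = event.get("op")
    if op in ("c", "u", "r"):    # create, update, snapshot read: upsert the new row
        row = event["after"]
        store[row["driver_id"]] = row
    elif op == "d":              # delete: remove by the key found in `before`
        row = event.get("before") or {}
        store.pop(row.get("driver_id"), None)


events = [
    {"op": "r", "before": None, "after": {"driver_id": "d1", "on_time_deliveries": 90}},
    {"op": "u", "before": {"driver_id": "d1"}, "after": {"driver_id": "d1", "on_time_deliveries": 91}},
    {"op": "c", "before": None, "after": {"driver_id": "d2", "on_time_deliveries": 0}},
    {"op": "d", "before": {"driver_id": "d2"}, "after": None},
    None,  # tombstone
]

store = {}
for e in events:
    apply_change(store, e)

replayed = dict(store)
for e in events:  # replay the whole batch, as after a restore from an older offset
    apply_change(replayed, e)
```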
Pattern 5: Streaming ML Inference
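Problem: Running model inference inside the Flink operator graph - scoring events as they arrive rather than batching them for a separate scoring service. This eliminates the network round-trip to a model serving endpoint. For small models such as gradient-boosted trees, latency drops from 5–50ms (network + serving overhead) to sub-millisecond. The trade-off is that model CPU and memory now compete with stream processing on the same TaskManagers.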
Loading an ONNX Model into a MapFunction
ONNX (Open Neural Network Exchange) is a widely used interchange format for production ML models. You export a scikit-learn, XGBoost, PyTorch, or TensorFlow model to ONNX, then load it in the Flink operator with onnxruntime. The model is loaded once when the operator is initialized (in open()) and reused for every event.
```python
import json

import numpy as np
import onnxruntime as ort
from pyflink.datastream.functions import MapFunction


class ETAPredictionFunction(MapFunction):
    """
    Loads an ONNX gradient boosting model and scores every enriched shipment event.
    The model is loaded once per operator instance (in open()), not once per event.
    """

    MODEL_PATH = "/opt/models/eta_prediction_v3.onnx"
    MODEL_VERSION = "v3"
    # Must match the column order the model was trained with
    FEATURE_NAMES = [
        "package_weight_lb",
        "distance_miles",
        "driver_on_time_rate",
        "driver_avg_speed_mph",
        "driver_experience_yr",
        "weather_temp_f",
        "weather_precip_in",
        "weather_visibility_mi",
        "weather_wind_mph",
        "route_avg_speed_mph",  # from windowed aggregation
        "hour_of_day",
        "day_of_week",
    ]

    def open(self, runtime_context):
        # Initialize ONNX Runtime session - CPU provider, 2 intra-op threads
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = 2
        session_options.inter_op_num_threads = 1
        self.session = ort.InferenceSession(
            self.MODEL_PATH,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def map(self, enriched_event_raw: str) -> str:
        event = json.loads(enriched_event_raw)

        # Build a 1 x N float32 feature vector; missing fields default to 0.0
        features = np.array(
            [[float(event.get(f, 0.0)) for f in self.FEATURE_NAMES]], dtype=np.float32
        )

        # Run ONNX inference - ~0.1–0.5ms for GBM models
        (predictions,) = self.session.run([self.output_name], {self.input_name: features})
        eta_minutes = float(np.ravel(predictions)[0])  # works for output shape (1,) or (1, 1)

        # Emit enriched event with prediction attached
        event["eta_minutes"] = round(eta_minutes, 1)
        event["model_version"] = self.MODEL_VERSION
        # Deterministic on replay: reuse the event timestamp rather than the wall clock
        event["scored_at_ts"] = event.get("event_ts", 0)
        return json.dumps(event)
```
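`event.get(f, 0.0)` in the scorer above silently turns a missing feature into 0.0. For fields like `driver_on_time_rate`, that is a value the model rarely saw in training, so missing data becomes a quiet accuracy problem. A sketch of a stricter alternative follows. `build_feature_vector` and the per-feature `defaults` mapping are hypothetical. The function keeps the training column order, applies an explicit default per feature, and reports which fields were missing so you can count them as a metric. The resulting row would be wrapped in `np.array([row], dtype=np.float32)` before `session.run`.

```python
def build_feature_vector(event: dict, feature_names, defaults: dict):
    """
    Return (row, missing): values in training column order, with explicit
    per-feature defaults for absent fields, plus the names that were absent.
    A KeyError on `defaults` means a feature has no agreed default, which is deliberate.
    """
    row, missing = [], []
    for name in feature_names:
        value = event.get(name)
        if value is None:
            missing.append(name)
            value = defaults[name]
        row.append(float(value))
    return row, missing


names = ["package_weight_lb", "distance_miles", "driver_on_time_rate"]
defaults = {"package_weight_lb": 0.0, "distance_miles": 0.0, "driver_on_time_rate": 0.75}
row, missing = build_feature_vector({"package_weight_lb": 12, "distance_miles": 4.2}, names, defaults)
```

Here the missing on-time rate falls back to the same 0.75 prior used during enrichment, instead of 0.0.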
Model Hot-Reload
When you deploy a new model version, you want to update the model in the running Flink job without restarting it (which would cause a brief outage and require a checkpoint restore). The hot-reload pattern uses a control stream to signal version changes.
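```python
import json
import logging
import os

import numpy as np
import onnxruntime as ort
from pyflink.common import Types
from pyflink.datastream.functions import BroadcastProcessFunction
from pyflink.datastream.state import MapStateDescriptor

# ETAPredictionFunction.FEATURE_NAMES is defined in the previous example

# Broadcast state: model metadata (version, path)
MODEL_META_DESC = MapStateDescriptor(
    "model_metadata",
    Types.STRING(),  # key: "current_model"
    Types.STRING(),  # value: JSON with {version, path}
)


class HotReloadInferenceFunction(BroadcastProcessFunction):
    """
    Runs ML inference on the main stream.
    Listens on the broadcast stream for model version changes.
    Reloads the model in place when a new version is signaled.
    """

    INITIAL_MODEL_PATH, INITIAL_VERSION = "/opt/models/eta_prediction_v1.onnx", "v1"

    def open(self, runtime_context):
        self.current_session = None
        self.current_version = None
        self.seen_meta = None  # last broadcast metadata this instance acted on
        self._load_model(self.INITIAL_MODEL_PATH, self.INITIAL_VERSION)

    def _load_model(self, model_path: str, version: str) -> bool:
        """Load a model, keeping the previous one if this fails. Blocks while loading."""
        if not os.path.exists(model_path):
            logging.warning("Model %s not found at %s; keeping %s", version, model_path, self.current_version)
            return False
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = 2
        session = ort.InferenceSession(model_path, sess_options=opts, providers=["CPUExecutionProvider"])
        # Swap references only after the new session has loaded successfully
        self.current_session = session
        self.current_version = version
        self.input_name = session.get_inputs()[0].name
        self.output_name = session.get_outputs()[0].name
        logging.info("Loaded model %s from %s", version, model_path)
        return True

    def process_element(self, event_raw: str, ctx):
        """Score events using the currently loaded model."""
        # After a restore, open() starts from v1 while broadcast state may already
        # name a newer model: catch up once per metadata change
        meta_json = ctx.get_broadcast_state(MODEL_META_DESC).get("current_model")
        if meta_json and meta_json != self.seen_meta:
            self.seen_meta = meta_json
            meta = json.loads(meta_json)
            if meta["version"] != self.current_version:
                self._load_model(meta["path"], meta["version"])

        if self.current_session is None:
            return  # No model loaded yet: the event is dropped (or route it to a side output)

        event = json.loads(event_raw)
        features = np.array(
            [[float(event.get(f, 0.0)) for f in ETAPredictionFunction.FEATURE_NAMES]], dtype=np.float32
        )
        (result,) = self.current_session.run([self.output_name], {self.input_name: features})
        event["eta_minutes"] = round(float(np.ravel(result)[0]), 1)
        event["model_version"] = self.current_version
        yield json.dumps(event)

    def process_broadcast_element(self, model_update_raw: str, ctx):
        """
        Called on every parallel instance when a new model version is broadcast.
        Triggers hot-reload on this operator instance.
        """
        update = json.loads(model_update_raw)
        new_version = update.get("version")
        new_path = update.get("path")
        if new_version != self.current_version:
            self._load_model(new_path, new_version)
        # Update broadcast state unconditionally: Flink expects every instance to
        # apply the same deterministic change, whether or not its local load succeeded
        meta_json = json.dumps({"version": new_version, "path": new_path})
        ctx.get_broadcast_state(MODEL_META_DESC).put("current_model", meta_json)
        self.seen_meta = meta_json
```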
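The core of any hot-reload scheme is "validate, then swap". A bad model push should leave the operator serving the previous version, not no version. This is a framework-free sketch of that rule. `ModelHolder`, `loader`, and `smoke_test` are hypothetical names. In the Flink function above, `loader` would build an `ort.InferenceSession`, and `smoke_test` might score a fixed reference event and check that the output is finite. Plain lambdas stand in for models here.

```python
class ModelHolder:
    """Holds the active model; swaps only after the candidate loads and passes a smoke test."""

    def __init__(self, loader, smoke_test):
        self.loader = loader          # path -> model object
        self.smoke_test = smoke_test  # model -> bool
        self.model = None
        self.version = None

    def maybe_reload(self, version: str, path: str) -> bool:
        if version == self.version:
            return False
        try:
            candidate = self.loader(path)
            if not self.smoke_test(candidate):
                return False          # reject: keep serving the current model
        except Exception:
            return False              # load failure: keep serving (log this in practice)
        self.model, self.version = candidate, version  # single reference swap
        return True


registry = {
    "/models/v1.onnx": lambda x: 2 * x,
    "/models/v2.onnx": lambda x: 3 * x,
    "/models/bad.onnx": lambda x: -1.0,   # "loads" but produces nonsense
}
holder = ModelHolder(loader=lambda p: registry[p], smoke_test=lambda m: m(1.0) > 0)

ok_v1 = holder.maybe_reload("v1", "/models/v1.onnx")
ok_missing = holder.maybe_reload("v2", "/models/missing.onnx")
ok_v2 = holder.maybe_reload("v2", "/models/v2.onnx")
ok_bad = holder.maybe_reload("v3", "/models/bad.onnx")
```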
Pattern 6: Feedback Loops
Problem: The model predicts ETA at shipment pickup. The actual delivery time is known only hours later. To generate labeled training examples, you need to join the prediction (emitted at pickup time) with the ground truth label (emitted at delivery time) - across a time gap of 2–8 hours.
This is the delayed label join problem. It requires buffering predictions for up to 8 hours waiting for the corresponding ground truth.
```python
import json

from pyflink.common import Types
from pyflink.common.time import Time
from pyflink.datastream.functions import KeyedCoProcessFunction
from pyflink.datastream.state import StateTtlConfig, ValueStateDescriptor


class DelayedLabelJoin(KeyedCoProcessFunction):
    """
    Joins prediction events (input 1) with ground truth labels (input 2), keyed by shipment_id.
    Predictions may arrive up to 8 hours before their ground truth.
    State per key:
      - pending_prediction: prediction waiting for its ground truth
      - pending_ground_truth: ground truth waiting for its prediction
    Emits: complete training example {features, prediction, actual_eta_min} when both arrive.
    """

    # TTL: expire unmatched entries after 12 hours (covers the max delivery window + buffer).
    # State TTL is based on processing time, and expired entries are dropped silently.
    STATE_TTL_HOURS = 12

    def open(self, runtime_context):
        ttl_config = (
            StateTtlConfig
            .new_builder(Time.hours(self.STATE_TTL_HOURS))
            .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite)
            .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
            .build()
        )
        # Prediction buffered for this shipment_id, waiting for ground truth
        pred_desc = ValueStateDescriptor("pending_prediction", Types.STRING())
        pred_desc.enable_time_to_live(ttl_config)
        self.pending_prediction = runtime_context.get_state(pred_desc)

        # Ground truth buffered in case it arrives before the prediction
        gt_desc = ValueStateDescriptor("pending_ground_truth", Types.STRING())
        gt_desc.enable_time_to_live(ttl_config)
        self.pending_ground_truth = runtime_context.get_state(gt_desc)

    def process_element1(self, prediction_raw: str, ctx):
        """Called for every prediction event."""
        gt = self.pending_ground_truth.value()
        if gt:
            # Ground truth got here first (e.g. the prediction stream is lagging)
            self.pending_ground_truth.clear()
            yield self._training_example(json.loads(prediction_raw), json.loads(gt))
        else:
            # Buffer prediction, wait for ground truth
            self.pending_prediction.update(prediction_raw)

    def process_element2(self, ground_truth_raw: str, ctx):
        """Called for every ground truth label event."""
        pred = self.pending_prediction.value()
        if pred:
            self.pending_prediction.clear()
            yield self._training_example(json.loads(pred), json.loads(ground_truth_raw))
        else:
            # Buffer ground truth, wait for prediction (race condition)
            self.pending_ground_truth.update(ground_truth_raw)

    @staticmethod
    def _training_example(prediction: dict, ground_truth: dict) -> str:
        """Build a complete labeled training example."""
        predicted_eta = prediction.get("eta_minutes", 0)
        actual_eta = ground_truth.get("actual_delivery_minutes", 0)
        return json.dumps({
            "shipment_id": prediction["shipment_id"],
            "features": prediction.get("features", {}),
            "predicted_eta_min": predicted_eta,
            "actual_eta_min": actual_eta,
            "prediction_error_min": abs(actual_eta - predicted_eta),
            "model_version": prediction.get("model_version"),
            # Taken from the label event, not the wall clock, so replays are deterministic
            "created_at_ts": ground_truth.get("event_ts", 0),
        })


def build_feedback_pipeline(prediction_stream, ground_truth_stream):
    """
    Connect prediction and ground truth streams for label joining.
    Key both streams by shipment_id so matching events go to the same operator.
    """
    keyed_predictions = prediction_stream.key_by(
        lambda e: json.loads(e)["shipment_id"], key_type=Types.STRING()
    )
    keyed_ground_truth = ground_truth_stream.key_by(
        lambda e: json.loads(e)["shipment_id"], key_type=Types.STRING()
    )
    return (
        keyed_predictions
        .connect(keyed_ground_truth)
        .process(DelayedLabelJoin(), output_type=Types.STRING())
    )
```
The feedback loop is what closes the ML training cycle. Without it, you must manually join prediction logs with outcome databases in batch, which introduces hours of delay between new outcomes and updated training data. With the streaming feedback loop, new labeled examples are available for training within minutes of delivery completion. That can compress the model update cycle from weekly batch retraining to near-continuous retraining.
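The join logic is easier to reason about outside Flink. This plain-Python model of a two-sided buffer uses the arrival time passed in as its clock, similar to processing-time TTL. Whichever side arrives first waits for the other, and entries older than `ttl_ms` are evicted and counted. Counting matters because Flink's TTL drops them silently. `DelayedLabelJoiner` is a hypothetical name. The linear eviction scan is for illustration only; Flink uses TTL or timers instead.

```python
class DelayedLabelJoiner:
    """Two-sided join buffer keyed by shipment_id, with expiry of unmatched entries."""

    def __init__(self, ttl_ms: int):
        self.ttl_ms = ttl_ms
        self.predictions = {}  # shipment_id -> (arrival_ms, prediction)
        self.labels = {}       # shipment_id -> (arrival_ms, label)
        self.expired = 0

    def _evict(self, now_ms: int):
        for buf in (self.predictions, self.labels):
            for key in [k for k, (t, _) in buf.items() if now_ms - t > self.ttl_ms]:
                del buf[key]
                self.expired += 1

    def _arrive(self, now_ms, record, mine, theirs, is_prediction):
        self._evict(now_ms)
        key = record["shipment_id"]
        if key in theirs:  # the other side is already waiting: emit an example
            _, other = theirs.pop(key)
            pred, label = (record, other) if is_prediction else (other, record)
            return {
                "shipment_id": key,
                "predicted_eta_min": pred["eta_minutes"],
                "actual_eta_min": label["actual_delivery_minutes"],
                "error_min": abs(label["actual_delivery_minutes"] - pred["eta_minutes"]),
            }
        mine[key] = (now_ms, record)  # wait for the other side
        return None

    def on_prediction(self, now_ms, prediction):
        return self._arrive(now_ms, prediction, self.predictions, self.labels, True)

    def on_label(self, now_ms, label):
        return self._arrive(now_ms, label, self.labels, self.predictions, False)


H = 3_600_000
j = DelayedLabelJoiner(ttl_ms=12 * H)
j.on_prediction(0, {"shipment_id": "s1", "eta_minutes": 42.0})
ex1 = j.on_label(3 * H, {"shipment_id": "s1", "actual_delivery_minutes": 50.0})
j.on_label(4 * H, {"shipment_id": "s2", "actual_delivery_minutes": 30.0})   # label arrives first
ex2 = j.on_prediction(4 * H + 1, {"shipment_id": "s2", "eta_minutes": 28.0})
j.on_prediction(5 * H, {"shipment_id": "s3", "eta_minutes": 60.0})          # never delivered
j.on_label(18 * H, {"shipment_id": "s4", "actual_delivery_minutes": 10.0})  # evicts s3
```

The `expired` counter is worth exporting in production. A rising count usually means lost delivery events or a TTL that is shorter than real delivery times.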
Pattern 7: Exactly-Once End-to-End
Problem: Ensuring that every event is processed exactly once through the entire pipeline - from Kafka source to Flink operators to Kafka sink - with no duplicates and no missing events, even in the presence of failures.
The Two-Phase Commit Protocol
Flink's Kafka sink implements exactly-once using a two-phase commit (2PC) protocol:
Phase 1 (Pre-commit): Between checkpoints, the Kafka sink writes records into an open Kafka transaction. When the checkpoint barrier reaches the sink, it flushes those records and pre-commits the transaction: the records are in Kafka, but the transaction is not committed yet. The sink then opens a fresh transaction for subsequent records and reports the pre-commit as part of its checkpoint state.
Phase 2 (Commit or Abort): When the checkpoint completes successfully (all operators have snapshotted their state), Flink notifies the sink, which commits the transaction. The records become visible to consumers reading with isolation.level=read_committed. If the job fails before the checkpoint completes, the transaction is aborted on recovery. Its records remain in the log but are never exposed to read_committed consumers, and the source replays those events from the last completed checkpoint.
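The two phases can be simulated in a few lines. This toy model is not Flink or Kafka code. `TransactionalLog` stands in for a topic whose records carry a transaction ID, and `TwoPhaseCommitSink` follows the pre-commit and commit steps described above. The recovery path is simplified: it aborts every uncommitted transaction, which is correct only when the pending checkpoint never completed. Real Flink also commits transactions whose checkpoint completed but whose notification was lost.

```python
class TransactionalLog:
    """Toy topic: records written in a transaction are visible to read_committed only after commit."""

    def __init__(self):
        self.records = []      # (txn_id, value), in write order
        self.committed = set()
        self.aborted = set()

    def read_committed(self):
        return [v for t, v in self.records if t in self.committed]

    def read_uncommitted(self):
        return [v for _, v in self.records]  # sees everything, including aborted records


class TwoPhaseCommitSink:
    def __init__(self, log: TransactionalLog):
        self.log = log
        self.txn_seq = 0
        self.txn = self._begin()
        self.pending = {}      # checkpoint_id -> pre-committed txn awaiting commit

    def _begin(self):
        self.txn_seq += 1
        return f"txn-{self.txn_seq}"

    def invoke(self, value):                         # normal processing
        self.log.records.append((self.txn, value))

    def snapshot_state(self, checkpoint_id):         # phase 1: pre-commit at the barrier
        self.pending[checkpoint_id] = self.txn
        self.txn = self._begin()

    def notify_checkpoint_complete(self, checkpoint_id):  # phase 2: commit
        self.log.committed.add(self.pending.pop(checkpoint_id))

    def fail(self):                                  # crash before the next checkpoint completes
        for txn in list(self.pending.values()) + [self.txn]:
            self.log.aborted.add(txn)
        self.pending.clear()


log = TransactionalLog()
sink = TwoPhaseCommitSink(log)
sink.invoke("f1"); sink.invoke("f2")
sink.snapshot_state(1)
sink.notify_checkpoint_complete(1)   # f1, f2 become visible
sink.invoke("f3")
sink.snapshot_state(2)
sink.fail()                          # checkpoint 2 never completes: f3's transaction is aborted
```

After recovery from checkpoint 1, the source replays the event behind `f3`, and it is written again in a new transaction. A read_committed consumer therefore sees it exactly once. A read_uncommitted consumer would see both copies.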
```python
from pyflink.common import Types
from pyflink.common.serialization import SimpleStringSchema
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import (
    DeliveryGuarantee,
    KafkaRecordSerializationSchema,
    KafkaSink,
)


def build_exactly_once_pipeline(env: StreamExecutionEnvironment, source_stream, compute_features_stateful):
    """
    End-to-end exactly-once: Kafka source → Flink (exactly-once state) → Kafka sink.
    Requirements:
      - Flink: checkpointing enabled (exactly-once mode is the default)
      - Kafka sink: EXACTLY_ONCE delivery guarantee (transactional producer)
      - Downstream consumers: isolation.level=read_committed
      - Kafka source: isolation.level=read_committed if the input topic is itself written transactionally
    """
    # Transactions commit only when a checkpoint completes, so this interval is also
    # the minimum end-to-end latency seen by read_committed consumers
    env.enable_checkpointing(30_000)  # checkpoint every 30 seconds

    # Process stream with stateful feature computation (emits JSON strings)
    feature_stream = source_stream.map(compute_features_stateful, output_type=Types.STRING())

    # Exactly-once Kafka sink using a transactional producer
    exactly_once_sink = (
        KafkaSink.builder()
        .set_bootstrap_servers("kafka:9092")
        .set_record_serializer(
            KafkaRecordSerializationSchema.builder()
            .set_topic("features.eta_inputs")
            .set_value_serialization_schema(SimpleStringSchema())
            .build()
        )
        # EXACTLY_ONCE: uses Kafka transactions + Flink 2PC
        .set_delivery_guarantee(DeliveryGuarantee.EXACTLY_ONCE)
        # Transactional ID prefix - must be unique per job writing to this Kafka cluster
        .set_transactional_id_prefix("eta-feature-pipeline-sink")
        # Must not exceed the broker's transaction.max.timeout.ms (15 minutes by default)
        .set_property("transaction.timeout.ms", str(15 * 60 * 1000))
        .build()
    )

    feature_stream.sink_to(exactly_once_sink)
    return feature_stream
```
What Breaks Exactly-Once
Exactly-once is fragile. Several patterns silently break it:
```python
import json
import time

import redis
from pyflink.datastream.functions import MapFunction


class RedisMapFunction(MapFunction):
    """Base class: one Redis client per operator instance."""

    def open(self, runtime_context):
        self.redis = redis.Redis(host="redis", port=6379)


# --- PATTERNS THAT BREAK EXACTLY-ONCE ---

# DANGER 1: Writing to Redis inside the operator without idempotency.
# After a restore, events since the last checkpoint are replayed and written again.
class DangerousRedisWriter(RedisMapFunction):
    def map(self, event_raw: str) -> str:
        event = json.loads(event_raw)
        # DANGER: Redis INCR is not idempotent - replay causes double-counting
        self.redis.incr(f"user:{event['user_id']}:event_count")
        return event_raw


# SAFE alternative: use HSET with an absolute value (idempotent)
class SafeRedisWriter(RedisMapFunction):
    def map(self, event_raw: str) -> str:
        event = json.loads(event_raw)
        # SAFE: HSET is idempotent - replaying sets the same value again
        self.redis.hset(f"user:{event['user_id']}", "last_event_ts", event["event_ts"])
        return event_raw


# DANGER 2: Non-deterministic operators (current time, random numbers)
class DangerousTimestampAdder(MapFunction):
    def map(self, event_raw: str) -> str:
        event = json.loads(event_raw)
        # DANGER: time.time() returns a different value on replay, so re-execution
        # differs from what the first attempt may already have written externally
        event["processed_at"] = time.time()
        return json.dumps(event)


# SAFE: use event time from the message, not the wall clock
class SafeTimestampAdder(MapFunction):
    def map(self, event_raw: str) -> str:
        event = json.loads(event_raw)
        # SAFE: deterministic - same input always produces same output on replay
        event["processed_at"] = event.get("event_ts", 0)
        return json.dumps(event)
```
The silent exactly-once killer: external writes without idempotency
Exactly-once in Kafka (source and sink) does not protect external writes (Redis, PostgreSQL, Elasticsearch). If your Flink operator writes to Redis with a non-idempotent operation (INCR, RPUSH, or any other increment or append), a replay after checkpoint restore will write twice. Always use idempotent write patterns for external stores: SET key value (overwrites), HSET with absolute values, or writes guarded by a processed event ID (for example, SET on a per-event key with NX, which succeeds only the first time that ID is seen).
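The replay problem can be reproduced without Flink or Redis. The sketch below is a simulation, not production code: `FakeKV` is a hypothetical in-memory stand-in for a Redis client, and `run_with_replay` processes two events and then replays both, as a restore from an earlier checkpoint would. It compares an INCR-style counter, an absolute SET, and a counter guarded by an NX-style "seen event ID" key.

```python
class FakeKV:
    """Hypothetical in-memory stand-in for a key-value store (not a Redis client)."""

    def __init__(self):
        self.data = {}

    def incr(self, key):
        # Non-idempotent: every call changes the stored value
        self.data[key] = self.data.get(key, 0) + 1
        return self.data[key]

    def set(self, key, value, nx=False):
        # nx=True mimics SET ... NX: write only if the key does not exist yet
        if nx and key in self.data:
            return False
        self.data[key] = value
        return True


def process(store, event, mode):
    key = f"user:{event['user_id']}"
    if mode == "incr":
        store.incr(f"{key}:event_count")
    elif mode == "set":
        # Absolute value: a replay writes the same value again
        store.set(f"{key}:last_event_ts", event["event_ts"])
    elif mode == "dedupe_incr":
        # Record the event ID first; only the first delivery increments
        if store.set(f"seen:{event['event_id']}", 1, nx=True):
            store.incr(f"{key}:event_count")


def run_with_replay(mode):
    events = [
        {"event_id": "e1", "user_id": 7, "event_ts": 100},
        {"event_id": "e2", "user_id": 7, "event_ts": 200},
    ]
    store = FakeKV()
    for e in events:
        process(store, e, mode)
    # Simulated restore from a checkpoint taken before both events: replay them
    for e in events:
        process(store, e, mode)
    return store.data
```

The INCR path ends at 4 for two real events, while the SET and deduplicated paths end in the same state as a single pass. In a real Redis deployment the "check ID, then increment" pair is two commands, so a crash between them can lose an increment; making the pair atomic (for example with a Lua script) closes that gap.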
Production Notes
Pattern selection guide:
- Reference data fits in memory (less than 1 GB total), updates infrequently (hourly) → Broadcast State
- Reference data is large or updates continuously → Async I/O with Redis/DynamoDB
- Reference data is huge but per-key is small, access pattern is hot/cold → Cache-Aside
- Join two streams where matching events fall within a bounded time offset of each other → Interval Join
- Join two streams by grouping both into the same fixed window (tumbling, sliding, or session) → Windowed Join
- Database changes need to propagate to the feature store in real-time → CDC to Feature Store
- Score events without a network call to a model server → Streaming Inference
- Labels arrive hours after predictions → Delayed Label Join (Feedback Loop)
- Cannot tolerate duplicate feature writes → Exactly-Once with 2PC
Async I/O capacity tuning:
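The capacity parameter in AsyncDataStream.unordered_wait() is the maximum number of in-flight async requests per operator instance. Too low: the operator blocks waiting for responses, reducing throughput. Too high: you overwhelm the external store with concurrent requests. Because the limit applies to each parallel instance, the external store can see up to parallelism × capacity concurrent requests in total. A reasonable starting point is the number of requests each instance actually has in flight (its events/sec multiplied by the average lookup latency) plus headroom for latency spikes; then tune based on observed latency and external store saturation metrics.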
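One way to ground the starting capacity in measured numbers is Little's law: the average number of requests in flight equals the arrival rate multiplied by the average time each request takes. The helper functions below are a hypothetical sketch; the default headroom factor of 2 is an assumption to tune, not a Flink recommendation.

```python
import math


def async_capacity(events_per_sec_per_instance: float,
                   avg_latency_s: float,
                   headroom: float = 2.0) -> int:
    """Starting value for the per-instance Async I/O capacity (a sketch)."""
    # Little's law: requests in flight = arrival rate x average latency
    in_flight = events_per_sec_per_instance * avg_latency_s
    # Headroom absorbs latency spikes (p99 is often far above the mean)
    return max(1, math.ceil(in_flight * headroom))


def total_store_concurrency(parallelism: int, capacity: int) -> int:
    # capacity is per operator instance, so the store sees up to this many at once
    return parallelism * capacity
```

For example, if the 20,000 events/sec from the ShipFast scenario were spread over 4 instances (5,000 each) and lookups averaged 5 ms, `async_capacity(5000, 0.005)` suggests 50 per instance, and the store would see up to `total_store_concurrency(4, 50)` = 200 concurrent requests.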
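Interval join memory budget: For an interval join with a window of $W$ seconds, a rate of $R$ events/sec on each stream, and an average serialized event size of $S$ bytes, the buffered state per operator instance is roughly:

$$M \approx 2 \cdot R \cdot W \cdot S$$

where $R$ is the per-stream rate reaching that instance (the total rate divided by the parallelism, if keys are evenly distributed). For $W = 600$ seconds (10 min) and $S = 500$ bytes/event, this gives $M \approx 2 \cdot R \cdot 600 \cdot 500 = 6 \times 10^{5} \cdot R$ bytes, or about 0.6 MB of state for every event/sec of per-stream rate.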
Size your TaskManager memory accordingly before deploying.
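The budget can be checked with a one-line estimate: each stream buffers roughly rate × window events of average size S, and the total is split across parallel instances. The function name and the inputs in the example are hypothetical, chosen only to show the arithmetic.

```python
def interval_join_state_bytes(rate_per_stream: float, window_s: float,
                              bytes_per_event: float, parallelism: int = 1) -> float:
    """Rough buffered-state size per operator instance for an interval join."""
    # Both streams buffer about rate x window events; evenly split over instances
    return 2 * rate_per_stream * window_s * bytes_per_event / parallelism
```

Assuming, purely as a hypothetical, 20,000 events/sec on each stream, a 10-minute window, 500-byte events, and a parallelism of 8, `interval_join_state_bytes(20000, 600, 500, 8)` gives 1.5e9 bytes, about 1.5 GB per instance before any state-backend overhead.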
Common Mistakes
Mistake: Using Broadcast State for large, frequently-updated reference data
Broadcast state is replicated to every operator instance. If your driver profile table has 500,000 rows at 1KB each, that is 500MB of broadcast state per operator instance. With 12 operator instances, you are holding 6GB of redundant data. For large or frequently-updated reference data, use Async I/O with an external store instead.
Mistake: Interval join with an unbounded window
Setting the interval join window to a very large value (hours or days) to "be safe" will cause your job to OOM. Every event from both streams must be buffered for the entire join window. Set the interval to the minimum reasonable window for your use case, and route events that miss the join window to a side output for batch reconciliation.
Mistake: Loading the ONNX model in map() instead of open()
If you load the model on every call to map() (once per event), you pay the model deserialization cost on every event - typically 50–500ms. Load in open() once per operator instance lifetime. The model is reused for all events processed by that instance.
Mistake: Feedback loop without TTL on buffered predictions
If a shipment is lost or canceled, the ground truth event never arrives. Without TTL on the pending_prediction state, the prediction is buffered forever, slowly accumulating unmatched predictions until state storage is exhausted. Always configure State TTL on both sides of a delayed join, set to the maximum expected label delay plus a safety buffer.
Interview Q&A
What is the broadcast state pattern and when would you use it?
The broadcast state pattern is a Flink feature for enriching a high-volume main stream with a small, slowly-changing reference dataset. The reference dataset (e.g., driver profiles, product metadata) is published to a separate "control" stream and broadcast to every operator instance. Each operator holds a full in-memory copy of the reference data, making enrichment lookups local and sub-millisecond with no network I/O.
Use it when: the reference dataset fits in memory (typically under a few GB), updates are infrequent (hourly or less), and you need maximum enrichment throughput. Don't use it when the reference dataset is large (use Async I/O instead), or when you need the absolute latest value on every lookup (broadcast replication has some delay).
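The two-callback structure of the pattern can be sketched in plain Python. The method names mirror Flink's BroadcastProcessFunction, but this is a simplified simulation, not PyFlink code: there is no Context argument, and the broadcast state is a plain dict rather than a MapState obtained from a descriptor. In real Flink, the main-stream side only gets read-only access to broadcast state, and every instance must apply updates identically.

```python
class BroadcastEnricher:
    """Plain-Python stand-in for a BroadcastProcessFunction (not PyFlink)."""

    def __init__(self):
        # Each parallel instance holds a full copy of the reference data
        self.driver_profiles = {}

    def process_broadcast_element(self, update: dict) -> None:
        # Control stream: every instance receives every update
        self.driver_profiles[update["driver_id"]] = update

    def process_element(self, event: dict) -> dict:
        # Main stream: a local dict lookup, no network call
        profile = self.driver_profiles.get(event["driver_id"])
        enriched = dict(event)
        # None if the profile has not been broadcast yet
        enriched["driver_on_time_rate"] = profile["on_time_rate"] if profile else None
        return enriched


def broadcast(instances, update):
    # Simulates Flink delivering a control-stream record to all parallel instances
    for inst in instances:
        inst.process_broadcast_element(update)
```

The `None` branch matters in practice: main-stream events can arrive before the first broadcast update, so the enrichment has to define a default or buffer such events.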
How do you join two unbounded streams in Flink?
Two main approaches: the interval join and the windowed join. The interval join (streamA.intervalJoin(streamB).between(lower, upper)) matches each event from stream A with events from stream B where B's timestamp falls within [A.timestamp + lower, A.timestamp + upper]. Both streams are keyed by the join key first. Flink buffers both streams' events within the join window. This is the natural choice for time-correlated events like matching shipment events with weather events within ±5 minutes.
The windowed join (streamA.join(streamB).window(...)) joins events that fall within the same window boundary. Both streams must have events in the same window for a match.
The key constraint for both: you must define a bounded time window. Flink must know when it is safe to discard buffered events. Without a bound, you buffer forever and run out of memory.
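The buffering and eviction logic of an interval join can be sketched in a single-threaded simulation. This is not Flink's implementation: the class, its method names, and the choice to track unmatched left events are assumptions for illustration. Flink's built-in interval join is an inner join, so routing unmatched events to a side output requires custom logic such as a CoProcessFunction with timers. Times are in seconds, and bounds of -300/+300 model the ±5 minute shipment/weather match.

```python
from collections import defaultdict


class IntervalJoin:
    """Simplified interval join on event time: match when
    left_ts + lower <= right_ts <= left_ts + upper (same key)."""

    def __init__(self, lower: int, upper: int):
        if lower > upper:
            raise ValueError("lower bound must not exceed upper bound")
        self.lower, self.upper = lower, upper
        self.left = defaultdict(list)   # key -> [ts, value, matched] entries
        self.right = defaultdict(list)  # key -> (ts, value) tuples
        self.joined = []                # main output: (key, left_value, right_value)
        self.unmatched_left = []        # "side output" for batch reconciliation

    def _in_bounds(self, left_ts, right_ts):
        return left_ts + self.lower <= right_ts <= left_ts + self.upper

    def on_left(self, key, ts, value):
        entry = [ts, value, False]
        for r_ts, r_val in self.right[key]:
            if self._in_bounds(ts, r_ts):
                self.joined.append((key, value, r_val))
                entry[2] = True
        self.left[key].append(entry)

    def on_right(self, key, ts, value):
        for entry in self.left[key]:
            if self._in_bounds(entry[0], ts):
                self.joined.append((key, entry[1], value))
                entry[2] = True
        self.right[key].append((ts, value))

    def advance_watermark(self, wm):
        # No event with timestamp <= wm arrives later, so (conservatively):
        # a left event can no longer match once ts + upper < wm,
        # a right event can no longer match once ts - lower < wm.
        for key, entries in list(self.left.items()):
            for ts, value, matched in entries:
                if ts + self.upper < wm and not matched:
                    self.unmatched_left.append((key, value))
            self.left[key] = [e for e in entries if e[0] + self.upper >= wm]
        for key, entries in list(self.right.items()):
            self.right[key] = [(t, v) for t, v in entries if t - self.lower >= wm]
```

The watermark is what makes the buffers bounded: without `advance_watermark`, both lists would grow forever, which is exactly the memory problem described above.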
What is Async I/O and why is it important for feature store lookups?
Async I/O (AsyncDataStream.unorderedWait()) is a Flink API for making non-blocking external calls from a stream operator. Without Async I/O, a synchronous external call (e.g., Redis HGET for a feature lookup) blocks the operator thread until the response arrives. If the call takes 5ms and you process 10,000 events/sec, the operator is blocked 50 seconds per second - clearly impossible. You would need a parallelism of at least 50 just to keep up.
With Async I/O, the operator keeps up to capacity requests in flight at once. While waiting for one response, it issues the request for the next event. If Redis responds in 5ms on average with capacity=100, one operator instance can handle up to 20,000 requests/sec (100 / 0.005). Async I/O is the essential pattern for any stream operator that needs to look up data in an external store.
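The bounded-concurrency idea behind Async I/O can be illustrated with Python's asyncio. This is an analogy, not Flink's implementation: `fake_feature_lookup` is a hypothetical stand-in for a Redis call, a semaphore plays the role of `capacity`, and results are collected in completion order, like unordered mode. Unlike Flink, which stops pulling input (backpressure) when the queue is full, this sketch creates all tasks up front and lets the semaphore gate them.

```python
import asyncio


async def fake_feature_lookup(key: int) -> dict:
    # Hypothetical stand-in for a Redis/DynamoDB call; varying latency reorders completions
    await asyncio.sleep(0.001 * (1 + key % 5))
    return {"key": key, "on_time_rate": 0.9}


async def enrich_unordered(keys, capacity: int):
    sem = asyncio.Semaphore(capacity)  # at most `capacity` lookups in flight
    stats = {"in_flight": 0, "max_in_flight": 0}

    async def one(key):
        async with sem:
            stats["in_flight"] += 1
            stats["max_in_flight"] = max(stats["max_in_flight"], stats["in_flight"])
            try:
                return await fake_feature_lookup(key)
            finally:
                stats["in_flight"] -= 1

    tasks = [asyncio.create_task(one(k)) for k in keys]
    results = []
    # Emit each result as soon as it is ready, regardless of input order
    for fut in asyncio.as_completed(tasks):
        results.append(await fut)
    return results, stats["max_in_flight"]
```

Running `asyncio.run(enrich_unordered(range(200), capacity=20))` completes 200 lookups while never exceeding 20 concurrent calls, which is the same trade-off the capacity parameter controls.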
How do you run ML inference inside a Flink job?
Load the model in the operator's open() method (called once per operator instance lifecycle, not once per event). Use ONNX Runtime as the inference engine - it accepts models exported from scikit-learn, XGBoost, PyTorch, and TensorFlow, has both Python and Java APIs, and runs efficiently on CPU whether the operator runs in a PyFlink Python worker or in the JVM.
For the MapFunction or ProcessFunction, extract features from the event into a numpy array matching the model's input shape, call session.run(), and attach the prediction to the output event. Throughput is typically 2,000–20,000 events/sec per operator instance for GBM models, and 500–5,000 for small neural networks.
For model hot-reload without job restart, use the broadcast state pattern: a control stream carries model version updates, and each operator instance reloads the ONNX session when a new version is signaled.
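The lifecycle can be sketched without Flink or ONNX Runtime installed. The class below is a plain-Python illustration, not a PyFlink MapFunction (PyFlink's open() also receives a RuntimeContext), and `FEATURE_ORDER`, `on_model_update`, and the `eta_minutes` field are hypothetical. In a real job, `load_model` would wrap `onnxruntime.InferenceSession(path)` and the returned callable would wrap `session.run(None, {input_name: np.asarray([row], dtype=np.float32)})`.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical feature order; it must match the order used at training time
FEATURE_ORDER = ["package_weight_kg", "declared_value", "route_avg_minutes_7d"]


class ModelScoringOperator:
    """Sketch of the open()/map() lifecycle (plain Python, not a PyFlink class)."""

    def __init__(self, load_model: Callable[[], Callable[[List[float]], float]]):
        self._load_model = load_model
        self.model: Optional[Callable[[List[float]], float]] = None
        self.model_version: Optional[str] = None

    def open(self) -> None:
        # Runs once per operator instance: pay the deserialization cost here
        self.model = self._load_model()

    def on_model_update(self, version: str, load_model: Callable[[], Any]) -> None:
        # Hot reload, e.g. triggered by a broadcast control-stream message
        if version != self.model_version:
            self.model = load_model()
            self.model_version = version

    def map(self, event: Dict[str, float]) -> Dict[str, Any]:
        # Build the feature row in training order, score it, attach the prediction
        row = [float(event[name]) for name in FEATURE_ORDER]
        out = dict(event)
        out["eta_minutes"] = self.model(row)
        return out
```

However many events pass through map(), the loader runs once per instance; only a version change on the control stream triggers another load.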
How do you implement a feedback loop for streaming ML?
The feedback loop uses CoProcessFunction (connected streams) to join predictions with delayed ground truth labels. Both streams are keyed by the entity ID (shipment ID, transaction ID). The CoProcessFunction maintains two state entries per key: the buffered prediction (waiting for its label) and the buffered label (if it arrives before the prediction, though this is rare).
When a prediction arrives, check whether its label is already buffered; if so, emit the training example immediately, otherwise buffer the prediction. When a label arrives, check whether the matching prediction is buffered; if so, emit the example and clear the state, otherwise buffer the label. Configure State TTL on both buffers set to the maximum expected label delay (e.g., 12 hours for a same-day delivery system) to prevent unbounded state growth from unmatched events.
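The matching logic can be sketched as a plain-Python simulation of the CoProcessFunction. This is an illustration under simplifying assumptions, not PyFlink code: keyed state is modeled as dicts indexed by shipment ID, `expire()` stands in for Flink's State TTL (which is processing-time based and cleaned up lazily), and the field names are hypothetical.

```python
from typing import Dict, List, Tuple


class PredictionLabelJoiner:
    """Plain-Python sketch of the prediction/label CoProcessFunction logic."""

    def __init__(self, ttl_s: float):
        self.ttl_s = ttl_s
        # shipment_id -> (stored_at, payload); one entry per key, like keyed ValueState
        self.pending_prediction: Dict[str, Tuple[float, dict]] = {}
        self.pending_label: Dict[str, Tuple[float, dict]] = {}
        self.training_examples: List[dict] = []

    def _emit(self, pred: dict, label: dict) -> None:
        self.training_examples.append({**pred, "actual_minutes": label["actual_minutes"]})

    def on_prediction(self, key: str, pred: dict, now: float) -> None:
        stored = self.pending_label.pop(key, None)
        if stored is not None:
            self._emit(pred, stored[1])  # label arrived first (rare)
        else:
            self.pending_prediction[key] = (now, pred)

    def on_label(self, key: str, label: dict, now: float) -> None:
        stored = self.pending_prediction.pop(key, None)
        if stored is not None:
            self._emit(stored[1], label)
        else:
            self.pending_label[key] = (now, label)

    def expire(self, now: float) -> None:
        # Emulates State TTL: drop entries older than ttl_s (lost or canceled shipments)
        for buf in (self.pending_prediction, self.pending_label):
            for key in [k for k, (t, _) in buf.items() if now - t > self.ttl_s]:
                del buf[key]
```

A shipment whose delivery event never arrives stays buffered only until the TTL passes, so state stays proportional to the number of shipments in flight rather than growing without bound.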
What breaks exactly-once semantics in a streaming pipeline?
Three categories of breaks:
Non-idempotent external writes: writing to Redis with INCR or RPUSH inside a Flink operator. On replay after checkpoint restore, the same event is processed again and the external write happens twice. Fix: use idempotent writes (absolute value sets) or include the processed event offset as a conditional version check.
Non-deterministic operators: operators that produce different output for the same input on replay - using time.time(), random(), or reading from an external system inside map(). Fix: compute all non-deterministic values from event fields (use event time, not wall clock; use deterministic hash functions, not random).
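A common place this bites is sampling, for example logging 10% of events for monitoring. A minimal sketch of a deterministic alternative to `random()`: hash a stable event ID and compare the hash to the sampling rate. The function name and the 53-bit bucketing are illustrative choices. Python's built-in `hash()` on strings is randomized per process (see `PYTHONHASHSEED`), so after a restart it can give different answers; a cryptographic hash such as SHA-256 does not.

```python
import hashlib


def in_sample(event_id: str, rate: float) -> bool:
    """Deterministic sampling: a replayed event gets the same decision."""
    digest = hashlib.sha256(event_id.encode("utf-8")).digest()
    # Top 53 bits -> float in [0, 1); 53 bits fit exactly in a double
    bucket = (int.from_bytes(digest[:8], "big") >> 11) / 2**53
    return bucket < rate
```

Because the decision depends only on `event_id`, a replay after checkpoint restore selects exactly the same events as the original run.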
Sinks without transactional support: writing to a system that does not support 2PC (Elasticsearch, plain HTTP endpoints) inside a checkpoint-enabled Flink job. Flink's end-to-end exactly-once guarantee extends only to sinks that take part in its two-phase commit, such as the Kafka sink in EXACTLY_ONCE mode. For non-transactional sinks, at-least-once is the best you can achieve - make the downstream consumer idempotent instead.
