

Spark for ML Pipelines

Three Billion Rows and a Deadline

Elena is a senior data engineer on the ML Platform team at a ride-sharing company. The company's driver ETA model - the model that tells riders how long their driver will take to arrive - is one of the highest-value models in the entire product. A 15-second improvement in ETA accuracy meaningfully reduces cancellation rates and increases rider satisfaction scores. The model is retrained nightly. It trains on three years of trip history: three billion records.

The model team has given Elena a feature specification document. It lists 218 features. The features include driver-level statistics (average trips per hour in the last 7 days, average rating over the last 30 days, percentage of trips completed in under 5 minutes), city-level context (how many trips are currently active in this geohash cell, historical average wait time for this cell at this time of day), and complex temporal features (rolling standard deviation of trip duration for this driver over the last 14 days, lag of the previous trip's tip amount). Many of these features require window functions over time-ordered trip data. Several require joining the trip table with the driver profile table and the city-level heatmap table. The entire computation must finish within a 3-hour window before the training job starts.

This is not a Pandas problem. A Pandas DataFrame of three billion rows at a modest 100 bytes per row is 300 GB of RAM - before any intermediate computation. Window functions in Pandas require sorting the entire frame in memory. The groupBy aggregations produce intermediate frames that can be as large as the original data. Elena knows from experience that Pandas on this dataset would either run for 12 hours or crash. She needs Spark.

But knowing that you need Spark is different from knowing how to use it well for ML features. The PySpark API has dozens of ways to compute a rolling average. Not all of them are efficient. The Window function is the right tool, but only if you understand which window operations trigger shuffles and which do not. UDFs for custom computations seem like an obvious solution - until you discover they are 10–20x slower than equivalent built-in functions. And the most insidious problem: if you compute features naively, you leak future information into training features, producing a model that looks excellent offline and fails in production. Point-in-time correctness is not optional. It is the difference between a useful model and a fraudulent one.

This lesson covers the full stack of PySpark techniques for ML feature engineering: the DataFrame API in depth, window functions for temporal features, UDFs and why to avoid them, Pandas UDFs as the efficient alternative, Spark MLlib Pipelines, point-in-time join patterns, and writing ML-ready datasets to Delta Lake. By the end, you will be able to write the pipeline Elena needs.


Why Spark for ML Features

At scales above 50–100 GB of input data, Spark is the standard tool for ML feature computation in production. The reasons are practical:

Memory model: Spark processes data partition by partition, so it never needs the entire dataset in driver or executor memory. A pipeline that reads 1 TB of Parquet, applies window aggregations, and writes 50 GB of features typically holds only a few GB per executor at any given time. When a partition does not fit in memory, Spark spills sorts and shuffles to disk.

Expressiveness for ML patterns: The DataFrame API has native support for the operations that dominate feature engineering - window aggregations, joins, pivots, conditional expressions, type casting. What would require custom loops in Pandas is a single DataFrame operation in Spark.

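Efficient execution via Tungsten: Spark's Tungsten engine stores rows in a compact binary format that avoids Java object overhead. It also uses whole-stage code generation to fuse chains of operators into a single compiled function. Combined with the vectorized Parquet reader, this lets aggregations over billions of rows of numeric data run far faster than interpreted, row-at-a-time processing.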

Native integration: Spark reads from and writes to S3, GCS, HDFS, Delta Lake, Hive Metastore, and most cloud data warehouses. Feature pipelines that need to read from multiple sources and write to a feature store have native connectors without custom serialization code.


PySpark DataFrame API in Depth

The DataFrame API is the foundation of every PySpark ML pipeline. Before covering window functions and UDFs, you need to be fluent in the core operations.

Column operations: select, withColumn, filter

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import DoubleType, TimestampType

spark = SparkSession.builder.appName("eta-features").getOrCreate()

trips = spark.read.parquet("s3://data/trips/")

# select - defines the output schema explicitly.
# Spark only reads these columns from Parquet (projection pruning).
trips_clean = trips.select(
    "trip_id",
    "driver_id",
    "city_id",
    F.col("start_time").cast(TimestampType()),
    F.col("end_time").cast(TimestampType()),
    F.col("actual_eta_seconds").cast(DoubleType()).alias("actual_eta"),
    F.col("predicted_eta_seconds").cast(DoubleType()).alias("predicted_eta"),
    "tip_amount",
    "rating",
    "status",
    "pickup_lat",  # kept for the geohash feature computed later
    "pickup_lon",
)

# filter - pushed down to the Parquet scanner when possible.
# Spark skips row groups in Parquet files that cannot satisfy the filter.
trips_completed = trips_clean.filter(
    (F.col("status") == "completed")
    & (F.col("actual_eta").isNotNull())
    & (F.col("actual_eta") > 0)
)

# withColumn - adds or overwrites a single column.
# Chain multiple withColumn calls to build up derived features.
# (For hundreds of columns, prefer one select or withColumns (Spark 3.3+):
# every withColumn call adds a projection to the plan Spark must analyze.)
trips_with_features = (
    trips_completed
    .withColumn(
        "trip_duration_minutes",
        (F.unix_timestamp("end_time") - F.unix_timestamp("start_time")) / 60,
    )
    .withColumn("eta_error_seconds", F.col("actual_eta") - F.col("predicted_eta"))
    .withColumn("eta_error_pct", F.col("eta_error_seconds") / F.col("predicted_eta"))
    .withColumn("hour_of_day", F.hour("start_time"))
    # dayofweek: 1 = Sunday ... 7 = Saturday
    .withColumn("day_of_week", F.dayofweek("start_time"))
    .withColumn("is_weekend", F.when(F.col("day_of_week").isin([1, 7]), 1).otherwise(0))
)
```
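The date conventions in this chain are easy to get wrong when checking results by hand. `F.dayofweek` returns 1 for Sunday and 7 for Saturday, which is why `is_weekend` tests `isin([1, 7])`. The block below is a small pure-Python mirror of the derived columns that works as a unit-test oracle. It assumes whole-second timestamps already expressed in the Spark session time zone, and the helper names are hypothetical.

```python
from datetime import datetime


def spark_dayofweek(ts: datetime) -> int:
    """Mirror of F.dayofweek: 1 = Sunday, 2 = Monday, ..., 7 = Saturday."""
    # isoweekday() is Monday = 1 ... Sunday = 7; shift so Sunday becomes 1
    return ts.isoweekday() % 7 + 1


def derived_time_features(start: datetime, end: datetime) -> dict:
    """Pure-Python reference for the time-based withColumn calls (one trip)."""
    day_of_week = spark_dayofweek(start)
    return {
        "trip_duration_minutes": (end - start).total_seconds() / 60,
        "hour_of_day": start.hour,
        "day_of_week": day_of_week,
        "is_weekend": 1 if day_of_week in (1, 7) else 0,
    }


# 2024-11-02 is a Saturday
print(derived_time_features(datetime(2024, 11, 2, 8, 15), datetime(2024, 11, 2, 8, 42)))
```

Running a handful of rows through both this function and the Spark job makes off-by-one weekday mistakes obvious.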

Aggregations: groupBy and agg

```python
# groupBy + agg uses partial aggregation before the shuffle.
# Spark computes partial sums/counts on each executor before moving data -
# equivalent to a combiner in MapReduce, much more efficient than groupByKey.
driver_stats = (
    trips_with_features
    .groupBy("driver_id")
    .agg(
        F.count("*").alias("total_trips"),
        F.avg("actual_eta").alias("avg_eta"),
        F.stddev("actual_eta").alias("std_eta"),
        F.avg("trip_duration_minutes").alias("avg_duration"),
        F.avg("rating").alias("avg_rating"),
        F.sum(F.when(F.col("rating") >= 4.5, 1).otherwise(0)).alias("high_rating_count"),
        F.avg("tip_amount").alias("avg_tip"),
        F.max("start_time").alias("last_trip_time"),
    )
    .withColumn("high_rating_rate", F.col("high_rating_count") / F.col("total_trips"))
)
```
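Partial aggregation works because `avg` and `stddev` can be rebuilt from small, mergeable per-key states. The pure-Python sketch below illustrates the idea, with each list standing in for one partition. It keeps a `(count, mean, M2)` state per key and combines states with the standard parallel-variance merge formula. This is similar in spirit to how Spark's central-moment aggregates combine, but it is not Spark's code, and the function names are made up for illustration.

```python
import math
from collections import defaultdict

EMPTY = (0, 0.0, 0.0)  # (count, mean, M2 = sum of squared deviations)


def update(state, x):
    """Fold one value into a state (Welford's online update)."""
    n, mean, m2 = state
    n += 1
    delta = x - mean
    mean += delta / n
    m2 += delta * (x - mean)
    return (n, mean, m2)


def merge(a, b):
    """Combine two states; this is what happens after the shuffle."""
    na, ma, m2a = a
    nb, mb, m2b = b
    n = na + nb
    if n == 0:
        return EMPTY
    delta = mb - ma
    return (n, ma + delta * nb / n, m2a + m2b + delta * delta * na * nb / n)


def partial_aggregate(partition):
    """Map-side step: one small state per key, per partition."""
    states = defaultdict(lambda: EMPTY)
    for key, x in partition:
        states[key] = update(states[key], x)
    return dict(states)


def combine(partials):
    """Reduce-side step: merge the per-partition states for each key."""
    result = {}
    for partial in partials:
        for key, state in partial.items():
            result[key] = merge(result.get(key, EMPTY), state)
    return result


def finalize(state):
    """avg and sample stddev (Spark's stddev is stddev_samp; null for one row)."""
    n, mean, m2 = state
    return mean, (math.sqrt(m2 / (n - 1)) if n > 1 else None)
```

Only one state per driver per partition crosses the network, instead of every trip row.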

Window Functions for ML Features

Window functions are the most important tool in PySpark for ML feature engineering. They compute a value for each row from a window of related rows, defined by a partition, an ordering, and a frame. They are the correct way to compute rolling averages, lag features, ranks, and cumulative aggregations.

The Window specification

```python
from pyspark.sql import Window, functions as F

# A Window specification has three parts:
# 1. partitionBy - which rows are in the same group (like GROUP BY)
# 2. orderBy - how rows are ordered within the group
# 3. rowsBetween / rangeBetween - the frame, relative to the current row

# Cumulative window: with orderBy and no explicit frame, Spark's default frame
# runs from the first row of the partition up to the current row (plus any
# rows with the same start_time). Ranking and offset functions (row_number,
# lag) should use an ordering-only spec like this one.
cumulative_window = Window.partitionBy("driver_id").orderBy("start_time")

# Rolling 7-day window based on time.
# rangeBetween is measured in units of the orderBy value, so we order by the
# timestamp cast to long (Unix seconds). The upper bound -1 ends the frame one
# second before the current trip, so the trip's own actual_eta (the label)
# never feeds its own features.
rolling_7d = (
    Window
    .partitionBy("driver_id")
    .orderBy(F.col("start_time").cast("long"))  # Unix timestamp (seconds)
    .rangeBetween(-7 * 24 * 3600, -1)           # previous 7 days, current trip excluded
)

# Rolling 30-day window
rolling_30d = (
    Window
    .partitionBy("driver_id")
    .orderBy(F.col("start_time").cast("long"))
    .rangeBetween(-30 * 24 * 3600, -1)
)

# Last 10 trips (row-count based - useful when time intervals are irregular)
last_10_trips = (
    Window
    .partitionBy("driver_id")
    .orderBy("start_time")
    .rowsBetween(-10, -1)  # the 10 rows before the current one, excluding it
)
```
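The difference between `rangeBetween` and `rowsBetween` is easiest to see on a handful of timestamps. The pure-Python sketch below shows which rows each frame covers for one driver's sorted trips. It is not Spark code, and `range_frame` and `rows_frame` are hypothetical helpers. A RANGE frame selects rows by the value of the ordering column, so tied timestamps fall in or out together. A ROWS frame counts physical rows regardless of time gaps.

```python
from bisect import bisect_left, bisect_right
from typing import List, Sequence


def range_frame(ts: Sequence[int], i: int, lower: int, upper: int) -> List[int]:
    """Indices in a RANGE frame: rows whose ordering value lies in
    [ts[i] + lower, ts[i] + upper]. ts must be sorted ascending."""
    lo = bisect_left(ts, ts[i] + lower)
    hi = bisect_right(ts, ts[i] + upper)
    return list(range(lo, hi))


def rows_frame(n: int, i: int, lower: int, upper: int) -> List[int]:
    """Indices in a ROWS frame: physical offsets [i + lower, i + upper], clipped."""
    return list(range(max(0, i + lower), min(n - 1, i + upper) + 1))


DAY = 86400
ts = [0, 2 * DAY, 5 * DAY, 5 * DAY, 9 * DAY]  # one driver's trip start times
print(range_frame(ts, 4, -7 * DAY, -1))  # 7-day frame for the last trip, excluding it
print(rows_frame(len(ts), 4, -2, -1))    # the two trips immediately before it
```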

Computing rolling and lag features

```python
# Frame for the expanding mean: every earlier trip, excluding the current one.
history_window = cumulative_window.rowsBetween(Window.unboundedPreceding, -1)

trips_with_rolling = (
    trips_with_features
    # Rolling mean ETA - key feature for ETA model
    .withColumn("avg_eta_7d", F.avg("actual_eta").over(rolling_7d))
    .withColumn("avg_eta_30d", F.avg("actual_eta").over(rolling_30d))
    # Rolling standard deviation - captures driver consistency
    .withColumn("std_eta_7d", F.stddev("actual_eta").over(rolling_7d))
    .withColumn("std_eta_30d", F.stddev("actual_eta").over(rolling_30d))
    # Rolling trip counts
    .withColumn("trip_count_7d", F.count("*").over(rolling_7d))
    .withColumn("trip_count_30d", F.count("*").over(rolling_30d))
    # Rolling rating and tip
    .withColumn("avg_rating_7d", F.avg("rating").over(rolling_7d))
    .withColumn("avg_tip_30d", F.avg("tip_amount").over(rolling_30d))
    # Lag features - previous trip's values for the same driver
    .withColumn("prev_eta", F.lag("actual_eta", 1).over(cumulative_window))
    .withColumn("prev_duration", F.lag("trip_duration_minutes", 1).over(cumulative_window))
    .withColumn("prev_tip", F.lag("tip_amount", 1).over(cumulative_window))
    # Rank within the driver's history (how experienced is this driver?)
    .withColumn("trip_rank", F.row_number().over(cumulative_window))
    # Expanding (cumulative) mean over all *previous* trips
    .withColumn("cumulative_avg_eta", F.avg("actual_eta").over(history_window))
    # Caution: percent_rank() divides by the partition's total row count, so it
    # is the trip's relative position in time (not an ETA percentile) and it
    # encodes how many trips the driver takes *after* this one. It leaks future
    # information - do not use it as a training feature.
    .withColumn("history_position_pct", F.percent_rank().over(cumulative_window))
)
```
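When porting features like these, it helps to have a small, hand-checkable fixture to compare against a Spark run on the same rows. The pure-Python sketch below computes leak-free versions, where every frame ends before the current trip. It assumes one driver's trips are pre-sorted by start time and that no two of them share a start second. The function name is hypothetical.

```python
from collections import deque
from typing import List, Optional, Tuple


def reference_driver_features(
    trips: List[Tuple[int, float]], window_seconds: int = 7 * 86400
) -> List[dict]:
    """trips: (start_ts_seconds, actual_eta) for ONE driver, sorted by time."""
    out = []
    window = deque()          # earlier trips still inside the 7-day frame
    window_sum = 0.0
    history_sum = 0.0
    prev_eta: Optional[float] = None
    for rank, (ts, eta) in enumerate(trips, start=1):
        # Evict trips older than ts - window_seconds (the lower bound is inclusive)
        while window and window[0][0] < ts - window_seconds:
            window_sum -= window.popleft()[1]
        out.append({
            "trip_rank": rank,
            "prev_eta": prev_eta,
            "trip_count_7d": len(window),
            "avg_eta_7d": window_sum / len(window) if window else None,
            "cumulative_avg_eta": history_sum / (rank - 1) if rank > 1 else None,
        })
        # Only after its features are emitted does this trip become history
        window.append((ts, eta))
        window_sum += eta
        history_sum += eta
        prev_eta = eta
    return out
```

The key design point is the order of operations inside the loop. Features are emitted before the current trip is added to the history, which is what `rangeBetween(..., -1)` and `rowsBetween(Window.unboundedPreceding, -1)` express in Spark.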

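:::tip Window functions and shuffle cost The shuffle is driven by partitionBy. Window specs that share partitionBy("driver_id") typically need only one exchange: Spark hash-partitions by driver_id once, and every later window operator reuses that distribution. Each distinct orderBy expression adds a sort within partitions. Specs that share both partitionBy and orderBy, such as rolling_7d and rolling_30d (which differ only in their frame), are evaluated together. Here cumulative_window orders by start_time rather than by the cast long, so it costs an extra sort but not an extra shuffle. Run df.explain() and count the Exchange nodes to confirm what your plan actually does. :::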


UDFs vs Built-in Functions

This is one of the most consequential performance decisions in PySpark development. Getting it wrong means a pipeline that takes 3 hours instead of 18 minutes.

Why Python UDFs are slow

A Python UDF bridges the JVM execution environment (where Spark runs natively) and a separate Python worker process. For each batch of rows:

  1. Spark converts the rows from its internal binary format into Python objects, serialized with pickle by default
  2. The serialized batch crosses from the JVM process to the Python worker process via a socket
  3. The Python worker deserializes it and calls your Python function once per row
  4. The results are serialized and sent back across the socket to the JVM

Rows travel in batches, but every row is still converted to and from Python objects individually and goes through its own interpreted function call. The UDF is also opaque to Spark's optimizer and code generation. At one billion rows, that means one billion Python calls plus conversion overhead on both sides, which dominates the actual computation time. Python UDFs are typically 10–20x slower than equivalent built-in Spark functions.

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# SLOW: Python UDF - called once per row, data crosses the JVM-Python boundary
@udf(returnType=StringType())
def categorize_eta_udf(eta_seconds):
    if eta_seconds is None:
        return "unknown"
    elif eta_seconds < 120:
        return "very_fast"
    elif eta_seconds < 300:
        return "fast"
    elif eta_seconds < 600:
        return "normal"
    else:
        return "slow"

df_slow = trips_with_features.withColumn(
    "eta_category",
    categorize_eta_udf(F.col("actual_eta")),
)

# FAST: Built-in F.when - stays in JVM, Tungsten whole-stage code generation
# Typically 10-20x faster on large datasets
df_fast = trips_with_features.withColumn(
    "eta_category",
    F.when(F.col("actual_eta").isNull(), "unknown")
    .when(F.col("actual_eta") < 120, "very_fast")
    .when(F.col("actual_eta") < 300, "fast")
    .when(F.col("actual_eta") < 600, "normal")
    .otherwise("slow"),
)
```

Rule: before writing any UDF, check pyspark.sql.functions for a built-in equivalent. Spark has over 200 built-in functions covering string manipulation, date/time operations, math, conditional logic, array operations, and JSON parsing. If a built-in exists, use it.

Pandas UDFs: the efficient middle ground

When a computation genuinely cannot be expressed with built-in functions, use a Pandas UDF (vectorized UDF). A Pandas UDF is called once per Arrow batch rather than once per row. Each call receives every input column as a Pandas Series of up to spark.sql.execution.arrow.maxRecordsPerBatch rows (10,000 by default), applies a vectorized operation, and returns a Pandas Series of the same length.

The performance difference is fundamental: instead of 1 million individual Python function calls per million rows, a Pandas UDF makes about 100 calls, each with a Series of 10,000 elements. Data is transferred using Apache Arrow (a columnar binary format), which is dramatically more efficient than row-by-row pickling. The Python function can then apply vectorized NumPy/Pandas operations across the whole batch.
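The batching contract is simple to simulate. The pure-Python sketch below shows how a scalar Pandas UDF is driven. It is not Spark's implementation, and `apply_in_batches` is a hypothetical helper. The function sees one batch at a time, with 10,000 rows matching the default spark.sql.execution.arrow.maxRecordsPerBatch, and it must return exactly one value per input row.

```python
from typing import Callable, List, Sequence, Tuple

DEFAULT_BATCH = 10_000  # default of spark.sql.execution.arrow.maxRecordsPerBatch


def apply_in_batches(
    values: Sequence[float],
    batch_fn: Callable[[List[float]], List[float]],
    batch_size: int = DEFAULT_BATCH,
) -> Tuple[List[float], int]:
    """Call batch_fn once per batch; return (concatenated outputs, number of calls)."""
    out: List[float] = []
    calls = 0
    for start in range(0, len(values), batch_size):
        batch = list(values[start:start + batch_size])
        result = batch_fn(batch)
        calls += 1
        if len(result) != len(batch):
            raise ValueError("a scalar Pandas UDF must return one value per input row")
        out.extend(result)
    return out, calls


doubled, calls = apply_in_batches([float(i) for i in range(25_000)], lambda b: [2 * x for x in b])
print(calls)  # 25,000 rows in batches of 10,000 -> 3 calls instead of 25,000
```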

```python
import pandas as pd
import pygeohash as pgh
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

# Pandas UDF: called once per Arrow batch (up to
# spark.sql.execution.arrow.maxRecordsPerBatch rows, 10,000 by default),
# with each input column delivered as a pandas Series.
# Returns a pandas Series of the same length.
@pandas_udf(StringType())
def compute_geohash_udf(lat: pd.Series, lon: pd.Series) -> pd.Series:
    """
    Encode (lat, lon) coordinates to a geohash at precision 5.
    Geohash precision 5 = ~4.9km x 4.9km cell.
    Used as a location bucket feature for the ETA model.
    There is no built-in Spark equivalent for geohash encoding.
    The encoding loop is still per-element Python; the gain comes from
    moving data in Arrow batches instead of pickled rows.
    """
    return pd.Series([
        pgh.encode(la, lo, precision=5)
        if pd.notna(la) and pd.notna(lo) else None
        for la, lo in zip(lat, lon)
    ])

# Apply: rows are Arrow-serialized and sent in batches, not one at a time
trips_with_geohash = trips_with_features.withColumn(
    "pickup_geohash",
    compute_geohash_udf(F.col("pickup_lat"), F.col("pickup_lon")),
)


# Iterator Pandas UDF: the function is called once per partition and receives
# an iterator over that partition's batches.
# Use this pattern when you need to initialize expensive state once per partition
# (e.g., loading a model artifact, opening a database connection).
from typing import Iterator
from pyspark.sql.types import DoubleType

@pandas_udf(DoubleType())
def score_with_local_model(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
    """
    Load a scoring model once per partition.
    Score every row using that model.
    Amortizes model loading cost across thousands of rows.
    """
    import pickle
    # Placeholder artifact; the model is assumed to expose predict_proba
    with open("/shared/models/geohash_pop_model.pkl", "rb") as f:
        model = pickle.load(f)

    for batch in iterator:
        features = batch[["hour_of_day", "day_of_week", "pickup_geohash_encoded"]]
        scores = model.predict_proba(features.fillna(0))[:, 1]
        yield pd.Series(scores, index=batch.index)

# Apply it to a struct column so each batch arrives as a pd.DataFrame, e.g.
# score_with_local_model(F.struct("hour_of_day", "day_of_week", "pickup_geohash_encoded"))
```
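`pygeohash` does the real work above. The block below is a minimal pure-Python encoder of the standard geohash algorithm, for readers who want to see what `pgh.encode` computes or who cannot install it on executors. The algorithm alternates longitude and latitude bisections, starting with longitude, and packs five bits into each base-32 character. This is a sketch, not pygeohash's code. Points lying exactly on a cell boundary may be assigned differently by different libraries.

```python
from typing import Tuple

_BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz"


def geohash_encode(lat: float, lon: float, precision: int = 5) -> str:
    """Standard geohash: interleave lon/lat bisection bits, 5 bits per character."""
    lat_lo, lat_hi = -90.0, 90.0
    lon_lo, lon_hi = -180.0, 180.0
    chars, bits, bit_count, use_lon = [], 0, 0, True
    while len(chars) < precision:
        if use_lon:
            mid = (lon_lo + lon_hi) / 2
            if lon >= mid:
                bits, lon_lo = (bits << 1) | 1, mid
            else:
                bits, lon_hi = bits << 1, mid
        else:
            mid = (lat_lo + lat_hi) / 2
            if lat >= mid:
                bits, lat_lo = (bits << 1) | 1, mid
            else:
                bits, lat_hi = bits << 1, mid
        use_lon = not use_lon
        bit_count += 1
        if bit_count == 5:  # a full base-32 digit
            chars.append(_BASE32[bits])
            bits, bit_count = 0, 0
    return "".join(chars)


def geohash_cell_size_deg(precision: int) -> Tuple[float, float]:
    """(longitude width, latitude height) of a cell in degrees."""
    total_bits = 5 * precision
    lon_bits = (total_bits + 1) // 2  # longitude gets the extra bit when odd
    lat_bits = total_bits // 2
    return 360.0 / 2 ** lon_bits, 180.0 / 2 ** lat_bits


print(geohash_encode(57.64911, 10.40744))  # precision 5
print(geohash_cell_size_deg(5))            # about 0.044 x 0.044 degrees
```

At precision 5, 25 bits split into 13 longitude and 12 latitude bits. That gives about 0.044° on each side, or roughly 4.9 km near the equator, which matches the cell size quoted in the docstring above.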
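The "initialize once per partition" behavior comes from ordinary Python generator semantics. The setup code before the `for` loop runs once, when Spark first pulls a batch, and then every batch reuses it. The toy sketch below makes that visible with a load counter. `load_model` and `score_partition` are hypothetical stand-ins, not Spark APIs.

```python
from typing import Callable, Iterator, List

LOAD_CALLS = {"count": 0}


def load_model() -> Callable[[float], float]:
    """Stand-in for an expensive load such as unpickling a model artifact."""
    LOAD_CALLS["count"] += 1
    return lambda x: 2.0 * x  # toy "model"


def score_partition(batches: Iterator[List[float]]) -> Iterator[List[float]]:
    model = load_model()  # runs once, when the first batch is requested
    for batch in batches:
        yield [model(x) for x in batch]


gen = score_partition(iter([[1.0, 2.0], [3.0], [4.0]]))  # nothing loaded yet
print(list(gen), LOAD_CALLS["count"])                     # three batches, one load
```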

:::note When to use which type Use built-in functions first - always. Use Pandas UDFs when the logic requires a Python library (pygeohash, scipy, a custom model) or genuinely vectorized operations not available as built-ins. Use plain Python UDFs only in interactive exploration where performance does not matter. Never use plain Python UDFs in production feature pipelines. :::


Spark MLlib Pipelines

Spark MLlib provides a scikit-learn-inspired Pipeline API for composing preprocessing transformations and ML algorithms into reproducible, serializable workflows.

The Transformer / Estimator pattern

  • Transformer: Stateless - takes a DataFrame, returns a DataFrame. Examples: VectorAssembler, SQLTransformer, Tokenizer.
  • Estimator: Learns parameters from data during .fit(), produces a fitted Transformer. Examples: StringIndexer learns the vocabulary from training data; StandardScaler learns mean and standard deviation.
  • Pipeline: A sequence of Transformers and Estimators. .fit(train_df) fits each Estimator in sequence, producing a PipelineModel.
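The fit/transform contract is small enough to sketch in plain Python. This toy version uses hypothetical classes, not MLlib's. It shows the key property: statistics are learned once in `fit()` and frozen inside the returned model, so every later `transform()` call uses the same numbers, on training and serving data alike. As with MLlib's StandardScaler, the standard deviation here is the corrected sample standard deviation.

```python
import math
from typing import Dict, List

Row = Dict[str, float]


class StandardScalerModel:
    """Transformer: applies frozen statistics, learns nothing."""

    def __init__(self, col: str, mean: float, std: float):
        self.col, self.mean, self.std = col, mean, std

    def transform(self, rows: List[Row]) -> List[Row]:
        return [{**r, self.col: (r[self.col] - self.mean) / self.std} for r in rows]


class StandardScaler:
    """Estimator: fit() learns mean and sample std, returns a Transformer."""

    def __init__(self, col: str):
        self.col = col

    def fit(self, rows: List[Row]) -> StandardScalerModel:
        values = [r[self.col] for r in rows]
        mean = sum(values) / len(values)
        std = math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1))
        return StandardScalerModel(self.col, mean, std)


class PipelineModel:
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows: List[Row]) -> List[Row]:
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


class Pipeline:
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows: List[Row]) -> PipelineModel:
        fitted = []
        for stage in self.stages:
            # Estimators are fitted; plain Transformers pass through unchanged
            model = stage.fit(rows) if hasattr(stage, "fit") else stage
            rows = model.transform(rows)  # later stages see earlier outputs
            fitted.append(model)
        return PipelineModel(fitted)


train = [{"eta": 300.0}, {"eta": 420.0}, {"eta": 360.0}]
model = Pipeline([StandardScaler("eta")]).fit(train)
print(model.transform([{"eta": 480.0}]))  # scaled with the training mean/std
```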

```python
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    StringIndexer,
    OneHotEncoder,
    VectorAssembler,
    StandardScaler,
    Bucketizer,
)
from pyspark.sql import functions as F

# 1. StringIndexer: city_id string -> integer index
# Learns the full vocabulary of city IDs from training data
city_indexer = StringIndexer(
    inputCol="city_id",
    outputCol="city_idx",
    handleInvalid="keep",  # Unseen cities at serving time get a dedicated index
)

# 2. OneHotEncoder: integer index -> binary sparse vector
city_encoder = OneHotEncoder(
    inputCols=["city_idx"],
    outputCols=["city_ohe"],
    dropLast=True,
)

# 3. Bucketize the previous trip's ETA into meaningful bins.
# (Never bucketize actual_eta itself: it is the label, and a bucketed copy
# of the label in the feature vector is direct target leakage.)
eta_bucketizer = Bucketizer(
    splits=[0, 60, 120, 300, 600, float("inf")],
    inputCol="prev_eta",
    outputCol="prev_eta_bucket",
    handleInvalid="keep",
)

# 4. Assemble all numeric and encoded features into a single feature vector
assembler = VectorAssembler(
    inputCols=[
        "city_ohe",
        "prev_eta_bucket",
        "hour_of_day",
        "day_of_week",
        "is_weekend",
        "trip_count_7d",
        "avg_eta_7d",
        "std_eta_7d",
        "avg_rating_7d",
        "trip_rank",
        "prev_eta",
        "cumulative_avg_eta",
    ],
    outputCol="features_raw",
    handleInvalid="skip",  # rows with null/NaN features are dropped
)

# 5. StandardScaler: zero mean, unit variance
# Learns mean and std from training data during .fit()
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features",
    withMean=True,
    withStd=True,
)

# Build and fit the pipeline
feature_pipeline = Pipeline(stages=[
    city_indexer,
    city_encoder,
    eta_bucketizer,
    assembler,
    scaler,
])

# Time-based split: train on earlier trips and evaluate on later ones, as the
# model will be used in production. SPLIT_TIME is a hypothetical cutoff.
SPLIT_TIME = "2024-10-01"
train_df = trips_with_rolling.filter(F.col("start_time") < SPLIT_TIME)
test_df = trips_with_rolling.filter(F.col("start_time") >= SPLIT_TIME)
pipeline_model = feature_pipeline.fit(train_df)

# Transform - applies all fitted transformers in sequence
train_features = pipeline_model.transform(train_df)
test_features = pipeline_model.transform(test_df)

# Serialize the fitted pipeline - load at serving time for consistent transforms.
# The saved model includes the learned vocabulary (StringIndexer),
# mean/std (StandardScaler), and all other fitted state.
pipeline_model.write().overwrite().save("s3://models/eta-feature-pipeline-v2/")

# Load for serving
from pyspark.ml import PipelineModel
loaded_pipeline = PipelineModel.load("s3://models/eta-feature-pipeline-v2/")
```

The key value of MLlib Pipelines for production ML systems is that a fitted PipelineModel applies exactly the same transformations wherever it is loaded, including the same learned vocabulary and scaling factors used in training. If serving runs the saved PipelineModel instead of re-implementing the preprocessing, this eliminates one of the most common forms of training-serving skew: inconsistent preprocessing logic between the training job and the serving layer.
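The frozen vocabulary is what makes `handleInvalid="keep"` safe at serving time. The pure-Python sketch below mimics StringIndexer's default ordering (`stringOrderType="frequencyDesc"`, with ties broken alphabetically as recent Spark documentation describes). Labels never seen during `fit` map to one extra index. The function names are hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable, List


def fit_string_indexer(values: Iterable[str]) -> Dict[str, int]:
    """Most frequent label gets index 0; ties are broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: idx for idx, label in enumerate(ordered)}


def transform_keep(vocab: Dict[str, int], values: Iterable[str]) -> List[int]:
    """handleInvalid="keep": every unseen label shares the index len(vocab)."""
    unknown = len(vocab)
    return [vocab.get(v, unknown) for v in values]


vocab = fit_string_indexer(["sf", "nyc", "sf", "la", "nyc", "sf"])  # learned once
print(transform_keep(vocab, ["la", "austin", "sf"]))  # "austin" was never seen
```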


Point-in-Time Correct Feature Computation

This is the most important concept in this lesson for ML correctness. Read it carefully.

When you train a model on historical data, each training example has a label timestamp - the moment in time when the label was observed. For the ETA model, the label timestamp is when the driver accepted the trip. The features for that training example must reflect only data that was available at or before that timestamp. Using data from after the label timestamp is called label leakage.

Why it fails silently

Label leakage is insidious because it produces excellent offline metrics and terrible production performance. Suppose you compute a driver's "average trip duration in the last 30 days" as of today (e.g., 2024-11-01) and attach that value to a training example from six months ago (e.g., 2024-05-01). Every trip in that average happened about five to six months after the example. The model learns to use information that is "good" in retrospect but unavailable at prediction time. Offline evaluation looks excellent. Production degrades.

The correct pattern is an as-of join (point-in-time join): for each label event at time $t$, join the most recent feature value computed at or before $t$.

$$
\text{feature}(t) = \text{feature\_snapshot}(t^{*}), \qquad t^{*} = \max\{\, t' \le t : \text{a snapshot exists at } t' \,\}
$$
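The formula amounts to a binary search per label event. The block below is a pure-Python reference version that is handy for unit-testing the Spark implementation on a few rows. It assumes each entity's snapshots are sorted by time, and `as_of_lookup` is a hypothetical helper.

```python
from bisect import bisect_right
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

Snapshot = Tuple[datetime, dict]  # (feature_time, feature values)


def as_of_lookup(
    snapshots: Dict[str, List[Snapshot]],
    entity: str,
    label_time: datetime,
    max_lookback: timedelta = timedelta(days=90),
) -> Optional[dict]:
    """Latest snapshot with feature_time <= label_time, or None if none is fresh enough."""
    history = snapshots.get(entity, [])
    times = [t for t, _ in history]  # rebuilt per call for clarity, not speed
    idx = bisect_right(times, label_time) - 1  # last feature_time <= label_time
    if idx < 0:
        return None  # no snapshot existed yet: never borrow a future one
    feature_time, values = history[idx]
    if feature_time < label_time - max_lookback:
        return None  # too stale to trust
    return values


snaps = {"d1": [(datetime(2024, 5, 1), {"avg_eta_7d": 310.0}),
                (datetime(2024, 5, 8), {"avg_eta_7d": 295.0})]}
print(as_of_lookup(snaps, "d1", datetime(2024, 5, 7, 18)))  # uses the May 1 snapshot
```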

Implementing point-in-time joins in PySpark

```python
from pyspark.sql import DataFrame, Window, functions as F


def point_in_time_join(
    spark,
    labels_df: DataFrame,
    features_df: DataFrame,
    entity_key: str,
    label_time_col: str,
    feature_time_col: str,
    feature_cols: list,
    max_lookback_days: int = 90,
) -> DataFrame:
    """
    For each row in labels_df, find the most recent row in features_df
    for the same entity where feature_time <= label_time.

    labels_df:         DataFrame with [entity_key, label_time_col, label, ...]
    features_df:       DataFrame with [entity_key, feature_time_col, feature1, ...]
    entity_key:        Entity identifier column (e.g., "driver_id")
    label_time_col:    Timestamp of the label event in labels_df
    feature_time_col:  Timestamp when the feature was computed in features_df
    feature_cols:      Feature columns to pull into the result
    max_lookback_days: Maximum age of a feature snapshot to consider valid

    Label events with no valid snapshot are kept, with null features.
    """
    # Give every label event its own ID, so two labels for the same entity at
    # the same timestamp are ranked independently instead of being collapsed.
    labels = labels_df.withColumn("_label_id", F.monotonically_increasing_id())

    # Rename the feature-side key and time columns to avoid ambiguous references
    feats = features_df.select(
        F.col(entity_key).alias("_feat_entity"),
        F.col(feature_time_col).alias("_feat_time"),
        *feature_cols,
    )

    # Step 1: Join on the entity key AND the time bounds.
    # Keeping the time conditions in the join condition (not in a later filter)
    # preserves the left join: unmatched label events survive with null features.
    joined = labels.join(
        feats,
        on=(
            (F.col(entity_key) == F.col("_feat_entity"))
            # Only feature rows computed at or before the label event
            & (F.col("_feat_time") <= F.col(label_time_col))
            # Discard feature snapshots older than the lookback window
            & (F.col("_feat_time") >= F.date_sub(F.col(label_time_col), max_lookback_days))
        ),
        how="left",
    )

    # Step 2: Rank candidate feature rows by recency (most recent = rank 1),
    # separately for every label event
    pit_window = Window.partitionBy("_label_id").orderBy(F.col("_feat_time").desc())
    ranked = joined.withColumn("_rank", F.row_number().over(pit_window))

    # Step 3: Keep only the most recent feature row for each label event
    return (
        ranked
        .filter(F.col("_rank") == 1)
        .drop("_rank", "_label_id", "_feat_entity", "_feat_time")
    )


# Usage: build a training set where each trip uses only driver features
# that were available when the trip started (no future leakage)
label_events = trips_completed.select(
    "trip_id",
    "driver_id",
    F.col("start_time").alias("label_time"),
    "actual_eta",  # the label
)

# Historical driver feature snapshots - one row per (driver_id, feature_date)
driver_feature_snapshots = spark.read.parquet(
    "s3://feature-store/driver-features-daily-snapshots/"
)

training_set = point_in_time_join(
    spark=spark,
    labels_df=label_events,
    features_df=driver_feature_snapshots,
    entity_key="driver_id",
    label_time_col="label_time",
    feature_time_col="feature_computed_at",
    feature_cols=["avg_eta_7d", "trip_count_30d", "avg_rating_30d", "avg_tip_30d"],
    max_lookback_days=30,
)
```

:::danger Never use F.lead() in training features F.lead("rating", 1).over(window) returns the value from the next row - a future row relative to the current one. Using lead features in training means the model sees future data it cannot see at serving time. This is the most common form of label leakage in window function-based feature pipelines. Among the offset functions, only F.lag() (previous rows) is safe for training features. Likewise, any rolling frame over the label column must end before the current row (for example rangeBetween(..., -1)), or the current trip's own label leaks into its features. :::


Complete Feature Pipeline

Here is a condensed version of the PySpark pipeline behind the ETA model. It computes a representative subset of the 218 features and integrates all the techniques covered above:

```python
from pyspark.sql import SparkSession, Window, functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType
import pandas as pd
import pygeohash as pgh


def build_eta_feature_pipeline(spark: SparkSession, run_date: str) -> None:
    """
    Feature pipeline for the Driver ETA model (representative subset of features).

    Reads raw trip data, computes rolling window features, joins with driver
    profiles using a broadcast join, and writes partitioned Parquet output.

    run_date: "YYYY-MM-DD" - all features reflect data available as of this date.
    """
    LOOKBACK_DAYS = 90

    # ================================================================
    # STEP 1: LOAD RAW DATA (with projection + predicate pushdown)
    # ================================================================
    trips_raw = (
        spark.read
        .parquet("s3://data/trips/")
        .filter(F.col("status") == "completed")
        .filter(F.col("start_time") >= F.date_sub(F.lit(run_date), LOOKBACK_DAYS))
        # "< run_date + 1 day" keeps all of run_date; "<= run_date" would stop
        # at midnight at the start of run_date
        .filter(F.col("start_time") < F.date_add(F.lit(run_date), 1))
        .select(
            "trip_id", "driver_id", "city_id",
            F.col("start_time").cast("timestamp"),
            F.col("end_time").cast("timestamp"),
            F.col("actual_eta_seconds").cast("double").alias("actual_eta"),
            F.col("predicted_eta_seconds").cast("double").alias("predicted_eta"),
            F.col("trip_distance_km").cast("double").alias("distance_km"),
            F.col("tip_amount").cast("double"),
            F.col("rating").cast("double"),
            F.col("pickup_lat").cast("double"),
            F.col("pickup_lon").cast("double"),
        )
        .filter(F.col("actual_eta") > 0)
        .filter(F.col("actual_eta") < 7200)  # Filter obvious sensor errors
    )
    # No cache() here: everything below runs as one action (the final write),
    # so Spark scans this source once. Cache only if you add more actions on
    # trips_raw, such as a row-count check.

    driver_profiles = (
        spark.read
        .parquet("s3://data/driver-profiles/")
        .select("driver_id", "vehicle_type", "years_active", "home_city_id")
    )
    # Broadcast: ~50MB - eliminates shuffle on the trips side
    driver_profiles_broadcast = F.broadcast(driver_profiles)

    # ================================================================
    # STEP 2: BASIC DERIVED COLUMNS
    # ================================================================
    @pandas_udf(StringType())
    def encode_geohash(lat: pd.Series, lon: pd.Series) -> pd.Series:
        return pd.Series([
            pgh.encode(la, lo, precision=5)
            if pd.notna(la) and pd.notna(lo) else None
            for la, lo in zip(lat, lon)
        ])

    trips = (
        trips_raw
        .withColumn(
            "duration_minutes",
            (F.unix_timestamp("end_time") - F.unix_timestamp("start_time")) / 60,
        )
        .withColumn("speed_kmh", F.col("distance_km") / (F.col("duration_minutes") / 60))
        .withColumn("eta_error", F.col("actual_eta") - F.col("predicted_eta"))
        .withColumn("hour_of_day", F.hour("start_time"))
        .withColumn("day_of_week", F.dayofweek("start_time"))
        .withColumn("is_weekend", F.when(F.col("day_of_week").isin([1, 7]), 1).otherwise(0))
        .withColumn("pickup_geohash", encode_geohash("pickup_lat", "pickup_lon"))
    )

    # ================================================================
    # STEP 3: WINDOW FEATURES
    # ================================================================
    unix_ts = F.unix_timestamp("start_time")

    # Frames end before the current trip (-1 second / -1 row), so a trip's own
    # label and post-trip outcomes (speed, rating, tip) never feed its features.
    w7 = Window.partitionBy("driver_id").orderBy(unix_ts).rangeBetween(-7 * 86400, -1)
    w30 = Window.partitionBy("driver_id").orderBy(unix_ts).rangeBetween(-30 * 86400, -1)
    # Ordering-only spec for row_number and lag. It uses the same orderBy as
    # w7/w30, so these specs typically share one shuffle and one sort.
    w_c = Window.partitionBy("driver_id").orderBy(unix_ts)
    # All previous trips, for the expanding mean
    w_hist = w_c.rowsBetween(Window.unboundedPreceding, -1)

    # Note: with LOOKBACK_DAYS = 90, trip_rank and cumulative_avg_eta cover the
    # lookback window only, not the driver's full history.
    trips_windowed = (
        trips
        .withColumn("avg_eta_7d", F.avg("actual_eta").over(w7))
        .withColumn("std_eta_7d", F.stddev("actual_eta").over(w7))
        .withColumn("trip_count_7d", F.count("*").over(w7))
        .withColumn("avg_speed_7d", F.avg("speed_kmh").over(w7))
        .withColumn("avg_rating_7d", F.avg("rating").over(w7))
        .withColumn("avg_eta_30d", F.avg("actual_eta").over(w30))
        .withColumn("std_eta_30d", F.stddev("actual_eta").over(w30))
        .withColumn("trip_count_30d", F.count("*").over(w30))
        .withColumn("avg_tip_30d", F.avg("tip_amount").over(w30))
        .withColumn("avg_rating_30d", F.avg("rating").over(w30))
        .withColumn("trip_rank", F.row_number().over(w_c))
        .withColumn("cumulative_avg_eta", F.avg("actual_eta").over(w_hist))
        .withColumn("prev_eta", F.lag("actual_eta", 1).over(w_c))
        .withColumn("prev_duration", F.lag("duration_minutes", 1).over(w_c))
        .withColumn("prev_tip", F.lag("tip_amount", 1).over(w_c))
    )

    # Handle nulls in lag features (first trip per driver has no lag).
    # Use a sentinel plus an indicator column: an average that includes the
    # current trip would copy its label into the feature.
    trips_windowed = (
        trips_windowed
        .withColumn("has_prev_trip", F.col("prev_eta").isNotNull().cast("int"))
        .withColumn("prev_eta", F.coalesce(F.col("prev_eta"), F.lit(-1.0)))
        .withColumn("prev_duration", F.coalesce(F.col("prev_duration"), F.lit(0.0)))
        .withColumn("prev_tip", F.coalesce(F.col("prev_tip"), F.lit(0.0)))
    )

    # ================================================================
    # STEP 4: JOIN WITH DRIVER PROFILES (broadcast join - no shuffle)
    # ================================================================
    enriched = trips_windowed.join(
        driver_profiles_broadcast,
        on="driver_id",
        how="left",
    )

    # ================================================================
    # STEP 5: SELECT FINAL FEATURE SET
    # ================================================================
    feature_cols = [
        "trip_id", "driver_id", "city_id",
        F.col("start_time").alias("label_time"),
        "actual_eta",  # label
        "hour_of_day", "day_of_week", "is_weekend",
        "distance_km", "pickup_geohash",
        "avg_eta_7d", "std_eta_7d", "trip_count_7d", "avg_speed_7d", "avg_rating_7d",
        "avg_eta_30d", "std_eta_30d", "trip_count_30d", "avg_tip_30d", "avg_rating_30d",
        "trip_rank", "cumulative_avg_eta",
        "prev_eta", "prev_duration", "prev_tip", "has_prev_trip",
        "vehicle_type", "years_active",
    ]

    features_final = enriched.select(*feature_cols)

    # ================================================================
    # STEP 6: WRITE PARTITIONED PARQUET
    # ================================================================
    features_final = features_final.withColumn("feature_date", F.to_date("label_time"))

    (
        features_final
        # Shuffle by feature_date so each date is typically written by a single
        # task: one file per date instead of up to 200 small files per date
        .repartition("feature_date")
        .write
        .mode("overwrite")
        .partitionBy("feature_date")
        .parquet(f"s3://feature-store/eta-features/run_date={run_date}/")
    )

    print(f"Pipeline complete: s3://feature-store/eta-features/run_date={run_date}/")


if __name__ == "__main__":
    spark = SparkSession.builder.appName("eta-feature-pipeline").getOrCreate()
    build_eta_feature_pipeline(spark, run_date="2024-11-01")
```

Writing Delta Lake Feature Tables

Delta Lake adds ACID transactions, schema enforcement, and time travel to Parquet files. For ML feature tables, these properties are essential:

  • Atomic writes: a feature refresh either fully commits or leaves the previous version intact - no partial writes that could corrupt a training run started concurrently.
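  • Schema enforcement: Delta rejects writes that add columns missing from the table schema or change a column's data type, catching many upstream pipeline bugs before they reach model training. (A column missing from the incoming DataFrame is not rejected on append; Delta fills it with nulls.)
  • Time travel: query the feature table as it existed at a past version or timestamp, as long as that version's data files are still within the VACUUM retention window. This is essential for reproducing historical training runs or debugging model degradation.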

```python
from delta.tables import DeltaTable


def write_delta_feature_table(
    spark,
    features_df,
    table_path: str,
    merge_key: str = "trip_id",
    partition_col: str = "feature_date",
):
    """
    Write features to a Delta Lake table.
    Uses MERGE (upsert) for idempotent re-runs:
      - existing rows with the same merge_key are updated
      - new rows are inserted
    """
    if DeltaTable.isDeltaTable(spark, table_path):
        delta_table = DeltaTable.forPath(spark, table_path)
        (
            delta_table.alias("existing")
            .merge(
                features_df.alias("new"),
                f"existing.{merge_key} = new.{merge_key}",
            )
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )
    else:
        # First write - create the Delta table
        (
            features_df
            .write
            .format("delta")
            .mode("overwrite")
            .option("overwriteSchema", "true")
            .partitionBy(partition_col)
            .save(table_path)
        )

    # Vacuum: delete data files that are no longer referenced by the table and
    # are older than 7 days, reclaiming storage. Time travel can only reach
    # versions whose files are still inside this retention window.
    DeltaTable.forPath(spark, table_path).vacuum(retentionHours=168)


# Time travel: read features as they existed on 2024-10-01
# (e.g., to reproduce a historical training run).
# This only works while that version's files survive VACUUM: with the 7-day
# retention above, a month-old version is gone. Raise retentionHours (and the
# table's delta.logRetentionDuration) to cover how far back you need to go.
features_30d_ago = (
    spark.read
    .format("delta")
    .option("timestampAsOf", "2024-10-01")
    .load("s3://feature-store/delta/eta-features/")
)

# Schema enforcement in action:
# If an upstream change adds an unexpected column or turns "avg_eta_7d" from
# double into string, an append to this table raises an AnalysisException
# before writing, protecting the table.
# A *dropped* column is not rejected on append: Delta fills it with nulls,
# so check the expected column list explicitly before writing.
```
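Because Delta fills a missing column with nulls on append instead of rejecting the write, a dropped feature column needs an explicit check. Below is a minimal sketch that compares `DataFrame.dtypes` (a list of `(name, type)` pairs) against an expected schema. It assumes you keep that schema as a dict mapping column names to Spark simple type strings. `EXPECTED` and `schema_problems` are hypothetical names.

```python
from typing import Dict, Iterable, List, Tuple


def schema_problems(actual: Iterable[Tuple[str, str]], expected: Dict[str, str]) -> List[str]:
    """Compare (name, type) pairs, the shape of DataFrame.dtypes, with an expected schema."""
    actual_map = dict(actual)
    problems = []
    for name, dtype in expected.items():
        if name not in actual_map:
            problems.append(f"missing column: {name}")
        elif actual_map[name] != dtype:
            problems.append(f"type mismatch for {name}: expected {dtype}, got {actual_map[name]}")
    for name in actual_map:
        if name not in expected:
            problems.append(f"unexpected column: {name}")
    return problems


EXPECTED = {"trip_id": "string", "avg_eta_7d": "double", "feature_date": "date"}
print(schema_problems([("trip_id", "string"), ("feature_date", "date")], EXPECTED))
```

In the pipeline, you would call `schema_problems(features_df.dtypes, EXPECTED)` before `write_delta_feature_table` and fail the job if the list is non-empty.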



Production Engineering Notes

:::tip Cache a base DataFrame only when it feeds several actions or branches When several window specs (7d, 30d, cumulative) are chained on one DataFrame, Spark builds a single linear plan and scans the Parquet source once per action. Re-scans happen in two situations. The first is when the same DataFrame feeds several actions (a count, then a write). The second is when it feeds several branches that are later joined, for example 7d and 30d features computed in separate DataFrames and joined back on trip_id. In those cases, call trips.cache() after filtering and type-casting, let the first action materialize it, and call trips.unpersist() once the dependent computations are complete. :::

:::tip Always broadcast the smaller side of a join

Any driver or merchant profile table (hundreds of thousands of rows, tens of MB) is small enough to broadcast. However, it exceeds the default spark.sql.autoBroadcastJoinThreshold of 10 MB, so Spark will not broadcast it automatically. Use F.broadcast(small_df) explicitly, or raise spark.sql.autoBroadcastJoinThreshold above the table size. Broadcasting eliminates the shuffle on the large side of the join, which is often the single highest-impact optimization in a feature pipeline.

:::
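Here is a plain-Python model, not Spark code, of why broadcasting removes the shuffle on the large side. The small table becomes an in-memory hash map; in Spark it is collected, built into a hash table, and shipped to every executor. Each large-side row is then probed against that map where it already sits. The function name and sample rows are illustrative, and the sketch assumes the join key is unique on the small side. In PySpark the corresponding call is trips.join(F.broadcast(drivers), "driver_id").

```python
def broadcast_hash_join(large_rows, small_rows, key):
    """Inner join modeled on a broadcast hash join.

    Assumes `key` is unique on the small side (e.g. one row per driver).
    """
    # Build side: the small table becomes a hash map that every worker holds.
    lookup = {row[key]: row for row in small_rows}
    joined = []
    # Probe side: large-side rows are looked up one by one where they sit,
    # so the large table is never repartitioned by the join key.
    for row in large_rows:
        match = lookup.get(row[key])
        if match is not None:  # inner join: rows without a match are dropped
            joined.append({**match, **row})
    return joined


trips = [
    {"trip_id": "t1", "driver_id": 7, "eta_s": 300},
    {"trip_id": "t2", "driver_id": 9, "eta_s": 120},
    {"trip_id": "t3", "driver_id": 7, "eta_s": 240},
]
drivers = [{"driver_id": 7, "rating": 4.9}, {"driver_id": 8, "rating": 4.5}]
enriched = broadcast_hash_join(trips, drivers, "driver_id")
```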

:::warning Null handling in lag features

F.lag("value", 1) returns null for the first row in each driver's history. Always handle these nulls explicitly. One option is to fill with a window-level average or an average computed on the training split only (F.coalesce(lag_col, avg_col)). The other is to fill with a sentinel value that communicates "no previous trip." If nulls propagate unhandled into VectorAssembler, the row is either silently dropped (handleInvalid="skip") or raises a runtime error (handleInvalid="error", the default).

:::
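A plain-Python model makes the null cases of the lag-plus-fallback pattern concrete. The sketch assumes trips are dicts with trip_id, driver_id, and start_time fields, and it uses a hypothetical sentinel of -1.0 for "no previous trip". In PySpark the same logic is F.coalesce(F.lag("tip_amount", 1).over(w), F.lit(-1.0)) with w = Window.partitionBy("driver_id").orderBy("start_time").

```python
from collections import defaultdict

NO_PREVIOUS_TRIP = -1.0  # hypothetical sentinel for "no previous trip"


def lag_with_fallback(trips, value_col, fallback=NO_PREVIOUS_TRIP):
    """Previous trip's value_col for each trip, per driver, ordered by start_time.

    Returns {trip_id: value}. The first trip of each driver (where F.lag
    returns null) and any null previous value get `fallback`, as
    F.coalesce would.
    """
    by_driver = defaultdict(list)
    for t in trips:
        by_driver[t["driver_id"]].append(t)
    lagged = {}
    for driver_trips in by_driver.values():
        driver_trips.sort(key=lambda t: t["start_time"])
        previous = None  # nothing precedes the first trip
        for t in driver_trips:
            lagged[t["trip_id"]] = fallback if previous is None else previous
            previous = t[value_col]
    return lagged


trips = [
    {"trip_id": "a", "driver_id": 1, "start_time": 10, "tip_amount": 2.0},
    {"trip_id": "b", "driver_id": 1, "start_time": 5, "tip_amount": 1.0},
    {"trip_id": "c", "driver_id": 2, "start_time": 7, "tip_amount": 3.0},
]
prev_tip = lag_with_fallback(trips, "tip_amount")
```

A sentinel like this suits tree-based models. For linear models, a separate has-previous-trip indicator column is usually the safer encoding.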

:::danger collect_list inside window functions on large groups

F.collect_list("amount").over(window) collects all values in the window into an array on every row. Take an active driver with 200 trips per day over a 30-day window: each row carries an array of up to 6,000 elements, where a scalar aggregate would store one value. Across millions of drivers, the output can be orders of magnitude larger than the input. Use scalar aggregation functions (F.avg, F.stddev, F.count, F.sum) in window functions. If you genuinely need the list, write a separate aggregation job and join the result.

:::
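The blow-up is easy to estimate. This back-of-the-envelope helper assumes a single driver with a constant trip rate. It counts window membership at day granularity, a simplification of rangeBetween over timestamps that slightly overcounts same-day trips. It returns how many array elements collect_list would materialize, compared with one scalar per row for F.avg or F.count.

```python
def collect_list_elements(trips_per_day: int, window_days: int, days: int) -> int:
    """Total array elements collect_list materializes for one driver.

    Every trip is one row, and each row's array holds all trips from the
    trailing window_days days, counting the row's own day.
    """
    total = 0
    for day in range(days):
        trips_in_window = trips_per_day * min(day + 1, window_days)
        total += trips_per_day * trips_in_window  # trips_per_day rows on this day
    return total


rows = 200 * 30                            # one month of trips for this driver
elements = collect_list_elements(200, 30, 30)
print(rows, elements, elements // rows)    # → 6000 18600000 3100
```

Even while the window is still filling up, the average row carries 3,100 elements instead of one value. After the first month, every row carries the full 6,000.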


Common Mistakes

:::danger Applying a global orderBy before window functions

df.orderBy("start_time") before windowing performs a global sort - a full shuffle that sorts all data by a single key. The window function then re-partitions by driver_id and re-sorts within each partition. You have paid for an expensive global sort that the window function completely ignores. Remove the global orderBy. The Window.partitionBy("driver_id").orderBy("start_time") specification handles all necessary ordering internally.

:::

:::warning Sharing a single window spec across incompatible aggregations

Suppose w7 uses rangeBetween(-7*86400, 0) ordered by unix_timestamp("start_time"). Applying F.count("*").over(w7), F.avg("amount").over(w7), and F.stddev("amount").over(w7) in the same chain is valid, and they share the same shuffle. Ranking functions are different. F.row_number(), F.rank(), and F.dense_rank() require the frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so applying them over a rangeBetween window fails at analysis time with an AnalysisException stating that the window frame must match the required frame. Use the ordering-only cumulative window for rank functions; use the range-based window for time-range aggregations.

:::

:::danger Training on features computed over the entire dataset

A common leakage pattern: computing global_avg_eta = df.agg(F.avg("actual_eta")).collect()[0][0] and then joining this back as a feature. This value uses all trips - including those after any given label timestamp. In production, you cannot know the future global average. Compute global statistics only on the training split (data before the train cutoff date), never on the full historical dataset before splitting.

:::
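The fix is to fit the statistic on the training split and reuse it unchanged everywhere else. This plain-Python sketch assumes rows are dicts with a date column; the function and field names are hypothetical. In PySpark the same idea is df.filter(F.col("start_time") < cutoff).agg(F.avg("actual_eta")).collect()[0][0].

```python
from datetime import date


def mean_before_cutoff(rows, value_col, time_col, cutoff):
    """Mean of value_col over rows strictly before the training cutoff.

    Fit once on the training split, then reuse the same number as a feature
    for training, validation, and test rows: no row at or after the cutoff
    can influence it.
    """
    train_values = [r[value_col] for r in rows if r[time_col] < cutoff]
    if not train_values:
        raise ValueError("no rows before the cutoff")
    return sum(train_values) / len(train_values)


trips = [
    {"start_date": date(2024, 1, 1), "actual_eta": 300.0},
    {"start_date": date(2024, 2, 1), "actual_eta": 500.0},
    {"start_date": date(2024, 6, 1), "actual_eta": 1000.0},  # after the cutoff
]
train_avg_eta = mean_before_cutoff(trips, "actual_eta", "start_date", date(2024, 3, 1))
# train_avg_eta == 400.0, whereas the leaky full-dataset average would be 600.0
```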


Interview Q&A

Q: Why are Spark Python UDFs slow, and what is the alternative?

Python UDFs execute in a separate Python worker process, outside the JVM where Spark runs. Spark pickles the input rows and ships them over a local socket to the Python worker in batches. The worker calls the function once per row and pickles each result back. At one billion rows, that is one billion Python function calls, each wrapped in per-row serialization and deserialization, and this overhead completely dominates the actual computation. The primary alternative is the built-in Spark functions in pyspark.sql.functions, over 200 of them for string, date, math, and conditional operations. These run entirely in the JVM under Tungsten's whole-stage code generation, with no Python overhead. When a built-in does not exist, use a Pandas UDF (vectorized UDF). It sends rows to Python as Apache Arrow columnar batches of up to spark.sql.execution.arrow.maxRecordsPerBatch rows each (10,000 by default), and it calls the function once per batch instead of once per row. At the default batch size, one billion rows need about 100,000 calls instead of one billion, and Arrow avoids per-row pickling, typically recovering 80–90% of the performance gap.
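The difference in call counts is simple arithmetic. This sketch assumes evenly sized partitions and Spark's default Arrow batch cap of 10,000 rows. The helper is hypothetical and counts only Python function invocations, not bytes moved.

```python
import math
from typing import List, Optional

DEFAULT_ARROW_BATCH = 10_000  # default of spark.sql.execution.arrow.maxRecordsPerBatch


def python_udf_calls(partition_rows: List[int], max_records_per_batch: Optional[int] = None) -> int:
    """Count Python function invocations across all partitions.

    None models a row-at-a-time Python UDF (one call per row); an integer
    models a Series-to-Series Pandas UDF, called once per Arrow batch.
    Batches are cut per partition, so each partition rounds up separately.
    """
    if max_records_per_batch is None:
        return sum(partition_rows)
    return sum(math.ceil(n / max_records_per_batch) for n in partition_rows)


partitions = [1_000_000] * 1_000  # one billion rows in 1,000 partitions
row_udf_calls = python_udf_calls(partitions)                          # 1,000,000,000
pandas_udf_calls = python_udf_calls(partitions, DEFAULT_ARROW_BATCH)  # 100,000
```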


Q: How do you implement rolling window features in PySpark?

Use the Window class with partitionBy, orderBy, and rangeBetween (for time-based windows) or rowsBetween (for row-count-based windows). For a 7-day rolling average, order by the timestamp cast to Unix seconds so that the range bounds are in seconds: Window.partitionBy("driver_id").orderBy(F.col("start_time").cast("long")).rangeBetween(-7*86400, 0). Apply aggregation functions over this window: F.avg("amount").over(w7), F.count("*").over(w7), F.stddev("amount").over(w7). For lag features, use the ordering-only cumulative window (no range specification): Window.partitionBy("driver_id").orderBy("start_time"), then F.lag("amount", 1).over(cumulative_window). The key performance consideration is that each distinct partitionBy key forces a shuffle, and each distinct orderBy within the same partitioning forces another sort. Group all window functions that share the same spec into one chain of withColumn calls so that the shuffle and sort are paid only once.
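The semantics of rangeBetween(-7*86400, 0) are easy to get subtly wrong. Below is a plain-Python model of what Spark computes, though not how it computes it: Spark sorts each partition once and slides the frame, while this sketch rescans for clarity. It assumes start_time is already in epoch seconds and trips are dicts; the function name is hypothetical. Two properties are worth noticing. Both bounds are inclusive, and trips with the same start_time are peers that fall in each other's frame.

```python
from collections import defaultdict

WEEK_SECONDS = 7 * 86400


def rolling_avg_range(trips, value_col, range_seconds=WEEK_SECONDS):
    """Average of value_col over each driver's trailing time range.

    Models F.avg(value_col).over(
        Window.partitionBy("driver_id")
              .orderBy(F.col("start_time").cast("long"))
              .rangeBetween(-range_seconds, 0))
    with start_time already in epoch seconds. Returns {trip_id: average}.
    """
    by_driver = defaultdict(list)
    for t in trips:
        by_driver[t["driver_id"]].append(t)
    averages = {}
    for driver_trips in by_driver.values():
        for t in driver_trips:
            lo, hi = t["start_time"] - range_seconds, t["start_time"]
            # Both bounds are inclusive; rows with the same start_time are peers.
            in_frame = [u[value_col] for u in driver_trips if lo <= u["start_time"] <= hi]
            averages[t["trip_id"]] = sum(in_frame) / len(in_frame)
    return averages


trips = [
    {"trip_id": "t1", "driver_id": 1, "start_time": 0, "amount": 10.0},
    {"trip_id": "t2", "driver_id": 1, "start_time": WEEK_SECONDS, "amount": 20.0},
    {"trip_id": "t3", "driver_id": 1, "start_time": WEEK_SECONDS + 1, "amount": 30.0},
    {"trip_id": "t4", "driver_id": 1, "start_time": WEEK_SECONDS + 1, "amount": 10.0},
    {"trip_id": "t5", "driver_id": 2, "start_time": 100, "amount": 5.0},
]
avg_7d = rolling_avg_range(trips, "amount")
# t2 still sees t1 (exactly 7 days earlier); t3 and t4 are peers and no longer see t1.
```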


Q: What is a Pandas UDF and when should you use it?

A Pandas UDF (vectorized UDF) is a Python function decorated with @pandas_udf(returnType). It receives one or more Pandas Series objects (one per input column) and returns a Pandas Series. It is called once per Arrow batch of rows (up to 10,000 rows by default) rather than once per row. Data is transferred between JVM and Python using Apache Arrow (a columnar binary format), which is far more efficient than row-by-row pickling. Use Pandas UDFs in three situations. (1) The computation requires a Python library that has no Spark built-in equivalent, such as pygeohash, scipy, or a custom model. (2) The computation is naturally vectorized and can benefit from NumPy operations on a whole batch. (3) You need to load an expensive artifact (a model, a database client) once per partition rather than once per batch. For this case, use the iterator variant (Iterator[pd.Series] -> Iterator[pd.Series]), which receives all of a partition's batches through one iterator. Do not use Pandas UDFs when a built-in Spark function can do the same job. Built-ins are faster because they never cross the JVM-Python boundary.
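The point of the iterator variant is where the expensive load sits: once, before the loop over batches. The sketch below models that shape in plain Python, with lists standing in for pandas Series, so it runs without Spark. In PySpark the scoring function would be decorated with @pandas_udf("double") and typed Iterator[pd.Series] -> Iterator[pd.Series]. The function names and the doubling "model" are hypothetical stand-ins.

```python
from typing import Callable, Iterator, List

Model = Callable[[float], float]


def make_batch_scorer(load_model: Callable[[], Model]):
    """Build a function shaped like an iterator-of-Series Pandas UDF."""

    def score_batches(batches: Iterator[List[float]]) -> Iterator[List[float]]:
        model = load_model()   # expensive load: once per partition, not per batch
        for batch in batches:  # each list stands in for one Arrow record batch
            yield [model(x) for x in batch]

    return score_batches


load_count = 0


def load_doubling_model() -> Model:
    """Hypothetical stand-in for an expensive load (e.g. reading a model file)."""
    global load_count
    load_count += 1
    return lambda x: 2.0 * x


score = make_batch_scorer(load_doubling_model)
partition_batches = iter([[1.0, 2.0], [3.0]])  # one partition, two batches
scored = list(score(partition_batches))
# scored == [[2.0, 4.0], [6.0]] and the model was loaded exactly once
```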


Q: How do you prevent label leakage in batch feature computation?

Label leakage in batch feature computation occurs when a training example's features include data from after the label's timestamp. Three defenses: (1) Point-in-time joins: when joining label events with feature tables, filter to only feature rows where feature_time <= label_time, then rank candidates by recency and keep the most recent. This is the as-of join pattern. (2) Avoid F.lead() in training features: lead() reads a future row's value - inherently leaky. Only lag() is safe. (3) Never compute global statistics over the full historical dataset and use them as features: statistics computed over all time include future data. Compute them only over data before the training cutoff. The practical test for any feature: "At the moment when this label was generated in production, could the model have accessed this feature value?" If the answer is no, it is leakage.
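The as-of join in defense (1) can be stated precisely in a few lines of plain Python. This sketch assumes labels and feature rows are dicts keyed by driver_id, with feature rows carrying feature_time and value fields; all names are hypothetical. In Spark the same result comes from three steps: join on the key, filter feature_time <= label_time, then keep F.row_number() == 1 over a window partitioned by the label's identifier and ordered by feature_time descending.

```python
from bisect import bisect_right
from collections import defaultdict


def as_of_join(labels, features, key="driver_id"):
    """Left as-of join: attach the latest value with feature_time <= label_time.

    labels:   dicts with `key` and "label_time"
    features: dicts with `key`, "feature_time", and "value"
    Labels with no eligible feature row get value None.
    """
    history = defaultdict(list)
    for f in features:
        history[f[key]].append((f["feature_time"], f["value"]))
    times = {}
    for k, rows in history.items():
        rows.sort(key=lambda r: r[0])  # order each entity's history by time
        times[k] = [t for t, _ in rows]

    joined = []
    for label in labels:
        rows = history.get(label[key], [])
        # Number of feature rows at or before the label time (never after it).
        i = bisect_right(times.get(label[key], []), label["label_time"])
        value = rows[i - 1][1] if i > 0 else None
        joined.append({**label, "value": value})
    return joined


features = [
    {"driver_id": 1, "feature_time": 10, "value": 0.5},
    {"driver_id": 1, "feature_time": 30, "value": 0.9},
    {"driver_id": 1, "feature_time": 20, "value": 0.7},
]
labels = [
    {"driver_id": 1, "label_time": 25},  # sees the row at 20, never the one at 30
    {"driver_id": 1, "label_time": 20},  # equal timestamps are eligible
    {"driver_id": 1, "label_time": 5},   # no history yet
    {"driver_id": 2, "label_time": 25},  # unknown driver
]
result = as_of_join(labels, features)
```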


Q: How would you design a Spark pipeline for 200+ features at 1 billion rows per day?

Start with the storage layout: partition the raw data by date so daily reads only scan the relevant partition, not the entire history. Then build the pipeline in five steps. (1) Project to only the columns needed: select before any transformation, since Parquet column pruning reduces scan size dramatically. (2) Cache the cleaned base DataFrame if several independently computed feature DataFrames reuse it. (3) Group all window functions that share a (partitionBy, orderBy) combination into one chain. Each distinct partitionBy key costs a shuffle, and each distinct ordering costs a sort. With 200 features, you may have 5–10 distinct window specs; structure the computation to minimize their number. (4) Broadcast all small lookup tables under 50 MB. (5) Write output partitioned by feature_date and pre-sorted by entity key for efficient training set reads. Size spark.sql.shuffle.partitions to keep post-shuffle partitions at 128–256 MB. Keep AQE enabled (spark.sql.adaptive.enabled, on by default since Spark 3.2) so Spark can coalesce overly small shuffle partitions at runtime. Monitor the Spark UI for executor spill. If spill occurs during window aggregations, reduce cores per executor to give each core more memory.
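The shuffle partition count follows from the size of the shuffle. This helper assumes you read the shuffle write size of the widest stage from the Spark UI, and it targets 200 MB per partition, inside the 128–256 MB band. The function and constant names are hypothetical.

```python
import math

TARGET_PARTITION_BYTES = 200 * 1024**2  # 200 MB, inside the 128-256 MB band


def shuffle_partitions(shuffle_bytes: int, target_bytes: int = TARGET_PARTITION_BYTES) -> int:
    """Partition count that keeps each post-shuffle partition near target_bytes."""
    return max(1, math.ceil(shuffle_bytes / target_bytes))


n = shuffle_partitions(300 * 1024**3)  # 300 GB of shuffle write -> 1536 partitions
# With a live session: spark.conf.set("spark.sql.shuffle.partitions", str(n))
```

With AQE enabled, this value acts as the starting point that Spark can coalesce downward at runtime.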


Q: What is the Spark MLlib Pipeline pattern and why is it useful for production ML systems?

The MLlib Pipeline composes preprocessing stages - Transformers (stateless, e.g., VectorAssembler) and Estimators (learn parameters from data, e.g., StringIndexer, StandardScaler) - into a single object. Calling pipeline.fit(train_df) fits all Estimators in sequence and returns a PipelineModel where every stage is a fitted Transformer. The fitted PipelineModel can be saved with model.save(path) and reloaded with PipelineModel.load(path) in any process that has a Spark runtime. This guarantees identical transformations between training and serving, including learned vocabularies, scaling factors, and bin boundaries. It eliminates the most common cause of training-serving skew: inconsistent preprocessing code. In some teams, the training pipeline (Python/Spark) and the serving layer (Java/Scala/REST) are owned by different engineers. There, a serialized MLlib PipelineModel is a contract: apply this artifact to raw feature rows and you get the correctly preprocessed input the model expects.
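A plain-Python miniature makes the fit/transform contract concrete. The classes below are hypothetical stand-ins, not the MLlib API. There is one stateless Transformer and one Estimator that learns a mean and sample standard deviation; it centers and scales, which in MLlib's StandardScaler corresponds to the non-default withMean=True. The pipeline's fit() returns a model made only of fitted Transformers.

```python
import statistics


class FillMissingSketch:
    """Stateless Transformer: replaces None with a fixed value."""

    def __init__(self, value):
        self.value = value

    def transform(self, values):
        return [self.value if v is None else v for v in values]


class StandardScalerSketch:
    """Estimator: fit() learns mean and sample std from the data it is given."""

    def fit(self, values):
        return ScalerModelSketch(statistics.mean(values), statistics.stdev(values))


class ScalerModelSketch:
    """Fitted Transformer: applies the learned parameters, never re-learns them."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, values):
        return [(v - self.mean) / self.std for v in values]


class PipelineSketch:
    def __init__(self, stages):
        self.stages = stages

    def fit(self, data):
        fitted = []
        for stage in self.stages:
            # Estimators are fitted on the output of the previous stages;
            # Transformers are kept as-is.
            model = stage.fit(data) if hasattr(stage, "fit") else stage
            data = model.transform(data)
            fitted.append(model)
        return PipelineModelSketch(fitted)


class PipelineModelSketch:
    def __init__(self, stages):
        self.stages = stages

    def transform(self, data):
        for stage in self.stages:
            data = stage.transform(data)
        return data


model = PipelineSketch([FillMissingSketch(2.0), StandardScalerSketch()]).fit([1.0, None, 3.0])
print(model.transform([None, 4.0]))  # → [0.0, 2.0], using the training mean 2.0 and std 1.0
```

Serving-time rows are scaled with the parameters learned from the training data, which is exactly the guarantee a saved PipelineModel provides.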

© 2026 EngineersOfAI. All rights reserved.