

Batch Processing with Spark for ML Pipelines

Spark is the shovel. The feature engineering logic is the gold mine. How well you use the shovel determines whether you find gold in 4 hours or 40.

The Production Moment

The Spark job had been running for 9 hours. The on-call engineer checked the Spark UI. Stage 47 of 52 had been running for 7 hours. Of the 200 tasks in that stage, 199 had completed in the first 45 minutes. One task was still running.

Classic data skew. One partition - the one containing all the events for a single viral item that had been viewed 40 million times that week - was 300× larger than the average partition. The job would not complete until that one task finished.

This is the Spark engineer's equivalent of trying to fill a bathtub where 199 of the 200 faucets work perfectly and one produces a trickle. The total time is determined by the slowest component.

The engineer who understands Spark can diagnose this in five minutes by looking at the task timing distribution in the Spark UI, apply a salting technique to distribute the skewed key across multiple partitions, and rerun the job - completing in 40 minutes instead of 9 hours. The engineer who doesn't understand Spark waits, or kills and retries, or adds more nodes hoping it helps (it won't).

This lesson is about understanding Spark well enough to never be that engineer.

Why Spark for ML

The question is worth asking explicitly: why Spark and not Python + Pandas?

Scale: Pandas loads data into a single machine's memory. A 16 GB MacBook Pro can handle maybe 6 GB of Pandas data. A single day of events for a large ML system might be 500 GB. Pandas fails. Spark splits the data into partitions, spreads them across hundreds of machines, and processes them in parallel.

Integration: Spark integrates with the entire modern data ecosystem. It reads and writes Parquet, Delta Lake, Iceberg, Hudi, Hive, Kafka, S3, GCS, ADLS, JDBC databases. It is the universal processing engine for the data lake.

Optimization: Spark SQL's Catalyst optimizer automatically rewrites query plans for efficiency. It pushes filters down to minimize data read, chooses join strategies based on table sizes, and merges multiple operations into a single pass. With plain Python or Pandas, you would have to make these optimizations by hand.

Ecosystem for ML: PySpark DataFrames integrate with MLlib (for distributed ML), with loaders such as Petastorm that feed Spark-produced Parquet into TensorFlow and PyTorch training, and with Delta Lake (for transactional writes). Spark is a standard tool for the "training data preparation" stage of the ML pipeline.

Spark Architecture

Before writing Spark code, you must understand the architecture. This determines how to diagnose failures and optimize performance.

Driver: The "brain" of the application. Runs your PySpark script, builds the logical execution plan, converts it to a physical plan, and distributes tasks to executors. Single point of failure - if the driver dies, the job dies.

Executor: A JVM process running on a worker node. Receives tasks from the driver, reads data from storage, processes partitions, and writes results. Multiple tasks can run concurrently in one executor.

Task: The smallest unit of work. Each task processes one partition of the data, so a stage over 200 partitions runs 200 tasks.

Stage: A group of tasks that can run without a shuffle. A new stage begins whenever data must be redistributed across executors (shuffle operation).

Shuffle: The most expensive operation in Spark. Each executor writes its map output to local disk, split by the new partitioning key. Executors then fetch the pieces they need from each other over the network. Triggers: groupBy, join (non-broadcast), repartition, distinct.
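The following toy, single-process model in plain Python makes the stage and shuffle boundaries concrete. It is an illustration, not Spark's implementation. Each "map task" hash-partitions its rows into one bucket per reducer. Each "reduce task" then pulls its bucket from every map output. `zlib.crc32` stands in for the Murmur3 hash that Spark SQL uses, and the function names are hypothetical.

```python
import zlib
from collections import defaultdict

def hash_partition(key: str, num_partitions: int) -> int:
    """Deterministic stand-in for a hash partitioner (Spark SQL uses Murmur3)."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions

def simulate_shuffle(map_partitions, num_reducers):
    # Stage 1 (map side): every task splits its rows into one bucket per reducer.
    # In Spark these buckets are written to the executor's local disk.
    map_outputs = []
    for rows in map_partitions:
        buckets = defaultdict(list)
        for key, value in rows:
            buckets[hash_partition(key, num_reducers)].append((key, value))
        map_outputs.append(buckets)

    # Stage 2 (reduce side): task r fetches bucket r from *every* map output
    return [
        [row for output in map_outputs for row in output.get(r, [])]
        for r in range(num_reducers)
    ]

# Three input partitions of (user_id, 1) rows, as groupBy("user_id").count() sees them
inputs = [[("u1", 1), ("u2", 1)], [("u1", 1), ("u3", 1)], [("u2", 1), ("u1", 1)]]
shuffled = simulate_shuffle(inputs, num_reducers=2)

# After the shuffle each key lives in exactly one partition, so counting is local
counts = {}
for partition in shuffled:
    for key, _ in partition:
        counts[key] = counts.get(key, 0) + 1
```

Everything before the bucket split is one stage, and everything after it is the next. No reducer can start until every map task has written its output, which is why one slow task holds up the whole shuffle.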

DataFrames and Spark SQL for Feature Engineering

The DataFrame API is the primary interface for ML feature engineering. DataFrame code and Spark SQL queries both go through the Catalyst optimizer, so equivalent queries written either way compile to the same optimized physical execution plan.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, LongType, TimestampType

# Initialize Spark session with ML-optimized configuration
spark = SparkSession.builder \
    .appName("MLFeatureEngineering") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.sql.shuffle.partitions", "400") \
    .getOrCreate()

# Define the schema explicitly: it pins column types and skips schema inference
# (inference is slow and error-prone for JSON/CSV, where Spark must scan the data)
events_schema = StructType([
    StructField("user_id", StringType(), nullable=False),
    StructField("item_id", StringType(), nullable=True),
    StructField("event_type", StringType(), nullable=False),
    StructField("event_timestamp", TimestampType(), nullable=False),
    StructField("session_id", StringType(), nullable=True),
    StructField("amount", DoubleType(), nullable=True),
    StructField("device_type", StringType(), nullable=True),
])

# Read with explicit schema and partition pruning
# Filtering on the year/month partition columns means Spark reads only the
# matching directories - critical for large tables
events = spark.read \
    .schema(events_schema) \
    .parquet("s3://company-lake/events/") \
    .filter("year = 2024 AND month = 11")  # partition pruning

# count() is an action: it runs a full job over the filtered data
print(f"Loaded {events.count():,} events")

# Feature engineering with window functions
# Window spec: for each user, order by time (epoch seconds), look back 30 days
user_window_30d = Window \
    .partitionBy("user_id") \
    .orderBy(F.col("event_timestamp").cast("long")) \
    .rangeBetween(
        -30 * 24 * 3600,  # 30 days in seconds
        Window.currentRow
    )

user_features = events \
    .withColumn("tx_count_30d",
        F.count("*").over(user_window_30d)
    ) \
    .withColumn("avg_amount_30d",
        F.avg("amount").over(user_window_30d)
    ) \
    .withColumn("sessions_30d",
        F.approx_count_distinct("session_id").over(user_window_30d)
    ) \
    .withColumn("hour_of_day", F.hour("event_timestamp")) \
    .withColumn("day_of_week", F.dayofweek("event_timestamp")) \
    .withColumn("is_weekend",
        # dayofweek: 1 = Sunday, 7 = Saturday
        F.when(F.dayofweek("event_timestamp").isin(1, 7), 1).otherwise(0)
    )

# Get the latest feature snapshot per user (one row per user)
user_window_latest = Window.partitionBy("user_id").orderBy(F.desc("event_timestamp"))
latest_features = user_features \
    .withColumn("row_num", F.row_number().over(user_window_latest)) \
    .filter("row_num = 1") \
    .drop("row_num")
```
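The frame `rangeBetween(-30 days, currentRow)` is value-based, not row-based. It covers every row of the same user whose timestamp falls in [t - 30 days, t], and that includes other rows with exactly the same timestamp. The plain-Python reference implementation below is illustrative and not part of the pipeline. A reference like this is handy for unit-testing the Spark output on a small sample. The function name and the `(user_id, epoch_seconds)` input shape are assumptions.

```python
from bisect import bisect_left, bisect_right
from collections import defaultdict

THIRTY_DAYS = 30 * 24 * 3600

def trailing_counts(events, window_seconds=THIRTY_DAYS):
    """events: list of (user_id, epoch_seconds). Returns one count per input row,
    mimicking count(*) over partitionBy(user_id).orderBy(ts).rangeBetween(-w, currentRow)."""
    by_user = defaultdict(list)
    for user, ts in events:
        by_user[user].append(ts)
    for ts_list in by_user.values():
        ts_list.sort()

    counts = []
    for user, ts in events:
        ts_list = by_user[user]
        lo = bisect_left(ts_list, ts - window_seconds)  # lower bound is inclusive
        hi = bisect_right(ts_list, ts)                  # includes rows tied at ts
        counts.append(hi - lo)
    return counts

events = [("a", 0), ("a", 100), ("a", 100), ("a", THIRTY_DAYS + 100), ("b", 50)]
result = trailing_counts(events)
```

The two rows at t = 100 both count each other, because a RANGE frame treats tied order values as peers. The row at exactly 30 days + 100 still sees them, because the lower bound is inclusive.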

Partitioning and Parallelism

Partitioning is the single most important configuration decision in a Spark job. Too few partitions: some cores sit idle. Too many: task scheduling overhead dominates.

Rule of thumb: target partition sizes of 128–256 MB. For a 100 GB dataset, aim for 400–800 partitions.
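The rule of thumb is simple arithmetic. Here is a hypothetical helper that turns a data size into a partition count. You supply the size yourself. For shuffle partitions, the "Shuffle Write" size shown in the Spark UI is a reasonable input. Parquet's on-disk size is usually smaller than the data's in-memory or shuffled size.

```python
import math

def target_partitions(total_bytes: int, target_partition_mb: int = 256) -> int:
    """Number of partitions that keeps each one near target_partition_mb."""
    target_bytes = target_partition_mb * 1024 ** 2
    return max(1, math.ceil(total_bytes / target_bytes))

GB = 1024 ** 3
print(target_partitions(100 * GB, 256))  # 400
print(target_partitions(100 * GB, 128))  # 800
```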

```python
import statistics

# Check current partition statistics
def analyze_partitions(df):
    """Analyze partition size distribution to detect skew and inefficiency.

    Counts the rows in every partition, so this is a full pass over the data.
    """
    partition_counts = df.rdd.mapPartitions(lambda it: [sum(1 for _ in it)]).collect()

    avg_count = statistics.mean(partition_counts)
    median_count = statistics.median(partition_counts)
    max_count = max(partition_counts)
    min_count = min(partition_counts)
    skew_ratio = max_count / avg_count if avg_count else 0.0

    print(f"Number of partitions: {len(partition_counts)}")
    print(f"Average rows per partition: {avg_count:,.0f}")
    print(f"Median rows per partition: {median_count:,.0f}")
    print(f"Max rows per partition: {max_count:,.0f}")
    print(f"Min rows per partition: {min_count:,.0f}")
    print(f"Skew ratio (max/avg): {skew_ratio:.1f}x")

    if skew_ratio > 5:
        print("WARNING: Significant data skew detected!")
        print("Consider: salting, bucket joins, or AQE skew join optimization")

# Repartition strategies for ML workloads (df and large_df stand for any DataFrame)
# Option 1: repartition by number (round-robin, even distribution)
balanced_df = df.repartition(400)

# Option 2: repartition by column (hash-partitions on the column)
# Use when you'll join or aggregate on this column - rows with the same key
# land in the same partition
user_partitioned = df.repartition(200, "user_id")

# Option 3: repartition by range (sorted ranges, useful for ordered processing)
time_partitioned = df.repartitionByRange(100, "event_timestamp")

# Option 4: coalesce (reduce partitions without a full shuffle - use when reducing)
small_result = large_df.coalesce(10)
```

Adaptive Query Execution (AQE)

Spark 3.0 introduced Adaptive Query Execution (AQE), which re-optimizes the execution plan at runtime using statistics from shuffle stages that have already completed. AQE is enabled by default from Spark 3.2. It is critical for ML workloads with variable data distributions:

```python
# Enable AQE - strongly recommended for production ML jobs (on by default from Spark 3.2).
# If a session already exists, getOrCreate() returns it and applies these
# runtime SQL settings to it.
spark = SparkSession.builder \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.minPartitionNum", "1") \
    .config("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "400") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m") \
    .getOrCreate()

# With AQE enabled:
# - Coalesce: merges many small/empty shuffle partitions into fewer, larger ones
# - SkewJoin: in sort-merge joins, splits partitions larger than both
#   skewedPartitionFactor x the median partition size and skewedPartitionThresholdInBytes
# - Join selection: can switch a sort-merge join to a broadcast join once the
#   actual size of one side is known
# - Stats: uses runtime statistics (actual sizes and row counts) instead of estimates
```

Caching and Persistence

Caching intermediate DataFrames avoids recomputation when the same data is used multiple times (e.g., computing multiple feature sets from the same filtered dataset).

```python
from pyspark import StorageLevel
from pyspark.sql import functions as F

# Storage levels - choose based on memory vs. compute trade-off
# MEMORY_ONLY: fastest; partitions that don't fit are recomputed when needed
# MEMORY_AND_DISK: safer, spills to disk when RAM is full
# DISK_ONLY: avoids memory pressure at the cost of speed

# Best practice for ML pipelines:
filtered_events = events \
    .filter("event_timestamp >= '2024-10-01'") \
    .filter("event_type IN ('click', 'purchase', 'view')")

# Cache after expensive filter - both feature sets reuse this
filtered_events.persist(StorageLevel.MEMORY_AND_DISK)
filtered_events.count()  # Action: materializes the cache

# Now compute multiple feature sets from the cached DataFrame
user_features = filtered_events.groupBy("user_id").agg(
    F.count("*").alias("total_events"),
    F.sum(F.when(F.col("event_type") == "purchase", 1).otherwise(0)).alias("purchase_count"),
    F.countDistinct("session_id").alias("session_count")
)

item_features = filtered_events.groupBy("item_id").agg(
    F.count("*").alias("view_count"),
    F.countDistinct("user_id").alias("unique_users"),
    F.avg("amount").alias("avg_purchase_price")
)

# Materialize both outputs while the cache is still in place (example output paths)
user_features.write.mode("overwrite").parquet("s3://company-lake/features/user_daily/")
item_features.write.mode("overwrite").parquet("s3://company-lake/features/item_daily/")

# Unpersist only after the last action that reads the cached data, to free
# cluster memory. Transformations are lazy: unpersisting before the writes
# above would make both of them recompute filtered_events from scratch.
filtered_events.unpersist()

# Anti-pattern: don't cache everything
# Only cache DataFrames that are used multiple times in the same job
# Caching everything wastes memory and can cause spills that slow down the job
```

Join Strategies: The Critical Performance Decision

Non-broadcast joins are usually the most expensive operations in feature engineering because they shuffle both sides. Choosing the wrong join strategy is the most common performance mistake.

Broadcast Join: The Secret Weapon

If one side of a join fits in memory on each executor (typically under 200 MB), Spark can broadcast it - sending a copy to every executor and avoiding the shuffle entirely. A broadcast join can be 10–100× faster than a shuffle join.

```python
from pyspark.sql.functions import broadcast

# Scenario: join 500 GB events table with 50 MB item metadata table
# WITHOUT broadcast: 50 MB exceeds the default 10 MB auto-broadcast threshold,
# so Spark plans a sort-merge join and shuffles BOTH tables by item_id (expensive!)
wrong_way = events.join(item_metadata, on="item_id", how="left")

# WITH broadcast: item_metadata is sent to every executor; events is not shuffled
fast_way = events.join(broadcast(item_metadata), on="item_id", how="left")
# On a job like this, the difference is often hours vs. minutes

# Equivalent, using a join hint
also_fast = events.join(item_metadata.hint("broadcast"), on="item_id", how="left")

# Raise the auto-broadcast threshold (default: 10 MB, often too conservative).
# Spark compares it against its size *estimate*, which for file sources is based
# on on-disk size and can understate the in-memory size.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "200m")  # 200 MB
```

When to broadcast: Any small lookup table - item metadata, user demographics, product categories, currency conversion rates, geo lookup. Rule: if the table is less than a few hundred MB and used repeatedly, broadcast it.

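When not to broadcast: If the "small" table is actually 5 GB, broadcasting it means 5 GB × N_executors of memory overhead. With 100 executors, that's 500 GB of memory for one table. The driver also has to collect the whole table before shipping it, and Spark refuses to broadcast tables larger than 8 GB.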
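Here is a toy, plain-Python model of a broadcast hash join. It is not Spark code, and the helper names are hypothetical. The small side becomes a lookup dict that every "executor" holds in full, and each partition of the large side is probed locally without moving any rows. It assumes the join key is the first field and that the small side has unique keys, which is the usual shape of a broadcast lookup table. A second helper reproduces the memory arithmetic from the paragraph above.

```python
def broadcast_hash_join(large_partitions, small_rows):
    """Left join. Rows are tuples with the join key first; small_rows has unique keys."""
    # Built once (on the driver, in Spark) and shipped read-only to every executor
    lookup = {row[0]: row[1:] for row in small_rows}
    null_pad = (None,) * (len(small_rows[0]) - 1) if small_rows else ()

    joined = []
    for partition in large_partitions:  # each task probes locally; large rows never move
        joined.append([row + lookup.get(row[0], null_pad) for row in partition])
    return joined

def broadcast_memory_mb(table_mb: float, num_executors: int) -> float:
    """Total executor memory spent holding copies of the broadcast table."""
    return table_mb * num_executors

events = [[("i1", "u1"), ("i2", "u2")], [("i3", "u3"), ("i1", "u4")]]
items = [("i1", "books"), ("i2", "toys")]
result = broadcast_hash_join(events, items)
```

The output keeps the large side's partitioning exactly as it was, which is why no shuffle is needed. `broadcast_memory_mb(5000, 100)` gives 500,000 MB, the roughly 500 GB from the 5 GB example.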

Shuffle Hash Join vs Sort-Merge Join

For large-large joins (both tables too big to broadcast):

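```python
# Sort-Merge Join (Spark's default for large-large equi-joins): both sides are
# shuffled by the join key, sorted within each partition, then merged
# Good when: data is pre-bucketed (and sorted) on the join key
# Bad when: data is highly skewed (one hot key makes one giant task)

# Shuffle Hash Join: both sides are shuffled, then each task builds a hash table
# from its slice of the smaller side. It skips the sort, but that slice must fit
# in memory. Spark prefers sort-merge by default (spark.sql.join.preferSortMergeJoin);
# you can request hash explicitly with a hint (Spark 3.0+), e.g.:
# big_a.join(big_b.hint("shuffle_hash"), on="user_id")

# For highly skewed data, use bucket joins (if data is pre-bucketed on the same key)
# or use AQE's skew join optimization

# Pre-bucketing: saves shuffle time for repeated joins on the same key
# Useful when you join the same large table many times (e.g., user features).
# bucketBy only works with saveAsTable (a metastore table), not a plain path write.
events.write \
    .bucketBy(200, "user_id") \
    .sortBy("user_id") \
    .saveAsTable("bucketed_events")  # 200 buckets by user_id

user_features.write \
    .bucketBy(200, "user_id") \
    .sortBy("user_id") \
    .saveAsTable("bucketed_user_features")  # bucket count should match events

# Both sides are already hash-partitioned by user_id into matching buckets,
# so this join needs no shuffle (Exchange) on either side
result = spark.table("bucketed_events").join(
    spark.table("bucketed_user_features"),
    on="user_id"
)
```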
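Within a single shuffle partition, the two large-large strategies differ only in how they find matches. The plain-Python sketch below shows both for inner joins on `(key, value)` pairs. It is illustrative, not Spark's implementation, and the function names are hypothetical. The hash variant has to hold one side in memory. The sort-merge variant needs both sides sorted, which bucketing with `sortBy` can sometimes provide up front.

```python
from collections import defaultdict

def shuffle_hash_join(left, right):
    """Inner join of one partition: hash the (smaller) right side, probe with left."""
    table = defaultdict(list)
    for key, value in right:
        table[key].append(value)
    return [(k, lv, rv) for k, lv in left for rv in table.get(k, [])]

def sort_merge_join(left, right):
    """Inner join of one partition: sort both sides, then walk them in step."""
    left, right = sorted(left), sorted(right)
    i = j = 0
    out = []
    while i < len(left) and j < len(right):
        lk, rk = left[i][0], right[j][0]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Find the run of equal keys on the right, then pair it with
            # every left row that has the same key
            j_end = j
            while j_end < len(right) and right[j_end][0] == lk:
                j_end += 1
            while i < len(left) and left[i][0] == lk:
                out.extend((lk, left[i][1], right[r][1]) for r in range(j, j_end))
                i += 1
            j = j_end
    return out

left = [("a", 1), ("b", 2), ("a", 3), ("c", 4)]
right = [("a", "x"), ("a", "y"), ("c", "z"), ("d", "w")]
hash_result = sorted(shuffle_hash_join(left, right))
merge_result = sorted(sort_merge_join(left, right))
```

Both return the same five rows. Key "a" contributes 2 × 2 matches, and this fan-out per hot key is also what turns skew into one oversized task.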

Writing Efficient Spark Jobs: Production Patterns

Pattern 1: Avoid Wide Transformations (Shuffles) Where Possible

```python
# AVOID: separate aggregations over the same input
# events is scanned twice, and each groupBy is its own shuffle
feature1 = events.groupBy("user_id").agg(F.count("*").alias("event_count"))
feature2 = events.filter("event_type = 'purchase'").groupBy("user_id").agg(
    F.sum("amount").alias("total_spend")
)
# ...plus a join on top (left join keeps users with no purchases)
combined = feature1.join(feature2, on="user_id", how="left")  # 2 scans, 2+ shuffles

# PREFER: single pass with conditional aggregation
combined_efficient = events.groupBy("user_id").agg(
    F.count("*").alias("event_count"),
    # F.when without otherwise() yields null for non-purchases; sum ignores nulls
    F.sum(F.when(F.col("event_type") == "purchase", F.col("amount"))).alias("total_spend"),
    F.count(F.when(F.col("event_type") == "purchase", 1)).alias("purchase_count")
)
# 1 scan, 1 shuffle total
```
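The plain-Python loop below writes out the same single-pass logic. It illustrates the semantics and is not Spark code. It makes the null behaviour explicit: a user with no purchases still gets a row, with `total_spend` null and `purchase_count` 0. A left join of the two separate aggregates would produce the same result, while an inner join would drop that user entirely.

```python
from collections import defaultdict

# (user_id, event_type, amount) rows; amount is null for non-purchases
events = [
    ("u1", "view", None),
    ("u1", "purchase", 20.0),
    ("u1", "purchase", 5.0),
    ("u2", "view", None),  # u2 never purchases
]

def new_row():
    return {"event_count": 0, "total_spend": None, "purchase_count": 0}

features = defaultdict(new_row)
for user_id, event_type, amount in events:
    row = features[user_id]
    row["event_count"] += 1  # F.count("*")
    if event_type == "purchase":
        row["purchase_count"] += 1  # F.count(F.when(purchase, 1))
        # F.sum ignores nulls and stays null if it never sees a value
        if amount is not None:
            row["total_spend"] = (row["total_spend"] or 0.0) + amount
```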

Pattern 2: Filter Early, Filter Hard

```python
# RISKY: filter written after an expensive operation
result = events.join(user_features, on="user_id") \
    .filter("event_timestamp >= '2024-11-01'")  # filter AFTER join

# For a simple predicate like this one, Catalyst usually pushes the filter below
# the join on its own. It cannot push it past barriers such as Python UDFs,
# persisted DataFrames, or window functions, so write it early anyway.

# RIGHT: filter before the join to reduce join input size
filtered_events = events.filter("event_timestamp >= '2024-11-01'")
result = filtered_events.join(user_features, on="user_id")  # smaller join

# Even better: filter on partition columns so whole directories are skipped
# at read time (partition pruning)
result = spark.read \
    .parquet("s3://lake/events/") \
    .filter("year = 2024 AND month = 11") \
    .join(user_features, on="user_id")
```

Pattern 3: Control Output File Size

```python
def write_training_data(df, output_path: str, target_file_size_mb: int = 256):
    """Write training data with file sizes suited to downstream ML reads."""
    # Estimate the number of output files from the data size.
    # count() is a full action; the write below recomputes df unless it is cached.
    row_count = df.count()

    # Rough bytes per row of the *compressed Parquet output* - measure it from a
    # previous run of this job rather than trusting this placeholder
    bytes_per_row = 500
    total_mb = row_count * bytes_per_row / (1024 * 1024)
    n_partitions = max(1, int(total_mb / target_file_size_mb))

    print(f"Estimated size: {total_mb:.0f} MB, writing with {n_partitions} partitions")

    # Each partition becomes one output file (for a write without partitionBy)
    df.repartition(n_partitions) \
        .write \
        .mode("overwrite") \
        .option("compression", "zstd") \
        .parquet(output_path)
```

Integration with Delta Lake for Versioned Training Data

```python
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from pyspark.sql.window import Window

def write_features_to_delta(df, delta_path: str, feature_date: str):
    """
    Write one day of features to Delta Lake with full ACID guarantees.
    replaceWhere overwrites only that day's rows, so reruns are idempotent,
    and Delta time travel keeps earlier versions readable.
    """
    # Write with Delta format - adds transaction log
    df.withColumn("feature_date", F.to_date(F.lit(feature_date))) \
        .write \
        .format("delta") \
        .mode("overwrite") \
        .option("replaceWhere", f"feature_date = '{feature_date}'") \
        .save(delta_path)

    print(f"Features written to {delta_path} for date {feature_date}")

def create_training_dataset_with_lineage(
    events_path: str,
    features_path: str,
    training_start: str,
    training_end: str,
    output_path: str
) -> str:
    """
    Create a training dataset with point-in-time correct features.
    Returns the Delta table version for reproducibility tracking.
    """
    # Read raw events
    events = spark.read.format("delta").load(events_path) \
        .filter(f"event_date BETWEEN '{training_start}' AND '{training_end}'")

    # Delta time travel: read the feature table as it existed at training_end,
    # so rows written to it later are excluded
    features = spark.read.format("delta") \
        .option("timestampAsOf", training_end) \
        .load(features_path)

    # Point-in-time join (as described in Module 1): for each (user, event),
    # keep the most recent snapshot dated strictly before the event's date.
    # A snapshot dated the same day may already include activity from after the event.
    event_keys = events.select("user_id", "event_timestamp").distinct()
    window = Window.partitionBy("user_id", "event_timestamp").orderBy(F.desc("feature_date"))
    pit_features = features \
        .join(event_keys, on="user_id") \
        .filter(F.col("feature_date") < F.col("event_timestamp").cast("date")) \
        .withColumn("rank", F.row_number().over(window)) \
        .filter("rank = 1") \
        .drop("rank")

    # Join back on both keys so each event gets at most one feature row
    training_data = events.join(pit_features, on=["user_id", "event_timestamp"], how="left")
    training_data.write \
        .format("delta") \
        .mode("overwrite") \
        .save(output_path)

    # Return Delta version for tracking in MLflow
    delta_table = DeltaTable.forPath(spark, output_path)
    version = delta_table.history(1).select("version").collect()[0][0]
    return f"delta_version_{version}"
```
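Point-in-time correctness is easy to get subtly wrong, so it is worth keeping a tiny reference implementation to test against. The following plain-Python sketch is hypothetical, not the author's code. It uses the rule "latest snapshot dated strictly before the event date", which avoids using a daily snapshot that may contain same-day activity from after the event. Use `<=` semantics instead only if your snapshots are built exclusively from data before the date they carry.

```python
from bisect import bisect_left
from datetime import date

def point_in_time_lookup(snapshots, events):
    """snapshots: {user_id: list of (feature_date, features) sorted by date}
    events: list of (user_id, event_date).
    Returns the features of the latest snapshot dated strictly before each event."""
    result = []
    for user_id, event_date in events:
        history = snapshots.get(user_id, [])
        dates = [d for d, _ in history]
        idx = bisect_left(dates, event_date)  # first snapshot on or after event_date
        result.append(history[idx - 1][1] if idx > 0 else None)
    return result

snapshots = {
    "u1": [(date(2024, 11, 1), {"purchase_count_30d": 3}),
           (date(2024, 11, 3), {"purchase_count_30d": 5})],
}
events = [
    ("u1", date(2024, 11, 3)),   # same-day snapshot is excluded -> Nov 1 features
    ("u1", date(2024, 11, 4)),   # -> Nov 3 features
    ("u1", date(2024, 10, 31)),  # no earlier snapshot -> None
    ("u2", date(2024, 11, 3)),   # unknown user -> None
]
result = point_in_time_lookup(snapshots, events)
```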

A Full Example: Building a 1 TB Training Dataset from Raw Logs

```python
from pyspark import StorageLevel
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast
from pyspark.sql.window import Window

def build_recommendation_training_data(
    events_path: str,                      # s3://lake/events/ - 5 TB total
    items_path: str,                       # s3://lake/items/ - 500 MB
    output_path: str,
    training_date_range: tuple[str, str],  # ("2024-08-01", "2024-11-01")
) -> None:
    """
    Full ML training data pipeline:
    1. Load and filter raw events (partition pruning reduces 5 TB to ~1 TB)
    2. Compute user features using window functions
    3. Join with item metadata (broadcast join - the projected columns fit in memory)
    4. Create positive/negative samples for training
    5. Write in Parquet format with optimal partitioning
    """
    start_date, end_date = training_date_range

    # Step 1: Load events with partition pruning
    # Assumes the table is partitioned by an event_date directory column
    # (.../event_date=2024-08-01/), so this filter skips whole directories.
    # events_schema is the schema defined earlier.
    events = spark.read \
        .schema(events_schema) \
        .parquet(events_path) \
        .filter(
            (F.col("event_date") >= start_date) &
            (F.col("event_date") <= end_date)
        )
    events.persist(StorageLevel.MEMORY_AND_DISK)

    # Step 2: Compute user features over a 30-day trailing window per user
    user_window = Window.partitionBy("user_id") \
        .orderBy(F.col("event_timestamp").cast("long")) \
        .rangeBetween(-30 * 24 * 3600, Window.currentRow)

    user_features = events \
        .withColumn("view_count_30d",
            F.sum(F.when(F.col("event_type") == "view", 1).otherwise(0)).over(user_window)
        ) \
        .withColumn("purchase_count_30d",
            F.sum(F.when(F.col("event_type") == "purchase", 1).otherwise(0)).over(user_window)
        ) \
        .withColumn("unique_items_30d", F.approx_count_distinct("item_id").over(user_window)) \
        .withColumn("hour_of_day", F.hour("event_timestamp")) \
        .withColumn("day_of_week", F.dayofweek("event_timestamp")) \
        .withColumn("is_weekend",
            F.when(F.dayofweek("event_timestamp").isin(1, 7), 1).otherwise(0)
        )

    # Step 3: Load item metadata and keep only the needed columns before broadcasting
    # (500 MB of Parquet on disk can be several times larger in memory)
    item_metadata = spark.read.parquet(items_path)
    item_features = item_metadata.select(
        "item_id", "category", "price_tier", "avg_rating", "age_days"
    )

    # Broadcast join: the large events side is not shuffled
    user_item_features = user_features.join(
        broadcast(item_features),
        on="item_id",
        how="left"
    )

    # Step 4: Create training labels
    # Positives: purchases
    positives = user_item_features \
        .filter("event_type = 'purchase'") \
        .withColumn("label", F.lit(1))

    # Negatives: (user, item) pairs the user viewed but never purchased.
    # A left anti join on both keys does this without collecting IDs to the driver.
    purchased_pairs = positives.select("user_id", "item_id").distinct()
    negatives = user_item_features \
        .filter("event_type = 'view'") \
        .join(purchased_pairs, on=["user_id", "item_id"], how="left_anti") \
        .sample(fraction=0.1, seed=42) \
        .withColumn("label", F.lit(0))

    # unionByName matches columns by name rather than by position
    training_data = positives.unionByName(negatives)

    # Step 5: Write optimized output
    feature_columns = [
        "user_id", "item_id", "label",
        "view_count_30d", "purchase_count_30d", "unique_items_30d",
        "category", "price_tier", "avg_rating", "age_days",
        "hour_of_day", "day_of_week", "is_weekend"
    ]

    # 2,000 partitions gives ~500 MB files for ~1 TB of output
    training_data.select(feature_columns) \
        .repartition(2000) \
        .write \
        .mode("overwrite") \
        .option("compression", "zstd") \
        .parquet(output_path)

    events.unpersist()
    print(f"Training data written to {output_path}")
```
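Negative sampling is the step most worth unit-testing, because a plausible-looking filter can quietly implement the wrong rule. The following plain-Python reference for the intended semantics is a hypothetical helper, not Spark code. Negatives are pairs the same user viewed but never purchased, matched per item, which is a left anti join on both keys. They are then Bernoulli-sampled, the way `DataFrame.sample(fraction)` samples rows.

```python
import random

def sample_negatives(views, purchases, fraction, seed=42):
    """views / purchases: iterables of (user_id, item_id) pairs."""
    purchased = set(purchases)
    rng = random.Random(seed)
    candidates = [pair for pair in views if pair not in purchased]  # left anti join
    return [pair for pair in candidates if rng.random() < fraction]

views = [("u1", "i1"), ("u1", "i2"), ("u2", "i1")]
purchases = [("u1", "i1")]
all_negatives = sample_negatives(views, purchases, fraction=1.0)
```

User u1, who purchased one item, still contributes a negative for the item they only viewed. A filter that excluded every purchasing user would lose exactly these rows.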

Monitoring and Debugging ML Spark Jobs

Production Spark jobs for ML require observability built-in from the start. The three most important failure modes are: jobs that run indefinitely (data skew), jobs that OOM executors (memory pressure), and jobs that produce incorrect results silently (logical bugs in feature computation).

```python
import time
from dataclasses import dataclass
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F

@dataclass
class SparkJobMetrics:
    job_name: str
    start_time: float
    end_time: float
    input_rows: int
    output_rows: int
    stages_count: int
    shuffle_bytes_written_gb: float
    peak_executor_memory_gb: float
    status: str
    error: str = ""

class MLSparkJobRunner:
    """
    Wrapper for running ML Spark jobs with monitoring and metrics emission.
    Retries are left to the orchestrator (e.g. Airflow task retries).
    """
    def __init__(self, spark: SparkSession, job_name: str, metrics_emitter=None):
        self.spark = spark
        self.job_name = job_name
        self.metrics = metrics_emitter  # any object with emit(name, value, tags=...)

    def _stage_count(self) -> int:
        """Count distinct stages across the Spark jobs tagged with this job group.
        The status tracker only keeps a bounded history of recent jobs."""
        tracker = self.spark.sparkContext.statusTracker()
        stage_ids = set()
        for job_id in tracker.getJobIdsForGroup(self.job_name):
            job_info = tracker.getJobInfo(job_id)
            if job_info is not None:
                stage_ids.update(job_info.stageIds)
        return len(stage_ids)

    def run_with_monitoring(self, job_fn, input_df: DataFrame,
                            output_path: str, **kwargs) -> SparkJobMetrics:
        """Run a Spark job function and collect execution metrics."""
        # Tag every Spark job triggered from this thread so it can be looked up later
        self.spark.sparkContext.setJobGroup(self.job_name, f"ML job: {self.job_name}")
        start = time.time()
        input_rows = 0

        try:
            # Count input (a full Spark action - consider sampling for large datasets)
            input_rows = input_df.count()
            print(f"[{self.job_name}] Processing {input_rows:,} input rows")

            # Run the actual transformation
            result_df = job_fn(input_df, **kwargs)

            # Write output
            result_df.write.mode("overwrite").parquet(output_path)

            # Re-read the written output to count rows for validation
            output_rows = self.spark.read.parquet(output_path).count()

            end = time.time()
            metrics = SparkJobMetrics(
                job_name=self.job_name,
                start_time=start,
                end_time=end,
                input_rows=input_rows,
                output_rows=output_rows,
                stages_count=self._stage_count(),
                shuffle_bytes_written_gb=0.0,  # not collected here; read from Spark UI / REST API
                peak_executor_memory_gb=0.0,   # same
                status="success"
            )

            duration_min = (end - start) / 60
            retention = output_rows / input_rows if input_rows else 0.0
            print(f"[{self.job_name}] Completed in {duration_min:.1f} min. "
                  f"Output: {output_rows:,} rows ({retention:.1%} retention)")

            if self.metrics:
                self.metrics.emit("spark_job_duration_minutes", duration_min,
                                  tags={"job": self.job_name})
                self.metrics.emit("spark_job_output_rows", output_rows,
                                  tags={"job": self.job_name})

            return metrics

        except Exception as e:
            end = time.time()
            print(f"[{self.job_name}] FAILED after {(end - start) / 60:.1f} min: {e}")
            if self.metrics:
                self.metrics.emit("spark_job_failure", 1,
                                  tags={"job": self.job_name, "error": type(e).__name__})
            return SparkJobMetrics(
                job_name=self.job_name, start_time=start, end_time=end,
                input_rows=input_rows, output_rows=0, stages_count=0,
                shuffle_bytes_written_gb=0.0, peak_executor_memory_gb=0.0,
                status="failed", error=str(e)
            )


def validate_feature_output(df: DataFrame, expected_entity_count: int,
                            required_columns: list[str],
                            null_threshold: float = 0.01) -> bool:
    """
    Validate Spark output before writing to the feature store.
    Returns True if validation passes, raises ValueError if it fails.
    """
    # Check required columns exist
    missing_cols = [c for c in required_columns if c not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    actual_count = df.count()
    # Allow up to 10% fewer rows than expected (some users may have no activity)
    if actual_count < expected_entity_count * 0.9:
        raise ValueError(
            f"Output has {actual_count:,} rows, expected ~{expected_entity_count:,}. "
            f"Possible data pipeline failure upstream."
        )

    # Null rate per required column, computed in a single aggregation pass
    null_rates = df.select([
        (F.sum(F.col(c).isNull().cast("int")) / F.count("*")).alias(c)
        for c in required_columns
    ]).collect()[0].asDict()

    high_null_columns = {col: rate for col, rate in null_rates.items()
                         if rate is not None and rate > null_threshold}
    if high_null_columns:
        raise ValueError(
            f"High null rates in critical columns: {high_null_columns}. "
            f"Check upstream data quality."
        )

    print(f"Validation passed: {actual_count:,} rows, all columns present, null rates acceptable")
    return True
```
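The runner above records duration and row counts but does not detect skew, which is the first failure mode in the list. Below is a minimal sketch of a hypothetical helper that applies the max/median rule of thumb to per-task durations. You could copy those durations from the Stages tab of the Spark UI or fetch them from Spark's monitoring REST API. The threshold of 5 is the same heuristic used elsewhere in this lesson, not a Spark default.

```python
import statistics

def detect_task_skew(task_durations_s, ratio_threshold=5.0):
    """Return (is_skewed, max/median ratio) for one stage's task durations in seconds."""
    median = statistics.median(task_durations_s)
    if median <= 0:
        raise ValueError("task durations must be positive")
    ratio = max(task_durations_s) / median
    return ratio > ratio_threshold, ratio

# Roughly the opening story: 199 tasks at ~45 minutes, one at 7 hours
durations = [45 * 60] * 199 + [7 * 3600]
is_skewed, ratio = detect_task_skew(durations)
```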

Orchestrating Multi-Step ML Feature Pipelines with Airflow

Production ML Spark jobs don't run in isolation - they are part of DAGs with dependencies, retries, and SLA monitoring.

```python
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

default_args = {
    "owner": "ml-platform",
    "depends_on_past": False,
    "start_date": datetime(2024, 1, 1),
    "email_on_failure": True,
    "email": ["[email protected]"],
    "retries": 2,
    "retry_delay": timedelta(minutes=15),
    # Task-level SLA: flag a miss if a task hasn't succeeded
    # 4 hours after the run's scheduled time
    "sla": timedelta(hours=4),
}

with DAG(
    "ml_feature_daily_pipeline",
    default_args=default_args,
    schedule_interval="0 2 * * *",  # 2 AM daily, after upstream data lands (schedule= in Airflow 2.4+)
    catchup=False,
    tags=["ml", "features", "critical"],
) as dag:

    def validate_upstream_complete(**context):
        """Check that bronze data has landed before starting feature computation."""
        from pyspark.sql import SparkSession
        ds = context["ds"]  # logical date, e.g. "2024-11-15"

        spark = SparkSession.builder.getOrCreate()
        bronze_path = f"s3://company-lake/bronze/events/date={ds}/"

        try:
            count = spark.read.parquet(bronze_path).count()
            min_expected_rows = 1_000_000  # fail if < 1M events (data pipeline issue)
            if count < min_expected_rows:
                raise ValueError(
                    f"Bronze data for {ds} has only {count:,} rows "
                    f"(expected >= {min_expected_rows:,}). Upstream may have failed."
                )
            print(f"Upstream validation passed: {count:,} rows for {ds}")
        except Exception as e:
            raise RuntimeError(f"Bronze data validation failed: {e}") from e

    validate_upstream = PythonOperator(
        task_id="validate_upstream_data",
        python_callable=validate_upstream_complete,
    )

    compute_features = SparkSubmitOperator(
        task_id="compute_user_features",
        application="s3://company-code/spark/user_features_daily.py",
        conf={
            "spark.sql.adaptive.enabled": "true",
            "spark.sql.shuffle.partitions": "800",
            "spark.executor.memory": "16g",
            "spark.executor.cores": "4",
            "spark.dynamicAllocation.enabled": "true",
            "spark.dynamicAllocation.maxExecutors": "50",
        },
        application_args=["--date", "{{ ds }}"],  # templated by Airflow
        conn_id="spark_default",
    )

    def validate_feature_output_task(**context):
        """Verify the feature output before marking the pipeline as successful."""
        from pyspark.sql import SparkSession
        ds = context["ds"]

        spark = SparkSession.builder.getOrCreate()
        output_path = f"s3://company-lake/gold/user_features/feature_date={ds}/"

        feature_df = spark.read.parquet(output_path)
        # validate_feature_output is the helper defined above; it must be
        # importable on the Airflow worker
        validate_feature_output(
            df=feature_df,
            expected_entity_count=5_000_000,  # ~5M active users
            required_columns=["user_id", "purchase_count_30d", "sessions_30d", "unique_items_viewed_30d"],
            null_threshold=0.02,
        )

    validate_output = PythonOperator(
        task_id="validate_feature_output",
        python_callable=validate_feature_output_task,
    )

    validate_upstream >> compute_features >> validate_output
```

Common Mistakes

:::danger Setting spark.sql.shuffle.partitions Too Low The default spark.sql.shuffle.partitions is 200: fine for small datasets, catastrophically slow for large ones. Shuffling 1 TB into 200 partitions puts about 5 GB in each one, so tasks spill to disk constantly and the job grinds to a halt. Set this based on your data size, targeting 128–256 MB per partition after the shuffle. For 1 TB of shuffled data, that is roughly 4,000–8,000 partitions. With AQE partition coalescing enabled, you can start high and let Spark merge small partitions at runtime. :::

:::danger Using Python UDFs for Performance-Critical Operations Python UDFs in PySpark ship each row from the JVM to a Python worker process, run the function, and ship the result back. This can be 10–100× slower than equivalent Spark SQL or built-in functions. Use pyspark.sql.functions (which runs as Catalyst-optimized JVM code) wherever possible. If you must use Python, use Pandas UDFs (@pandas_udf). They move data in Arrow batches and process whole columns at once, which often makes them several times faster than row-at-a-time Python UDFs. :::

:::warning Not Caching Strategically Over-caching (caching everything "just in case") wastes executor memory and can cause spills that hurt overall job performance. Under-caching (never caching) forces recomputation of expensive operations. Cache only DataFrames that are used multiple times in the same job, and always unpersist when done. :::

:::warning Collecting Large DataFrames to the Driver df.collect() pulls all data into the driver's memory. On a 100 GB DataFrame, this will fail: either it exceeds spark.driver.maxResultSize (1 GB by default) or it runs the driver out of memory. Use df.take(N) or df.show() to inspect a few rows, df.sample(...) for a random sample, and df.write to persist results. If you need Python-based processing on the full dataset, use Pandas UDFs to keep processing distributed. :::

Interview Q&A

Q1: Explain the Spark execution model: what are stages, tasks, and shuffles?

A Spark job is broken into stages. A stage is a sequence of transformations that can run without exchanging data between executors (no shuffle). When Spark needs to redistribute data - for a groupBy, join, or repartition - it starts a new stage. The stage boundary is the shuffle.

Within a stage, each partition of the data becomes one task. Tasks run in parallel across executors. If you have 400 partitions and 100 executor cores, Spark runs 100 tasks simultaneously (4 waves of 100 tasks).
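The wave arithmetic is simple enough to sketch. This helper is hypothetical. It assumes one task per slot, where a slot is one core when `spark.task.cpus` is 1 (the default), and roughly equal task durations:

```python
import math

def task_waves(num_partitions, num_executors, cores_per_executor, task_cpus=1):
    """Rounds of tasks a stage needs when every slot runs one task at a time."""
    slots = num_executors * (cores_per_executor // task_cpus)
    return math.ceil(num_partitions / slots)

print(task_waves(400, num_executors=25, cores_per_executor=4))  # 4
print(task_waves(401, num_executors=25, cores_per_executor=4))  # 5
```

The second call shows why partition counts just above a multiple of the slot count are wasteful: one extra partition adds a whole wave, during which 99 of the 100 slots sit idle.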

A shuffle writes data to disk on each executor and reads from all executors. It is expensive: disk writes, network transfer, disk reads. Minimizing shuffles - through predicate pushdown, broadcast joins, and smart partitioning - is the primary optimization strategy for Spark ML pipelines.

Q2: What is a broadcast join and when should you use it?

A broadcast join sends a full copy of the smaller table to every executor, avoiding the shuffle that a join would otherwise require. If table A has 500 GB and table B has 50 MB, a sort-merge join shuffles both tables by join key. With a broadcast join, Spark collects table B on the driver and ships it to every executor (50 MB × number of executors in total). Each executor then joins its own partitions of table A locally.

Use broadcast join when: the smaller table is under roughly 200–500 MB (fits comfortably in executor memory), the data is lookup-style (item metadata, user demographics, geo lookups), and each executor runs many tasks against it (an executor receives the broadcast once and reuses it across all of its tasks).

Do not broadcast when: the "small" table is actually several GB, the join produces a very large output (cartesian product), or you're already memory-constrained (broadcast adds N_executors × table_size memory overhead).

Q3: How do you diagnose and fix data skew in Spark?

Diagnosis: In the Spark UI, go to the Stages tab. Find the stage that's taking unusually long. Click it and look at the Task Duration distribution. If 99% of tasks complete in 2 minutes and 1 task takes 2 hours, you have a skewed partition. The "Tasks" section shows min/median/max duration - a max/median ratio greater than 5× signals skew.

Fixes in order of preference: (1) Enable AQE with spark.sql.adaptive.skewJoin.enabled=true. From Spark 3.0, Spark automatically detects and splits skewed partitions in sort-merge joins during execution; this does not help with skew in aggregations. (2) Key salting: if the skew comes from a hot key in a join (e.g., a viral item_id), add a random salt to the hot key and explode the small side to match. (3) Custom partitioning: repartition the data so that the skewed key is split across multiple partitions. (4) Separate processing: filter out the hot key, process it independently, then union the result back with the results for normal keys.
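Fix (2) can be sketched in plain Python. This is illustrative, and `SALT_BUCKETS` and the data are hypothetical. In PySpark the same idea is usually expressed with `F.rand()` to salt the large side and `F.explode(F.sequence(...))` to replicate the small side. The sketch salts every key for simplicity, though in practice you can salt only the known hot keys.

```python
import random
from collections import Counter

SALT_BUCKETS = 8  # hypothetical: more buckets = smaller groups but a bigger small side

# Large side: one viral item dominates. Small side: one metadata row per item.
events = [("viral", i) for i in range(10_000)] + [(f"item{i}", i) for i in range(800)]
items = {"viral": "music", **{f"item{i}": "misc" for i in range(800)}}

rng = random.Random(0)

# Large side: append a random salt to every key
salted_events = [((key, rng.randrange(SALT_BUCKETS)), value) for key, value in events]
# Small side: "explode" each row into one copy per salt value
salted_items = {(key, s): cat for key, cat in items.items() for s in range(SALT_BUCKETS)}

# Join on the salted key; every salted key has exactly one match
joined = [(key, value, salted_items[(key, salt)]) for (key, salt), value in salted_events]

# A shuffle keeps all rows of one join key in one task, so the largest key
# group is a floor on the slowest task's input size
largest_before = max(Counter(key for key, _ in events).values())
largest_after = max(Counter(skey for skey, _ in salted_events).values())
```

The join result is unchanged. The hot key's 10,000 rows are now spread over eight salted keys of roughly 1,250 rows each, at the cost of an eight-times-larger small side.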

Q4: What is the difference between repartition and coalesce?

repartition(N) redistributes data across N partitions using a full shuffle: round-robin by default, or hash-partitioned when you pass columns. Use when: increasing the partition count, or when you need evenly balanced partitions (for example, after a filter left many empty or tiny partitions). It is expensive because it triggers a shuffle.

coalesce(N) merges existing partitions into fewer ones without a shuffle. Use when: reducing the partition count after filtering (e.g., you filtered from 1B rows to 10M rows and now have 200 tiny partitions but want 20). It is cheaper than repartition, but it can create uneven partitions if the input partitions are uneven. Because it adds no stage boundary, it also lowers the parallelism of the work before it: coalesce(1) at the end of a job can make the whole final stage run as a single task.

Rule of thumb: use coalesce to reduce, repartition to increase or when you need balanced output.
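The difference is easiest to see on uneven input. The two functions below are simplified, hypothetical models, not Spark's code. Spark's real coalesce groups partitions with a locality-aware coalescer, and its round-robin repartition starts at a pseudo-random offset in each input partition. Still, the shape of the result is the same: coalesce concatenates whole partitions, while repartition deals rows out one by one.

```python
def coalesce(partitions, n):
    """Merge neighbouring partitions into n groups; no row is moved individually."""
    if n >= len(partitions):  # coalesce cannot increase the partition count
        return [list(p) for p in partitions]
    groups = [[] for _ in range(n)]
    for i, p in enumerate(partitions):
        groups[i * n // len(partitions)].extend(p)
    return groups

def repartition(partitions, n):
    """Deal every row round-robin into n partitions (a full shuffle)."""
    out = [[] for _ in range(n)]
    for idx, row in enumerate(r for p in partitions for r in p):
        out[idx % n].append(row)
    return out

# Partition sizes left over after a selective filter
parts = [list(range(size)) for size in [100, 90, 1, 2, 0, 3, 1, 0]]
coalesced_sizes = [len(g) for g in coalesce(parts, 2)]        # [193, 4]
repartitioned_sizes = [len(g) for g in repartition(parts, 2)]  # [99, 98]
```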

Q5: How would you build a PySpark pipeline to generate 1 TB of training data from raw event logs efficiently?

Key decisions:

  1. Partition pruning: Read only the date range needed. Use Hive-style partitioning (year/month/day) and filter on partition columns so Spark only reads relevant files.

  2. Schema specification: Always specify schemas explicitly - never infer from JSON. Schema inference requires Spark to read all files once to determine types, doubling your read cost.

  3. Broadcast joins: Load all small reference tables (item metadata, user demographics) and broadcast them before any joins with the large events table. This eliminates the most expensive shuffles.

  4. Single-pass aggregation: Compute all features in one groupBy pass using conditional aggregation (F.sum(F.when(...))) instead of multiple separate groupBy operations.

  5. Caching: After the initial filter/read (which is expensive), persist the filtered DataFrame if it's used multiple times. Unpersist when done.

  6. Output tuning: Write with repartition(N) where N is chosen to produce 256 MB files. Use ZSTD compression. Write to Delta Lake for ACID guarantees and time travel capability.

  7. Enable AQE: spark.sql.adaptive.enabled=true handles runtime skew, empty partition coalescing, and join strategy selection automatically.

Summary

Spark is the standard batch processing engine for ML feature engineering at scale because it combines distributed computing with a rich SQL and Python API, integrates with the entire data ecosystem, and provides automatic optimization through the Catalyst query planner and Adaptive Query Execution.

The performance principles: understand shuffles and minimize them, use broadcast joins for small tables, partition correctly (128–256 MB per partition), cache strategically, filter early. The common mistakes: too few shuffle partitions, Python UDFs instead of built-in functions, collecting large DataFrames to the driver, and ignoring data skew.

Master these patterns and you can process a terabyte of training data in minutes instead of hours. Get them wrong and you get the 9-hour job with one stuck task from the opening story.

:::tip Key Takeaway The single most valuable Spark optimization habit: look at the Spark UI after every job. The Stage timeline shows where time is spent. The Task metrics show data skew. The SQL tab shows the physical execution plan. Ten minutes of Spark UI analysis can identify 80% of performance problems that would otherwise take days of code debugging to find. :::

© 2026 EngineersOfAI. All rights reserved.