Apache Spark Architecture
The Nightly Job That Couldn't Finish
Marcus is a data engineer at a fintech company. His team owns the feature pipeline for the fraud detection model - the model that decides, in real time, whether to approve or decline a transaction. The model is good. The features it depends on are not the problem. The problem is computing those features every night before the model is retrained.
The features are computed from transaction history: rolling 7-day and 30-day aggregations per user, per merchant category, per geographic region. Things like "how many transactions did this user make in the last 7 days," "what was the average transaction amount for this merchant category in the last 30 days," and "what percentage of this user's transactions in the last month were declined." These are standard fraud features. The underlying data: 500 gigabytes of raw transaction records, updated nightly, covering the last 90 days of activity across 12 million active users.
Marcus's first attempt uses Pandas. He loads the data, applies groupby aggregations, computes rolling windows with .rolling(). It works perfectly on a 5% sample. When he points it at the full dataset, the process runs for 20 minutes, hits the 64 GB RAM limit on the computation server, and dies with a MemoryError. He doubles the RAM allocation. 45 minutes, MemoryError. He tries Dask, which distributes the computation across multiple processes. Better - it doesn't crash - but after 3 hours it still hasn't finished, and the nightly training window is 4 hours long. If feature computation takes 3 hours, the model retraining has an hour, which isn't enough.
A senior engineer on the team points Marcus to Spark. The same pipeline, rewritten in PySpark, runs in 18 minutes on a 10-node cluster. The reason is not that Spark is just "faster" - it's that Spark's architecture is designed from the ground up for this exact problem: distributed, in-memory computation over datasets that don't fit on a single machine. Understanding why Spark works requires understanding how it works. That's what this lesson covers.
This is not a lesson you should rush through. Spark's architecture is the foundation for everything in the next seven lessons of this module. Engineers who understand it deeply can diagnose failures, tune performance, and design pipelines that scale. Engineers who treat it as a black box spend weeks chasing mysterious slowdowns and out-of-memory errors.
Why This Exists
The Problem with Single-Node Processing
A single machine has two hard constraints: CPU cores and RAM. Pandas is single-threaded for most operations, meaning it uses one core. Even with parallelism via Dask or multiprocessing, you are still bounded by the resources of one machine. At 500 GB of data, a machine with 128 GB of RAM cannot hold the dataset in memory. Pandas fails outright with a MemoryError; out-of-core tools like Dask avoid the crash by spilling to disk, but disk I/O is 100–1000x slower than memory access, so the computation becomes disk-bound, not compute-bound.
The only solution is to distribute: break the data into chunks, process each chunk on a separate machine, and combine the results. But distributing computation introduces a new set of problems. How do you split the data? How do machines coordinate? What happens if one machine fails mid-computation? How do you handle computations that require combining data from multiple partitions (like a join or a global sort)? These are hard engineering problems. Spark solves all of them.
The Problem with MapReduce
Before Spark, the dominant distributed computing framework was Hadoop MapReduce, the open-source implementation of the programming model described in Google's 2004 MapReduce paper. MapReduce worked: it could process petabytes of data across thousands of machines. But it had a fundamental performance problem for iterative algorithms - exactly the kind of computation ML requires.
In MapReduce, every stage of computation writes its output to HDFS (Hadoop Distributed File System) on disk. The next stage reads from disk. For a pipeline with 10 stages, data moves between disk and memory 10 times. Each disk write and read is slow. For batch ETL jobs that run once, this is acceptable. For iterative ML algorithms - training a linear model for 100 iterations, each iteration requiring a pass over the entire dataset - MapReduce is catastrophically slow. An algorithm that needs 100 iterations reads and writes every data partition 100 times each: 200 disk I/O passes. On large datasets, training times stretched into days.
Matei Zaharia's work at UC Berkeley's AMPLab (the 2010 Spark paper, followed by the 2012 paper that introduced RDDs) rested on a single key insight: keep data in memory across iterations. Instead of writing intermediate results to disk, keep them in the cluster's RAM. For iterative algorithms, this produces 10–100x speedups over MapReduce. Spark was designed around this insight, and by 2014 it had become the dominant batch processing framework for data engineering and ML.
Core Concept: How Spark Actually Works
RDDs: The Foundation
The fundamental abstraction in Spark is the Resilient Distributed Dataset (RDD). An RDD is a collection of elements partitioned across the nodes of the cluster, which can be operated on in parallel.
Three words in that definition matter:
Resilient - RDDs can be recomputed if a partition is lost due to node failure. Spark tracks the computation history that produced each RDD (called the lineage), so if a partition is lost, Spark reruns only the steps needed to recompute that specific partition. This is Spark's fault tolerance mechanism.
Distributed - The data physically lives on multiple machines. Each machine holds a subset of the total data (a partition). Spark runs operations on all partitions in parallel.
Dataset - It is a collection of records. Each partition is an ordered collection of records. The records can be any Python/Scala/Java objects (in RDD API) or structured rows (in DataFrame API).
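To make lineage concrete, here is a toy model in plain Python. It is not Spark's implementation, and ToyRDD and its methods are hypothetical names. Each derived dataset records the steps that produced it, so a lost partition can be rebuilt by replaying those steps on just that partition's source data.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

Step = Callable[[List[int]], List[int]]


@dataclass
class ToyRDD:
    """Toy stand-in for an RDD: source partitions plus the lineage (steps) applied to them."""
    source: List[List[int]]                      # base data, one list per partition
    steps: List[Step] = field(default_factory=list)
    cache: Dict[int, List[int]] = field(default_factory=dict)

    def map_partitions(self, fn: Step) -> "ToyRDD":
        # Record the step in the lineage; nothing is computed here.
        return ToyRDD(self.source, self.steps + [fn])

    def compute(self, index: int) -> List[int]:
        # Replay the full lineage for ONE partition, starting from its source data.
        rows = list(self.source[index])
        for step in self.steps:
            rows = step(rows)
        return rows

    def materialize(self) -> Dict[int, List[int]]:
        self.cache = {i: self.compute(i) for i in range(len(self.source))}
        return self.cache

    def lose_partition(self, index: int) -> None:
        # Simulate an executor failure that takes one computed partition with it.
        del self.cache[index]

    def recover(self) -> List[int]:
        # Recompute only the partitions that are missing.
        missing = [i for i in range(len(self.source)) if i not in self.cache]
        for i in missing:
            self.cache[i] = self.compute(i)
        return missing


base = ToyRDD(source=[[1, 2, 3], [4, 5, 6], [7, 8, 9]])
derived = (
    base.map_partitions(lambda rows: [r * 2 for r in rows])
        .map_partitions(lambda rows: [r for r in rows if r % 4 == 0])
)
snapshot = dict(derived.materialize())
derived.lose_partition(1)
recomputed = derived.recover()   # only partition 1 is rebuilt
```

Only the missing partition is recomputed; the other two are untouched, which is why lineage-based recovery is cheap compared with replicating every intermediate result.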
The key property of RDDs is the distinction between transformations and actions:
- Transformations are lazy. They define what computation will happen but do not execute anything. map(), filter(), groupByKey(), and join() are all transformations. Calling them returns a new RDD representing the computation, but no data moves.
- Actions trigger execution. collect(), count(), take(), and saveAsTextFile() are actions. When you call an action, Spark looks at the chain of transformations that produced the RDD, builds an execution plan (the DAG), optimizes it, and executes it.
This lazy evaluation is critical for optimization. Spark can see the entire computation graph before executing, which means it can reorder, combine, and optimize operations in ways that eager execution cannot.
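The effect of laziness can be sketched with a toy pipeline in plain Python (LazyPipeline is a hypothetical class, not a Spark API). Transformations only append to a recorded plan; the action walks the whole plan once, applying every step to each row in a single pass, which is the kind of fusion that seeing the full plan makes possible.

```python
from typing import Any, Callable, Iterable, List, Tuple


class LazyPipeline:
    """Toy lazy pipeline: transformations record steps, actions execute them."""

    def __init__(self, data: Iterable[Any], ops: Tuple[Tuple[str, Callable], ...] = ()):
        self._data = data
        self._ops = ops

    # --- transformations: return a new pipeline, execute nothing ---
    def map(self, fn: Callable[[Any], Any]) -> "LazyPipeline":
        return LazyPipeline(self._data, self._ops + (("map", fn),))

    def filter(self, pred: Callable[[Any], bool]) -> "LazyPipeline":
        return LazyPipeline(self._data, self._ops + (("filter", pred),))

    # --- actions: run the recorded plan, one pass per row ---
    def collect(self) -> List[Any]:
        out = []
        for row in self._data:
            keep = True
            for kind, fn in self._ops:
                if kind == "map":
                    row = fn(row)
                elif not fn(row):      # a filter rejected the row
                    keep = False
                    break
            if keep:
                out.append(row)
        return out

    def count(self) -> int:
        return len(self.collect())


calls = {"n": 0}


def expensive_double(x: int) -> int:
    calls["n"] += 1
    return x * 2


pipeline = LazyPipeline(range(10)).map(expensive_double).filter(lambda x: x > 10)
calls_before_action = calls["n"]   # 0: defining the pipeline ran nothing
result = pipeline.collect()        # the action triggers execution
```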
The DAG Execution Model
When you call an action, Spark builds a Directed Acyclic Graph (DAG) of the computation. The DAG represents every transformation from the source data to the final result.
Spark then divides this DAG into stages. A stage boundary occurs wherever a shuffle is required - an operation that requires redistributing data across partitions (more on this shortly). Within a stage, all operations can be performed on a partition without data leaving that machine. Across stage boundaries, data must move between machines.
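A rough sketch of stage splitting, in plain Python: walk a linear chain of operations and start a new stage at every operation that needs a shuffle. The operation names and the WIDE_OPS set are simplifying assumptions. Real Spark plans are trees, and a wide operation such as a groupBy actually straddles the boundary (partial aggregation and shuffle write at the end of one stage, shuffle read and final aggregation at the start of the next).

```python
from typing import List

# Assumed classification for this sketch: these operations need a shuffle.
WIDE_OPS = {"groupBy", "join", "distinct", "repartition", "orderBy"}


def split_into_stages(plan: List[str]) -> List[List[str]]:
    """Cut a linear chain of operations into stages at each shuffle boundary."""
    stages: List[List[str]] = [[]]
    for op in plan:
        if op in WIDE_OPS and stages[-1]:
            stages.append([])          # shuffle boundary: start a new stage
        stages[-1].append(op)          # narrow ops stay in the current stage
    return stages


plan = ["read", "filter", "select", "groupBy", "withColumn", "join", "write"]
stages = split_into_stages(plan)
```

Three stages means two shuffles; within each stage, a partition can be processed start to finish without leaving its executor.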
Each stage is divided into tasks - one task per partition. Tasks are the unit of work sent to executors. If you have 200 partitions and 4 executors with 10 cores each (40 total cores), Spark will run 40 tasks concurrently, cycling through all 200 tasks in 5 rounds.
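The scheduling arithmetic above fits in a few lines. This sketch assumes every task takes the same time and that each task uses one core (Spark's spark.task.cpus setting, which defaults to 1); task_waves is a hypothetical helper.

```python
import math


def task_waves(num_partitions: int, num_executors: int, cores_per_executor: int,
               cpus_per_task: int = 1) -> int:
    """Rounds of tasks needed to process one task per partition."""
    slots = num_executors * (cores_per_executor // cpus_per_task)  # concurrent tasks
    return math.ceil(num_partitions / slots)


print(task_waves(200, 4, 10))   # 40 slots -> 5 rounds
print(task_waves(201, 4, 10))   # 6 rounds: the last one runs a single task
```

The second call shows why one extra partition can cost a whole extra round in which 39 of the 40 cores sit idle.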
Driver vs Executor Architecture
Spark has a master-worker architecture:
The Driver is the process running your Python script. It contains the SparkSession, constructs the DAG, communicates with the cluster manager, schedules tasks on executors, and collects results. The driver is the brain. It does not process data - it orchestrates the workers that do.
The Cluster Manager (YARN, Kubernetes, Mesos, or Spark Standalone) manages cluster resources. When the driver needs executors, it requests them from the cluster manager. The cluster manager allocates nodes and starts executor processes.
Executors are JVM processes running on worker nodes. Each executor has a fixed number of CPU cores and a fixed amount of RAM (configured at session startup). Executors receive task assignments from the driver, execute the tasks (processing their assigned partitions), and report results back. Executors are long-lived - unless dynamic allocation is enabled, they persist for the entire duration of the application, holding cached data and intermediate results in memory.
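The division of labor can be mimicked on one machine. In this sketch, a thread pool stands in for the executors' cores and run_job (a hypothetical helper) plays the driver: it ships one task per partition, waits, and gathers the results. Real executors are separate JVM processes on other machines, so treat this as an analogy for the control flow, not a model of performance.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List


def run_job(partitions: List[List[int]],
            task: Callable[[List[int]], int],
            total_cores: int) -> Dict[int, int]:
    """Toy driver: one task per partition, executed by a pool of worker slots."""
    with ThreadPoolExecutor(max_workers=total_cores) as pool:  # stand-in for executor cores
        futures = {i: pool.submit(task, part) for i, part in enumerate(partitions)}
        # The driver only schedules and collects; the rows are processed by the workers.
        return {i: future.result() for i, future in futures.items()}


partitions = [[1, 2], [3, 4], [5, 6, 7]]
per_partition = run_job(partitions, sum, total_cores=2)
total = sum(per_partition.values())
```

The final sum over per-partition results happens on the driver, which is why pulling large results back with collect() can exhaust driver memory (the reason spark.driver.maxResultSize exists).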
DataFrame API vs RDD API
Spark originally exposed only the RDD API. In 2015, Spark 1.3 introduced DataFrames (and later Datasets in Scala/Java). The DataFrame API is now the recommended way to write Spark code. Here is why:
The RDD API is flexible - any Python function can be applied to any RDD. But this flexibility has a cost: Spark has no visibility into what your function does. It cannot reorder operations, merge transformations, or push predicates down to the data source. In PySpark, every Python function applied to an RDD also runs in a separate Python worker process, so each record crosses the JVM-Python boundary and pays serialization overhead.
The DataFrame API expresses operations in terms of structured columns and SQL-like operations. Spark can inspect these operations. The Catalyst optimizer (Spark's query planner) can then apply a battery of logical and physical optimizations before any data moves. This is why the same computation expressed in the DataFrame API is frequently 3–10x faster than the equivalent RDD code.
For ML feature pipelines, always use the DataFrame API. Reserve the RDD API for cases where you need full control over partitioning or are doing custom serialization.
The Catalyst Optimizer
Catalyst is the query optimizer that transforms your DataFrame operations into an efficient execution plan. It operates in four phases:
Phase 1 - Unresolved Logical Plan: Spark parses your DataFrame operations into a tree of logical operators. Column names are not yet resolved to actual types.
Phase 2 - Analyzed Logical Plan: Spark resolves column references against the schema catalog. If you reference a column that doesn't exist, you get an AnalysisException here, before any data moves.
Phase 3 - Optimized Logical Plan: Catalyst applies rule-based optimizations to the logical plan. Key optimizations include:
- Predicate pushdown: Move filter conditions as early in the plan as possible - ideally down to the data source, so Parquet files can skip irrelevant row groups.
- Projection pruning: Drop columns not needed for the final result as early as possible, reducing data movement.
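- Constant folding: Evaluate constant expressions at planning time (e.g., WHERE year > 2020 + 1 becomes WHERE year > 2021).
- Join reordering: Reorder multi-way joins so that smaller intermediate results are produced first. This is part of cost-based optimization, needs table statistics, and is disabled by default (spark.sql.cbo.joinReorder.enabled).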
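Constant folding is easy to demonstrate on Python's own expression trees. This sketch uses the standard ast module (Python 3.9+ for ast.unparse) rather than Catalyst, and folds only +, - and *; it illustrates the idea behind Catalyst's ConstantFolding rule, not its implementation.

```python
import ast
import operator

# Operators this toy rule knows how to fold.
_FOLDABLE = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
}


class ConstantFolder(ast.NodeTransformer):
    """Replace `constant <op> constant` with its value, bottom-up."""

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        self.generic_visit(node)  # fold children first
        op = _FOLDABLE.get(type(node.op))
        if op and isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
            folded = ast.Constant(op(node.left.value, node.right.value))
            return ast.copy_location(folded, node)
        return node


def fold(expr: str) -> str:
    tree = ast.parse(expr, mode="eval")
    folded = ast.fix_missing_locations(ConstantFolder().visit(tree))
    return ast.unparse(folded)


print(fold("year > 2020 + 1"))         # year > 2021
print(fold("amount * (2 + 3) > 10"))   # amount * 5 > 10
```

Expressions that reference columns (here, names like year or amount) are left alone; only the purely constant subtrees are evaluated ahead of time.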
Phase 4 - Physical Plan: Catalyst selects physical execution strategies (e.g., SortMergeJoin vs BroadcastHashJoin for joins), applies cost-based optimization (CBO) if statistics are available, and generates the final execution plan.
The physical plan is then passed to the Tungsten engine for code generation and execution.
Tungsten: Off-Heap Memory and Code Generation
Tungsten is Spark's execution engine, introduced in Spark 1.5 (whole-stage code generation followed in Spark 2.0). It addresses two performance bottlenecks that the JVM introduces:
Memory management: The JVM's garbage collector adds unpredictable latency and overhead when operating on large object graphs. Tungsten sidesteps most of this by storing data as raw bytes in a compact binary format instead of as Java objects, and it can place that memory off-heap entirely when spark.memory.offHeap.enabled is set. This reduces GC pressure dramatically and enables memory-efficient binary sorting and hashing.
Whole-stage code generation (WSCG): The traditional Volcano model of query execution involves a virtual function call per row per operator - a significant overhead for billions of rows. Tungsten's WSCG fuses multiple operators into a single compiled function. Instead of row → filter → project → aggregate (three function calls per row), Tungsten generates Java source for a single tight loop that applies all operations together and compiles it at runtime. The resulting code is close to what you would write by hand.
The combination of binary memory management and WSCG typically delivers 2–5x speedups over non-Tungsten execution on CPU-bound operations.
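The difference between the two execution styles is easier to see side by side. The sketch below writes the same scan → filter → project → sum pipeline twice in plain Python: once as a chain of pull-based iterators (Volcano-style, one call per row per operator) and once as a single fused loop (conceptually what WSCG generates). Python timings would not reflect JVM code generation, so the point is the structure, not the speed.

```python
from typing import Iterable, Iterator, List, Tuple

Row = Tuple[str, float]  # (user_id, amount)


# --- Volcano style: every operator is an iterator pulling rows from its child ---
def scan(rows: Iterable[Row]) -> Iterator[Row]:
    for row in rows:
        yield row


def filter_op(child: Iterator[Row], min_amount: float) -> Iterator[Row]:
    for row in child:
        if row[1] > min_amount:
            yield row


def project_amount(child: Iterator[Row]) -> Iterator[float]:
    for _, amount in child:
        yield amount


def sum_agg(child: Iterator[float]) -> float:
    total = 0.0
    for amount in child:
        total += amount
    return total


def volcano_sum(rows: Iterable[Row], min_amount: float) -> float:
    # Each row passes through four separate operators
    return sum_agg(project_amount(filter_op(scan(rows), min_amount)))


# --- Fused style: the same stage collapsed into one loop ---
def fused_sum(rows: Iterable[Row], min_amount: float) -> float:
    total = 0.0
    for _, amount in rows:        # scan
        if amount > min_amount:   # filter
            total += amount       # project + aggregate
    return total


rows: List[Row] = [("u1", 50.0), ("u2", 150.0), ("u1", 300.0)]
print(volcano_sum(rows, 100.0), fused_sum(rows, 100.0))  # both 450.0
```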
Shuffle: The Most Expensive Operation
A shuffle redistributes data across partitions, typically across the network. It is triggered by any operation that requires data from multiple partitions to be combined: groupBy, join, distinct, repartition, sort.
During a shuffle:
- The map phase: Each executor writes its output records to local shuffle files, grouped by the target partition they belong to (determined by a hash of the key).
- The reduce phase: Executors fetch the relevant partition files from all other executors over the network.
- The data is reassembled into new partitions for the next stage.
Shuffle is expensive because it involves network I/O (potentially reading gigabytes across machines), disk I/O (writing intermediate shuffle files), and serialization/deserialization. In a poorly designed pipeline, a single shuffle on a large dataset can take longer than all the non-shuffle operations combined.
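The two shuffle phases can be sketched in plain Python. Each input partition buckets its records by target partition (the map phase; in Spark these buckets are written to local shuffle files), then each output partition gathers its bucket from every map output (the reduce phase; in Spark this is a network fetch). zlib.crc32 stands in for Spark's own hash function, which is an assumption of the sketch.

```python
import zlib
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Record = Tuple[str, int]  # (key, value)


def target_partition(key: str, num_partitions: int) -> int:
    # Python's built-in str hash is randomized per process, so use a stable hash.
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def map_side(partition: List[Record], num_partitions: int) -> Dict[int, List[Record]]:
    """Map phase: bucket one input partition's records by target partition."""
    buckets: Dict[int, List[Record]] = defaultdict(list)
    for key, value in partition:
        buckets[target_partition(key, num_partitions)].append((key, value))
    return buckets


def shuffle(inputs: List[List[Record]], num_partitions: int) -> List[List[Record]]:
    """Reduce phase: each output partition fetches its bucket from every map output."""
    map_outputs = [map_side(p, num_partitions) for p in inputs]
    return [
        [record for out in map_outputs for record in out.get(target, [])]
        for target in range(num_partitions)
    ]


inputs = [[("alice", 1), ("bob", 2)], [("alice", 3), ("carol", 4)]]
outputs = shuffle(inputs, num_partitions=3)

# After the shuffle, every key lives in exactly one output partition.
homes: Dict[str, Set[int]] = defaultdict(set)
for idx, part in enumerate(outputs):
    for key, _ in part:
        homes[key].add(idx)
```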
Operations that trigger a shuffle
- groupBy().agg() - requires all rows with the same key on the same executor
- join() - most joins require aligning matching keys on the same executor
- distinct() - requires global deduplication
- repartition(n) - explicitly redistributes data
- sort() / orderBy() - global sort requires a shuffle
Reducing shuffle cost
- Use reduceByKey instead of groupByKey (RDD API): groupByKey shuffles all values for a key, then aggregates. reduceByKey aggregates locally first, then shuffles only the partial aggregates. Dramatically less data moved.
- Broadcast joins: If one side of a join is small enough to fit in memory on each executor, broadcast it rather than shuffling both sides. This eliminates the shuffle entirely.
- Pre-partition your data: If you know you will frequently join on user_id, bucket your tables by user_id at write time (bucketBy() together with saveAsTable()). Joins between tables bucketed on the same key with the same bucket count can skip the shuffle. Directory partitioning with partitionBy() suits low-cardinality filter columns like date, not a key with 12 million distinct values.
- Tune spark.sql.shuffle.partitions: The default is 200, which is too low for large datasets and too high for small ones. Target 128–256 MB per partition post-shuffle.
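Why reduceByKey moves less data than groupByKey can be shown by counting records that would cross the network. This plain-Python sketch assumes one shuffled record per (key, value) pair for groupByKey and one partial result per key per partition for reduceByKey's map-side combine; the helper names are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

Record = Tuple[str, int]  # (key, value)


def shuffled_without_combine(partitions: List[List[Record]]) -> int:
    """groupByKey-style: every (key, value) record crosses the network."""
    return sum(len(p) for p in partitions)


def combine_locally(partition: List[Record]) -> Dict[str, int]:
    """reduceByKey-style map-side combine: one partial sum per key per partition."""
    partial: Dict[str, int] = defaultdict(int)
    for key, value in partition:
        partial[key] += value
    return dict(partial)


def shuffled_with_combine(partitions: List[List[Record]]) -> int:
    return sum(len(combine_locally(p)) for p in partitions)


def final_sums(partitions: List[List[Record]]) -> Dict[str, int]:
    # Reduce side: merge the partial sums that arrived from each partition
    totals: Counter = Counter()
    for p in partitions:
        totals.update(combine_locally(p))
    return dict(totals)


# 3 partitions x 1,000 records, but only 2 distinct keys
partitions = [[("u1", 1)] * 600 + [("u2", 1)] * 400 for _ in range(3)]
without_combine = shuffled_without_combine(partitions)   # 3000 records shuffled
with_combine = shuffled_with_combine(partitions)         # 6 records shuffled
totals = final_sums(partitions)
```

The final totals are identical either way; only the volume of shuffled data changes, and the gap grows with the number of records per key.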
Memory Management
Spark's executor memory is divided into regions:
Reserved memory (~300 MB): internal overhead, not configurable.
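User memory: the part of the heap (after reserved memory) that Spark's own fraction does not claim. spark.memory.fraction defaults to 0.6, so Spark's unified pool gets 60% and user memory the remaining 40%. User memory holds user-defined data structures and Spark's internal metadata for JVM-side code. Python UDFs and pandas objects do not live here: they run in separate Python worker processes, which draw on spark.executor.memoryOverhead (or spark.executor.pyspark.memory, if set).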
Within Spark's memory fraction (spark.memory.fraction), memory is shared between two pools via the unified memory model:
Execution memory: Used for shuffle buffers, sort buffers, hash tables for joins and aggregations. If an executor runs out of execution memory, it spills intermediate data to disk - dramatically slowing the job.
Storage memory: Used for caching DataFrames (.cache() or .persist()). If execution needs more memory, Spark evicts cached blocks to make room, but only down to a protected share of the pool set by spark.memory.storageFraction (default 0.5).
Key insight: execution memory and storage memory share the same pool and compete with each other. If you cache many large DataFrames, you leave less room for execution memory, increasing spill risk during aggregations and joins.
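The arithmetic behind these regions for a 20 GB executor heap can be sketched as below. It is an approximation: memory_regions is a hypothetical helper, and Spark actually starts from the heap size the JVM reports, which is somewhat smaller than spark.executor.memory.

```python
RESERVED_MB = 300  # fixed reserved memory in the unified memory model


def memory_regions(executor_memory_mb: int,
                   memory_fraction: float = 0.6,
                   storage_fraction: float = 0.5) -> dict:
    """Approximate heap split for one executor, in MB (rounded to 0.1 MB).

    memory_fraction mirrors spark.memory.fraction; storage_fraction mirrors
    spark.memory.storageFraction, the share of the unified pool whose cached
    blocks execution cannot evict.
    """
    usable = executor_memory_mb - RESERVED_MB
    unified = usable * memory_fraction
    return {
        "reserved": RESERVED_MB,
        "user": round(usable - unified, 1),
        "unified": round(unified, 1),
        "storage_protected": round(unified * storage_fraction, 1),
    }


regions = memory_regions(20 * 1024)  # a 20 GB executor heap
print(regions)
```

Roughly 12 GB of the 20 GB heap is shared between execution and storage; caching more than about 6 GB per executor eats into space that aggregations and joins could otherwise use without spilling.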
Partitioning Strategy
Partitions are the unit of parallelism in Spark. Each executor core processes one partition at a time, so the number of partitions sets an upper bound on how much parallelism is available.
Too few partitions: Not enough work to keep all cores busy. If you have 40 cores and 20 partitions, half your cluster is idle.
Too many partitions: Overhead from task scheduling, shuffle file management, and small file writes at the end. Each task has fixed overhead - thousands of tiny tasks can be slower than hundreds of right-sized tasks.
Target partition size: 128–256 MB of uncompressed data per partition. This is a rule of thumb, not a law. IO-heavy workloads can handle larger partitions; memory-intensive operations (wide aggregations) benefit from smaller ones.
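Turning the rule of thumb into numbers: a small sketch (hypothetical helpers) that converts a data size into a partition count and shows how many cores sit idle when a stage has fewer tasks than cores. It treats Marcus's 500 GB as the uncompressed size, which is an assumption; compressed Parquet on disk usually expands in memory.

```python
import math


def partitions_for(data_size_bytes: int, target_partition_bytes: int = 128 * 1024**2) -> int:
    """Partition count that yields roughly target-sized partitions."""
    return max(1, math.ceil(data_size_bytes / target_partition_bytes))


def idle_core_fraction(num_partitions: int, total_cores: int) -> float:
    """Fraction of cores with nothing to do when there are fewer tasks than cores."""
    return max(0, total_cores - num_partitions) / total_cores


print(partitions_for(500 * 1024**3))                # 4000 partitions at 128 MB
print(partitions_for(500 * 1024**3, 256 * 1024**2))  # 2000 partitions at 256 MB
print(idle_core_fraction(20, 40))                   # 0.5: half the cluster idle
```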
repartition vs coalesce
repartition(n) performs a full shuffle to redistribute data into exactly n partitions. Use it when you need to increase partitions or when your data is heavily skewed and needs balanced redistribution.
coalesce(n) reduces the number of partitions with minimal data movement by combining existing partitions on the same executor. It does not do a full shuffle. Use it to reduce partition count after a filter that removed most of your data (avoiding the shuffle cost of repartition). Never use coalesce to increase partition count - because it does not shuffle, asking for more partitions than currently exist simply leaves the count unchanged.
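A toy model of the difference, in plain Python: coalesce merges neighboring partitions without splitting any of them, while repartition deals every record out round-robin (which is what DataFrame.repartition(n) does when no columns are given). Spark's real coalesce also considers data locality when grouping partitions, and the helper names here are hypothetical.

```python
from typing import List


def coalesce(partitions: List[list], n: int) -> List[list]:
    """Toy coalesce: merge neighbouring partitions into n groups, never
    splitting an input partition (no shuffle). Asking for more partitions
    than exist leaves the count unchanged."""
    if n >= len(partitions):
        return [list(p) for p in partitions]
    groups: List[list] = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        groups[i * n // len(partitions)].extend(part)
    return groups


def repartition(partitions: List[list], n: int) -> List[list]:
    """Toy round-robin repartition: every record may move (a full shuffle),
    so output sizes are balanced and n may exceed the input count."""
    out: List[list] = [[] for _ in range(n)]
    records = (r for p in partitions for r in p)
    for i, record in enumerate(records):
        out[i % n].append(record)
    return out


# Partition sizes after a selective filter: 100, 0, 1, 2, 50, 0 rows
after_filter = [list(range(100)), [], [1], [2, 3], list(range(50)), []]
print([len(p) for p in coalesce(after_filter, 3)])     # [100, 3, 50]: cheap but uneven
print([len(p) for p in repartition(after_filter, 3)])  # [51, 51, 51]: balanced, but shuffled
print(len(coalesce(after_filter, 10)))                 # 6: cannot grow
```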
Code Examples
1. Production PySpark Session Initialization
```python
from pyspark.sql import SparkSession


def create_spark_session(app_name: str, num_executors: int = 10) -> SparkSession:
    """
    Initialize a SparkSession with production-appropriate configurations.
    Adjust executor memory and cores based on your cluster sizing.

    Executor and driver resource settings only take effect when the first session
    is created; if a session is already running, getOrCreate() reuses it and those
    settings are ignored. In client mode, spark.driver.memory must be passed to
    spark-submit instead, because the driver JVM is already running.
    """
    spark = (
        SparkSession.builder
        .appName(app_name)
        # Executor sizing: target 4–8 cores, 16–32 GB RAM per executor
        .config("spark.executor.instances", str(num_executors))
        .config("spark.executor.cores", "5")
        .config("spark.executor.memory", "20g")
        .config("spark.executor.memoryOverhead", "4g")  # Non-heap headroom, incl. Python workers
        # Driver sizing: only orchestrates, doesn't process data
        .config("spark.driver.memory", "8g")
        .config("spark.driver.maxResultSize", "2g")
        # Shuffle configuration
        # Rule of thumb: shuffle data size (bytes) / 128 MB = num shuffle partitions
        .config("spark.sql.shuffle.partitions", "400")
        # Adaptive Query Execution (Spark 3.0+) - let Spark tune partitions at runtime
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
        # Broadcast join threshold: tables smaller than this are broadcast automatically
        .config("spark.sql.autoBroadcastJoinThreshold", "50mb")
        # Parquet optimization
        .config("spark.sql.parquet.filterPushdown", "true")
        .config("spark.sql.parquet.mergeSchema", "false")
        # Kryo serialization is faster than Java default for RDD operations
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")
    return spark


spark = create_spark_session("fraud-feature-pipeline", num_executors=10)
```
2. RDD vs DataFrame API - Same Computation, Different Performance
```python
# Both compute average transaction amount per user.
# The DataFrame version is often several times faster (the exact ratio depends
# on the workload) thanks to Catalyst and Tungsten.
from pyspark.sql import functions as F

# --- RDD API (slower) ---
# Each Python function runs in a Python worker, so every row is serialized
# across the Python-JVM boundary.
rdd = spark.read.parquet("s3://data/transactions/").rdd


def parse_row(row):
    return (row["user_id"], float(row["amount"]))  # (user_id, amount)


avg_by_user_rdd = (
    rdd
    .map(parse_row)
    .groupByKey()  # Triggers shuffle - moves ALL values
    .mapValues(lambda amounts: sum(amounts) / len(amounts))
)
# avg_by_user_rdd.collect()  # Triggers execution

# --- DataFrame API (faster) ---
# Catalyst can optimize this; Tungsten generates efficient bytecode
df = spark.read.parquet("s3://data/transactions/")
avg_by_user_df = (
    df
    .groupBy("user_id")
    .agg(F.avg("amount").alias("avg_amount"))
)
# avg_by_user_df.show()  # Triggers execution

# Key difference: groupByKey in RDD shuffles all values then aggregates.
# groupBy().agg() in DataFrame uses partial aggregation before shuffle
# (equivalent to combineByKey or reduceByKey in RDD), moving far less data.
```
3. Reading Parquet, Transforming, and Writing Back
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType


def compute_transaction_features(spark: SparkSession, input_path: str, output_path: str):
    """
    Read raw transaction Parquet, compute user-level features,
    write partitioned Parquet for downstream training jobs.
    """
    # Read - Spark reads only columns referenced downstream (projection pruning)
    # and skips Parquet row groups that don't match filters (predicate pushdown)
    raw = (
        spark.read
        .parquet(input_path)
        .filter(F.col("status") == "completed")  # Pushed down to Parquet reader
        .filter(F.col("transaction_date") >= "2024-01-01")
        .select("user_id", "merchant_category", "amount", "transaction_date", "is_declined")
    )

    # Compute user-level aggregations
    user_features = (
        raw
        .groupBy("user_id", "merchant_category")
        .agg(
            F.count("*").alias("tx_count"),
            F.avg("amount").alias("avg_amount"),
            F.stddev("amount").alias("std_amount"),
            F.sum(F.col("is_declined").cast(DoubleType())).alias("decline_count"),
            F.max("transaction_date").alias("last_tx_date"),
        )
        .withColumn(
            "decline_rate",
            F.col("decline_count") / F.col("tx_count")
        )
        .withColumn(
            "days_since_last_tx",
            F.datediff(F.current_date(), F.col("last_tx_date"))
        )
    )

    # Write - partitioned by merchant_category for efficient downstream reads
    # Training jobs filtering on merchant_category will read only relevant partitions
    (
        user_features
        .repartition("merchant_category")  # One shuffle to co-locate by partition key
        .write
        .mode("overwrite")
        .partitionBy("merchant_category")
        .parquet(output_path)
    )
    # user_features is still a lazy plan: any further action on it recomputes
    # the aggregation from the raw data unless it is cached.
    return user_features


features_df = compute_transaction_features(
    spark,
    input_path="s3://raw-data/transactions/",
    output_path="s3://feature-store/user-merchant-features/"
)
```
4. Inspecting and Controlling Partitioning
```python
from pyspark.sql import functions as F

# How many partitions does your DataFrame have?
print(f"Partition count: {features_df.rdd.getNumPartitions()}")

# Check partition size distribution.
# Note: this runs the full features_df computation; call features_df.cache()
# first if you plan to inspect it several times.
partition_sizes = (
    features_df
    .withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .count()
    .orderBy("partition_id")
)
partition_sizes.show(20)

# Diagnose skew: are most rows in a few partitions?
partition_sizes.agg(
    F.min("count").alias("min_rows"),
    F.max("count").alias("max_rows"),
    F.avg("count").alias("avg_rows"),
    F.stddev("count").alias("std_rows"),
).show()

# Repartition to increase parallelism before a heavy aggregation
features_repartitioned = features_df.repartition(400, "user_id")
print(f"After repartition: {features_repartitioned.rdd.getNumPartitions()}")

# Coalesce to reduce small file count before writing
# (No shuffle - much cheaper than repartition for reducing partition count)
before_write = features_df.coalesce(50)
print(f"After coalesce: {before_write.rdd.getNumPartitions()}")
```
5. Reading the Execution Plan via explain()
```python
from pyspark.sql import functions as F

# Read the execution plan before running expensive jobs.
# This lets you verify Catalyst is applying the optimizations you expect.
query = (
    spark.read.parquet("s3://raw-data/transactions/")
    .filter(F.col("amount") > 100)
    .groupBy("user_id")
    .agg(F.count("*").alias("tx_count"))
)

# Default ("simple") mode - shows only the physical plan
query.explain()

# Extended mode - shows the parsed, analyzed, and optimized logical plans plus the physical plan
query.explain(mode="extended")

# Formatted mode - easier to read for complex queries (Spark 3.0+)
query.explain(mode="formatted")

# What to look for in the plan:
# - "PushedFilters: [...]" on the Parquet scan node → predicate pushdown is working
# - "BroadcastHashJoin" → small table is being broadcast (no shuffle of the large side)
# - "SortMergeJoin" → both sides shuffled (expensive, check if broadcast is possible)
# - "Exchange" → a shuffle is happening here
# - "HashAggregate" with partial_ functions (e.g. partial_count) → partial aggregation before shuffle (good)
# - "AQEShuffleRead" → AQE coalesced shuffle partitions (good; appears in the final
#   plan once the query has run, and is named "CustomShuffleReader" before Spark 3.2)

# Read the physical plan for the fraud feature query
fraud_features = (
    spark.read.parquet("s3://raw-data/transactions/")
    .filter(F.col("status") == "completed")
    .filter(F.col("amount") > 10)
    .groupBy("user_id")
    .agg(
        F.count("*").alias("tx_count"),
        F.avg("amount").alias("avg_amount"),
    )
)
fraud_features.explain(mode="formatted")
# Expected operators, top to bottom (exact names and layout vary by Spark version):
#   AdaptiveSparkPlan   (if AQE is enabled)
#   HashAggregate       (final aggregate)
#   Exchange            (shuffle by user_id)
#   HashAggregate       (partial aggregate)
#   Project / Filter    (the Filter re-checks the predicates row by row)
#   Scan parquet        PushedFilters: [IsNotNull(status), EqualTo(status,completed), ...]
```
Reading Data Efficiently: Parquet and Predicate Pushdown
Spark's performance on real-world pipelines is often limited by how fast it can read data, not by computation speed. Choosing the right file format and understanding how Spark reads it is as important as optimizing your transformations.
Why Parquet
Parquet is the default format for Spark ML pipelines for three reasons:
Columnar storage: Parquet stores data column-by-column rather than row-by-row. Reading 10 columns from a 200-column table touches roughly 10/200 = 5% of the file bytes (assuming columns of similar size), because unneeded columns are never read. Row-oriented formats (CSV, JSON, Avro) must read every column to access any column.
Efficient compression: Parquet applies column-level compression (typically Snappy or Zstd). Columns with low cardinality (e.g., a status column with 3 possible values) compress 20–50x. Numeric columns compress well with dictionary encoding. A 500 GB CSV dataset often becomes 80–120 GB as Parquet.
Statistics for predicate pushdown: Each Parquet file's footer stores statistics (min/max values, null count) for every column of every row group, where a row group is a horizontal chunk of the file (128 MB by default in the common parquet-mr writer). Spark's Parquet reader uses these statistics to skip entire row groups that cannot satisfy a filter predicate - without reading the row group's data.
```python
from pyspark.sql import functions as F

# Parquet predicate pushdown in action
# This filter: status == "completed" AND amount > 1000
# Spark passes these predicates to the Parquet reader.
# Row groups whose max(amount) <= 1000 are skipped entirely.
# Row groups where every status value is "declined" are skipped entirely.
# Only row groups that COULD contain matching rows are read.
transactions = (
    spark.read
    .parquet("s3://data/transactions/")
    .filter(F.col("status") == "completed")
    .filter(F.col("amount") > 1000)
)

# Confirm predicate pushdown is working:
transactions.explain(mode="formatted")
# Look for: "PushedFilters: [IsNotNull(status), EqualTo(status,completed),
#   IsNotNull(amount), GreaterThan(amount,1000.0)]" on the Parquet scan node.
# A Filter node ABOVE the scan is expected even when pushdown works: row-group
# skipping is coarse, so Spark re-checks the predicate on every row it reads.
# Pushdown is NOT working if your predicate is missing from PushedFilters.

# Predicate pushdown does NOT work for:
# - Computed columns: .filter(F.upper(F.col("status")) == "COMPLETED")
#   (Parquet can't evaluate functions, only compare raw column values)
# - Predicates on aggregated or join-derived values (e.g., a count produced by
#   groupBy); filters on grouping or join keys can still be pushed below them
# - Formats without column statistics, such as CSV and JSON (no data can be skipped)
```
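Row-group skipping boils down to a range check on footer statistics. This plain-Python sketch uses hypothetical names and only min/max values; real Parquet readers also use null counts and handle more predicate types.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RowGroupStats:
    """Footer statistics for one row group (toy version: min/max only)."""
    min_amount: float
    max_amount: float
    min_status: str
    max_status: str


def can_skip(stats: RowGroupStats, status: str, amount_gt: float) -> bool:
    """True if no row in the group can satisfy status == status AND amount > amount_gt."""
    status_impossible = not (stats.min_status <= status <= stats.max_status)
    amount_impossible = stats.max_amount <= amount_gt
    return status_impossible or amount_impossible


def row_groups_to_read(groups: List[RowGroupStats], status: str, amount_gt: float) -> List[int]:
    return [i for i, g in enumerate(groups) if not can_skip(g, status, amount_gt)]


groups = [
    RowGroupStats(5.0, 900.0, "completed", "declined"),    # max amount too low: skip
    RowGroupStats(10.0, 5000.0, "declined", "declined"),   # only "declined": skip
    RowGroupStats(20.0, 8000.0, "completed", "pending"),   # might match: read
    RowGroupStats(1.0, 1000.0, "completed", "completed"),  # max == 1000 fails "> 1000": skip
]
print(row_groups_to_read(groups, "completed", 1000.0))  # [2]
```

A group that survives the check is not guaranteed to contain matches, only to possibly contain them, which is exactly why Spark keeps a row-level Filter above the scan.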
Partitioned Parquet directories
Hive-style partitioning stores data in subdirectories named by partition key values:
```text
s3://data/transactions/
  status=completed/
    date=2024-11-01/
      part-0000.parquet
      part-0001.parquet
    date=2024-11-02/
      part-0000.parquet
  status=declined/
    date=2024-11-01/
      ...
```
Spark's Parquet reader treats the directory names as column values and uses them for partition pruning - completely skipping directories whose partition key values cannot satisfy the filter. This is more aggressive than row-group-level predicate pushdown: Spark never even opens the files in pruned partitions.
```python
from pyspark.sql import functions as F

# Partition pruning - Spark skips the entire "status=declined/" directory
# No files in that partition are opened or read
recent_completed = (
    spark.read
    .parquet("s3://data/transactions/")
    .filter(F.col("status") == "completed")  # Partition pruning on status
    .filter(F.col("date") >= "2024-10-01")   # Partition pruning on date
    .filter(F.col("amount") > 500)           # Row-group predicate pushdown
)

# Confirm both pruning levels work:
recent_completed.explain(mode="formatted")
# Look for the directory predicates under "PartitionFilters", e.g.
#   PartitionFilters: [isnotnull(status#1), (status#1 = completed), ...]
# AND the row-group predicates under "PushedFilters", e.g.
#   PushedFilters: [IsNotNull(amount), GreaterThan(amount,500.0)]

# Design principle: choose your Parquet partition keys based on how the data will
# be queried. For ML training sets, partition by date - training jobs always
# filter on a date range. For feature stores, partition by (date, entity type).
```
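Partition pruning works on paths alone, before any file is opened. This sketch (hypothetical helpers, plain Python) parses the key=value directory names from the example layout and keeps only files whose values can satisfy the filter. It relies on ISO-formatted dates comparing correctly as strings.

```python
from typing import Dict, List


def parse_partition_values(path: str) -> Dict[str, str]:
    """Extract Hive-style key=value directory names from a file path."""
    values: Dict[str, str] = {}
    for segment in path.split("/"):
        if "=" in segment:
            key, _, value = segment.partition("=")
            values[key] = value
    return values


def prune(paths: List[str], status: str, min_date: str) -> List[str]:
    """Keep files whose directory values can satisfy status == status AND date >= min_date."""
    kept = []
    for path in paths:
        values = parse_partition_values(path)
        # ISO dates (YYYY-MM-DD) sort correctly as plain strings
        if values.get("status") == status and values.get("date", "") >= min_date:
            kept.append(path)
    return kept


files = [
    "s3://data/transactions/status=completed/date=2024-09-30/part-0000.parquet",
    "s3://data/transactions/status=completed/date=2024-11-01/part-0000.parquet",
    "s3://data/transactions/status=completed/date=2024-11-02/part-0000.parquet",
    "s3://data/transactions/status=declined/date=2024-11-01/part-0000.parquet",
]
kept = prune(files, status="completed", min_date="2024-10-01")  # files[1] and files[2]
```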
Join Strategies
Joins are the most complex and expensive operations in Spark. Understanding how Spark executes joins helps you choose the right strategy and avoid common performance traps.
Broadcast Hash Join
When one side of a join is small enough to fit in the executor's memory, Spark broadcasts the small table to every executor as a hash map. Each executor then probes the hash map for every row in its partition of the large table. No shuffle required on the large table.
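This is the fastest join strategy. Use it whenever the smaller table is under ~50–200 MB. Spark only picks it automatically when a table's estimated size is below spark.sql.autoBroadcastJoinThreshold, which defaults to 10 MB, so for tables in that range either raise the threshold or use an explicit broadcast hint.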
```python
from pyspark.sql import functions as F

# Explicit broadcast hint - forces broadcast even if Catalyst wouldn't choose it
merchant_lookup = spark.read.parquet("s3://data/merchant-categories/")  # ~10 MB
transactions = spark.read.parquet("s3://data/transactions/")            # 500 GB

# F.broadcast() tells Catalyst to broadcast merchant_lookup
enriched = transactions.join(
    F.broadcast(merchant_lookup),
    on="merchant_id",
    how="left",
)

# In the query plan: BroadcastHashJoin - no shuffle on the transactions side
# Check the plan to confirm:
enriched.explain(mode="formatted")
# Look for: BroadcastHashJoin with BuildRight (merchant_lookup is the right side)
# and a BroadcastExchange feeding it from merchant_lookup (expected).
# NOT: SortMergeJoin, or an Exchange hashpartitioning on the transactions side
# (which would indicate a shuffle)
```
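What happens inside each executor during a broadcast hash join can be sketched in plain Python: build a hash map from the small side once, then probe it for every row of each large partition. The helper name is hypothetical, and the sketch assumes keys on the small side are unique (a real join would emit one output row per match).

```python
from typing import Dict, List, Optional, Tuple

LargeRow = Tuple[str, float]                  # (merchant_id, amount)
SmallRow = Tuple[str, str]                    # (merchant_id, category)
JoinedRow = Tuple[str, float, Optional[str]]


def broadcast_hash_join(large_partitions: List[List[LargeRow]],
                        small_table: List[SmallRow]) -> List[List[JoinedRow]]:
    """Toy left join: build a hash map from the small side once, then probe it
    inside each large partition. Large-side rows never leave their partition."""
    lookup: Dict[str, str] = dict(small_table)  # the "broadcast" copy every executor gets
    return [
        [(merchant_id, amount, lookup.get(merchant_id)) for merchant_id, amount in partition]
        for partition in large_partitions
    ]


large = [[("m1", 10.0), ("m2", 20.0)], [("m3", 5.0), ("m1", 7.5)]]
small = [("m1", "grocery"), ("m2", "travel")]
joined = broadcast_hash_join(large, small)
```

The output keeps the large side's partitioning unchanged, which is the whole point: only the small table traveled.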
Sort Merge Join
When both sides of the join are too large to broadcast, Spark falls back to Sort Merge Join. Both sides are shuffled so matching keys end up on the same executor, then each executor merges its two sorted partitions to find matching rows.
Sort Merge Join costs $O(n \log n)$ for $n$ rows because of the sorting, and it involves two shuffles (one per side). It is the default for large-large joins. On 500 GB × 100 GB joins, this produces significant network I/O.
```python
from pyspark.sql import functions as F

# Sort Merge Join - when neither side can be broadcast
# This happens automatically when both sides exceed autoBroadcastJoinThreshold
user_events = spark.read.parquet("s3://data/user-events/")      # 200 GB
user_profiles = spark.read.parquet("s3://data/user-profiles/")  # 50 GB

# Both too large for broadcast - Spark chooses SortMergeJoin
# In the plan: SortMergeJoin, two Exchange nodes (shuffles)
joined = user_events.join(user_profiles, on="user_id", how="left")
joined.explain()

# Forcing a broadcast only makes sense when one side really fits in executor memory.
# A 50 GB profile table does not (Spark also refuses to broadcast tables over 8 GB),
# so the lines below apply only to a small table, e.g. a hypothetical `small_profiles`
# holding a filtered subset of user_profiles.
# Either raise the threshold:
# spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "60mb")
# Or use explicit F.broadcast():
# joined_fast = user_events.join(F.broadcast(small_profiles), on="user_id", how="left")
```
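The merge step itself is compact. This plain-Python sketch of an inner sort-merge join (hypothetical helper) sorts both inputs by key and advances two cursors, pairing up runs of equal keys. In Spark, each executor does this on one pair of co-partitioned shuffle outputs; the shuffle that brings matching keys together is left out here.

```python
from typing import List, Tuple


def sort_merge_join(left: List[Tuple[str, int]],
                    right: List[Tuple[str, str]]) -> List[Tuple[str, int, str]]:
    """Toy inner join on the first tuple element: sort both sides, then merge."""
    left = sorted(left, key=lambda r: r[0])
    right = sorted(right, key=lambda r: r[0])
    i = j = 0
    out: List[Tuple[str, int, str]] = []
    while i < len(left) and j < len(right):
        lk, rk = left[i][0], right[j][0]
        if lk < rk:
            i += 1                      # left key has no partner; advance left
        elif lk > rk:
            j += 1                      # right key has no partner; advance right
        else:
            # Find the run of equal keys on the right, pair it with each matching left row
            j_end = j
            while j_end < len(right) and right[j_end][0] == lk:
                j_end += 1
            while i < len(left) and left[i][0] == lk:
                for _, value in right[j:j_end]:
                    out.append((lk, left[i][1], value))
                i += 1
            j = j_end
    return out


events = [("u2", 3), ("u1", 1), ("u1", 2), ("u3", 9)]
profiles = [("u1", "gold"), ("u2", "silver"), ("u4", "bronze")]
result = sort_merge_join(events, profiles)
```

After sorting, the merge is a single linear pass, so the sort dominates the cost.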
Shuffle Hash Join
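Shuffle Hash Join is an alternative to Sort Merge Join: both sides are shuffled by join key, but instead of sorting and merging, one side builds a hash table in memory and the other side probes it. It avoids the sort step ($O(n)$ to build and probe a hash table versus $O(n \log n)$ to sort) but requires the build side of each partition to fit in executor memory. Because of that memory risk, Spark prefers Sort Merge Join by default (spark.sql.join.preferSortMergeJoin is true).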
Spark 3.0+ uses AQE to switch between join strategies at runtime based on actual data sizes after a shuffle, even if the pre-execution statistics were wrong.
Skewed Joins
Data skew in joins is a severe performance problem: if 80% of rows have user_id = NULL or user_id = "default", a single executor receives 80% of the data during the shuffle. That executor runs for hours while the other 49 executors finish in minutes. The job's wall-clock time equals the slowest task's time.
# Check for skew: examine key distribution before joining
user_event_counts = (
user_events
.groupBy("user_id")
.count()
.orderBy(F.col("count").desc())
)
user_event_counts.show(10)
# If the top key has 100M rows and the median has 50, you have severe skew.
# AQE Skew Join Optimization (Spark 3.0+)
# Automatically detects skewed partitions after shuffle and splits them
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256mb")
# With these settings, Spark splits partitions larger than 256MB * 5 = 1.28GB
# and duplicates the matching partition from the non-skewed side.
# Manual salting for cases AQE's skew-join handling does not cover: it only
# splits the sides the join type allows (e.g., not the right side of a left
# outer join), and it is unavailable on Spark 2.x.
from pyspark.sql import functions as F

SALT_BUCKETS = 16
# Add a random salt in [0, SALT_BUCKETS) to every row of the large, skewed side
transactions_salted = transactions.withColumn(
    "salt", (F.rand() * SALT_BUCKETS).cast("int")
)
# Replicate each row of the lookup table once per salt value
merchant_replicated = merchant_lookup.withColumn(
    "salt",
    F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)]))
)
# Join on both the original key and the salt, then discard the salt
enriched_salted = transactions_salted.join(
    merchant_replicated,
    on=["merchant_id", "salt"],
    how="left",
).drop("salt")
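Why salting works can be checked in plain Python. The sketch below is a simplified model with hypothetical data and a `zlib.crc32` stand-in for Spark's partitioner. It shows that a single hot key spreads over several partitions once a salt is added, and that every salted row still matches exactly one replicated lookup row.

```python
import random
import zlib
from collections import Counter

SALT_BUCKETS = 16
NUM_PARTITIONS = 50


def partition_of(key):
    """Deterministic stand-in for Spark's hash partitioner (an assumption)."""
    return zlib.crc32(repr(key).encode()) % NUM_PARTITIONS


random.seed(0)
# Skewed side: every transaction has the same hot merchant_id.
transactions = [("merchant_42", amount) for amount in range(10_000)]
# Salt each row with a random bucket in [0, SALT_BUCKETS).
salted = [(m, random.randrange(SALT_BUCKETS), amt) for m, amt in transactions]

# Lookup side: replicate each row once per salt value.
merchant_lookup = {"merchant_42": "grocery"}
replicated = {(m, s): cat for m, cat in merchant_lookup.items()
              for s in range(SALT_BUCKETS)}

# Join on (merchant_id, salt): every salted row finds exactly one match.
enriched = [(m, amt, replicated[(m, s)]) for m, s, amt in salted]

before = Counter(partition_of(m) for m, _ in transactions)
after = Counter(partition_of((m, s)) for m, s, _ in salted)
print(f"hot key spread over {len(before)} -> {len(after)} partitions")
```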
Adaptive Query Execution (AQE)
Introduced in Spark 3.0, Adaptive Query Execution (AQE) is one of the most impactful features for production pipelines. Traditional query planning happens once, before execution, using estimated statistics that are often wrong. AQE re-optimizes the query plan at runtime using actual statistics collected after each shuffle.
AQE addresses three major problems:
1. Suboptimal shuffle partition count: Before execution, Spark uses spark.sql.shuffle.partitions (default: 200) for every shuffle. A small shuffle therefore produces 200 tiny partitions, and a large one produces 200 huge ones. AQE coalesces small post-shuffle partitions into fewer, larger ones based on the actual bytes written.
2. Skewed join partitions: AQE detects partitions significantly larger than the median and splits them across multiple tasks, duplicating the matching partition from the other side. For supported join types, this removes the straggler problem in skewed joins without manual salting.
3. Wrong join strategy selection: If the pre-execution statistics estimated a table as large but the actual filtered result is small, AQE can switch from SortMergeJoin to BroadcastHashJoin mid-query.
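The first two decisions can be sketched in plain Python. This is a simplified model, not Spark's code. Real AQE also aligns coalescing across all shuffles in a stage and applies further limits. The default values mirror spark.sql.adaptive.skewJoin.skewedPartitionFactor, skewedPartitionThresholdInBytes, and advisoryPartitionSizeInBytes (64 MB by default).

```python
import statistics

MB = 1024 * 1024


def skewed_partitions(sizes, factor=5, threshold=256 * MB):
    """Indices AQE would treat as skewed in this model: larger than
    factor x median AND larger than threshold bytes."""
    median = statistics.median(sizes)
    return [i for i, s in enumerate(sizes) if s > factor * median and s > threshold]


def coalesce_partitions(sizes, advisory=64 * MB):
    """Greedily merge adjacent partitions until a group would exceed the
    advisory size. Returns groups of partition indices (one task each)."""
    groups, current, current_bytes = [], [], 0
    for i, s in enumerate(sizes):
        if current and current_bytes + s > advisory:
            groups.append(current)
            current, current_bytes = [], 0
        current.append(i)
        current_bytes += s
    if current:
        groups.append(current)
    return groups


sizes = [10 * MB] * 199 + [3000 * MB]  # one hot partition after a shuffle
print(skewed_partitions(sizes))  # [199]
print([len(g) for g in coalesce_partitions([10 * MB] * 20)])  # [6, 6, 6, 2]
```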
# Enable AQE (on by default since Spark 3.2; set it explicitly on 3.0/3.1)
spark.conf.set("spark.sql.adaptive.enabled", "true")
# Coalesce small shuffle partitions (default: enabled when AQE is enabled)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128mb")
# Skew join optimization
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# With AQE, you can often start with a high shuffle partition count
# and let AQE coalesce down to the right number at runtime
spark.conf.set("spark.sql.shuffle.partitions", "1000")
# AQE will coalesce 1000 shuffle partitions into fewer if they're small
# Verify AQE is active in the query plan
query = (
spark.read.parquet("s3://data/transactions/")
.groupBy("user_id")
.agg(F.count("*").alias("tx_count"))
)
query.explain(mode="formatted")
# Look for: AdaptiveSparkPlan in the physical plan
# Look for: AQEShuffleRead (coalesced partitions) after execution in Spark UI
AQE is not magic. It cannot fix fundamental design problems like reading 1 TB when you only need 10 GB, or running Python UDFs on every row. For correctly structured pipelines, though, AQE is close to a free performance improvement, and it can reduce job time by 20–40% on workloads with unpredictable data distributions.
The Spark UI: Reading the Evidence
The Spark UI (default port 4040 during execution, or the history server after completion) is the primary tool for diagnosing Spark performance problems. Engineers who can read the Spark UI can diagnose issues that would otherwise take days of guesswork.
What to look for in the Stages tab
The Stages tab shows every stage in the job, with the duration of each stage and the distribution of task times across partitions.
Stragglers (a few tasks much longer than the median) usually indicate data skew. If 199 out of 200 tasks complete in 2 seconds and one task takes 45 minutes, that one partition holds most of the data. The fix is either AQE skew join optimization or manual salting.
Long stages with many small tasks indicate too many partitions: scheduling overhead dominates useful computation. Reduce the partition count by lowering spark.sql.shuffle.partitions, coalescing, or letting AQE coalesce small shuffle partitions.
Spill columns in the stage summary show how much data was spilled to disk. Non-zero spill means executors are running out of execution memory. Fix by: increasing executor memory, reducing cores per executor (more memory per core), or reducing partition size.
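Straggler hunting is easy to automate once you have per-task durations, for example from the Stages tab or the REST API task list. The sketch below assumes a plain list of seconds (a hypothetical input format) and flags tasks that ran more than a chosen multiple of the median.

```python
import statistics


def find_stragglers(task_seconds, multiplier=3.0):
    """Return indices of tasks that ran more than `multiplier` x the median.

    task_seconds: per-task durations for one stage (hypothetical input).
    """
    median = statistics.median(task_seconds)
    return [i for i, t in enumerate(task_seconds) if t > multiplier * median]


durations = [2.0] * 199 + [45 * 60.0]  # 199 tasks at 2 s, one at 45 min
print(find_stragglers(durations))  # [199]
```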
What to look for in the SQL tab
The SQL tab shows the physical query plan as a visual DAG with timing information per node. This is the most useful view for understanding why a specific query is slow.
- Exchange nodes are shuffles. The number of bytes written to shuffle files is shown. Excessive shuffle bytes indicate either too many large aggregations or missing broadcast joins.
- Sort nodes before SortMergeJoin are necessary overhead. If you see them, ask whether the join could be a broadcast join instead.
- HashAggregate (partial) before an Exchange means Spark is correctly doing partial aggregation before the shuffle - good.
- FileScan nodes list PushedFilters (and PartitionFilters for partitioned tables). If your predicates appear there, predicate pushdown is working: filters are applied while reading the files, not after reading all the data.
# Programmatic access to job metrics (useful for pipeline monitoring).
# PySpark's StatusTracker (spark.sparkContext.statusTracker()) only reports stage
# names and task counts, so for bytes and spill we query the Spark UI REST API.
import requests

def get_stage_metrics(spark_ui_url: str = "http://localhost:4040"):
    """
    Fetch stage metrics from the Spark UI REST API.
    Useful for logging pipeline performance in monitoring systems.
    """
    apps = requests.get(f"{spark_ui_url}/api/v1/applications", timeout=10).json()
    if not apps:
        return None
    # The live UI lists only the current application; on a History Server,
    # apps[0] is just the first listed application, so pick the ID you need.
    app_id = apps[0]["id"]
    stage_data = requests.get(
        f"{spark_ui_url}/api/v1/applications/{app_id}/stages", timeout=10
    ).json()
    for stage in stage_data:
        print(f"Stage {stage['stageId']}: {stage['name']}")
        # executorRunTime is summed over all tasks in the stage, not wall-clock time
        print(f"  Executor run time: {stage.get('executorRunTime', 0) / 1000:.1f}s")
        print(f"  Input: {stage.get('inputBytes', 0) / 1e9:.2f} GB")
        print(f"  Shuffle write: {stage.get('shuffleWriteBytes', 0) / 1e9:.2f} GB")
        print(f"  Spill (disk): {stage.get('diskBytesSpilled', 0) / 1e9:.2f} GB")
        print(f"  Spill (mem): {stage.get('memoryBytesSpilled', 0) / 1e9:.2f} GB")
    return stage_data

# Call while the application is running (live UI, port 4040 by default), or
# point spark_ui_url at the History Server (port 18080 by default) afterwards.
get_stage_metrics()
Speculative Execution
Speculative execution is Spark's mechanism for handling slow tasks (stragglers) that are slow due to hardware issues rather than data skew. When a task runs significantly longer than the median for its stage, Spark launches a duplicate copy of the task on a different executor. Whichever finishes first wins; the other is killed.
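# spark.speculation* are core configs: set them when the session is first built
# (or via spark-submit --conf); spark.conf.set() cannot change them at runtime.
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .config("spark.speculation", "true")
    .config("spark.speculation.multiplier", "3")  # task is 3x the median
    .config("spark.speculation.quantile", "0.9")  # after 90% of tasks finish
    .getOrCreate()
)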
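The two settings interact as follows, in a simplified plain-Python model of Spark's per-stage speculation check. It assumes durations are in seconds and that `min_seconds` plays the role of spark.speculation.minTaskRuntime (100 ms by default). It ignores locality and the scheduler's polling interval.

```python
import math
import statistics


def speculatable_tasks(finished_seconds, running_elapsed, total_tasks,
                       quantile=0.9, multiplier=3.0, min_seconds=0.1):
    """Simplified model of Spark's speculation check for one stage.

    finished_seconds: durations of tasks that already succeeded.
    running_elapsed: {task_id: seconds running so far} for unfinished tasks.
    Returns the task ids that would get a speculative copy.
    """
    # Only consider speculation once `quantile` of the tasks have finished.
    min_finished = max(1, math.floor(quantile * total_tasks))
    if len(finished_seconds) < min_finished:
        return []
    threshold = max(multiplier * statistics.median(finished_seconds), min_seconds)
    return sorted(t for t, elapsed in running_elapsed.items() if elapsed > threshold)


finished = [10.0] * 95
running = {95: 12.0, 96: 45.0, 97: 31.0, 98: 5.0, 99: 29.0}
print(speculatable_tasks(finished, running, total_tasks=100))  # [96, 97]
```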
:::warning Speculative execution and non-idempotent writes Speculative execution runs the same task twice on different executors. If your task writes to an external system (a database, an API, a non-transactional file sink) and is not idempotent, speculative execution can cause duplicate writes. It is safe for pure read-transform-write-to-Parquet pipelines, but disable it if your tasks have non-idempotent side effects. :::
Spark on Kubernetes vs YARN
Production Spark clusters run on either YARN (the traditional Hadoop resource manager) or Kubernetes (the modern container orchestration platform). For teams building on cloud-native infrastructure, Kubernetes is increasingly the default.
YARN is well-understood and deeply integrated with Hadoop ecosystems. It manages resource negotiation, executor placement, and multi-tenant job queues. If your organization runs Hadoop or EMR, YARN is the natural choice.
Kubernetes treats each Spark executor as a pod. Benefits: fine-grained resource control via pod specifications, integration with existing Kubernetes RBAC and network policies, and better autoscaling. The driver runs as a pod; executors are dynamically launched and terminated. Dynamic allocation on Kubernetes is well-supported in Spark 3.1+.
# Spark on Kubernetes - session configuration
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("fraud-feature-pipeline")
    .master("k8s://https://my-k8s-cluster:6443")
    .config("spark.kubernetes.container.image", "my-registry/spark:3.5.0-python3.11")
    .config("spark.kubernetes.namespace", "ml-platform")
    # Dynamic allocation: executors scale up/down based on workload.
    # initialExecutors sets the starting size, so spark.executor.instances is omitted.
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.minExecutors", "2")
    .config("spark.dynamicAllocation.maxExecutors", "50")
    .config("spark.dynamicAllocation.initialExecutors", "5")
    # Kubernetes has no external shuffle service by default; shuffle tracking
    # keeps executors that still hold shuffle files from being removed
    .config("spark.dynamicAllocation.shuffleTracking.enabled", "true")
    .config("spark.executor.cores", "4")
    .config("spark.executor.memory", "16g")
    .getOrCreate()
)
Dynamic allocation is particularly valuable for ML feature pipelines that run once per night: the cluster starts small, scales up to handle the heavy shuffle stages, and scales back down for the write phase. Combined with spot/preemptible instances, this significantly reduces compute costs.
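The sizing logic behind "scales up, then back down" can be approximated in a few lines. In this simplified model, dynamic allocation targets enough task slots for every running or pending task, clamped to the min/max bounds configured above. It deliberately ignores the gradual ramp-up and the idle-timeout-based release of executors.

```python
import math


def target_executors(pending_plus_running_tasks, executor_cores=4, task_cpus=1,
                     min_executors=2, max_executors=50, allocation_ratio=1.0):
    """Simplified model of the executor count dynamic allocation aims for:
    enough slots for every running or pending task, clamped to [min, max]."""
    tasks_per_executor = executor_cores // task_cpus
    needed = math.ceil(pending_plus_running_tasks * allocation_ratio / tasks_per_executor)
    return max(min_executors, min(max_executors, needed))


print(target_executors(1000))  # heavy shuffle stage: capped at 50
print(target_executors(20))    # small write phase: 5
print(target_executors(0))     # idle: floor of 2
```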
Production Engineering Notes
:::tip Cache strategically
Calling .cache() on a DataFrame only marks it for caching. The first action that computes it stores the result using a memory-and-disk storage level, so partitions that don't fit in memory go to local disk. Only cache DataFrames that several actions reuse, and call .unpersist() once they are no longer needed. Caching everything wastes storage memory and leaves less room for execution memory during aggregations and joins, which increases spill.
:::
:::note Checkpoint for long lineage chains
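If a pipeline has 50+ transformations, its lineage becomes very long. When a partition is lost, Spark recomputes it by replaying that lineage back to the source data, and a very long plan also slows Catalyst's analysis and optimization on the driver. Use df.checkpoint() at logical midpoints to truncate the lineage. It requires a checkpoint directory set with spark.sparkContext.setCheckpointDir(...), ideally on reliable storage such as HDFS or S3. Checkpointing materializes the DataFrame to that directory and breaks the lineage chain, which makes recomputation cheaper on failure.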
:::
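A toy model makes the effect of truncation visible. In the sketch below (hypothetical classes, not Spark's API), each transformation points at its parent, `lineage()` lists the steps that would be replayed to rebuild a node, and `checkpoint()` drops the parent link because the data is now saved.

```python
class Node:
    """Toy lineage node: each transformation points at its parent."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent

    def lineage(self):
        """Steps that would be replayed to rebuild this node from its source."""
        steps, node = [], self
        while node is not None:
            steps.append(node.name)
            node = node.parent
        return list(reversed(steps))

    def checkpoint(self):
        """Model of checkpoint(): data is saved, so the parent link is dropped."""
        return Node(f"checkpoint({self.name})", parent=None)


df = Node("read_parquet")
for i in range(50):
    df = Node(f"transform_{i}", df)
print(len(df.lineage()))  # 51
df = Node("transform_50", df.checkpoint())
print(df.lineage())  # ['checkpoint(transform_49)', 'transform_50']
```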
:::warning Driver memory and collect()
df.collect() brings all data from all executors back to the driver as a Python list. On a 500 GB dataset, this will kill your driver with an OOM error. Only collect small DataFrames - aggregate results, samples, or small lookup tables. For writing large DataFrames, always use df.write, which writes in parallel from executors without touching the driver.
:::
:::danger Implicit collect() in loops
Pulling rows into a Python loop routes every row through the driver. Examples are for row in df.collect(): and for _, row in df.toPandas().iterrows():. This is one of the most common Spark performance bugs. Express the logic as DataFrame transformations instead. If you need per-row side effects, use df.foreach() or df.foreachPartition() so the work stays on the executors.
:::
Common Mistakes
:::danger Using groupByKey instead of reduceByKey (RDD API)
groupByKey collects all values for each key on the reducer before aggregating, shuffling all the raw data. reduceByKey applies a commutative/associative function locally on each partition first, then shuffles only the partial results. For an aggregation like sum or count, reduceByKey can be 10x faster on large datasets. In the DataFrame API, groupBy().agg() already uses the partial-aggregation pattern automatically.
:::
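The difference in shuffled volume can be counted with a plain-Python model. The sketch below is hypothetical and not the RDD API. `reduce_by_key` combines values inside each partition before "shuffling" the partials, while `group_by_key_then_sum` ships every record first. Both return the same totals along with the number of records that would cross the network.

```python
def reduce_by_key(partitions, fn):
    """Model of reduceByKey: combine within each partition (map-side combine),
    then merge the partials. Returns (result, records_shuffled)."""
    partials = []
    for part in partitions:  # map side: one partial per key per partition
        local = {}
        for key, value in part:
            local[key] = fn(local[key], value) if key in local else value
        partials.extend(local.items())
    merged = {}
    for key, value in partials:  # reduce side: merge the shuffled partials
        merged[key] = fn(merged[key], value) if key in merged else value
    return merged, len(partials)


def group_by_key_then_sum(partitions):
    """Model of groupByKey().mapValues(sum): ship every record, then aggregate."""
    grouped, shuffled = {}, 0
    for part in partitions:
        for key, value in part:
            grouped.setdefault(key, []).append(value)
            shuffled += 1
    return {k: sum(vs) for k, vs in grouped.items()}, shuffled


partitions = [[("a", 1)] * 1000 + [("b", 1)] * 10 for _ in range(4)]
print(reduce_by_key(partitions, lambda x, y: x + y))  # ({'a': 4000, 'b': 40}, 8)
print(group_by_key_then_sum(partitions))              # ({'a': 4000, 'b': 40}, 4040)
```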
:::danger repartition to reduce partition count
repartition(n) always performs a full shuffle, even when reducing partition count. Use coalesce(n) when you want to reduce partition count - it merges partitions on the same executor without network I/O. Reserve repartition for when you need to increase partition count or redistribute skewed data.
:::
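A plain-Python model shows both the benefit and the limitation of coalesce. In this sketch (hypothetical functions, not Spark's API), `coalesce` only groups neighbouring partitions, so uneven inputs stay uneven and the count can never grow. `repartition` reshuffles every record round-robin, similar to how DataFrame repartition(n) distributes rows.

```python
def coalesce(partitions, n):
    """Model of coalesce(n): merge runs of existing partitions without moving
    individual records between groups. It can only reduce the count."""
    if n >= len(partitions):
        return [list(p) for p in partitions]
    merged = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        merged[i * n // len(partitions)].extend(part)  # contiguous grouping
    return merged


def repartition(partitions, n):
    """Model of repartition(n): full shuffle, records dealt round-robin."""
    out = [[] for _ in range(n)]
    records = [r for p in partitions for r in p]
    for i, record in enumerate(records):
        out[i % n].append(record)
    return out


parts = [[1] * 10, [2] * 10, [3], [4]]
print([len(p) for p in coalesce(parts, 2)])     # [20, 2]
print([len(p) for p in repartition(parts, 2)])  # [11, 11]
print(len(coalesce(parts, 8)))                  # 4: cannot increase
```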
:::warning spark.sql.shuffle.partitions = 200 (the default)
The default of 200 shuffle partitions made sense for Spark's early use cases. On modern datasets of tens or hundreds of GB, 200 partitions means each partition could be gigabytes - far larger than the 128–256 MB target. This causes memory spill and slow execution. Set this based on your data volume: total_shuffle_data_bytes / 128MB. Enable AQE (spark.sql.adaptive.enabled=true) on Spark 3.0+ to let Spark adjust this at runtime.
:::
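The rule of thumb above reduces to a one-line calculation. The sketch assumes you have read the shuffle volume from the Spark UI's Shuffle Write column. The 300 GB figure is a hypothetical example, not a measurement.

```python
import math

MB = 1024 * 1024


def shuffle_partitions(shuffle_bytes, target_partition_bytes=128 * MB, minimum=1):
    """Partition count that keeps each post-shuffle partition near the target."""
    return max(minimum, math.ceil(shuffle_bytes / target_partition_bytes))


# Hypothetical stage that writes 300 GB of shuffle data.
n = shuffle_partitions(300 * 1024 * MB)
print(n)  # 2400
# Then, on a running session: spark.conf.set("spark.sql.shuffle.partitions", str(n))
```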
:::warning Building DataFrames from RDDs instead of files
If explain() shows Scan ExistingRDD instead of FileScan parquet, your DataFrame was built from an RDD rather than read from files, for example via spark.createDataFrame(rdd) or a round trip through df.rdd. Catalyst cannot push predicates or column pruning into an RDD, so every row is produced first and filtered afterwards. Read from Parquet/Delta/ORC files directly to benefit from predicate pushdown.
:::
Interview Q&A
Q: What is the difference between a transformation and an action in Spark?
A transformation defines a computation but does not execute it - Spark builds a DAG node for it and returns a new RDD/DataFrame. Examples: filter, select, groupBy, join, withColumn. An action triggers execution of the DAG to produce a result or write output. Examples: collect, count, show, write, take. The lazy evaluation model of transformations allows Spark's Catalyst optimizer to see the entire computation graph before executing, enabling optimizations like predicate pushdown, projection pruning, and join reordering. If Spark evaluated eagerly (like Pandas), it could not apply these cross-operator optimizations.
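The contract can be shown with a toy model (hypothetical class, not Spark's API). Transformations only record steps and return a new object. Nothing runs until an action executes the recorded plan. Because the full plan exists before execution, an optimizer has the chance to rewrite it first, which is what Catalyst does on a much richer plan.

```python
class LazyFrame:
    """Toy model of lazy evaluation: transformations record steps,
    an action runs the recorded plan."""

    executions = 0  # how many times any plan has actually run

    def __init__(self, source, steps=()):
        self.source = source
        self.steps = tuple(steps)

    # Transformations: return a new LazyFrame and do no work.
    def filter(self, pred):
        return LazyFrame(self.source, self.steps + (("filter", pred),))

    def map(self, fn):
        return LazyFrame(self.source, self.steps + (("map", fn),))

    # Action: execute the whole plan now.
    def collect(self):
        LazyFrame.executions += 1
        rows = list(self.source)
        for kind, fn in self.steps:
            rows = [fn(r) for r in rows] if kind == "map" else [r for r in rows if fn(r)]
        return rows


plan = LazyFrame(range(10)).filter(lambda x: x % 2 == 0).map(lambda x: x * 10)
print(LazyFrame.executions)  # 0: nothing has run yet
print(plan.collect())        # [0, 20, 40, 60, 80]
print(LazyFrame.executions)  # 1
```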
Q: What triggers a shuffle in Spark and how do you minimize it?
A shuffle is triggered by any operation that requires data from multiple partitions to be co-located: groupBy, most joins (SortMergeJoin), distinct, repartition, and sort. To minimize shuffle cost:

(1) Use broadcast joins for small tables. This eliminates the shuffle on the large side entirely.

(2) Bucket tables at write time on the join key (df.write.bucketBy(...).saveAsTable(...)), so later joins and aggregations on that key can skip the shuffle. Plain partitionBy on write does not give this benefit.

(3) Use reduceByKey (RDD) or partial aggregation patterns (DataFrame) to reduce data volume before shuffling.

(4) Tune spark.sql.shuffle.partitions to keep post-shuffle partitions in the 128–256 MB range.

(5) Enable AQE (spark.sql.adaptive.enabled=true) so Spark can dynamically coalesce small shuffle partitions at runtime.
Q: How does the Catalyst optimizer work?
Catalyst is a rule-based optimizer with cost-based components. When you submit a DataFrame operation (or a SQL string, which is parsed first), Catalyst builds an unresolved logical plan. It then resolves column and table references into an analyzed logical plan. Next it applies logical optimization rules:

- predicate pushdown: move filters as early as possible, ideally to the data source
- projection pruning: drop unused columns early
- constant folding
- join reordering, when cost-based optimization is enabled (spark.sql.cbo.enabled and spark.sql.cbo.joinReorder.enabled) and statistics exist from ANALYZE TABLE

The result is an optimized logical plan. Catalyst then applies physical planning strategies that use size estimates to pick operators. For example, it chooses BroadcastHashJoin over SortMergeJoin when one side's estimated size is below spark.sql.autoBroadcastJoinThreshold. The selected physical plan is passed to Tungsten for code generation and execution.
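Each logical rule is a tree rewrite. As a toy illustration (hypothetical expression format, not Catalyst's classes), here is constant folding over tuples of the form `(op, left, right)`. It is the same idea as Catalyst's ConstantFolding rule: replace subtrees whose operands are all literals with their computed value.

```python
def fold_constants(expr):
    """Toy constant-folding rule over expression tuples.

    An expression is a number, a column name (str), or (op, left, right)
    with op in {"+", "*"}. Subtrees with only literal operands are
    replaced by their value; anything touching a column is kept.
    """
    if not isinstance(expr, tuple):
        return expr
    op, left, right = expr
    left, right = fold_constants(left), fold_constants(right)
    if isinstance(left, (int, float)) and isinstance(right, (int, float)):
        return left + right if op == "+" else left * right
    return (op, left, right)


print(fold_constants(("+", "amount", ("*", 60, 60))))  # ('+', 'amount', 3600)
```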
Q: What is the difference between repartition and coalesce?
repartition(n) performs a full shuffle to create exactly n new partitions, distributing data evenly. It can both increase and decrease the partition count. Use it when you need to increase parallelism or fix a skewed data distribution. coalesce(n) reduces partitions without a full shuffle by merging existing partitions, minimizing data movement. It can only decrease the partition count. Use it before a write to reduce the number of small output files without paying for a full shuffle. Never use coalesce to increase partitions: it does not shuffle, so data cannot move to new partitions and the count stays the same. One caveat: coalesce adds no stage boundary, so a drastic target such as coalesce(1) also limits the parallelism of the upstream computation. In that case repartition(1) can be faster.
Q: How does Spark handle fault tolerance without replication?
Spark uses lineage-based fault tolerance. Instead of replicating data like HDFS, Spark remembers the sequence of transformations (the lineage DAG) that produced each RDD or DataFrame partition. If an executor fails and loses a partition, Spark reruns only the tasks needed to recompute that specific partition on another executor - tracing back through the lineage to find the last materialized point (source data or a checkpoint). This approach is efficient for read-only workloads: recomputing a partition is often cheaper than the overhead of synchronous replication. For very long lineage chains, checkpoint() truncates the lineage by materializing the DataFrame to stable storage, preventing expensive full recomputation on failure.
Q: What is whole-stage code generation and why does it matter?
Whole-stage code generation (WSCG) is a Tungsten optimization where Spark fuses multiple query operators into a single compiled Java function instead of calling a separate function per operator per row. In the traditional Volcano (iterator) model, processing a row through filter → project → aggregate requires three virtual function calls per row. At a billion rows, that is three billion function calls - each with indirection overhead. WSCG generates a tight loop that applies all operations to a row in a single compiled function, eliminating the overhead. The resulting code is close to hand-written Java or C. On CPU-bound workloads (aggregations, projections, sort), WSCG typically delivers 2–5x speedups. You can see it in the query plan as WholeStageCodegen wrapping multiple operators.
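Python cannot demonstrate the speedup meaningfully, because interpreter overhead dominates. It can still show the shape of the transformation. In the sketch below (hypothetical functions), `volcano` chains separate generator operators that pull one row at a time. `fused` is what whole-stage codegen aims for: a single loop doing filter, project, and aggregate inline.

```python
def volcano(rows):
    """Iterator model: each operator is a separate generator that pulls rows
    from its child one at a time (one call per operator per row)."""
    def scan():
        yield from rows

    def filter_op(child):
        for r in child:
            if r > 0:
                yield r

    def project(child):
        for r in child:
            yield r * 2

    def aggregate(child):
        total = 0
        for r in child:
            total += r
        return total

    return aggregate(project(filter_op(scan())))


def fused(rows):
    """Fused form: filter, project, and aggregate in one loop body,
    with no per-operator calls in between."""
    total = 0
    for r in rows:
        if r > 0:
            total += r * 2
    return total


print(volcano([-1, 2, 3]), fused([-1, 2, 3]))  # 10 10
```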
Q: How would you size executors for a Spark job processing 500 GB of Parquet data?
Start with the "YARN executor sizing" heuristic: 5 cores per executor, 20 GB executor RAM (with 4 GB overhead), and as many executors as your cluster allows. For 500 GB of input, target 128–256 MB per input partition (controlled by spark.sql.files.maxPartitionBytes, default 128 MB). That gives 2000–4000 partitions. With 50 executors at 5 cores each (250 cores), 2000 partitions run in 2000 / 250 = 8 waves of tasks, which is reasonable.

Size spark.sql.shuffle.partitions from the shuffle volume rather than the input size. 500 GB / 128 MB ≈ 3900, but column pruning, filtering, and partial aggregation usually make the shuffle much smaller than the input. A starting value such as 500 (about 64 GB of shuffle data at 128 MB per partition) is reasonable; then check the Shuffle Write size in the Spark UI and adjust. Enable AQE so Spark can coalesce shuffle partitions if they turn out smaller than expected.

Monitor the Spark UI for spill. If executors spill to disk during aggregations, increase executor memory or reduce cores per executor to give each core more memory.
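The arithmetic behind this heuristic fits in a short helper. The sketch assumes one common variant of the rule: reserve 1 core and 1 GB per node for the OS and daemons, pack 5-core executors, and leave roughly 10% of each executor's share for memory overhead. The 10-node, 16-core, 64 GB cluster is hypothetical.

```python
import math


def executor_layout(nodes, cores_per_node, mem_gb_per_node,
                    cores_per_executor=5, overhead_fraction=0.10):
    """Sizing heuristic: reserve 1 core and 1 GB per node, pack executors of
    `cores_per_executor` cores, split the rest into heap and overhead."""
    executors_per_node = (cores_per_node - 1) // cores_per_executor
    mem_per_executor = (mem_gb_per_node - 1) / executors_per_node
    heap_gb = math.floor(mem_per_executor / (1 + overhead_fraction))
    return {
        "executors": nodes * executors_per_node,
        "total_cores": nodes * executors_per_node * cores_per_executor,
        "executor_memory_gb": heap_gb,
        "overhead_gb": math.floor(mem_per_executor - heap_gb),
    }


def task_waves(num_partitions, total_cores):
    """How many rounds of tasks it takes to process every partition."""
    return math.ceil(num_partitions / total_cores)


layout = executor_layout(nodes=10, cores_per_node=16, mem_gb_per_node=64)
print(layout)  # 30 executors, 150 cores, 19 GB heap + 2 GB overhead each
print(task_waves(2000, layout["total_cores"]))  # 14
```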
