Advanced Spark Performance Tuning for ML Workloads
Reading time: ~45 min | Production relevance: Critical | Roles: Data Engineer, ML Engineer, MLOps Engineer
The $2,700/Day Problem
It is a Monday morning at a financial institution running compliance analytics. The data engineering team responsible for the transaction risk feature pipeline has a performance review meeting scheduled because last month's Databricks bill came back at $3,000 per run, $90,000 per month.
A senior data engineer named Isabel spends the next week applying systematic performance tuning. She starts with the Spark UI, not guesswork. She finds three things: a data skew on the merchant_category join that is causing 2 tasks to process 40% of the data while 398 tasks complete instantly, a shuffle producing 800GB of intermediate data from a join that could be eliminated entirely with bucketing, and 47% of executor time spent on garbage collection due to oversized in-memory caching of DataFrames that are only read once.
After salting the skewed join, implementing bucket joins for the three most expensive joins, fixing the caching strategy, and enabling Adaptive Query Execution, the job runs in 45 minutes. The cluster cost drops from $3,000 to $300 per night, a saving of $2,700 per day. The annual savings: $985,500. The time spent: one week of one engineer.
This lesson covers exactly what changed and how to apply the same methodology to any Spark job.
Why Performance Tuning Matters for ML Workloads
ML feature pipelines are among the most demanding Spark workloads because of their unique characteristics:
High fan-out joins: Feature computation joins a central entity table (customers, transactions, vehicles) to dozens of feature source tables. Each join is expensive, and inefficiency compounds when each join inflates the data fed to the next one: a 2x row blow-up at each of 20 joins produces a 2^20 (roughly 1,000,000x) blow-up overall.
Aggregation at multiple granularities: A feature like "transaction count in the last 7, 14, 30 days" requires computing window aggregations at multiple time horizons. These are shuffle-heavy operations.
Skewed entity distributions: In financial data, power users make 10,000 transactions per day while most users make 5. In logistics, a hub city has 100x more vehicle events than a rural depot. This skew concentrates work on a few tasks, causing the entire job to wait on the slowest one.
Iterative access patterns: ML training often re-reads feature tables multiple times (multiple epochs in some workflows, or multiple feature extraction steps). Caching these tables can eliminate redundant I/O, but over-caching causes memory pressure and GC pauses that are slower than re-reading.
Reading the Spark UI: Finding the Bottleneck
Before tuning anything, use the Spark UI to find the actual bottleneck. Optimization based on guesswork wastes time and sometimes makes things worse.
The Spark UI hierarchy: Jobs → Stages → Tasks
What to look for in the Spark UI:
| Symptom | Where to find it | Likely cause |
|---|---|---|
| Most tasks finish in 1s, a few take 10+ minutes | Stage detail → Task duration histogram | Data skew |
| Large "Shuffle Read" column for a stage | Stage list → shuffle read bytes | Too many or too large shuffles |
| High "GC Time" relative to task duration | Task metrics | Over-caching, large objects in memory |
| "Spill (Memory)" or "Spill (Disk)" non-zero | Stage detail → task metrics | Too many rows per partition, insufficient executor memory |
| Low CPU utilization (40-60%) with many tasks | Executor tab | Too many small tasks, scheduling overhead |
Reading task duration histograms: The single most useful view for detecting data skew. If 95% of tasks complete in under 10 seconds but 5 tasks take 15 minutes, you have skew. The entire stage waits for the slowest task - those 5 tasks determine stage duration.
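The histogram check can also be expressed numerically. The sketch below is a minimal, Spark-free illustration: the function names and the 10x cutoff are assumptions chosen for this example, not values defined by Spark.

```python
from statistics import median


def task_skew_ratio(durations_s: list[float]) -> float:
    """Ratio of the slowest task to the median task (1.0 means perfectly even)."""
    if not durations_s:
        raise ValueError("need at least one task duration")
    med = median(durations_s)
    return max(durations_s) / med if med > 0 else float("inf")


def looks_skewed(durations_s: list[float], ratio_threshold: float = 10.0) -> bool:
    """Flag a stage whose slowest task is far slower than the typical task.

    ratio_threshold is an assumed cutoff for illustration only.
    """
    return task_skew_ratio(durations_s) >= ratio_threshold


# 95 tasks finishing in 8 seconds, 5 stragglers taking 15 minutes
durations = [8.0] * 95 + [900.0] * 5
print(task_skew_ratio(durations))  # 112.5
print(looks_skewed(durations))     # True
```

The median is used rather than the mean because a handful of stragglers barely move the median, so the ratio stays sensitive to exactly the pattern the histogram reveals.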
Extracting Metrics Programmatically
import json
from urllib.request import urlopen

from pyspark import SparkContext


def get_stage_metrics(sc: SparkContext) -> list[dict]:
    """
    Extract stage-level metrics from Spark's monitoring REST API.
    Useful for automated performance regression tests or dashboards.

    This reads the public REST endpoint rather than the JVM-internal status
    store, whose methods are not a stable API. It requires the Spark UI to be
    enabled and reachable from the driver. Fields that a given Spark version
    does not report default to 0.
    """
    url = f"{sc.uiWebUrl}/api/v1/applications/{sc.applicationId}/stages"
    with urlopen(url) as response:
        stages = json.load(response)
    metrics = []
    for stage in stages:
        metrics.append({
            "stage_id": stage.get("stageId"),
            "name": stage.get("name"),
            "num_tasks": stage.get("numTasks", 0),
            # Summed across all tasks in the stage, not wall-clock duration
            "executor_run_time_ms": stage.get("executorRunTime", 0),
            "shuffle_read_bytes": stage.get("shuffleReadBytes", 0),
            "shuffle_write_bytes": stage.get("shuffleWriteBytes", 0),
            "gc_time_ms": stage.get("jvmGcTime", 0),
            "spill_memory_bytes": stage.get("memoryBytesSpilled", 0),
            "spill_disk_bytes": stage.get("diskBytesSpilled", 0),
            "input_records": stage.get("inputRecords", 0),
            "output_records": stage.get("outputRecords", 0),
        })
    return metrics


def identify_bottleneck_stage(metrics: list[dict]) -> dict:
    """Return the stage consuming the most executor time."""
    return max(metrics, key=lambda s: s["executor_run_time_ms"])
Data Skew: The Silent Killer
Data skew is the single most common cause of unexpectedly slow Spark jobs. Understanding it requires understanding how Spark distributes work.
When Spark performs a shuffle (for a groupBy, join, or window operation), it partitions data by hash of the join/group key. If certain keys appear far more often than others - a "hot key" - those partitions receive far more data. The tasks processing hot-key partitions take 100x longer than others. The stage does not complete until the slowest task finishes. All other executor cores sit idle.
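To make the mechanism concrete, here is a small pure-Python simulation of hash partitioning with one hot key. It is only a sketch: `zlib.crc32` stands in for Spark's internal hash function, and the key names and counts are made up for illustration.

```python
import zlib
from collections import Counter


def partition_sizes(keys: list[str], num_partitions: int) -> list[int]:
    """Count how many rows land in each partition under hash partitioning."""
    counts = Counter(zlib.crc32(k.encode("utf-8")) % num_partitions for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


# 30% of rows share one hot key; the remaining rows have distinct keys
keys = ["GROCERY"] * 300 + [f"category_{i}" for i in range(700)]
sizes = partition_sizes(keys, num_partitions=8)

print("rows per partition:", sizes)
print("largest partition:", max(sizes), "| average:", sum(sizes) / len(sizes))
```

Every "GROCERY" row hashes to the same partition, so the largest partition holds at least 300 rows while the average is 125. Adding more partitions does not help, because the hot key still maps to exactly one of them.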
Common hot keys in ML workloads:
- merchant_category = "GROCERY" - represents 30% of all transactions
- city = "New York" - represents 15% of all GPS events
- customer_segment = "UNKNOWN" - represents 25% of all customer records (a data quality artifact)
- null - if your join key has nulls, all null-key records hash to the same partition
Detecting Data Skew
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


def detect_key_skew(df: DataFrame, key_column: str, top_n: int = 10) -> None:
    """
    Analyze the distribution of a join/group key.
    Prints the top-N most frequent key values and their percentage of total records.
    If the top key accounts for > 10% of records, you likely have a skew problem.
    """
    total_rows = df.count()
    top_keys = (
        df
        .groupBy(key_column)
        .agg(F.count("*").alias("count"))
        .withColumn("pct_of_total", F.round(F.col("count") / total_rows * 100, 2))
        .orderBy(F.desc("count"))
        .limit(top_n)
    )
    print(f"\nKey distribution for '{key_column}' (total rows: {total_rows:,}):")
    top_keys.show(truncate=False)
    # Flag if top key is more than 10% of data
    first_row = top_keys.first()
    if first_row is None:
        return  # Empty DataFrame: nothing to flag
    top_pct = first_row["pct_of_total"]
    if top_pct > 10.0:
        print(
            f"\nWARNING: Top key accounts for {top_pct}% of data - "
            f"this will cause significant skew on shuffle operations."
        )


# Usage: check before running the expensive join
detect_key_skew(transactions_df, "merchant_category")
Fixing Skew: Salting
Salting appends a random suffix to the skewed join key, spreading a single hot-key partition across multiple partitions:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


def skewed_join_with_salting(
    large_df: DataFrame,
    small_df: DataFrame,
    join_key: str,
    salt_buckets: int = 100,
) -> DataFrame:
    """
    Join two DataFrames where large_df has skewed join key distribution.
    Strategy:
    1. Add random salt (0..N-1) to every row in large_df
    2. Explode small_df to have N copies (one per salt bucket)
    3. Join on (original_key, salt)
    4. The join is now evenly distributed
    salt_buckets: higher = more even distribution, more replication of small_df
    Choose based on: hot_key_fraction * total_large_rows / target_partition_size
    """
    # Step 1: Add random salt to the large table
    salted_large = large_df.withColumn(
        "salt",
        (F.rand() * salt_buckets).cast("int")
    )
    # Step 2: Replicate the small table across all salt buckets
    # Create an array of salt values [0, 1, 2, ..., N-1], then explode
    salt_array = F.array([F.lit(i) for i in range(salt_buckets)])
    salted_small = small_df.withColumn("salt", F.explode(salt_array))
    # Step 3: Join on (key, salt) - now evenly distributed.
    # Joining on a list of column names keeps a single copy of each join column.
    result = salted_large.join(salted_small, on=[join_key, "salt"], how="inner")
    # Step 4: Clean up - drop the salt column from the result
    return result.drop("salt")


# Example: joining 1TB transactions with merchant metadata
# merchant_category is skewed - "GROCERY" appears in 30% of rows
result_df = skewed_join_with_salting(
    large_df=transactions_df,
    small_df=merchant_metadata_df,
    join_key="merchant_category",
    salt_buckets=50,  # Distribute hot key across 50 partitions
)
Salting is the right fix when: (1) you cannot use a broadcast join (small table too large), (2) AQE skew join handling is not fixing it (skew threshold not met), and (3) you can identify the skewed key in advance. Salting has overhead - it replicates the small table N times. For salt_buckets=50, your small table is replicated 50x in memory. Make sure the small table actually fits in memory multiplied by the salt factor.
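The trade-off between salt bucket count and replication cost can be estimated before running anything. The helpers below are a rough sketch: the function names are hypothetical, and the row counts and table size are made-up inputs, not measurements.

```python
MB = 1024 * 1024


def choose_salt_buckets(hot_key_rows: int, target_rows_per_partition: int) -> int:
    """Smallest bucket count that brings the hot key under the per-partition target."""
    if target_rows_per_partition <= 0:
        raise ValueError("target_rows_per_partition must be positive")
    # Integer ceiling division avoids floating-point rounding surprises
    return max(1, -(-hot_key_rows // target_rows_per_partition))


def replicated_small_table_bytes(small_table_bytes: int, salt_buckets: int) -> int:
    """Size of the small side after it is exploded once per salt bucket."""
    return small_table_bytes * salt_buckets


# Assumed inputs: the hot key covers 3 billion rows, target is 60 million rows per partition
buckets = choose_salt_buckets(hot_key_rows=3_000_000_000, target_rows_per_partition=60_000_000)
print(buckets)  # 50

# A 200MB small table replicated 50x becomes 10,000MB of join input
print(replicated_small_table_bytes(200 * MB, buckets) // MB)  # 10000
```

If the replicated size is uncomfortable, a common variation is to salt only the rows carrying known hot keys and leave the rest of the table unsalted, at the cost of a more complicated join.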
Broadcast Joins: Eliminating Shuffle Entirely
When one side of a join is small enough to fit in executor memory, Spark can broadcast it to all executors, eliminating the shuffle entirely. This is the highest-impact join optimization available.
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

# ── Method 1: Automatic broadcast (configure threshold) ───────────────────────
# Default: 10MB. Increase to cover larger dimension tables.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024))  # 100MB
# Now any join where one side < 100MB is automatically broadcast
result = large_transactions.join(merchant_metadata, on="merchant_id", how="left")

# ── Method 2: Explicit broadcast hint ─────────────────────────────────────────
# Use when you know the table is small even if Spark doesn't estimate correctly
# (Spark's size estimator can be wrong for non-parquet sources)
result = large_transactions.join(
    F.broadcast(merchant_metadata),  # Explicit broadcast hint
    on="merchant_id",
    how="left",
)


# ── Method 3: Check whether a join is using broadcast ─────────────────────────
def check_join_plan(df: DataFrame) -> None:
    """Print the physical plan and highlight join strategy."""
    plan = df._jdf.queryExecution().executedPlan().toString()
    if "BroadcastHashJoin" in plan:
        print("JOIN STRATEGY: BroadcastHashJoin (efficient)")
    elif "SortMergeJoin" in plan:
        print("JOIN STRATEGY: SortMergeJoin (shuffle required - consider broadcast)")
    elif "BroadcastNestedLoopJoin" in plan:
        print("JOIN STRATEGY: BroadcastNestedLoopJoin (cross-join - check conditions)")
    print("\nFull plan:")
    print(plan)


check_join_plan(result)
Size thresholds in practice:
| Table size | Recommendation |
|---|---|
| Under 50MB | Always broadcast (well within the 100MB threshold configured above) |
| 50MB–500MB | Explicitly broadcast - increase threshold or use broadcast() hint |
| 500MB–2GB | Use AQE + salting if skewed; sort-merge join otherwise |
| Over 2GB | Sort-merge join; consider bucketing to eliminate shuffle permanently |
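The table can be read as a simple lookup. The function below is an illustrative encoding of those rows only; the name `recommend_join_strategy` and its return strings are hypothetical, and the cutoffs are the rule-of-thumb values from the table rather than anything Spark enforces.

```python
MB = 1024 * 1024
GB = 1024 * MB


def recommend_join_strategy(small_side_bytes: int, skewed: bool = False) -> str:
    """Map the size of the smaller join side to the recommendation in the table."""
    if small_side_bytes < 50 * MB:
        return "broadcast"
    if small_side_bytes < 500 * MB:
        return "broadcast (raise threshold or use broadcast() hint)"
    if small_side_bytes < 2 * GB:
        return "AQE + salting" if skewed else "sort-merge join"
    return "sort-merge join (consider bucketing)"


print(recommend_join_strategy(20 * MB))               # broadcast
print(recommend_join_strategy(300 * MB))              # broadcast (raise threshold or use broadcast() hint)
print(recommend_join_strategy(1 * GB, skewed=True))   # AQE + salting
print(recommend_join_strategy(10 * GB))               # sort-merge join (consider bucketing)
```

The sizes that matter are the in-memory sizes after filtering and column pruning, which can differ substantially from the compressed size on disk.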
Adaptive Query Execution (AQE): Spark's Self-Tuning
Spark 3.0 introduced Adaptive Query Execution - the optimizer re-plans the query at runtime based on actual data statistics collected during execution. It solves three common problems automatically:
# Enable AQE (enabled by default in Spark 3.2+, but confirm)
spark.conf.set("spark.sql.adaptive.enabled", "true")
# ── AQE Feature 1: Dynamic partition coalescing ───────────────────────────────
# After a shuffle, Spark knows the actual partition sizes.
# Instead of 200 shuffle partitions of 5MB each (too small, too many tasks),
# AQE merges them into larger partitions.
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionSize", "64mb")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256mb")
# ── AQE Feature 2: Skew join handling ─────────────────────────────────────────
# AQE detects skewed partitions and splits them automatically.
# A partition is "skewed" if it is 5x the median size AND > 256MB.
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes",
str(256 * 1024 * 1024))
# ── AQE Feature 3: Dynamic join strategy switching ────────────────────────────
# If AQE determines one side of a join is small after execution,
# it can switch from sort-merge to broadcast join at runtime.
# The switch itself is driven by the broadcast threshold; the local shuffle
# reader lets the converted join read already-shuffled data locally
# instead of fetching it over the network.
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")


# Verify AQE is working: check the physical plan for "AdaptiveSparkPlan"
def is_aqe_active(df: DataFrame) -> bool:
    plan = df._jdf.queryExecution().executedPlan().toString()
    return "AdaptiveSparkPlan" in plan
AQE is excellent at fixing skew after it detects it, but it has limits. Skew must exceed the threshold (5x median, >256MB per partition) to trigger AQE skew handling. For extreme skew (a single key representing 30% of 1TB = 300GB in one partition), AQE helps but may not fully solve it. Manual salting is still needed in the most extreme cases. Think of AQE as a safety net that catches moderate skew automatically, not a replacement for understanding your data distribution.
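The two-part threshold is why moderately skewed partitions can slip past AQE. The sketch below applies the rule as described here (larger than factor times the median AND larger than the byte threshold) to a list of partition sizes. It is a simplified model for reasoning about the thresholds, not Spark's implementation, and `aqe_skewed_partitions` is a hypothetical name.

```python
from statistics import median

MB = 1024 * 1024


def aqe_skewed_partitions(
    sizes_bytes: list[int],
    factor: float = 5.0,
    threshold_bytes: int = 256 * MB,
) -> list[int]:
    """Return indexes of partitions that satisfy both skew conditions."""
    med = median(sizes_bytes)
    return [
        i for i, size in enumerate(sizes_bytes)
        if size > factor * med and size > threshold_bytes
    ]


# 6x the median and above 256MB: flagged
print(aqe_skewed_partitions([100 * MB] * 9 + [600 * MB]))  # [9]
# Above 256MB but only 4x the median: not flagged
print(aqe_skewed_partitions([100 * MB] * 9 + [400 * MB]))  # []
# 20x the median but below 256MB: not flagged
print(aqe_skewed_partitions([10 * MB] * 9 + [200 * MB]))   # []
```

The last case is the one to watch on small-partition jobs: the skew ratio is severe, yet nothing triggers because the absolute size is under the threshold. Lowering `skewedPartitionThresholdInBytes` or salting manually are the two ways out.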
Shuffle Optimization
Shuffle is the most expensive operation in Spark - data is serialized, written to disk, transferred over the network, and deserialized on the receiving executor. Minimizing shuffle is the highest-leverage optimization after addressing data skew.
# ── Setting 1: Number of shuffle partitions ────────────────────────────────────
# Default is 200 - correct for small datasets, too few for 1TB jobs.
# Rule of thumb: target 100-200MB per partition after shuffle.
# For 800GB shuffle: 800GB / 128MB = ~6250 partitions
spark.conf.set("spark.sql.shuffle.partitions", "6000")
# With AQE enabled, this is the MAXIMUM - AQE coalesces down to larger partitions.
# Without AQE, set this precisely based on your data volume.
# ── Setting 2: Max partition bytes (input reading) ────────────────────────────
# Controls the target size of each input partition when reading files.
# Default 128MB works for most cases. For small files (many Parquet files < 1MB),
# increasing this reduces the number of tasks and scheduling overhead.
spark.conf.set("spark.sql.files.maxPartitionBytes", str(256 * 1024 * 1024)) # 256MB
# ── Detecting shuffle spill ────────────────────────────────────────────────────
# Spill = data too large to fit in executor memory, written to disk during shuffle.
# Spill causes 10-100x slowdown for affected tasks.
# Detect: Spark UI → Stage details → "Shuffle Spill (Memory)" and "Shuffle Spill (Disk)"
# Fix: increase executor memory, reduce partition size, or add repartition before shuffle
def analyze_shuffle_overhead(df: DataFrame) -> None:
    """
    Print the physical plan so shuffle-inducing operations can be located.
    Helps identify which transformations cause the most expensive shuffles.
    """
    print("=== Physical Plan (looking for Exchange nodes = shuffles) ===")
    df.explain(mode="formatted")
    # Look for: Exchange (shuffle), Sort (often precedes SortMergeJoin), HashAggregate
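The partition-count rule of thumb is plain arithmetic, sketched below. The helper name is hypothetical and sizes are in binary units (1GB = 1024MB), which is why the result for 800GB comes out slightly above the ~6250 quoted for decimal units.

```python
MB = 1024 * 1024
GB = 1024 * MB


def target_shuffle_partitions(shuffle_bytes: int, target_partition_bytes: int = 128 * MB) -> int:
    """Number of shuffle partitions needed to hit the target partition size."""
    if target_partition_bytes <= 0:
        raise ValueError("target_partition_bytes must be positive")
    # Ceiling division so the last partial partition is counted
    return max(1, -(-shuffle_bytes // target_partition_bytes))


print(target_shuffle_partitions(800 * GB))             # 6400
print(target_shuffle_partitions(800 * GB, 256 * MB))   # 3200
print(target_shuffle_partitions(20 * GB))              # 160
```

The last line shows why the default of 200 works for small jobs: a 20GB shuffle needs about 160 partitions at 128MB each.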
Memory Tuning
Spark's memory is divided between execution (joins, aggregations, shuffles) and storage (caching). Understanding this split eliminates GC problems and spill.
from pyspark.sql import SparkSession

# ── Executor memory configuration ─────────────────────────────────────────────
# These settings are read when the executor JVMs start. Set them when the
# session is built (or via spark-submit --conf / cluster configuration);
# calling spark.conf.set() on a running session does not apply them.
# Total executor memory allocated by YARN/Kubernetes
# spark.executor.memory = "16g"
spark = (
    SparkSession.builder
    # ── Memory fractions ──────────────────────────────────────────────────────
    # spark.memory.fraction: fraction of (heap - 300MB reserved) used for
    # execution + storage combined
    # Default: 0.6; the remainder holds user data structures and Spark metadata
    .config("spark.memory.fraction", "0.7")  # Increase if heap is clean
    # spark.memory.storageFraction: fraction of the Spark memory pool protected for storage
    # Default: 0.5 (so roughly 30% of heap for storage = caching)
    # Execution can evict cached blocks above this protected share;
    # storage can never evict execution memory
    .config("spark.memory.storageFraction", "0.3")  # Reduce for join-heavy jobs
    # ── Off-heap memory (advanced) ────────────────────────────────────────────
    # Moves data out of JVM heap - reduces GC pressure for large-memory workloads
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "8g")
    # ── GC tuning ─────────────────────────────────────────────────────────────
    # If GC time > 10% of task time: executor heap is too full.
    # Options:
    # 1. Increase executor memory
    # 2. Reduce cached data (cache only what is reused 2+ times)
    # 3. Switch to G1GC (better for large heaps)
    .config(
        "spark.executor.extraJavaOptions",
        "-XX:+UseG1GC -XX:InitiatingHeapOccupancyPercent=35",
    )
    .getOrCreate()
)
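To see what the fractions mean in megabytes, the sketch below works through the unified memory arithmetic for a given heap size. It assumes Spark's 300MB reserved region and ignores off-heap memory and container overhead; the function and key names are made up for this example.

```python
RESERVED_MB = 300  # Fixed amount Spark sets aside before applying the fractions


def memory_regions_mb(
    heap_mb: float,
    memory_fraction: float = 0.6,
    storage_fraction: float = 0.5,
) -> dict[str, float]:
    """Split an executor heap into Spark's unified pool and the remaining user memory."""
    usable = heap_mb - RESERVED_MB
    unified = usable * memory_fraction          # execution + storage combined
    storage_protected = unified * storage_fraction  # cached blocks safe from eviction
    return {
        "unified_mb": unified,
        "storage_protected_mb": storage_protected,
        "execution_guaranteed_mb": unified - storage_protected,
        "user_mb": usable - unified,
    }


defaults = memory_regions_mb(16 * 1024)
tuned = memory_regions_mb(16 * 1024, memory_fraction=0.7, storage_fraction=0.3)
for name, regions in (("defaults", defaults), ("tuned", tuned)):
    print(name, {k: round(v, 1) for k, v in regions.items()})
```

With a 16GB heap the defaults give a unified pool of about 9,650MB, half of it protected for caching. The tuned values grow the pool to roughly 11,259MB while shrinking the protected cache share, which favors joins and aggregations.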
Caching Strategy: Cache Only What You Reuse
Caching is the most commonly misused optimization in Spark. Caching a DataFrame tells Spark to keep it in memory (or on disk) after its first computation, avoiding recomputation on subsequent accesses.
from pyspark import StorageLevel
# ── Storage levels ─────────────────────────────────────────────────────────────
# MEMORY_ONLY: Fastest. Partitions that don't fit in memory are not cached and are recomputed on access.
# MEMORY_AND_DISK: Spills to disk if memory is full. Most common production choice.
# DISK_ONLY: Saves memory but adds disk I/O on every access. Use for large, rarely reused DataFrames.
# MEMORY_AND_DISK_SER: Serialized in memory (less space, more CPU). A Scala/Java level;
#   recent PySpark versions do not expose the *_SER constants.
# DataFrame.cache() uses a memory-and-disk level (the exact default constant depends on Spark version)
large_feature_table = (
spark.read.parquet("s3://features/base_features/")
.filter(F.col("date") == processing_date)
.cache() # Will be reused in 5 subsequent joins
)
# Trigger the cache (lazy evaluation means cache is not filled until an action)
large_feature_table.count() # Force materialization
# Use the cached DataFrame in multiple operations
gps_enriched = large_feature_table.join(gps_features, on="vehicle_id")
route_enriched = large_feature_table.join(route_features, on="vehicle_id")
weather_enriched = large_feature_table.join(weather_features, on="vehicle_id")
# Explicit persist with custom storage level
important_intermediate = (
base_df
.groupBy("customer_id")
.agg(F.count("*").alias("total_transactions"))
.persist(StorageLevel.MEMORY_AND_DISK)
)
# Always unpersist when done - releases memory for other operations
large_feature_table.unpersist()
important_intermediate.unpersist()
When to cache:
- The DataFrame is used in 3 or more downstream operations
- Recomputation is expensive (involves a long chain of transformations or a large shuffle)
- The DataFrame fits comfortably in executor memory (less than 50% of total storage memory)
When NOT to cache:
- The DataFrame is only used once - caching adds overhead with no benefit
- The DataFrame is too large for available memory - it will spill to disk and be slower than just recomputing from Parquet
- The source data is already in Parquet with column pruning and predicate pushdown - reading from Parquet is often faster than reading from a spilled cache
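These rules of thumb are easy to encode as a guard in pipeline code. The helper below is a sketch of that idea only: `should_cache` is a hypothetical function, its defaults mirror the 3-uses and 50%-of-storage-memory guidelines above, and the sizes passed in are estimates you would supply yourself.

```python
GB = 1024 ** 3


def should_cache(
    downstream_uses: int,
    dataframe_bytes: int,
    storage_memory_bytes: int,
    min_uses: int = 3,
    max_storage_fraction: float = 0.5,
) -> bool:
    """Apply the reuse and size guidelines; both must hold for caching to pay off."""
    if downstream_uses < min_uses:
        return False  # Read once or twice: caching overhead is not repaid
    return dataframe_bytes < max_storage_fraction * storage_memory_bytes


print(should_cache(5, 20 * GB, 100 * GB))   # True: reused often and fits comfortably
print(should_cache(1, 20 * GB, 100 * GB))   # False: only read once
print(should_cache(5, 80 * GB, 100 * GB))   # False: would crowd out other cached data
```

The expense of recomputation is deliberately left out because it is hard to estimate up front; in practice it tilts borderline cases toward caching when the lineage contains a large shuffle.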
Bucketing: Eliminating Shuffle Permanently
Bucketing pre-partitions data by a hash of the join key when writing to disk, optionally sorting rows within each bucket. When two tables bucketed on the same key and with the same number of buckets are joined, Spark performs a "bucket join" - no shuffle required, ever.
# ── Write tables with bucketing ────────────────────────────────────────────────
# This is a one-time cost when writing the table.
# The benefit is paid back on every join thereafter.
# Write the transactions table bucketed by customer_id
(
transactions_df
.write
.mode("overwrite")
.bucketBy(256, "customer_id") # 256 buckets - must match the join partner
.sortBy("customer_id") # Optional but improves merge performance
.saveAsTable("transactions_bucketed")
)
# Write customer features bucketed by customer_id with same bucket count
(
customer_features_df
.write
.mode("overwrite")
.bucketBy(256, "customer_id") # Must use same number of buckets!
.sortBy("customer_id")
.saveAsTable("customer_features_bucketed")
)
# ── Join without shuffle ────────────────────────────────────────────────────────
# Read the bucketed tables
txn = spark.table("transactions_bucketed")
features = spark.table("customer_features_bucketed")
# This join produces NO Exchange (shuffle) node in the physical plan
result = txn.join(features, on="customer_id", how="left")
# Verify: the scan nodes should report bucketing (e.g. "Bucketed: true") and
# SortMergeJoin should have no Exchange beneath it
result.explain("formatted")


# ── Checking if bucket join is being used ──────────────────────────────────────
def check_bucket_join_used(df: DataFrame) -> None:
    """
    Verify the join is using bucket join (no shuffle).
    Look for: a bucketed scan marker in the plan (good), no Exchange before SortMergeJoin (good).
    If you see Exchange: buckets may not match or bucketing is being ignored.
    The marker text is version-dependent; "Bucketed: true" and
    "SelectedBucketsCount" are the strings assumed here.
    """
    plan = df._jdf.queryExecution().executedPlan().toString()
    has_bucket_scan = "Bucketed: true" in plan or "SelectedBucketsCount" in plan
    # Count hash-partitioned Exchange nodes (each = one shuffle);
    # broadcast exchanges are not shuffles and are not counted
    exchange_count = plan.count("Exchange hashpartitioning")
    print(f"Bucketed scan present: {has_bucket_scan}")
    print(f"Exchange (shuffle) nodes: {exchange_count}")
    if has_bucket_scan and exchange_count == 0:
        print("OPTIMAL: Bucket join with no shuffle.")
    elif has_bucket_scan and exchange_count > 0:
        print("SUBOPTIMAL: Bucket scan present but shuffles still occurring. "
              "Check bucket counts match and tables are in same catalog.")
    else:
        print("NOT USING BUCKET JOIN: Falls back to sort-merge with shuffle.")
If you bucket transactions by 256 buckets and customer_features by 128 buckets, Spark cannot join them shuffle-free by default - it falls back to a sort-merge join that reshuffles at least one side. The bucket count is fixed at write time and must match on both sides of every join you want to optimize. Standardize on a few bucket counts (64, 128, 256, 512) across your feature tables and document this as a team convention.
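The reason the counts must match follows directly from how a bucket number is derived: hash of the key modulo the bucket count. The sketch below illustrates this in plain Python, with `zlib.crc32` standing in for Spark's internal hash function and made-up customer keys.

```python
import zlib


def bucket_id(key: str, num_buckets: int) -> int:
    """Bucket number for a key: hash modulo the bucket count."""
    return zlib.crc32(key.encode("utf-8")) % num_buckets


def co_located(keys: list[str], left_buckets: int, right_buckets: int) -> bool:
    """True if every key lands in the same bucket number in both tables."""
    return all(bucket_id(k, left_buckets) == bucket_id(k, right_buckets) for k in keys)


keys = [f"customer_{i}" for i in range(1000)]

# Same bucket count: bucket N of one table only ever needs bucket N of the other
print(co_located(keys, 256, 256))  # True

# Different counts: many keys land in different bucket numbers
print(co_located(keys, 256, 128))  # False
moved = sum(bucket_id(k, 256) != bucket_id(k, 128) for k in keys)
print(f"{moved} of {len(keys)} keys are in a different bucket number")
```

With matching counts, each task can pair bucket N from both tables and join locally. With mismatched counts that pairing no longer holds, so the data has to be redistributed before the join.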
Z-Ordering (Delta Lake): Optimizing Range Queries
For ML feature tables stored in Delta Lake, Z-ordering co-locates related data - data with similar values in the Z-ordered column is physically adjacent in the same Parquet files:
from delta.tables import DeltaTable
# Z-order by the columns most frequently used in predicates
# This is run after writing the table (or as part of OPTIMIZE)
delta_table = DeltaTable.forPath(spark, "s3://features/ml_feature_table/")
(
delta_table
.optimize()
.executeZOrderBy("date", "region") # Most common filter columns
)
# Practical impact: reading features for a specific date and region
# Before Z-ordering: scans all Parquet files in the table (full scan)
# After Z-ordering: skips ~85-95% of files via file-level statistics
filtered = (
spark.read
.format("delta")
.load("s3://features/ml_feature_table/")
.filter(
(F.col("date") == "2024-01-15") &
(F.col("region") == "NORTHEAST")
)
)
# Spark UI: "Files Pruned" should be >> "Files Read" after Z-ordering
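The file-skipping effect can be illustrated with a toy model: interleave the bits of two column values to get a Z-value, sort by it, cut the sorted rows into files, and count how many files a filter must open based on per-file min/max ranges. This is a conceptual sketch on a 4x4 grid of integers, not Delta Lake's implementation, and all names in it are made up.

```python
def interleave_bits(x: int, y: int, bits: int = 16) -> int:
    """Z-value (Morton code): alternate the bits of x and y."""
    z = 0
    for i in range(bits):
        z |= ((x >> i) & 1) << (2 * i)
        z |= ((y >> i) & 1) << (2 * i + 1)
    return z


def chunk(rows: list, size: int) -> list[list]:
    """Cut a sorted list of rows into fixed-size 'files'."""
    return [rows[i:i + size] for i in range(0, len(rows), size)]


def files_read(files: list[list], column: int, value: int) -> int:
    """Count files whose min/max range for the column could contain the value."""
    return sum(
        1 for f in files
        if min(r[column] for r in f) <= value <= max(r[column] for r in f)
    )


points = [(x, y) for x in range(4) for y in range(4)]
linear_files = chunk(sorted(points), 4)  # sorted by x, then y
z_files = chunk(sorted(points, key=lambda p: interleave_bits(*p)), 4)

# Filter on the second column (index 1): linear order cannot skip anything
print(files_read(linear_files, 1, 0), files_read(z_files, 1, 0))  # 4 2
# Filter on the first column (index 0): linear order is ideal, Z-order is close
print(files_read(linear_files, 0, 0), files_read(z_files, 0, 0))  # 1 2
```

A plain sort favors only the leading column. Z-ordering gives up a little on that column in exchange for useful skipping on every Z-ordered column, which is why it suits tables filtered on several columns in different combinations.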
Columnar Execution: Parquet Predicate Pushdown
Parquet stores data in columnar format with column-level and row-group-level statistics. Spark uses these for predicate pushdown (skipping files and row groups that cannot satisfy the filter) and column pruning (reading only columns referenced in the query).
# Parquet predicate pushdown - happens automatically for simple predicates
# on columns with min/max statistics. Check: Spark UI → SQL tab → "Files Scanned" vs "Files Skipped"
# Column pruning - read only what you need
efficient_read = (
spark.read
.parquet("s3://data/transactions/")
.select("customer_id", "amount", "merchant_category", "transaction_date")
# Spark reads ONLY these 4 columns from the Parquet files
# Not all 50 columns in the table
)
# Partition pruning - for partitioned tables, filter on partition key first
partitioned_read = (
spark.read
.parquet("s3://features/ml_features/")
.filter(F.col("date") == "2024-01-15") # Partition key filter
# Spark reads only the date=2024-01-15 directory
)
# Verify pushdown is working
partitioned_read.explain("extended")
# A filter on a partition column shows up as:
#   "PartitionFilters: [isnotnull(date#0), (date#0 = 2024-01-15)]"
# A filter on an ordinary data column shows up instead as:
#   "PushedFilters: [IsNotNull(col), EqualTo(col,value)]"
Kryo Serialization
Java's default serialization is verbose and slow. Kryo is 3-5x faster and produces 2-3x smaller serialized objects:
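from pyspark.sql import SparkSession

# spark.serializer is read at startup: pass it when the session is built
# (or via spark-submit --conf); it cannot be changed on a running session.
spark = (
    SparkSession.builder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.registrationRequired", "false")  # Permissive mode
    # For best performance, register frequently serialized classes
    .config(
        "spark.kryo.classesToRegister",
        "com.example.TransactionRecord,com.example.VehicleEvent",
    )
    .getOrCreate()
)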
Performance Tuning Checklist
A systematic checklist organized by category. Work through these in order - address data distribution first, then joins, then memory, then configuration.
Data Distribution
- Run detect_key_skew() on all join keys before optimizing anything else
- If any key > 10% of data: apply salting or AQE skew join handling
- Check partition sizes: target 100-256MB per partition after shuffle
- Repartition by join key before expensive joins to co-locate data
- Use repartitionByRange() for range-based aggregations (time windows)
Joins
- Check if any small tables (under 100MB) can be broadcast - biggest win possible
- Increase spark.sql.autoBroadcastJoinThreshold to at least 100MB
- For large joins on stable tables: implement bucket joins (eliminates shuffle permanently)
- Enable AQE for automatic skew handling and dynamic broadcast switching
- Review join order: put the most selective filter first to reduce data volume early
Memory
- Check GC time in Spark UI - if > 10% of task time, reduce cached data
- Only cache DataFrames used in 3+ downstream operations
- Always call .unpersist() when a cached DataFrame is no longer needed
- For large executors (>32GB): enable G1GC for better garbage collection
- Consider off-heap memory for workloads with many large intermediate DataFrames
Configuration
- Set spark.sql.shuffle.partitions based on data volume (target 128-256MB/partition)
- Enable AQE: spark.sql.adaptive.enabled = true
- Enable Kryo serialization for 2-3x faster task serialization
- Set spark.dynamicAllocation.enabled = true for automatic executor scaling
- Enable Delta Lake Z-ordering on frequently filtered columns
Monitoring
- Check Spark UI stage list: identify the slowest stage first
- Review task duration histogram for each slow stage: uniform = good, bimodal = skew
- Check spill columns: if nonzero, executor memory is too small or partitions too large
- Track shuffle read bytes: primary indicator of join and aggregation cost
Performance Tuning Decision Tree
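One way to read the checklist as a decision procedure is sketched below. It checks symptoms in the order recommended in this lesson (skew, then spill, then GC, then shuffle volume). The dictionary keys, the 10x skew cutoff, and the function name are assumptions made for this illustration; only the 10% GC guideline comes from the text.

```python
def recommend_fix(stage: dict) -> str:
    """Return the first fix to try for a slow stage, checking symptoms in order."""
    median_s = stage.get("median_task_s", 0.0)
    max_s = stage.get("max_task_s", 0.0)
    if median_s > 0 and max_s / median_s >= 10:  # assumed skew cutoff
        return "skew: broadcast the small side, enable AQE skew join, or salt the hot key"
    if stage.get("spill_disk_bytes", 0) > 0:
        return "spill: raise the shuffle partition count or executor memory"
    run_ms = stage.get("executor_run_time_ms", 0)
    if run_ms > 0 and stage.get("gc_time_ms", 0) > 0.10 * run_ms:
        return "gc: cache less and unpersist DataFrames that are no longer needed"
    if stage.get("shuffle_read_bytes", 0) > 0:
        return "shuffle: broadcast small tables or bucket stable ones"
    return "no obvious bottleneck: check total data volume and input partitioning"


skewed_stage = {"median_task_s": 1.0, "max_task_s": 600.0, "spill_disk_bytes": 5_000}
gc_stage = {"median_task_s": 30.0, "max_task_s": 45.0,
            "executor_run_time_ms": 1_000_000, "gc_time_ms": 470_000}

print(recommend_fix(skewed_stage))  # skew is reported first, even though the stage also spills
print(recommend_fix(gc_stage))
```

Skew is checked first on purpose: an overloaded partition often causes spill and GC pressure as side effects, so fixing it can clear the later symptoms without further tuning.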
Common Mistakes
# WRONG: Increasing executor memory or shuffle partitions without looking at Spark UI first
spark.conf.set("spark.executor.memory", "64g") # Won't help if problem is data skew
spark.conf.set("spark.sql.shuffle.partitions", "2000") # Won't help if problem is GC
# RIGHT: Open Spark UI → identify the slowest stage → examine task duration histogram
# → find the actual bottleneck → apply the targeted fix
Premature optimization based on intuition is often wrong. The Spark UI shows you what is actually happening. Use it before writing a single line of tuning code.
Forgetting to call .unpersist()
# WRONG - DataFrame stays cached indefinitely, consuming memory
df_features = base_df.join(feature_source, on="id").cache()
df_features.count() # Materializes cache
# ... use df_features in joins ...
# Never called df_features.unpersist()
# RIGHT - release cache when done
df_features = base_df.join(feature_source, on="id").cache()
df_features.count()
# ... use df_features ...
df_features.unpersist() # Frees executor memory for the next stage
A long-running job with many cached DataFrames will exhaust executor memory, causing eviction, recomputation, and GC pressure - ultimately slower than not caching at all.
Spark's default spark.sql.shuffle.partitions = 200 is correct for datasets around 20-50GB. For 1TB jobs, 200 partitions means each post-shuffle partition is ~5GB - far too large for a single task, causing massive spill and OOM errors. Always calculate target partition count as: data_size_bytes / (128 * 1024 * 1024), that is, one partition per 128MB of shuffled data. With AQE enabled, set this high (2000-10000) and let AQE coalesce down to optimal size.
# WRONG - bucket counts differ; Spark cannot use bucket join
(large_df.write.bucketBy(256, "customer_id").saveAsTable("large_bucketed"))
(small_df.write.bucketBy(128, "customer_id").saveAsTable("small_bucketed")) # Different!
# Spark falls back to a sort-merge join that reshuffles at least one side - bucket optimization wasted
spark.table("large_bucketed").join(spark.table("small_bucketed"), on="customer_id")
# RIGHT - same bucket count on both sides
(large_df.write.bucketBy(256, "customer_id").saveAsTable("large_bucketed"))
(small_df.write.bucketBy(256, "customer_id").saveAsTable("small_bucketed")) # Same
Interview Q&A
Q: What is data skew in Spark and how do you fix it?
A: Data skew occurs when some partitions contain far more data than others due to uneven key distribution. In a groupBy or join, Spark partitions data by hash of the key. If a key like "GROCERY" appears in 30% of all rows, the task processing that partition receives 30% of the data while all other tasks process tiny fractions. Since a stage waits for all tasks to complete, those few overloaded tasks determine total stage duration.
Fixes in order of preference: first, try broadcast() if the small side of the join is under 100-200MB - this eliminates the shuffle entirely. Second, enable AQE skew join handling, which detects and splits skewed partitions automatically at runtime. Third, apply salting manually - add a random salt (0..N-1) to the hot key in the large table, explode the small table to have N copies, and join on the salted key. The salt distributes the hot-key work across N partitions.
Q: What is Adaptive Query Execution and what problems does it solve?
A: AQE re-plans the query at runtime using actual data statistics collected during execution, rather than relying on the optimizer's pre-execution estimates (which can be wrong by orders of magnitude for skewed data or poor table statistics).
Three specific problems it solves: (1) Dynamic partition coalescing - after a shuffle produces 200 small partitions, AQE merges them into fewer, larger partitions, reducing task scheduling overhead. (2) Skew join handling - detects partitions that are significantly larger than the median and splits them into sub-partitions, processing the hot-key work in parallel. (3) Dynamic join strategy switching - if runtime statistics show one side of a join is much smaller than estimated, AQE can switch from a sort-merge join to a broadcast join mid-query.
Enable with spark.sql.adaptive.enabled=true (default in Spark 3.2+).
Q: When should you use a broadcast join vs a sort-merge join?
A: Use a broadcast join when one side of the join is small enough to fit in executor memory - typically under 100-200MB. The entire small table is sent to every executor, and each executor performs the join locally against its partition of the large table. No shuffle occurs. This is the most impactful join optimization possible.
Use a sort-merge join (the default) when both sides are too large to broadcast. Spark shuffles both tables by the join key, sorts them, and merges. This produces a network shuffle proportional to the size of both tables.
The threshold is configurable: spark.sql.autoBroadcastJoinThreshold. For ML pipelines with large dimension tables (region mappings, merchant metadata, customer segments), I typically set this to 256MB and explicitly use the broadcast() hint for tables I know are small even if Spark's statistics underestimate them.
Q: How do you debug a Spark job that's running slow?
A: Always start with the Spark UI, not code changes. The systematic approach:
- Open Spark UI → Jobs tab → find the slowest-running job.
- Click into the job → Stages tab → find the slowest stage (highest duration).
- Click into the stage → scroll to the task duration histogram. If tasks are uniformly distributed, the problem is total data volume. If a few tasks take 100x longer than the median, the problem is data skew.
- Check the stage metrics: if "Shuffle Spill (Memory)" or "Shuffle Spill (Disk)" is nonzero, you have partitions too large for executor memory.
- Check the Executors tab: if GC time is more than 10% of task time, you have memory pressure from over-caching or too many objects in executor heap.
Then apply the targeted fix: skew → salting or AQE; spill → increase partition count or executor memory; GC pressure → reduce cached data.
Q: What is bucketing in Spark and how does it eliminate shuffle?
A: Bucketing pre-partitions data by the join key when writing to disk. Data with the same key always ends up in the same bucket file. When two tables are bucketed on the same key with the same number of buckets, Spark knows that row customer_id=X from table A is in the same numbered bucket as row customer_id=X from table B. It can read the corresponding bucket files from each table and join them locally without any shuffle.
The cost: bucketing must be done at write time. The benefit: every subsequent join on that key is shuffle-free, permanently. For ML feature pipelines that join the same dimension tables on every nightly run, bucket joins are extremely high-value - write once, benefit on every future run. The critical constraint: both tables must be bucketed on the same key with the same bucket count, and tables must be in the same catalog (Hive metastore or Glue).
Q: What is the difference between .cache() and .persist(DISK_ONLY)?
A: .cache() is equivalent to .persist(MEMORY_AND_DISK) - data is stored in executor memory if it fits, and spills to disk if memory is exhausted. First access reads from source; subsequent accesses read from memory (fast) or disk cache (slower than memory, faster than recomputation).
.persist(DISK_ONLY) stores exclusively on disk - no in-memory copy. Every access goes through disk I/O. This sounds slow and usually is for hot DataFrames. But it is useful when: (1) the DataFrame is too large to fit in any reasonable memory configuration but too expensive to recompute, (2) the DataFrame is accessed infrequently and you want to save memory for other operations, or (3) you have fast local NVMe storage where disk I/O is faster than recomputing complex aggregations from S3.
In practice, MEMORY_AND_DISK covers most cases well. Use DISK_ONLY specifically when you need to preserve a checkpoint in the execution graph without consuming executor heap memory.
Module 02 complete. Next: Module 03 - Stream Processing for ML: real-time feature computation, Kafka, Flink, and online feature serving.
