

Testing Data Pipelines for ML Correctness

Reading time: ~40 min | Production relevance: Critical | Roles: Data Engineer, ML Engineer, MLOps Engineer


The Label Leakage Incident

It is a quarterly business review at a credit scoring fintech. The VP of Risk puts up a slide: loan default rates on the "AI-approved" cohort are running 3.4 percentage points higher than the statistical model they replaced six months ago. The room goes quiet. The ML team has been citing 94% AUC on their holdout set. The risk team is pointing at real-world loss data. These two facts cannot both be true.

Three days of investigation later, a data engineer finds the bug. Three months ago, during a schema migration of the transactions table, a column called days_since_last_missed_payment was inadvertently computed using the loan outcome date rather than the loan application date. For borrowers who later defaulted, this column contained a value of 0 - effectively encoding the label into the feature. The model learned to predict from the future. It achieved 94% AUC on historical data because the answer was encoded in the features. In production, where the future is unknown, the model had learned nothing real.

This is label leakage. The feature pipeline silently introduced it during a schema change. The model evaluation passed because evaluation used the same leaky features. The model deployed. The company approved bad loans. Now they need to retroactively explain to regulators why their AI system approved borrowers it should have flagged.

A temporal ordering assertion - a single 10-line test - would have caught this immediately. It would have failed the first time the migrated pipeline ran, before any training data was written. The entire incident, its financial impact, and its regulatory exposure would not have happened.

Testing data pipelines is not optional for ML systems. It is how you prevent one of the most expensive classes of bugs in ML: silent data corruption that trains models on the wrong signals.


Why Pipeline Testing Is Different

Testing a data pipeline is fundamentally different from testing application software. Understanding the differences helps you design the right test strategy.

Non-determinism: Data pipelines process real-world data that changes. A transformation that worked correctly on January data may produce different behavior on February data due to new values, edge cases, or distribution shifts. You cannot test once and assume correctness forever - you need continuous validation in production.

Data volume: You cannot run full-scale integration tests on terabytes in CI. A Spark job that takes 45 minutes on the full dataset needs a test strategy that uses sampled data while still catching real bugs. The sampling strategy itself is a design decision.

Schema evolution: Upstream schemas change constantly - new columns added, types changed, columns dropped. Your pipeline may silently ignore a dropped column it depended on, producing null output with no error. Schema evolution tests catch this before the model trains on corrupted features.

Environment parity: Production runs against 1TB in a Databricks cluster. CI runs against 10MB on a laptop. Behavior can differ: data type inference on small files may succeed where it fails at scale, memory pressure does not appear in small tests, and distributed execution edge cases only appear with real parallelism.

Implicit assumptions: Pipeline correctness depends on data properties rarely stated explicitly - "customer_id is always present," "event_time is always before label_time," "amount is always positive." These assumptions need to become explicit assertions.
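As a minimal sketch of that idea, assuming rows arrive as plain dicts with hypothetical `customer_id`, `event_time`, `label_time`, and `amount` keys, the three assumptions above can be written as one function that reports every violation instead of silently passing:

```python
from datetime import datetime


def check_row_assumptions(rows):
    """Return a list of violations of three implicit pipeline assumptions.

    Each row is assumed to be a dict with the (hypothetical) keys
    customer_id, event_time, label_time and amount.
    """
    violations = []
    for i, row in enumerate(rows):
        # "customer_id is always present"
        if row.get("customer_id") is None:
            violations.append(f"row {i}: customer_id is missing")
        # "event_time is always before label_time"
        event_time, label_time = row.get("event_time"), row.get("label_time")
        if event_time is None or label_time is None:
            violations.append(f"row {i}: event_time or label_time is missing")
        elif event_time >= label_time:
            violations.append(f"row {i}: event_time is not before label_time")
        # "amount is always positive"
        amount = row.get("amount")
        if amount is None or amount <= 0:
            violations.append(f"row {i}: amount is not positive")
    return violations


rows = [
    {"customer_id": "C1", "event_time": datetime(2024, 1, 10),
     "label_time": datetime(2024, 1, 15), "amount": 120.0},
    {"customer_id": None, "event_time": datetime(2024, 1, 16),
     "label_time": datetime(2024, 1, 15), "amount": -5.0},
]
for violation in check_row_assumptions(rows):
    print(violation)  # three violations, all for row 1
```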


The Testing Pyramid for Data Pipelines

Layer 1 - Unit tests: Test individual transformation functions with tiny, hand-crafted DataFrames. These are fast (seconds), require no cluster, and catch logical errors in transformation code. Run on every commit.

Layer 2 - Integration tests: Run the full pipeline end-to-end on a small, representative sample of real data. Verify that stages connect correctly, outputs have the right shape, and the pipeline handles realistic edge cases. Slower (minutes), run on every PR.

Layer 3 - Data quality tests: Run Great Expectations or Soda Core checks on the actual pipeline output on every pipeline execution. These verify that the data itself, not just the code, meets expectations, and they catch problems that correct code still produces when upstream data is bad.

Layer 4 - Distribution monitoring: Compare output statistics (null rates, cardinality, value distributions) against historical baselines. Catches gradual drift that passes all other tests because the code is correct and data is technically valid, but the distribution has shifted in ways that affect model performance.
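A minimal, standard-library sketch of a Layer 4 check. It compares one column's values today with a historical baseline on two statistics: the null rate and the Population Stability Index (PSI). PSI, the bucket edges, and the 0.02 / 0.2 thresholds are illustrative choices, not values this lesson prescribes; in production the inputs would be aggregates computed by Spark rather than Python lists:

```python
import math


def null_rate(values):
    """Fraction of values that are None (0.0 for an empty list)."""
    return sum(v is None for v in values) / len(values) if values else 0.0


def bucket_shares(values, edges):
    """Share of non-null values per bucket; edges must be sorted ascending.

    Buckets are (-inf, e0), [e0, e1), ..., [e_last, +inf).
    """
    present = [v for v in values if v is not None]
    counts = [0] * (len(edges) + 1)
    for v in present:
        counts[sum(v >= e for e in edges)] += 1  # index = number of edges <= v
    total = len(present) or 1
    return [c / total for c in counts]


def psi(baseline, current, edges, eps=1e-6):
    """Population Stability Index of `current` against `baseline`."""
    b = bucket_shares(baseline, edges)
    c = bucket_shares(current, edges)
    # eps keeps log() finite when a bucket is empty in one of the samples
    return sum((ci - bi) * math.log((ci + eps) / (bi + eps)) for bi, ci in zip(b, c))


def drift_alerts(baseline, current, edges, max_null_increase=0.02, max_psi=0.2):
    """Compare today's column values with a historical baseline."""
    alerts = []
    increase = null_rate(current) - null_rate(baseline)
    if increase > max_null_increase:
        alerts.append(f"null rate up {increase:.1%}")
    score = psi(baseline, current, edges)
    if score > max_psi:
        alerts.append(f"PSI {score:.2f} > {max_psi}")
    return alerts


edges = [20, 40, 60, 80]                 # km/h bucket boundaries
baseline = [10, 30, 50, 70, 90] * 20     # e.g. last month's avg_speed_kmh
today = [90] * 80 + [None] * 20          # distribution shift plus more nulls
print(drift_alerts(baseline, baseline, edges))  # []
print(drift_alerts(baseline, today, edges))     # null-rate alert and PSI alert
```

Every value in `today` is still a valid speed, so range checks from Layer 3 would pass; only the comparison with the baseline flags the shift.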


Unit Testing PySpark Transformations

The key to unit testing PySpark is keeping transformation logic in pure functions that take DataFrames and return DataFrames. This makes them testable with no infrastructure.

Setting Up the SparkSession Fixture

```python
# conftest.py - shared SparkSession fixture for all tests
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    """Create a SparkSession shared across all tests in the session.

    scope='session' means one SparkSession for the entire test run.
    SparkSession startup takes 10-15 seconds - per-test startup makes
    the suite unusably slow.
    """
    spark = (
        SparkSession.builder
        .master("local[2]")  # 2 local threads - enough for tests
        .appName("pipeline-tests")
        .config("spark.sql.shuffle.partitions", "4")  # Low for test speed
        .config("spark.ui.enabled", "false")  # No UI overhead in CI
        .config("spark.driver.memory", "2g")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("ERROR")  # Suppress noisy INFO logs
    yield spark
    spark.stop()


@pytest.fixture
def sample_gps_df(spark):
    """Reusable small GPS DataFrame for multiple tests."""
    from pyspark.sql.types import (
        StructType, StructField, StringType, DoubleType, TimestampType
    )
    from datetime import datetime

    schema = StructType([
        StructField("vehicle_id", StringType(), nullable=False),
        StructField("latitude", DoubleType(), nullable=True),
        StructField("longitude", DoubleType(), nullable=True),
        StructField("speed_kmh", DoubleType(), nullable=True),
        StructField("event_time", TimestampType(), nullable=False),
    ])

    data = [
        # VH-001: 2 moving events, 1 stop (speed=0)
        ("VH-001", 37.7749, -122.4194, 45.0, datetime(2024, 1, 15, 8, 0, 0)),
        ("VH-001", 37.7800, -122.4100, 60.0, datetime(2024, 1, 15, 8, 5, 0)),
        ("VH-001", 37.7850, -122.4050, 0.0, datetime(2024, 1, 15, 8, 10, 0)),
        # VH-002: 1 moving event, 1 null-GPS event (counts as stop)
        ("VH-002", 34.0522, -118.2437, 35.0, datetime(2024, 1, 15, 9, 0, 0)),
        ("VH-002", None, None, None, datetime(2024, 1, 15, 9, 5, 0)),
    ]
    return spark.createDataFrame(data, schema)
```

Testing the Feature Transformation

```python
# src/features/gps_features.py - production transformation function
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


def compute_vehicle_stop_features(df: DataFrame) -> DataFrame:
    """
    Compute stop-related features per vehicle per day.

    A stop: speed_kmh is 0 or null.
    Output columns:
        vehicle_id, date, total_events, stop_count,
        stop_rate, has_high_stop_rate (stop_rate > 0.3)
    """
    df_with_stop = df.withColumn(
        "is_stop",
        F.when(
            F.col("speed_kmh").isNull() | (F.col("speed_kmh") == 0.0),
            1
        ).otherwise(0)
    )

    return (
        df_with_stop
        .groupBy("vehicle_id", F.to_date("event_time").alias("date"))
        .agg(
            F.count("*").alias("total_events"),
            F.sum("is_stop").alias("stop_count"),
            (F.sum("is_stop") / F.count("*")).alias("stop_rate"),
        )
        .withColumn(
            "has_high_stop_rate",
            F.col("stop_rate") > 0.3
        )
    )
```
```python
# tests/unit/test_gps_features.py
import pytest
from pyspark.sql import functions as F
from src.features.gps_features import compute_vehicle_stop_features


class TestComputeVehicleStopFeatures:

    def test_stop_count_is_correct(self, spark, sample_gps_df):
        """VH-001 has 1 stop (speed=0), VH-002 has 1 stop (null GPS)."""
        result = compute_vehicle_stop_features(sample_gps_df)
        result_dict = {
            row["vehicle_id"]: row["stop_count"]
            for row in result.collect()
        }

        assert result_dict["VH-001"] == 1, "VH-001: exactly one speed=0 event"
        assert result_dict["VH-002"] == 1, "VH-002: exactly one null-speed event"

    def test_stop_rate_within_bounds(self, spark, sample_gps_df):
        """Stop rate must be between 0.0 and 1.0 for all vehicles."""
        result = compute_vehicle_stop_features(sample_gps_df)

        invalid = (
            result
            .filter((F.col("stop_rate") < 0) | (F.col("stop_rate") > 1))
            .count()
        )
        assert invalid == 0, "stop_rate must be in [0, 1] for all vehicles"

    def test_output_schema_has_required_columns(self, spark, sample_gps_df):
        """Output must contain the expected columns."""
        result = compute_vehicle_stop_features(sample_gps_df)
        actual = {f.name for f in result.schema.fields}

        required = {
            "vehicle_id", "date", "total_events",
            "stop_count", "stop_rate", "has_high_stop_rate"
        }
        assert required.issubset(actual), (
            f"Missing columns: {required - actual}"
        )

    def test_no_nulls_in_key_columns(self, spark, sample_gps_df):
        """vehicle_id and date must never be null in output."""
        result = compute_vehicle_stop_features(sample_gps_df)
        null_keys = (
            result
            .filter(F.col("vehicle_id").isNull() | F.col("date").isNull())
            .count()
        )
        assert null_keys == 0

    def test_one_row_per_vehicle_per_day(self, spark, sample_gps_df):
        """Output has exactly one row per (vehicle_id, date)."""
        result = compute_vehicle_stop_features(sample_gps_df)
        total = result.count()
        distinct = result.select("vehicle_id", "date").distinct().count()
        assert total == distinct, "Duplicate (vehicle_id, date) rows detected"

    def test_vehicle_with_all_null_speed(self, spark):
        """Vehicle with all-null speed readings: stop_rate = 1.0."""
        from pyspark.sql.types import (
            StructType, StructField, StringType, DoubleType, TimestampType
        )
        from datetime import datetime

        schema = StructType([
            StructField("vehicle_id", StringType()),
            StructField("latitude", DoubleType()),
            StructField("longitude", DoubleType()),
            StructField("speed_kmh", DoubleType()),
            StructField("event_time", TimestampType()),
        ])
        data = [
            ("VH-OFFLINE", None, None, None, datetime(2024, 1, 15, 10, 0, 0)),
            ("VH-OFFLINE", None, None, None, datetime(2024, 1, 15, 10, 5, 0)),
        ]
        df = spark.createDataFrame(data, schema)
        result = compute_vehicle_stop_features(df)

        row = result.filter(F.col("vehicle_id") == "VH-OFFLINE").first()
        assert row["stop_count"] == 2
        assert row["stop_rate"] == 1.0
        assert row["has_high_stop_rate"] is True
```

Great Expectations: Data Quality as Code

Great Expectations (GE) defines data quality expectations - assertions about what output data should look like - as code. These form a suite that runs against every pipeline execution.

```python
# expectations/build_ml_feature_suite.py
import great_expectations as gx
from great_expectations.core import ExpectationSuite, ExpectationConfiguration

context = gx.get_context()

# add_or_update_expectation_suite() lets this script be re-run after the
# suite is edited in source control; add_expectation_suite() raises if the
# suite already exists.
suite = context.add_or_update_expectation_suite(
    expectation_suite_name="ml_features_nightly"
)


def build_ml_feature_expectations(suite: ExpectationSuite) -> ExpectationSuite:
    """Define the full expectation suite for the nightly ML feature table."""

    # ── Row count: expected fleet-scale volume per day ─────────────────────────
    suite.add_expectation(ExpectationConfiguration(
        expectation_type="expect_table_row_count_to_be_between",
        kwargs={"min_value": 500_000, "max_value": 2_000_000},
    ))

    # ── Primary key completeness ───────────────────────────────────────────────
    for column in ["vehicle_id", "date", "region"]:
        suite.add_expectation(ExpectationConfiguration(
            expectation_type="expect_column_values_to_not_be_null",
            kwargs={"column": column},
        ))

    # ── Feature completeness (allow GPS outage noise) ──────────────────────────
    for column in ["avg_speed_kmh", "stop_rate", "route_deviation_km"]:
        suite.add_expectation(ExpectationConfiguration(
            expectation_type="expect_column_values_to_not_be_null",
            kwargs={"column": column, "mostly": 0.95},
        ))

    # ── Speed range: trucks max ~120 km/h; 150 allows GPS noise ───────────────
    suite.add_expectation(ExpectationConfiguration(
        expectation_type="expect_column_values_to_be_between",
        kwargs={
            "column": "avg_speed_kmh",
            "min_value": 0.0,
            "max_value": 150.0,
            "mostly": 0.99,
        },
    ))

    # ── Stop rate: must be a valid fraction ────────────────────────────────────
    suite.add_expectation(ExpectationConfiguration(
        expectation_type="expect_column_values_to_be_between",
        kwargs={"column": "stop_rate", "min_value": 0.0, "max_value": 1.0},
    ))

    # ── Region: only known values (guards against upstream enum corruption) ────
    suite.add_expectation(ExpectationConfiguration(
        expectation_type="expect_column_values_to_be_in_set",
        kwargs={
            "column": "region",
            "value_set": [
                "NORTHEAST", "SOUTHEAST", "MIDWEST",
                "SOUTHWEST", "NORTHWEST", "UNKNOWN"
            ],
        },
    ))

    # ── Temporal consistency: feature_date must match partition_date ───────────
    # Catches cross-partition contamination bugs
    suite.add_expectation(ExpectationConfiguration(
        expectation_type="expect_column_pair_values_to_be_equal",
        kwargs={
            "column_A": "feature_date",
            "column_B": "partition_date",
        },
    ))

    return suite


context.save_expectation_suite(build_ml_feature_expectations(suite))
```
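To make the `mostly` threshold concrete, here is a plain-Python approximation of how a column-map expectation such as `expect_column_values_to_be_between` applies it: null values are left out (null handling is the job of the separate not-null expectations), and the check passes when at least a `mostly` fraction of the remaining values is in range. It is a sketch of the semantics, not Great Expectations' implementation, and the function name is hypothetical:

```python
def between_expectation_passes(values, min_value, max_value, mostly=1.0):
    """Approximate 'values to be between' with a `mostly` threshold.

    Nulls are excluded from the denominator; the check passes when the share
    of non-null values inside [min_value, max_value] is at least `mostly`.
    """
    non_null = [v for v in values if v is not None]
    if not non_null:
        return True  # nothing to evaluate; treated as a pass in this sketch
    in_range = sum(min_value <= v <= max_value for v in non_null)
    return in_range / len(non_null) >= mostly


# avg_speed_kmh: 0.5% GPS-noise outliers fit inside mostly=0.99 ...
speeds = [55.0] * 995 + [210.0] * 5 + [None] * 40
print(between_expectation_passes(speeds, 0.0, 150.0, mostly=0.99))  # True
# ... but 2% outliers do not
noisy = [55.0] * 980 + [210.0] * 20
print(between_expectation_passes(noisy, 0.0, 150.0, mostly=0.99))   # False
# stop_rate has no `mostly`, so a single bad value fails the check
print(between_expectation_passes([0.1, 0.5, 1.2], 0.0, 1.0))         # False
```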

Running GE in the Airflow DAG

```python
# Called as PythonOperator in the Airflow DAG
import great_expectations as gx


def run_great_expectations_checkpoint(**context):
    """Run GE checkpoint; fail the task if any expectation fails."""
    execution_date = context["ds"]
    gx_context = gx.get_context()

    result = gx_context.run_checkpoint(
        checkpoint_name="ml_features_nightly",
        batch_request={
            "datasource_name": "spark_datasource",
            "data_connector_name": "inferred_data_connector",
            "data_asset_name": "ml_feature_table",
            # Validate only this run's partition. Spark reader options have no
            # row "filter", so selection happens at the batch level. Assumes the
            # data connector's regex exposes a "date" group for partition paths.
            "data_connector_query": {
                "batch_filter_parameters": {"date": execution_date}
            },
        },
    )

    all_passed = all(vr.success for vr in result.list_validation_results())

    # Summarise each failed expectation as "type(column)": small enough for XCom
    failed = []
    for vr in result.list_validation_results():
        for r in vr.results:
            if not r.success:
                cfg = r.expectation_config
                failed.append(f"{cfg.expectation_type}({cfg.kwargs.get('column', '')})")

    context["ti"].xcom_push(key="validation_passed", value=all_passed)
    context["ti"].xcom_push(key="failed_expectations", value=failed)

    if not all_passed:
        raise ValueError(
            f"Data quality FAILED for {execution_date}. "
            f"Failed checks: {failed}"
        )

    return all_passed
```

Testing Incremental Pipelines: Idempotency

An idempotent pipeline produces identical output when run once or multiple times with the same input. This is critical because Airflow retries failed tasks - a retry must be safe.

```python
# tests/integration/test_idempotency.py
import pytest
from pyspark.sql import functions as F
from src.pipeline.feature_pipeline import run_feature_pipeline


class TestPipelineIdempotency:

    def test_feature_table_idempotent(self, spark, tmp_path):
        """Running the pipeline twice for the same date produces identical output."""
        date = "2024-01-15"
        output_1 = str(tmp_path / "features_run1")
        output_2 = str(tmp_path / "features_run2")

        run_feature_pipeline(spark, date=date, output_path=output_1)
        run_feature_pipeline(spark, date=date, output_path=output_2)

        df1 = spark.read.parquet(output_1).sort("vehicle_id", "date")
        df2 = spark.read.parquet(output_2).sort("vehicle_id", "date")

        # Row count must be identical
        count_1, count_2 = df1.count(), df2.count()
        assert count_1 == count_2, (
            f"Row counts differ: run1={count_1}, run2={count_2}"
        )

        # Schema must be identical
        assert df1.schema == df2.schema, "Schemas differ between runs"

        # Data must be identical (symmetric difference must be empty).
        # exceptAll keeps duplicates, so a row written twice is also caught;
        # subtract() is EXCEPT DISTINCT and would hide it.
        diff_count = df1.exceptAll(df2).union(df2.exceptAll(df1)).count()
        assert diff_count == 0, (
            f"{diff_count} rows differ between runs - pipeline is not idempotent"
        )

    def test_no_double_counting_on_retry(self, spark, tmp_path):
        """Re-running must overwrite, not append. Row count must stay stable."""
        date = "2024-01-15"
        output_path = str(tmp_path / "features")

        run_feature_pipeline(spark, date=date, output_path=output_path)
        count_after_first = spark.read.parquet(output_path).count()

        run_feature_pipeline(spark, date=date, output_path=output_path)
        count_after_second = spark.read.parquet(output_path).count()

        assert count_after_first == count_after_second, (
            f"Row count changed: {count_after_first} → {count_after_second}. "
            f"Pipeline is appending instead of overwriting on re-run."
        )
```

Testing for Label Leakage

Label leakage occurs when a feature is computed from information that was not available at prediction time (here, the loan application date). Testing for it is the most critical check in any ML feature pipeline.

```python
# tests/leakage/test_label_leakage.py
import pytest
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, TimestampType
)
from src.features.credit_features import (
    compute_credit_features,
    compute_days_since_last_missed_payment,
)


class TestLabelLeakage:
    """
    Temporal invariant: for every training row,
        max(feature_event_time) < loan_application_date

    Violation = label leakage.
    """

    def test_no_features_use_post_application_data(self, spark):
        """Feature events must all precede the loan application date."""
        application_date = datetime(2024, 1, 15)

        transactions_schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", TimestampType()),
            StructField("amount", DoubleType()),
            StructField("missed_payment", DoubleType()),
        ])
        # CUST-001: all transactions before application - valid
        # CUST-002: one transaction after application - leakage
        transactions_data = [
            ("CUST-001", datetime(2024, 1, 10), 500.0, 0.0),
            ("CUST-001", datetime(2024, 1, 12), 200.0, 0.0),
            ("CUST-001", datetime(2024, 1, 14), 100.0, 0.0),
            ("CUST-002", datetime(2024, 1, 10), 300.0, 0.0),
            ("CUST-002", datetime(2024, 1, 16), 150.0, 1.0),  # FUTURE - leakage!
        ]
        transactions_df = spark.createDataFrame(transactions_data, transactions_schema)

        labels_schema = StructType([
            StructField("customer_id", StringType()),
            StructField("application_date", TimestampType()),
            StructField("default_label", DoubleType()),
        ])
        labels_data = [
            ("CUST-001", application_date, 0.0),
            ("CUST-002", application_date, 1.0),
        ]
        labels_df = spark.createDataFrame(labels_data, labels_schema)

        features_df = compute_credit_features(transactions_df, labels_df)

        # Join and check: max feature event time must be before application date
        check_df = features_df.join(labels_df, on="customer_id", how="inner")

        leakage_rows = (
            check_df
            .filter(F.col("max_feature_event_time") >= F.col("application_date"))
            .count()
        )

        assert leakage_rows == 0, (
            f"LABEL LEAKAGE: {leakage_rows} rows have feature data from after "
            f"the application date. This will cause the model to train on future "
            f"information and fail silently in production."
        )

    def test_days_since_missed_payment_uses_application_date(self, spark):
        """
        This exact feature caused the production incident.

        Verify it computes days relative to application_date,
        NOT loan_outcome_date (which is known only after the fact).
        """
        application_date = datetime(2024, 1, 15)
        # Outcome date ~90 days later - known only after the loan decision
        loan_outcome_date = datetime(2024, 4, 14)
        last_missed_payment = datetime(2024, 1, 10)  # 5 days before application

        # The outcome date is part of the input so that a buggy implementation
        # *could* use it; without this column the test cannot catch the bug.
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("application_date", TimestampType()),
            StructField("loan_outcome_date", TimestampType()),
            StructField("last_missed_payment_date", TimestampType()),
        ])
        data = [("CUST-TEST", application_date, loan_outcome_date, last_missed_payment)]
        df = spark.createDataFrame(data, schema)

        result = compute_days_since_last_missed_payment(df)
        days = (
            result
            .filter(F.col("customer_id") == "CUST-TEST")
            .select("days_since_last_missed_payment")
            .first()[0]
        )

        # Correct: application_date - last_missed = 5 days
        # Leaked: loan_outcome_date - last_missed = 95 days (uses future outcome date)
        assert days == 5, (
            f"Expected 5 days (application-relative), got {days}. "
            f"Feature may be using outcome_date instead of application_date - "
            f"this is the production label leakage pattern."
        )
```

Schema Evolution Testing

Upstream schemas change constantly. Frozen schema contracts prevent silent corruption:

```python
# tests/schema/test_schema_evolution.py
import pytest
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType
)
from src.pipeline.feature_pipeline import compute_features_from_raw

# The expected schema - treated as a contract, reviewed in PRs like code
EXPECTED_FEATURE_SCHEMA = StructType([
    StructField("vehicle_id", StringType(), nullable=False),
    StructField("date", StringType(), nullable=False),
    StructField("avg_speed_kmh", DoubleType(), nullable=True),
    StructField("stop_rate", DoubleType(), nullable=True),
    StructField("route_deviation_km", DoubleType(), nullable=True),
    StructField("region", StringType(), nullable=True),
])


class TestSchemaEvolution:

    def test_output_contains_all_expected_columns(self, spark, sample_gps_df):
        """No expected column may be missing from output."""
        result = compute_features_from_raw(sample_gps_df)
        actual_names = {f.name for f in result.schema.fields}
        expected_names = {f.name for f in EXPECTED_FEATURE_SCHEMA}

        missing = expected_names - actual_names
        assert not missing, (
            f"Missing columns that downstream models depend on: {missing}. "
            f"If upstream removed a column, update feature computation to handle it."
        )

    def test_column_types_unchanged(self, spark, sample_gps_df):
        """Column types must not change - type changes silently corrupt downstream joins."""
        result = compute_features_from_raw(sample_gps_df)
        actual_types = {f.name: f.dataType for f in result.schema}
        expected_types = {f.name: f.dataType for f in EXPECTED_FEATURE_SCHEMA}

        mismatches = [
            f"'{col}': expected {expected_types[col]}, got {actual_types[col]}"
            for col in expected_types
            if col in actual_types and actual_types[col] != expected_types[col]
        ]

        assert not mismatches, (
            "Type mismatches detected:\n" + "\n".join(mismatches)
        )

    def test_not_null_columns_have_no_nulls(self, spark, sample_gps_df):
        """NOT NULL contract columns must never contain nulls in output."""
        result = compute_features_from_raw(sample_gps_df)
        non_nullable = [f.name for f in EXPECTED_FEATURE_SCHEMA if not f.nullable]

        for col in non_nullable:
            null_count = result.filter(result[col].isNull()).count()
            assert null_count == 0, (
                f"Column '{col}' declared NOT NULL contains {null_count} null rows"
            )
```
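The tests above fail on any missing or retyped column while tolerating extra ones. A minimal sketch of that policy in plain Python, assuming each schema is simplified to a `{column_name: type_name}` dict (Spark's `StructType` carries more detail, including nullability) and using a hypothetical helper name: removed and retyped columns are breaking, while columns that upstream newly added are additive and do not fail the contract:

```python
def classify_schema_changes(expected, actual):
    """Split differences between two {column: type} schemas into two lists.

    Removed or retyped columns are breaking for downstream models;
    columns present only in `actual` are additive.
    """
    breaking = []
    for column, expected_type in expected.items():
        if column not in actual:
            breaking.append(f"removed: {column}")
        elif actual[column] != expected_type:
            breaking.append(f"retyped: {column} {expected_type} -> {actual[column]}")
    additive = [f"added: {column}" for column in sorted(actual.keys() - expected.keys())]
    return breaking, additive


contract = {
    "vehicle_id": "string", "date": "string", "avg_speed_kmh": "double",
    "stop_rate": "double", "route_deviation_km": "double", "region": "string",
}
# Upstream migration: date became a DATE, route deviation dropped, a column added
observed = {
    "vehicle_id": "string", "date": "date", "avg_speed_kmh": "double",
    "stop_rate": "double", "region": "string", "fuel_level_pct": "double",
}
breaking, additive = classify_schema_changes(contract, observed)
print(breaking)  # ['retyped: date string -> date', 'removed: route_deviation_km']
print(additive)  # ['added: fuel_level_pct']
```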

CI/CD Integration: Data Tests on Every PR

```yaml
# .gitlab-ci.yml - data pipeline test stages
stages:
  - lint
  - unit-test
  - integration-test

# Shared setup. PySpark needs a JVM, which python:3.11-slim does not include.
.pyspark-setup:
  image: python:3.11-slim
  variables:
    PYSPARK_PYTHON: python3
    SPARK_LOCAL_IP: "127.0.0.1"
  before_script:
    - apt-get update && apt-get install -y --no-install-recommends default-jre-headless
    - pip install pyspark==3.5.0 pytest pytest-cov great-expectations==0.18.0
    - pip install -r requirements.txt

# Fast unit tests - every merge request and every commit to main
pipeline-unit-tests:
  extends: .pyspark-setup
  stage: unit-test
  script:
    - pytest tests/unit/ -v --tb=short --cov=src --cov-report=term-missing
    - pytest tests/schema/ -v --tb=short
    - pytest tests/leakage/ -v --tb=short
  coverage: '/TOTAL.*\s+(\d+%)$/'
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
    - if: $CI_COMMIT_BRANCH == "main"

# Integration tests on sampled data (~5 min)
pipeline-integration-tests:
  extends: .pyspark-setup
  stage: integration-test
  variables:
    TEST_DATA_SAMPLE_FRACTION: "0.01"
    TEST_DATE: "2024-01-15"
  script:
    - >
      python scripts/pull_test_sample.py
      --date $TEST_DATE
      --fraction $TEST_DATA_SAMPLE_FRACTION
      --output tests/fixtures/sample
    # test_idempotency.py lives in tests/integration/
    - pytest tests/integration/ -v --tb=short
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
  allow_failure: false
```
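```python
# scripts/pull_test_sample.py - stratified sample for CI integration tests
import argparse
from pyspark.sql import SparkSession


def pull_stratified_sample(
    date: str,
    fraction: float,
    output_path: str,
    min_rows_per_region: int = 1_000,
):
    """
    Pull a small stratified sample from production data.

    Stratifying by region with a per-region floor keeps rare regions
    (including UNKNOWN) present in tests; one uniform fraction for every
    region behaves like plain random sampling and can drop them entirely.
    seed=42 makes sampling deterministic - same sample every run on the
    same input data.
    """
    spark = SparkSession.builder.appName("test-sampler").getOrCreate()

    df = spark.read.parquet(f"s3://data-lake/telemetry/{date}/")

    # Per-region counts so small regions can be sampled at a higher rate.
    # sampleBy drops any stratum missing from `fractions` and rejects None
    # as a key, so null regions are skipped here.
    region_counts = {
        row["region"]: row["count"]
        for row in df.groupBy("region").count().collect()
        if row["region"] is not None
    }
    fractions = {
        region: min(1.0, max(fraction, min_rows_per_region / n))
        for region, n in region_counts.items()
    }

    sample = df.sampleBy(col="region", fractions=fractions, seed=42)

    sample.write.mode("overwrite").parquet(output_path)
    written = spark.read.parquet(output_path).count()
    print(f"Sampled {written:,} rows (base fraction {fraction*100:.1f}%) → {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--date", required=True)
    parser.add_argument("--fraction", type=float, default=0.01)
    parser.add_argument("--output", default="tests/fixtures/sample")
    args = parser.parse_args()
    pull_stratified_sample(args.date, args.fraction, args.output)
```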

Soda Core: YAML-Based Quality Checks

Soda Core provides a more readable alternative to GE - checks are defined in YAML, reviewed in PRs like configuration:

```yaml
# checks/ml_features.yml
checks for ml_feature_table:

  # Volume - sanity check on fleet-scale data
  - row_count between 500000 and 2000000

  # Primary key completeness
  - missing_count(vehicle_id) = 0:
      name: vehicle_id must never be null
  - missing_count(date) = 0:
      name: date must never be null

  # Feature completeness (5% null budget for GPS outages)
  - missing_percent(avg_speed_kmh) < 5%:
      name: avg_speed_kmh max 5% null

  # Value ranges
  - invalid_percent(avg_speed_kmh) < 1%:
      name: avg_speed_kmh within valid range
      valid min: 0
      valid max: 150

  - invalid_percent(stop_rate) = 0%:
      name: stop_rate must be 0.0-1.0
      valid min: 0.0
      valid max: 1.0

  # Data freshness - must be updated within last 25 hours
  - freshness(feature_date) < 25h:
      name: Feature table must be fresh

  # Uniqueness - no duplicate (vehicle_id, date)
  - duplicate_count(vehicle_id, date) = 0:
      name: No duplicate vehicle-date combinations
```
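```python
# Running Soda checks programmatically in a pipeline
# (Spark DataFrame support requires the soda-core-spark-df package)
from pyspark.sql import DataFrame, SparkSession, functions as F
from soda.scan import Scan


def run_soda_checks(spark: SparkSession, features_df: DataFrame, date: str) -> bool:
    """Run checks/ml_features.yml against one day's partition of the feature table."""
    # Expose only this run's partition under the table name the checks file uses
    (
        features_df
        .filter(F.col("date") == date)
        .createOrReplaceTempView("ml_feature_table")
    )

    scan = Scan()
    # Point Soda at the live SparkSession instead of a separate connection
    scan.add_spark_session(spark, data_source_name="spark_df")
    scan.set_data_source_name("spark_df")
    scan.add_sodacl_yaml_files("checks/ml_features.yml")
    scan.set_scan_definition_name(f"ml-features-{date}")
    scan.execute()

    if scan.has_check_failures():
        print(f"Soda checks FAILED for {date}: {scan.get_scan_results()}")
        return False

    print(f"All Soda checks passed for {date}")
    return True
```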

Common Mistakes

Testing only the happy path
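```python
from datetime import datetime
from src.features.gps_features import compute_vehicle_stop_features

# WRONG - clean, hand-curated input (clean_df: no nulls, zeros, or noise)
def test_compute_features(spark, clean_df):
    result = compute_vehicle_stop_features(clean_df)
    assert result.count() > 0  # passes for almost any implementation

# RIGHT - realistic production data shapes, exact expected values
# (extend with unknown regions, extreme values, minimum valid fleet size)
def test_compute_features_with_realistic_data(spark):
    t = datetime(2024, 1, 15, 8, 5)
    data = [
        ("VH-001", None, None, None, t),        # GPS outage (null speed)
        ("VH-001", 37.77, -122.41, 0.0, t),     # zero-speed event
        ("VH-001", 37.77, -122.41, 0.0, t),     # duplicate timestamp
        ("VH-001", 37.78, -122.40, 250.0, t),   # extreme GPS-noise speed
    ]
    schema = ("vehicle_id string, latitude double, longitude double, "
              "speed_kmh double, event_time timestamp")
    row = compute_vehicle_stop_features(spark.createDataFrame(data, schema)).first()
    assert row["total_events"] == 4
    assert row["stop_count"] == 3
    assert row["has_high_stop_rate"] is True
```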
Not testing idempotency on every pipeline that writes data

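One of the most common production bugs in batch ML: a pipeline that writes with mode("append") instead of replacing the partition it owns. A single Airflow retry duplicates whatever the failed attempt had already written, up to doubling the day's training data. Every evaluation metric (precision, recall, AUC) still looks fine, but the duplicated examples carry double weight in training. Note that mode("overwrite").partitionBy("date") is only safe with dynamic partition overwrite enabled (spark.sql.sources.partitionOverwriteMode=dynamic); under the default static mode it replaces every partition in the table, not just the one being written. Write an explicit idempotency test for every pipeline that writes to persistent storage.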
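The retry failure mode can be shown without Spark. In this sketch a partitioned table is modelled as a dict from partition date to rows (an illustrative stand-in with hypothetical helper names, not a Spark API). The three modes mimic append, Spark's dynamic partition overwrite, and Spark's default static overwrite, which on a partitioned write replaces the whole table:

```python
from collections import defaultdict


def write_partition(table, date, rows, mode):
    """Write one date partition into a dict-based stand-in for a table.

    append              -> add rows to whatever the partition already holds
    overwrite_partition -> replace only this date (like dynamic overwrite)
    overwrite_table     -> drop every partition first (like static overwrite)
    """
    if mode == "append":
        table[date].extend(rows)
    elif mode == "overwrite_partition":
        table[date] = list(rows)
    elif mode == "overwrite_table":
        table.clear()
        table[date] = list(rows)
    else:
        raise ValueError(f"unknown mode: {mode}")


def run_with_retry(mode):
    """Simulate a task that fails after a partial write and is then retried."""
    table = defaultdict(list)
    table["2024-01-14"] = [("VH-001", 0.2)]  # a previous day, already complete
    rows = [("VH-001", 0.1), ("VH-002", 0.4)]
    write_partition(table, "2024-01-15", rows[:1], mode)  # attempt 1: partial, fails
    write_partition(table, "2024-01-15", rows, mode)      # attempt 2: Airflow retry
    return {date: len(r) for date, r in sorted(table.items())}


for mode in ("append", "overwrite_partition", "overwrite_table"):
    print(mode, run_with_retry(mode))
# append              {'2024-01-14': 1, '2024-01-15': 3}  <- duplicated row
# overwrite_partition {'2024-01-14': 1, '2024-01-15': 2}  <- correct
# overwrite_table     {'2024-01-15': 2}                   <- previous day lost
```

Only the partition-scoped overwrite is both retry-safe and non-destructive, which is the property the idempotency tests above check.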

Skipping integration tests because they are slow

Unit tests do not catch integration bugs: wrong join keys, incorrect partition filters, column name mismatches between stages, join fanout producing 10x the expected rows. A 5-minute integration test on 1% sampled data can catch all of these. The cost of finding them in production instead (corrupted features, retrained models, investigation time, incident postmortems) is measured in days and thousands of dollars.
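A join-fanout guard can be sketched in a few lines of plain Python. Lists of dicts stand in for DataFrames here, and the helper names are hypothetical; in a Spark integration test the equivalent check is `joined.count() == left.count()` after a left join onto a table that should hold at most one row per key:

```python
from collections import Counter


def left_join(left, right, key):
    """Left join two lists of dicts on `key`; left values win on name clashes."""
    index = {}
    for r in right:
        index.setdefault(r[key], []).append(r)
    joined = []
    for row in left:
        for match in index.get(row[key], [{}]):  # [{}] keeps unmatched left rows
            joined.append({**match, **row})
    return joined


def join_without_fanout(left, right, key):
    """Left join that fails if any left row matches more than one right row."""
    joined = left_join(left, right, key)
    if len(joined) != len(left):
        dupes = sorted(k for k, n in Counter(r[key] for r in right).items() if n > 1)
        raise AssertionError(
            f"join fanout on {key}: {len(left)} rows in, {len(joined)} out; "
            f"duplicate right-side keys: {dupes}"
        )
    return joined


vehicles = [{"vehicle_id": "VH-001"}, {"vehicle_id": "VH-002"}]
regions = [
    {"vehicle_id": "VH-001", "region": "MIDWEST"},
    {"vehicle_id": "VH-002", "region": "NORTHEAST"},
]
print(len(join_without_fanout(vehicles, regions, "vehicle_id")))  # 2

# A duplicated dimension row, e.g. a vehicle re-registered in a new region
regions.append({"vehicle_id": "VH-001", "region": "SOUTHEAST"})
try:
    join_without_fanout(vehicles, regions, "vehicle_id")
except AssertionError as err:
    print(err)  # join fanout on vehicle_id: 2 rows in, 3 out; ...
```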

Not version-controlling expectation suites

If your GE expectation suite lives only in the GE data docs server (not in source control), a wrong expectation change silently lowers your quality bar. Version-control expectation suites as Python or YAML files, review changes in PRs, and treat them as code. The expectations ARE the data contract - they deserve the same rigor as the transformation code.

Use seed=42 for all test sampling

Make test data samples reproducible. Use seed=42 (or any fixed seed) for all sample(), sampleBy(), and random operations in test data generation. Flaky tests - ones that fail occasionally depending on which rows were sampled - are worse than no tests. Engineers start ignoring test failures and the CI gate becomes meaningless.


Interview Q&A

Q: How do you test a PySpark transformation function?

A: Keep transformation logic in pure functions that accept DataFrames and return DataFrames. In tests, create small hand-crafted DataFrames using spark.createDataFrame() with explicit schemas and rows that cover every edge case: nulls in optional columns, zeros where division might occur, boundary values for range-limited features, empty DataFrames for the zero-record edge case.

Share one SparkSession across all tests using a scope="session" pytest fixture - SparkSession startup takes 10-15 seconds and per-test startup makes the suite unusably slow. Assert on: specific output values for known inputs, schema correctness (names and types), invariants (no nulls in key columns, values within valid ranges), and the cardinality of the output (one row per entity, not more). The goal is tests that detect when transformation logic changes, not tests that pass regardless of what the code does.

Q: What is an idempotency test and why is it important for batch pipelines?

A: An idempotency test verifies that running a pipeline twice with the same inputs produces the same output as running it once - no doubling, no duplication, no state accumulation from the first run.

It matters because Airflow retries failed tasks from the beginning. If task B fails after writing partial output, Airflow re-runs task B. If the pipeline uses mode("append"), the output table now has partial data from the first attempt plus full data from the second attempt. The ML model trains on corrupted data - some examples appear twice, biasing the model toward their distribution. This is one of the most common production bugs in batch ML systems and one of the hardest to debug because the pipeline logs always show "SUCCESS."

The fix: always write with mode("overwrite") using a deterministic output path that includes the processing date. Idempotency test: run the pipeline twice, assert the output table has the same row count and identical data.

Q: How do you detect label leakage in a feature pipeline?

A: Write a temporal ordering assertion. Join the feature table with the label table on the entity key. For each row, assert that the maximum feature event timestamp is strictly less than the label event timestamp - the application date, decision date, or whatever point in time the prediction must be made.

In PySpark: compute max_feature_event_time in the feature computation, join to labels, filter max_feature_event_time >= label_time, assert count is zero. Write per-feature tests for the highest-risk features - specifically any feature involving "days since," "count of prior," or aggregations of historical events. For each, trace the reference date used in the computation and write a test that verifies it is the prediction date, not any future-dated event.
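The per-feature test is easiest to reason about when the reference date is an explicit argument. In this plain-Python sketch (function names are hypothetical, not from the lesson's codebase), the feature only considers events strictly before the application date, and a separate assertion enforces the temporal ordering invariant on whatever events a feature used:

```python
from datetime import date


def days_since_last_missed_payment(missed_payment_dates, as_of):
    """Days between `as_of` and the latest missed payment strictly before it.

    `as_of` is the prediction point (the loan application date). Events on or
    after it are ignored, so a future missed payment cannot leak in.
    Returns None when there is no earlier missed payment.
    """
    prior = [d for d in missed_payment_dates if d < as_of]
    return (as_of - max(prior)).days if prior else None


def assert_point_in_time(used_event_dates, as_of):
    """Temporal ordering assertion: every event a feature used precedes `as_of`."""
    late = sorted(d for d in used_event_dates if d >= as_of)
    if late:
        raise AssertionError(f"label leakage: events on/after {as_of}: {late}")


application_date = date(2024, 1, 15)
missed = [date(2024, 1, 10), date(2024, 2, 20)]  # second one is after the decision
print(days_since_last_missed_payment(missed, application_date))  # 5
assert_point_in_time([d for d in missed if d < application_date], application_date)
```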

Q: What is Great Expectations and how does it differ from dbt tests?

A: Great Expectations is a standalone data quality framework for validating data at any point in the stack - raw ingested data, Spark job outputs, warehouse tables, Pandas DataFrames in notebooks. It defines expectations (assertions about data properties) that run against real data and produce human-readable validation reports. It integrates with Airflow, Spark, Pandas, and a wide range of databases and file stores through datasource connectors.

dbt tests are SQL assertions that run against data in your warehouse as part of the dbt build process. They are tightly integrated with dbt's execution model and most naturally express relational constraints - uniqueness, not-null, accepted values, referential integrity between tables. They run after dbt models execute, not before or during.

The two complement each other. Use GE for raw data validation (before dbt), Spark output validation (outside the warehouse), and complex statistical assertions. Use dbt tests for the final transformed layer inside the warehouse. A mature data platform uses both: GE validates that the data entering the warehouse is correct; dbt tests validate that the transformations produced correct outputs.

Q: How do you balance test coverage with pipeline execution speed in CI?

A: Three strategies:

First, tier your tests by speed and run each tier at the right frequency. Unit tests run in seconds - run them on every commit push. Integration tests run in minutes - run them on every PR merge request. Full data quality checks on real production data run in hours - run them nightly or on main branch merges only.

Second, use sampled data for integration tests. A 1% stratified sample catches most logic and integration bugs in a fraction of the time. The sampling must be stratified: uniform random sampling keeps each category only in proportion to its size, so rare categories and edge cases, which are exactly what you need to test, can disappear from the sample entirely. Use sampleBy() with explicit per-category fractions (higher for rare categories) and a fixed seed for reproducibility.
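A small standard-library simulation makes the difference concrete. The dataset sizes, the 1% fraction, and the 50-row per-stratum floor below are illustrative assumptions, and the helper names are hypothetical. With a fixed seed both samplers are reproducible, but uniform 1% sampling expects only 20 × 0.01 = 0.2 of the 20 rare rows, while the stratified sampler guarantees a floor for every stratum:

```python
import random
from collections import Counter


def uniform_sample(rows, fraction, seed=42):
    """Keep each row independently with probability `fraction`."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]


def stratified_sample(rows, key, fraction, min_per_stratum, seed=42):
    """Sample each stratum at `fraction`, but never below `min_per_stratum` rows."""
    rng = random.Random(seed)
    strata = {}
    for row in rows:
        strata.setdefault(row[key], []).append(row)
    sample = []
    for _, members in sorted(strata.items()):  # fixed order keeps it reproducible
        k = min(len(members), max(round(len(members) * fraction), min_per_stratum))
        sample.extend(rng.sample(members, k))
    return sample


rows = [{"region": "MIDWEST"}] * 100_000 + [{"region": "UNKNOWN"}] * 20
stratified = stratified_sample(rows, "region", 0.01, 50)
print(Counter(r["region"] for r in stratified))  # 1000 MIDWEST rows, all 20 UNKNOWN rows
print(Counter(r["region"] for r in uniform_sample(rows, 0.01)))
```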

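Third, parallelize test execution with pytest-xdist. The -n auto flag starts one worker process per available CPU core. Each worker is a separate process, so a scope="session" fixture creates one SparkSession (and one JVM) per worker rather than one shared session. Keep local[N] small and budget memory for N workers × spark.driver.memory, or the parallel suite can run out of memory on CI runners.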

Q: What are the most common bugs in ML feature pipelines and how do you catch them?

A: Six patterns come up again and again in production incidents:

  1. Label leakage: Feature uses post-prediction-date data. Caught by temporal ordering assertions per feature.
  2. Silent schema change: Upstream column renamed or dropped; pipeline produces nulls with no error. Caught by schema evolution tests comparing output to a frozen contract.
  3. Wrong join type: INNER drops records that should be preserved; LEFT creates nulls that should be errors. Caught by integration tests asserting on specific expected row counts for known inputs.
  4. Append-on-retry: Pipeline uses mode("append"), doubles data on retry. Caught by idempotency tests that run the pipeline twice and compare output.
  5. Off-by-one date partition: Pipeline processes yesterday's data thinking it is today's, or vice versa. Caught by freshness expectations in GE (check max feature date equals processing date) and temporal ordering tests.
  6. Null propagation through joins: A null in a join key causes a left join to produce unexpected nulls in downstream features. Caught by completeness expectations and by comparing null rates against historical baselines.

The pattern across all six: they are silent. The pipeline succeeds. The data looks plausible. The model trains without errors. The problem surfaces weeks or months later in production metrics. The solution is making implicit assumptions explicit as executable tests.


Next lesson: Performance Tuning Spark - turning a 6-hour, $3,000/day job into a 45-minute, $300/day job.

© 2026 EngineersOfAI. All rights reserved.