Apache Airflow for ML
The Pipeline That Worked Until It Didn't
For eight months, the recommendation team's Airflow pipeline ran every night without incident. Raw interaction data landed in S3. A sequence of PythonOperators featurized it, trained a collaborative filtering model, evaluated against the holdout set, and pushed the new model to the registry. Monday morning standup never included pipeline problems.
Then the data volume doubled. A product launch drove user engagement past all projections. The pipeline that used to complete in 2.5 hours started taking 4.5 hours. At 4 hours, Airflow's task timeout fired, killing the training step. The retry picked up immediately, consumed the same memory, and also timed out. The pipeline marked itself as failed at 5:30am.
But nobody was paged. The on-call rotation covered infrastructure alerts, not Airflow task failures. The failure email went to a shared inbox that nobody monitored overnight. The recommendation system served the previous day's model, then the model from two days ago, then three days ago, as each night's pipeline failed in the same way.
The data scientist who built the pipeline diagnosed the problem in ten minutes once they looked at it: the training task was loading the entire dataset into memory. As data volume grew past the worker's RAM limit, the process was being OOM-killed before the timeout even fired. The fix was straightforward - switch to chunked processing and a larger worker class. But the investigation into why nobody knew for three days took two weeks and changed the team's entire approach to Airflow monitoring.
This is the typical Airflow journey. The framework is powerful and production-proven, but it requires intentional engineering to make it reliable at scale. This lesson gives you the mental model and practical patterns to use Airflow correctly from the start.
Why Airflow - and Its Origins
Airflow was created by Maxime Beauchemin at Airbnb in 2014 and open-sourced in 2015. At the time, Airbnb's data team was managing dozens of ETL pipelines using cron and shell scripts. The coordination overhead was becoming untenable: engineers spent more time debugging pipeline ordering failures than building data products.
Beauchemin's insight was that pipelines are code, not configuration. Expressing DAGs as Python code (rather than XML, JSON, or a GUI) meant that pipelines could be version-controlled, reviewed, tested, and refactored using standard software engineering tools. This "pipelines as code" principle became Airflow's defining characteristic.
Airflow graduated to a top-level Apache Software Foundation project in 2019 and released the landmark Airflow 2.0 in December 2020, which brought a new scheduler architecture, the TaskFlow API, and significant performance improvements. It is now one of the most widely deployed pipeline orchestrators in the industry, used at Google, LinkedIn, Robinhood, Twitter, and thousands of other organizations.
Airflow Architecture
Understanding Airflow's architecture is essential for diagnosing production issues.
Scheduler: The brain of Airflow. It continuously parses DAG files, resolves which tasks are ready to run (all dependencies met, no concurrency limit exceeded), and places them in the executor queue. The scheduler does NOT execute tasks - it only decides what should run.
Web Server: Serves the Airflow UI and REST API. Reads from and writes to the metadata database. Stateless - you can restart it without affecting running pipelines.
Metadata Database: PostgreSQL (strongly recommended) or MySQL. Stores DAG definitions, task instances, run states, XCom values, variables, and connections. This is the single source of truth for the system. Back it up.
Executor: Determines how tasks are actually run. The executor is a plugin point - the same DAG code runs on LocalExecutor (single machine), CeleryExecutor (distributed workers), or KubernetesExecutor (K8s pods).
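To make the scheduler's role concrete, here is a minimal sketch of the "which tasks are ready" decision in plain Python. It does not use Airflow's internals: the `ready_tasks` function, the state strings, and the dictionary layout are illustrative assumptions, and the real scheduler also accounts for pools, concurrency limits, and trigger rules.

```python
from typing import Dict, List

# Hypothetical, simplified model: task id -> list of upstream task ids.
UPSTREAM: Dict[str, List[str]] = {
    "ingest_data": [],
    "validate_data": ["ingest_data"],
    "featurize": ["validate_data"],
    "train_model": ["featurize"],
}


def ready_tasks(upstream: Dict[str, List[str]], states: Dict[str, str]) -> List[str]:
    """Return tasks that have not started and whose upstream tasks all succeeded."""
    ready = []
    for task_id, deps in upstream.items():
        not_started = states.get(task_id, "none") == "none"
        deps_met = all(states.get(dep) == "success" for dep in deps)
        if not_started and deps_met:
            ready.append(task_id)
    return ready


states = {"ingest_data": "success", "validate_data": "success"}
print(ready_tasks(UPSTREAM, states))  # ['featurize']
```

The function only decides what may run next. Handing those tasks to something that executes them is the executor's job, which is the separation described above.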
DAG Authoring for ML Pipelines
The Basic ML DAG
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator

# Default arguments inherited by all tasks
default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email_on_failure": True,
    "execution_timeout": timedelta(hours=4),  # always set this
}

with DAG(
    dag_id="recommendation_training_pipeline",
    default_args=default_args,
    description="Daily recommendation model training",
    schedule_interval="0 2 * * *",  # 2am daily (named schedule= on Airflow 2.4+)
    start_date=datetime(2024, 1, 1),
    catchup=False,  # don't backfill missed runs
    tags=["ml", "recommendations"],
    max_active_runs=1,  # only one run at a time
) as dag:

    def ingest_data(**context):
        run_date = context["ds"]  # logical date, not wall clock
        # pull data for exactly this date
        from src.data import ingest
        ingest.run(date=run_date)

    def validate_data(**context):
        run_date = context["ds"]
        from src.data import validate
        report = validate.run(date=run_date)
        if not report["passed"]:
            raise ValueError(f"Validation failed: {report['failures']}")

    def featurize(**context):
        run_date = context["ds"]
        from src.features import featurize
        featurize.run(date=run_date)

    def train_model(**context):
        run_date = context["ds"]
        from src.training import train
        metrics = train.run(date=run_date)
        # push metrics to XCom for downstream tasks
        context["ti"].xcom_push(key="train_metrics", value=metrics)

    def evaluate_model(**context):
        run_date = context["ds"]
        ti = context["ti"]
        # pull metrics from training task
        train_metrics = ti.xcom_pull(
            task_ids="train_model", key="train_metrics"
        )
        from src.evaluation import evaluate
        eval_report = evaluate.run(date=run_date, train_metrics=train_metrics)
        if eval_report["accuracy"] < 0.82:
            raise ValueError(
                f"Model accuracy {eval_report['accuracy']:.3f} below threshold 0.82"
            )
        ti.xcom_push(key="eval_report", value=eval_report)

    def register_model(**context):
        run_date = context["ds"]
        ti = context["ti"]
        eval_report = ti.xcom_pull(task_ids="evaluate_model", key="eval_report")
        from src.registry import register
        register.push(date=run_date, eval_report=eval_report)

    # Define tasks
    t_ingest = PythonOperator(task_id="ingest_data", python_callable=ingest_data)
    t_validate = PythonOperator(task_id="validate_data", python_callable=validate_data)
    t_featurize = PythonOperator(task_id="featurize", python_callable=featurize)
    t_train = PythonOperator(
        task_id="train_model",
        python_callable=train_model,
        execution_timeout=timedelta(hours=8),  # override default for training
    )
    t_evaluate = PythonOperator(task_id="evaluate_model", python_callable=evaluate_model)
    t_register = PythonOperator(task_id="register_model", python_callable=register_model)

    # Define dependencies
    t_ingest >> t_validate >> t_featurize >> t_train >> t_evaluate >> t_register
The TaskFlow API (Airflow 2.0+)
Airflow 2.0 introduced the TaskFlow API, which makes Python-based DAGs much cleaner:
from datetime import datetime, timedelta  # timedelta is needed for the task settings below

from airflow.decorators import dag, task


@dag(
    schedule_interval="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=["ml", "recommendations"],
)
def recommendation_pipeline():

    @task(retries=2, retry_delay=timedelta(minutes=5))
    def ingest_data(run_date: str) -> str:
        """Returns path to ingested data."""
        from src.data import ingest
        return ingest.run(date=run_date)

    @task
    def validate_data(data_path: str) -> str:
        from src.data import validate
        report = validate.run(path=data_path)
        if not report["passed"]:
            raise ValueError(f"Validation failed: {report['failures']}")
        return data_path

    @task
    def featurize(data_path: str) -> str:
        from src.features import featurize
        return featurize.run(input_path=data_path)

    @task(execution_timeout=timedelta(hours=8))
    def train_model(feature_path: str) -> dict:
        from src.training import train
        return train.run(feature_path=feature_path)

    @task
    def evaluate_and_register(metrics: dict, feature_path: str) -> None:
        from src.evaluation import evaluate
        from src.registry import register
        if metrics["accuracy"] < 0.82:
            raise ValueError(f"Accuracy {metrics['accuracy']:.3f} below threshold")
        register.push(metrics=metrics)

    # Build the pipeline
    run_date = "{{ ds }}"  # Jinja template, rendered to the logical date at run time
    raw_path = ingest_data(run_date)
    validated_path = validate_data(raw_path)
    feature_path = featurize(validated_path)
    metrics = train_model(feature_path)
    evaluate_and_register(metrics, feature_path)


dag_instance = recommendation_pipeline()
The TaskFlow API automatically handles XCom serialization - return values from tasks are passed as inputs to downstream tasks without explicit xcom_push / xcom_pull calls.
XCom: Passing Data Between Tasks
XCom (cross-communication) is Airflow's mechanism for passing data between tasks. It is stored in the metadata database, which means:
- It persists when downstream tasks retry (a task's own XComs are cleared when that task itself re-runs)
- It is visible in the Airflow UI
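- Its size is bounded by the database column that stores it (about 64 KB on MySQL, far larger on PostgreSQL), so keep values small; a custom XCom backend can move storage elsewhere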
# Push a value from a task
def my_task(**context):
    result = {"accuracy": 0.87, "f1": 0.84}
    context["ti"].xcom_push(key="metrics", value=result)

# Pull a value in a downstream task
def downstream_task(**context):
    metrics = context["ti"].xcom_pull(
        task_ids="my_task",
        key="metrics"
    )
    print(f"Received metrics: {metrics}")
:::warning XCom Size Limits Never pass large artifacts through XCom. Model files, DataFrames, and large arrays will either exceed the database column limit or bloat the metadata database. The correct pattern: write the artifact to S3/GCS in the producing task, push only the artifact path through XCom, and read from storage in the consuming task. :::
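# CORRECT pattern for large artifacts
# fit_model, save_model, load_model, and compute_metrics stand for your own project helpers
from airflow.decorators import task
from airflow.operators.python import get_current_context

@task
def train_model(feature_path: str) -> str:
    run_date = get_current_context()["ds"]  # logical date of this run
    model = fit_model(feature_path)
    artifact_path = f"s3://ml-artifacts/{run_date}/model.pkl"
    save_model(model, artifact_path)
    return artifact_path  # pass only the path, not the model object

@task
def evaluate_model(model_path: str) -> dict:
    model = load_model(model_path)  # load from storage
    return compute_metrics(model)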
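One way to enforce the "paths, not payloads" rule is a small guard that a task calls on its return value. The sketch below is plain Python. The `xcom_safe` helper and the 16 KB threshold are hypothetical team conventions, not Airflow settings, and JSON encoding is used only as a rough size estimate.

```python
import json

# Hypothetical team convention, not an Airflow setting: keep XCom payloads tiny.
MAX_XCOM_BYTES = 16 * 1024


def xcom_safe(value, max_bytes: int = MAX_XCOM_BYTES):
    """Return value unchanged if its JSON encoding is small enough, else raise."""
    size = len(json.dumps(value).encode("utf-8"))
    if size > max_bytes:
        raise ValueError(
            f"XCom payload is {size} bytes (limit {max_bytes}); "
            "write it to object storage and return the path instead"
        )
    return value


# Small metadata passes through; a large list of predictions is rejected.
print(xcom_safe({"model_path": "s3://ml-artifacts/2024-01-01/model.pkl"}))
try:
    xcom_safe(list(range(100_000)))
except ValueError as err:
    print(f"rejected: {err}")
```

A task would end with `return xcom_safe(result)`, so an oversized value fails loudly in that task instead of quietly growing the metadata database.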
Dynamic DAGs for Parameterized Training
Sometimes you need to run the same pipeline logic for multiple configurations - different hyperparameter sets, different data segments, different models. Dynamic DAGs generate task graphs programmatically:
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

HYPERPARAMETER_SETS = [
    {"learning_rate": 0.01, "depth": 6, "n_estimators": 500},
    {"learning_rate": 0.05, "depth": 4, "n_estimators": 300},
    {"learning_rate": 0.1, "depth": 3, "n_estimators": 200},
]

with DAG(
    "hyperparameter_sweep",
    schedule_interval=None,  # manual trigger only
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:

    def prepare_data(**context):
        # shared featurization - runs once
        from src.features import featurize
        featurize.run()

    # Hyperparameters arrive through op_kwargs as "hparams". Do not name the
    # argument "params": the task context already has a "params" key, which
    # would silently replace a default argument of that name.
    def train_with_params(hparams, **context):
        from src.training import train
        train.run(params=hparams)

    def evaluate_params(config_id, hparams, **context):
        from src.evaluation import evaluate
        evaluate.run(config_id=config_id, params=hparams)

    prepare = PythonOperator(task_id="prepare_data", python_callable=prepare_data)

    for i, hparams in enumerate(HYPERPARAMETER_SETS):
        t_train = PythonOperator(
            task_id=f"train_config_{i}",
            python_callable=train_with_params,
            op_kwargs={"hparams": hparams},
        )
        t_eval = PythonOperator(
            task_id=f"evaluate_config_{i}",
            python_callable=evaluate_params,
            op_kwargs={"config_id": i, "hparams": hparams},
        )
        prepare >> t_train >> t_eval
This creates a fan-out DAG where prepare_data runs once, then all three training runs execute in parallel (up to your worker concurrency limit), each followed by their own evaluation step.
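When the configuration list grows, it is often generated rather than written by hand. The sketch below builds a grid with `itertools.product`. The `SEARCH_SPACE` values and the `expand_grid` helper are illustrative assumptions. The ordering is made deterministic on purpose: the DAG file is re-parsed repeatedly, so the generated task ids need to come out the same every time.

```python
from itertools import product

# Hypothetical search space; the keys mirror the hand-written list above.
SEARCH_SPACE = {
    "learning_rate": [0.01, 0.1],
    "depth": [3, 6],
}


def expand_grid(space):
    """Return one dict per combination of values, in a deterministic order."""
    keys = sorted(space)
    return [dict(zip(keys, combo)) for combo in product(*(space[k] for k in keys))]


configs = expand_grid(SEARCH_SPACE)
task_ids = [f"train_config_{i}" for i, _ in enumerate(configs)]
print(len(configs))  # 4
print(configs[0])    # {'depth': 3, 'learning_rate': 0.01}
```

The resulting `configs` list can take the place of `HYPERPARAMETER_SETS` in the loop. Keep the grid small, since every combination becomes a pair of tasks in the graph.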
Sensor Patterns for ML
Sensors are special Airflow operators that wait for a condition to be true before proceeding. They are essential for event-driven ML pipelines:
# Airflow 2 import path (amazon provider package); older provider releases
# expose the sensor under airflow.providers.amazon.aws.sensors.s3_key
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.sensors.python import PythonSensor

# Wait for data to arrive in S3 before starting pipeline
wait_for_data = S3KeySensor(
    task_id="wait_for_raw_data",
    bucket_name="ml-data-lake",
    bucket_key="raw/{{ ds }}/data.parquet",
    aws_conn_id="aws_default",
    poke_interval=300,  # check every 5 minutes
    timeout=7200,  # fail after 2 hours if data never arrives
    mode="reschedule",  # release the worker slot while waiting
)

# Wait for an upstream pipeline to complete
# (assumes both DAGs share the same schedule, so their logical dates match)
def check_feature_pipeline_done(**context):
    from airflow.models import DagRun
    runs = DagRun.find(
        dag_id="feature_engineering_pipeline",
        execution_date=context["logical_date"],
    )
    return any(run.state == "success" for run in runs)

wait_for_features = PythonSensor(
    task_id="wait_for_feature_pipeline",
    python_callable=check_feature_pipeline_done,
    poke_interval=120,
    timeout=3600,
    mode="reschedule",
)
:::tip Use reschedule Mode for Sensors
With mode="poke" (default), a sensor occupies a worker slot while waiting - even though it is just sleeping. Use mode="reschedule" so the sensor releases its worker slot between checks. This prevents sensor tasks from starving real work tasks on limited worker pools.
:::
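The difference between the two modes is easy to put in numbers. The sketch below estimates how many worker-slot seconds a sensor consumes under a deliberately simplified model: each check is assumed to take one second, and the scheduling overhead of `reschedule` mode is ignored. The `slot_seconds` function is illustrative, not part of Airflow.

```python
def slot_seconds(mode: str, wait_s: int, poke_interval_s: int, check_cost_s: int = 1) -> int:
    """Estimate worker-slot seconds a sensor consumes while waiting wait_s seconds.

    Simplified model (an assumption for illustration): each check takes
    check_cost_s seconds, and checks happen every poke_interval_s seconds.
    """
    if mode == "poke":
        # The sensor sleeps inside the slot, so it holds it for the whole wait.
        return wait_s
    if mode == "reschedule":
        # The slot is held only while a check is actually running.
        checks = wait_s // poke_interval_s + 1
        return checks * check_cost_s
    raise ValueError(f"unknown mode: {mode}")


# Data arrives after 90 minutes, checked every 5 minutes.
print(slot_seconds("poke", wait_s=5400, poke_interval_s=300))        # 5400
print(slot_seconds("reschedule", wait_s=5400, poke_interval_s=300))  # 19
```

With very short poke intervals the picture changes, because each reschedule adds scheduler work. That is why `reschedule` suits waits measured in minutes rather than seconds.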
Airflow at Scale: Executors
LocalExecutor
Runs tasks as subprocesses on the same machine as the scheduler. Fine for development and small teams. Cannot scale beyond one machine.
# airflow.cfg
[core]
executor = LocalExecutor
CeleryExecutor
Distributes tasks to a fleet of Celery workers via a message broker (Redis or RabbitMQ). The standard choice for scaling Airflow to dozens of concurrent tasks.
# airflow.cfg
[core]
executor = CeleryExecutor

[celery]
broker_url = redis://redis:6379/0
result_backend = db+postgresql://airflow:password@postgres/airflow
Workers are stateless Python processes - you scale by adding more worker machines or pods. Each worker subscribes to one or more queues and picks up only the tasks routed to those queues.
KubernetesExecutor
Runs each task in a dedicated Kubernetes pod. When a task is queued, the scheduler creates a K8s pod, the pod runs the task, and the pod terminates. Advantages: perfect isolation between tasks, per-task resource allocation, no idle workers wasting resources.
# Values for helm chart
executor: KubernetesExecutor
workers:
  resources:
    requests:
      cpu: "500m"
      memory: "1Gi"
    limits:
      cpu: "2"
      memory: "4Gi"
The KubernetesExecutor is increasingly the right default for ML teams on Kubernetes. Each training task can request the exact GPU resources it needs; evaluation tasks can request CPU-only pods. No shared worker pool to manage.
Connection Management
Airflow Connections store credentials for external services (databases, cloud providers, APIs). Never hardcode credentials in DAG files:
from airflow.hooks.base import BaseHook

# Access a connection in a task
def upload_to_s3(**context):
    import boto3

    conn = BaseHook.get_connection("aws_s3_conn")
    # conn.login = access key, conn.password = secret key, conn.extra = JSON extras
    session = boto3.Session(
        aws_access_key_id=conn.login,
        aws_secret_access_key=conn.password,
        region_name=conn.extra_dejson.get("region_name", "us-east-1"),
    )
    s3 = session.client("s3")
    # ... upload logic
For Kubernetes-deployed Airflow, a common approach is to keep the connection URI in a Kubernetes Secret and expose it to the Airflow pods as an environment variable:
# Set connection via environment variable (overrides DB connection)
AIRFLOW_CONN_AWS_DEFAULT=aws://AKID:secret@/?region_name=us-east-1
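To see how the parts of such a URI map onto connection fields, the sketch below splits one with the standard library. This is not Airflow's own parser. `describe_conn_uri` is a hypothetical helper, and it skips details such as percent-decoding, which matters when a secret contains characters like `/` or `@` that must be percent-encoded in the URI.

```python
from urllib.parse import parse_qs, urlsplit


def describe_conn_uri(uri: str) -> dict:
    """Split a connection URI into the fields a Connection exposes (simplified)."""
    parts = urlsplit(uri)
    extras = {key: values[0] for key, values in parse_qs(parts.query).items()}
    return {
        "conn_type": parts.scheme,
        "login": parts.username,
        "password": parts.password,
        "host": parts.hostname,
        "extra": extras,
    }


info = describe_conn_uri("aws://AKID:secret@/?region_name=us-east-1")
print(info["conn_type"], info["login"], info["extra"])
# aws AKID {'region_name': 'us-east-1'}
```

The scheme becomes the connection type, the user-info section carries the key pair, and query parameters end up in the extras.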
Common Airflow ML Pitfalls
:::danger DAG File Import Side Effects
Airflow's scheduler imports every DAG file frequently (every min_file_process_interval seconds, default 30s). Any code at module level runs on every import. Never put expensive operations (database queries, API calls, model loading) at module level in a DAG file. All such code must be inside task functions.
:::
# BAD: loads model on every scheduler parse cycle - kills performance
import tensorflow as tf
model = tf.keras.models.load_model("s3://models/latest")  # at module level

with DAG(...) as dag:
    @task
    def predict():
        return model.predict(...)  # model already loaded

# GOOD: load inside task, runs only when task executes
with DAG(...) as dag:
    @task
    def predict():
        import tensorflow as tf
        model = tf.keras.models.load_model("s3://models/latest")
        return model.predict(...)
:::danger catchup=True with start_date Far in the Past
If catchup=True (the default) and your start_date is months ago, Airflow will try to run every scheduled interval since start_date on the first deployment. A daily pipeline with a start_date 6 months ago will immediately queue 180 DAG runs. Set catchup=False for ML training pipelines unless you explicitly need historical backfill.
:::
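The size of an accidental backfill is simple date arithmetic. The sketch below counts the completed daily intervals between a `start_date` and the deployment day. `pending_daily_runs` is an illustrative helper and the dates are examples.

```python
from datetime import date


def pending_daily_runs(start: date, today: date) -> int:
    """Number of completed daily intervals between start and today."""
    return max((today - start).days, 0)


# A daily DAG with start_date 2024-01-01, first deployed on 2024-07-01.
print(pending_daily_runs(date(2024, 1, 1), date(2024, 7, 1)))  # 182
```

Each of those intervals becomes a full training run, so the cost of leaving `catchup=True` grows with every day the `start_date` sits in the past.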
:::warning Zombie Tasks
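When a worker process crashes (OOM kill, node failure), its task becomes a "zombie" - still marked as running in the metadata database, never completing. The scheduler finds zombies by their missing heartbeat: every zombie_detection_interval (default 10 seconds) it looks for running tasks that have not sent a heartbeat within scheduler_zombie_task_threshold (default 300 seconds) and marks them failed so that retries can start. Setting names and defaults vary between Airflow versions, so check the configuration reference for yours. Until a zombie is detected, downstream tasks are blocked. execution_timeout does not help here, because it is enforced inside the task process that has already died; alert on task failures so that zombie cleanups do not go unnoticed.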
:::
Monitoring Airflow in Production
# airflow.cfg: enable StatsD metrics export
[metrics]
statsd_on = True
statsd_host = localhost
statsd_port = 8125
statsd_prefix = airflow
# Key metrics to alert on:
# airflow.scheduler.tasks.starving - tasks that cannot be scheduled because their pool has no open slot
# airflow.dagrun.duration.success.<dag_id> - how long successful DAG runs take
# airflow.ti_failures - task instance failures
# airflow.executor.open_slots - available executor capacity
The Airflow StatsD integration exports these metrics to Prometheus (via StatsD exporter) or Datadog. Set up dashboards for:
- Task failure rate by DAG and task
- Task duration (p50, p95, p99) - catch slow degradation before it becomes critical
- Scheduler lag - time between task becoming ready and task starting execution
- Executor queue depth - are workers keeping up with scheduled tasks?
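Dashboards only help if a failure also reaches a person. Airflow calls a task's `on_failure_callback` with the task context, so one option is a small function that turns that context into an alert. The sketch below is an assumption-laden outline: `build_failure_alert` and `notify_on_failure` are hypothetical names, the alert is printed rather than sent, and the context is faked with `SimpleNamespace` so the example runs without Airflow.

```python
import json
from types import SimpleNamespace


def build_failure_alert(context: dict) -> dict:
    """Turn an Airflow failure-callback context into a small alert payload."""
    ti = context["task_instance"]
    return {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "try_number": ti.try_number,
        "log_url": ti.log_url,
        "error": str(context.get("exception")),
    }


def notify_on_failure(context: dict) -> None:
    """Candidate on_failure_callback. Replace print with your paging client."""
    print(json.dumps(build_failure_alert(context)))


# Stand-in for the context Airflow would pass; the values are made up.
fake_context = {
    "task_instance": SimpleNamespace(
        dag_id="recommendation_training_pipeline",
        task_id="train_model",
        try_number=3,
        log_url="http://airflow.example.com/log",
    ),
    "exception": MemoryError("worker ran out of memory"),
}
notify_on_failure(fake_context)
```

Adding `"on_failure_callback": notify_on_failure` to `default_args` would attach it to every task in a DAG. Routing the payload to the on-call rotation instead of an unmonitored inbox is the part that addresses the failure mode from the opening story.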
Interview Q&A
Q1: Explain Airflow's architecture. What does the scheduler actually do?
Airflow's architecture has three main components: the scheduler, the web server, and the metadata database, plus an executor layer. The scheduler continuously parses DAG files to detect new DAGs and changes, scans the metadata database to identify tasks whose dependencies are all met, and places ready tasks into the executor queue. Critically, the scheduler does not execute tasks - it only decides what should run and when. The executor (LocalExecutor, CeleryExecutor, or KubernetesExecutor) is responsible for actually running the task code. The web server is a stateless read/write layer over the metadata database that serves the Airflow UI and REST API.
Q2: What is XCom and when should you NOT use it?
XCom (cross-communication) is Airflow's mechanism for passing data between tasks. Values are serialized and stored in the metadata database, making them available to downstream tasks via xcom_pull. You should not use XCom for large data - the metadata database is not designed to store model files, DataFrames, or large arrays. The hard limit depends on the database backend (about 64 KB on MySQL, much larger on PostgreSQL), but in practice XCom should carry only small values such as paths, IDs, and metrics. The correct pattern for large artifacts is to write to external storage (S3, GCS) in the producing task, push only the storage path through XCom, and load from storage in the consuming task.
Q3: What is the difference between CeleryExecutor and KubernetesExecutor?
CeleryExecutor distributes tasks to a pool of always-running Celery worker processes via a message broker (Redis/RabbitMQ). Workers are long-lived - they consume tasks from the queue for as long as they are running, and they hold no pipeline state themselves. KubernetesExecutor creates a new Kubernetes pod for each task and destroys it when the task completes. KubernetesExecutor provides strong isolation (each task gets its own environment), per-task resource allocation (training tasks can request GPUs, evaluation tasks can request CPU-only), and no idle resource cost (pods only exist while tasks run). The tradeoff is pod startup time (10-30 seconds) versus near-zero dispatch latency on an already-running Celery worker. For ML workloads with long-running tasks, KubernetesExecutor is usually the better choice.
Q4: A DAG that used to run in 2 hours now consistently takes 5 hours. How do you investigate?
First, check the Airflow UI's Gantt chart for the affected DAG run to identify which specific tasks are taking longer. Compare with historical runs to see when the slowdown started. Common causes for ML pipeline slowdowns: (1) data volume growth causing tasks to process more data, (2) a task loading an increasingly large artifact into memory (e.g., a model file that grows each run), (3) a worker running multiple tasks concurrently when it used to run one at a time (resource contention), (4) a downstream API or database becoming slower (connection timeouts), (5) a task that has moved to a slower machine type or no longer parallelizes its computation. Check StatsD/Prometheus metrics for task duration trends - a gradual increase points to data growth; a sudden jump points to an infrastructure or code change.
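The "gradual versus sudden" distinction can be checked mechanically once the run durations are exported. The sketch below labels a duration history with a simple ratio test. `classify_trend` and its `jump_ratio` threshold are illustrative choices, not an Airflow feature, and the durations are assumed to be positive.

```python
from typing import List


def classify_trend(durations_min: List[float], jump_ratio: float = 1.5) -> str:
    """Label a run-duration history as 'sudden jump', 'gradual growth', or 'stable'.

    jump_ratio is an arbitrary illustrative threshold, not an Airflow setting.
    """
    if len(durations_min) < 2:
        return "stable"
    # Ratio of each run's duration to the run before it.
    steps = [b / a for a, b in zip(durations_min, durations_min[1:])]
    if max(steps) >= jump_ratio:
        return "sudden jump"
    if durations_min[-1] / durations_min[0] >= jump_ratio:
        return "gradual growth"
    return "stable"


print(classify_trend([120, 130, 145, 160, 185]))  # gradual growth
print(classify_trend([120, 121, 119, 300, 298]))  # sudden jump
```

A "sudden jump" result tells you which run to diff against its predecessor: look at deploys, config changes, and infrastructure events between those two dates.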
Q5: How do you handle a training pipeline that needs to run for 12 hours? What Airflow configuration is needed?
Set execution_timeout to more than 12 hours on the training task specifically (override the default). Ensure dagrun_timeout on the DAG itself allows sufficient time for the full pipeline (training task + other tasks). For CeleryExecutor, check that no Celery time limit is configured on the workers and, with a Redis or SQS broker, raise visibility_timeout in [celery_broker_transport_options] above the longest task duration (Airflow's default is 21600 seconds, i.e. 6 hours); otherwise the broker may redeliver a task that is still running. For KubernetesExecutor, set appropriate pod resource requests (CPU, memory, GPU) to ensure the training pod is not OOM-killed. Also consider checkpointing: if the training is interrupted at hour 10, can it resume from a checkpoint rather than starting over? This requires saving checkpoints to persistent storage and checking for existing checkpoints at the start of the task.
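The checkpointing idea can be sketched as a loop that reads its starting point from storage. In the example below a local JSON file stands in for persistent storage such as S3 or GCS, and `train_with_checkpoints` is a toy stand-in for a real training loop. Both are assumptions for illustration.

```python
import json
import tempfile
from pathlib import Path


def train_with_checkpoints(checkpoint_file: Path, total_epochs: int) -> int:
    """Run a toy training loop that resumes from checkpoint_file if it exists.

    Returns the number of epochs executed in this call.
    """
    start_epoch = 0
    if checkpoint_file.exists():
        start_epoch = json.loads(checkpoint_file.read_text())["next_epoch"]

    executed = 0
    for epoch in range(start_epoch, total_epochs):
        # ... one epoch of real training would run here ...
        executed += 1
        # Persist progress after every epoch so a retry can pick up from here.
        checkpoint_file.write_text(json.dumps({"next_epoch": epoch + 1}))
    return executed


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.json"
    # Simulate a run that was killed after finishing 7 of 10 epochs.
    ckpt.write_text(json.dumps({"next_epoch": 7}))
    print(train_with_checkpoints(ckpt, total_epochs=10))  # 3
```

With this structure an Airflow retry of the training task repeats only the unfinished epochs. Keying the checkpoint location by the run's logical date keeps separate runs from picking up each other's progress.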
