

Airflow for ML Pipelines

The 3-Hour Manual Process

A data scientist owned the churn prediction model. Every Monday morning, she opened a Jupyter notebook, ran all the cells - 45 minutes of training on her laptop - downloaded the resulting pickle file, uploaded it to an S3 bucket, updated a configuration file in a Git repo with the new model path, and submitted a pull request. Another engineer reviewed it, merged it, and then manually triggered a deployment. The whole process took three to four hours if both people were available and nothing went wrong.

The model was supposed to be retrained weekly. Some weeks it happened on Monday. Some weeks the scientist was traveling, so it happened Thursday. One week it did not happen at all. The model silently served predictions based on a month-old training snapshot while customer churn patterns had shifted. The business noticed the prediction quality degrading, but attributed it to seasonal effects rather than a stale model.

When the team audited the process, they found further failure modes they had not noticed. There was no validation of training-data quality before training started, so occasionally the model trained on a corrupted dataset and produced worse metrics than the previous version - but was deployed anyway because nobody checked. There was no comparison against the production model. And there was no record of which model version was in production at what time.

The team automated the entire process in Airflow. The new pipeline ran unattended at 2am every Monday, validated the training data, ran training in an isolated Kubernetes pod with a GPU, compared the new model against the current production model, deployed only if there was a measurable improvement, and sent a Slack summary regardless of outcome. The data scientist's Monday was freed. The model quality improved because training now ran every single week without exception.

This lesson explains how to build that pipeline.


Why ML Pipelines Need Orchestration Specifically

Standard ETL pipelines have a straightforward failure model: data arrives, gets transformed, gets loaded. If a step fails, retry it. The data either looks right or it does not.

ML pipelines have additional failure modes that require additional machinery:

Data quality failures. Training on corrupted or incomplete data is worse than not training at all. You need a gate before training starts that validates row counts, feature null rates, feature distribution drift, and label balance. If the data fails the gate, the pipeline must halt - not proceed to training with bad input.

Training is expensive and slow. An ETL task might run for two minutes. A training task might run for six hours, cost $40 in GPU compute, and produce an artifact (model weights) that is 500MB. The orchestration system must handle long-running tasks gracefully, report progress, and not time out.

Model quality is not binary. A SQL transformation either succeeds or fails. A model always trains - but the resulting model might be worse than what is already in production. You need an evaluation step that measures the new model against the current production model and makes a gated deployment decision.

Deployment is conditional. Unlike ETL where every run replaces the previous output, model deployment should only happen when the new model is demonstrably better. This requires branching logic in the pipeline.

Reproducibility and auditability. Regulatory and operational requirements often demand that you can reproduce any model version and explain exactly what data it was trained on, what hyperparameters were used, and when it was deployed. The pipeline must log this information systematically.

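Airflow addresses each of these: ShortCircuitOperator for data quality gates, KubernetesPodOperator for isolated GPU training, PythonOperator for evaluation logic, BranchPythonOperator for conditional deployment, and the metadata database for a record of every run (when it ran and which tasks succeeded, failed, or were skipped). The model-level audit trail - training data path, hyperparameters, metrics - lives in MLflow, which the training step logs to.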


The ML Pipeline Pattern

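Every production ML retraining pipeline follows the same fundamental pattern. The names and specifics change, but the shape is constant: wait for fresh data, validate it, engineer features, train in isolation, evaluate the challenger against the champion, then either deploy or skip, and finally send a summary. The steps below build each stage in turn.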


Step 1 - Data Quality Gate with ShortCircuitOperator

The first and most important step is validating the training data before spending any compute on training. When its callable returns False, ShortCircuitOperator marks every task downstream of it as skipped, so the rest of the run does nothing. This is preferable to raising an exception because skipped is a distinct, non-alarming state - you might send a Slack notification rather than wake someone up at 3am.

```python
from airflow.operators.python import ShortCircuitOperator
import pandas as pd


def validate_training_data_fn(ds: str, **context) -> bool:
    """
    Validate training data quality. Return False to skip the rest of the pipeline.
    ds is the logical date, rendered from the "{{ ds }}" template in op_kwargs.
    """
    # Reading s3:// paths with pandas requires the s3fs package
    path = f"s3://data/features/churn/date={ds}/features.parquet"

    try:
        df = pd.read_parquet(path)
    except Exception as e:
        print(f"QUALITY FAIL: Could not read training data: {e}")
        return False

    # name -> (actual value, threshold, direction)
    # "min" fails when actual < threshold, "max" fails when actual > threshold
    checks = {
        "row_count": (len(df), 10_000, "min"),
        "max_null_rate": (df.isnull().mean().max(), 0.05, "max"),
        "positive_label_rate": (df["churned"].mean(), 0.03, "min"),
    }

    all_passed = True
    for check_name, (actual, threshold, direction) in checks.items():
        failed = actual < threshold if direction == "min" else actual > threshold
        if failed:
            print(f"QUALITY FAIL [{check_name}]: {actual} violates {direction} threshold {threshold}")
            all_passed = False
        else:
            print(f"QUALITY PASS [{check_name}]: {actual}")

    # Push metadata for downstream use, even if we are about to short-circuit
    context["ti"].xcom_push(key="row_count", value=len(df))
    context["ti"].xcom_push(key="s3_path", value=path)

    return all_passed


validate_data = ShortCircuitOperator(
    task_id="validate_training_data",
    python_callable=validate_training_data_fn,
    op_kwargs={"ds": "{{ ds }}"},
    # True (the default) skips every downstream task regardless of trigger rules;
    # False would skip only direct children and let trigger rules decide the rest
    ignore_downstream_trigger_rules=True,
)
```

:::tip Add feature distribution checks Row count and null rate checks catch the obvious problems. Production pipelines also check feature distribution drift - if the mean of a key feature has shifted by more than N standard deviations from its historical distribution, training on it may produce a model that performs well on the training period but is miscalibrated for the current distribution. The evidently library provides ready-made drift detectors you can integrate into this gate. :::
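The mean-shift rule in the tip can be sketched with the standard library alone. This is a minimal illustration, not Evidently's API. The function name `mean_shift_exceeds`, the default of 3 standard deviations, and the assumption that the historical window is a stable reference are all illustrative choices.

```python
import statistics


def mean_shift_exceeds(historical: list[float], current: list[float], max_z: float = 3.0) -> bool:
    """Return True if the current batch mean lies more than max_z historical
    standard deviations from the historical mean (hypothetical helper)."""
    if len(historical) < 2 or not current:
        raise ValueError("need at least two historical values and one current value")
    ref_mean = statistics.fmean(historical)
    ref_std = statistics.stdev(historical)
    current_mean = statistics.fmean(current)
    if ref_std == 0:
        # Constant reference feature: any change in the mean counts as drift
        return current_mean != ref_mean
    return abs(current_mean - ref_mean) / ref_std > max_z


history = [10, 12, 11, 9, 10, 11, 12, 9]  # e.g. last weeks' mean session length
print(mean_shift_exceeds(history, [10.4, 10.6]))  # False: same mean as history
print(mean_shift_exceeds(history, [20, 21]))      # True: roughly 8 std devs away
```

Inside the gate, a True result for any key feature would make `validate_training_data_fn` return False, exactly like a failed row-count check.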


Step 2 - Feature Engineering

After data quality is confirmed, compute the final feature matrix for training. This step is often a separate DAG (a feature pipeline that runs on its own schedule) rather than part of the training DAG, which keeps the two concerns separate. When it is part of the training DAG, it should write its output to a stable, date-partitioned S3 path and push that path via XCom.

```python
from airflow.operators.python import PythonOperator


def run_feature_engineering(ds: str, **context) -> None:
    """Read raw validated data, compute features, write to the feature store path."""
    import numpy as np
    import pandas as pd

    ti = context["ti"]
    raw_path = ti.xcom_pull(task_ids="validate_training_data", key="s3_path")

    df = pd.read_parquet(raw_path)

    # Example feature engineering
    df["days_since_last_login"] = (pd.Timestamp(ds) - pd.to_datetime(df["last_login_date"])).dt.days
    df["avg_session_length_7d"] = df["session_lengths"].apply(
        lambda x: np.mean(x[-7:]) if len(x) >= 7 else np.nan
    )
    # clip(lower=1) avoids division by zero for users with no payment attempts
    df["payment_failure_rate"] = df["payment_failures_30d"] / df["payment_attempts_30d"].clip(lower=1)

    feature_cols = [
        "days_since_last_login", "avg_session_length_7d",
        "payment_failure_rate", "plan_tier", "support_tickets_90d",
    ]
    output = df[feature_cols + ["churned", "user_id"]].dropna(subset=feature_cols)

    # Date-partitioned path: a rerun for the same ds rewrites only its own partition
    output_path = f"s3://data/features/training/churn/date={ds}/train_features.parquet"
    output.to_parquet(output_path, index=False)
    ti.xcom_push(key="feature_path", value=output_path)
    ti.xcom_push(key="feature_count", value=len(output))
    print(f"Feature engineering complete: {len(output)} rows written to {output_path}")


feature_engineering_task = PythonOperator(
    task_id="feature_engineering",
    python_callable=run_feature_engineering,
    op_kwargs={"ds": "{{ ds }}"},
)
```
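The three feature definitions can also be checked without pandas or S3. The following standard-library sketch computes them for a single user record. The field names mirror the columns above, but the helper name `compute_churn_features` and the sample record are hypothetical. It is meant for unit-testing the definitions, not as a replacement for the vectorized code.

```python
import math
from datetime import date
from statistics import fmean


def compute_churn_features(record: dict, ds: str) -> dict:
    """Compute the engineered features for one user as of logical date ds."""
    as_of = date.fromisoformat(ds)
    last_login = date.fromisoformat(record["last_login_date"])
    sessions = record["session_lengths"]
    return {
        # Whole days between the logical date and the last login
        "days_since_last_login": (as_of - last_login).days,
        # Mean of the last 7 sessions; NaN when fewer than 7 exist (dropped later)
        "avg_session_length_7d": fmean(sessions[-7:]) if len(sessions) >= 7 else math.nan,
        # Same denominator guard as .clip(lower=1)
        "payment_failure_rate": record["payment_failures_30d"] / max(record["payment_attempts_30d"], 1),
    }


user = {
    "last_login_date": "2024-01-01",
    "session_lengths": [1, 2, 3, 4, 5, 6, 7, 8],
    "payment_failures_30d": 2,
    "payment_attempts_30d": 0,
}
print(compute_churn_features(user, "2024-01-15"))
```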

Step 3 - KubernetesPodOperator for Training

Training should not run on an Airflow worker. Workers are shared resources - a training job consuming 32GB of RAM and a GPU would starve other tasks. The correct approach is to run training in an isolated Kubernetes pod via KubernetesPodOperator.

The training container handles all the ML logic. It reads from the feature path, trains the model, logs metrics and artifacts to MLflow, and writes its MLflow run ID to /airflow/xcom/return.json so Airflow can retrieve it via XCom.

```python
from datetime import timedelta

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

# Volume and mount exposing the MLflow client config map inside the pod
mlflow_config_volume = k8s.V1Volume(
    name="mlflow-config",
    config_map=k8s.V1ConfigMapVolumeSource(name="mlflow-client-config"),
)
mlflow_config_mount = k8s.V1VolumeMount(
    name="mlflow-config",
    mount_path="/etc/mlflow",
    read_only=True,
)

train_model = KubernetesPodOperator(
    task_id="train_churn_model",
    name="churn-training-{{ ds_nodash }}",  # unique pod name per run
    namespace="ml-pipelines",
    image="registry.company.com/churn-trainer:v3.2",
    image_pull_policy="Always",
    env_vars={
        "TRAIN_DATE": "{{ ds }}",
        "FEATURE_PATH": "{{ task_instance.xcom_pull('feature_engineering', key='feature_path') }}",
        "MLFLOW_TRACKING_URI": "https://mlflow.internal",
        "MODEL_NAME": "churn_prediction",
        "EXPERIMENT_NAME": "churn_weekly_training",
    },
    container_resources=k8s.V1ResourceRequirements(
        requests={"memory": "32Gi", "cpu": "8"},
        # GPUs are an extended resource and must be set in limits
        limits={"memory": "48Gi", "cpu": "16", "nvidia.com/gpu": "1"},
    ),
    volumes=[mlflow_config_volume],
    volume_mounts=[mlflow_config_mount],
    do_xcom_push=True,  # a sidecar reads /airflow/xcom/return.json as the XCom return value
    is_delete_operator_pod=True,
    get_logs=True,
    log_events_on_failure=True,
    startup_timeout_seconds=300,
    retries=1,
    retry_delay=timedelta(minutes=15),
)
```

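With do_xcom_push=True, KubernetesPodOperator does not read stdout. It mounts a shared volume at /airflow/xcom, and after the main container exits, a sidecar container reads /airflow/xcom/return.json and pushes its JSON contents as the task's return value. The training script therefore writes its result to that file: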

```python
# Inside the Docker container (train.py)
import json
import os

import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

feature_path = os.environ["FEATURE_PATH"]
model_name = os.environ["MODEL_NAME"]
experiment_name = os.environ["EXPERIMENT_NAME"]
train_date = os.environ["TRAIN_DATE"]

# The tracking server is taken from the MLFLOW_TRACKING_URI environment variable
mlflow.set_experiment(experiment_name)

df = pd.read_parquet(feature_path)
feature_cols = [c for c in df.columns if c not in ["churned", "user_id"]]
X_train, X_test, y_train, y_test = train_test_split(
    df[feature_cols], df["churned"], test_size=0.2, stratify=df["churned"], random_state=42
)

# A run name derived from the training date keeps every run independently addressable
with mlflow.start_run(run_name=f"churn_{train_date}") as run:
    mlflow.log_param("train_date", train_date)
    mlflow.log_param("n_estimators", 300)
    mlflow.log_param("learning_rate", 0.05)
    mlflow.log_param("feature_count", len(feature_cols))

    model = GradientBoostingClassifier(n_estimators=300, learning_rate=0.05, random_state=42)
    model.fit(X_train, y_train)
    y_prob = model.predict_proba(X_test)[:, 1]
    auc = roc_auc_score(y_test, y_prob)

    mlflow.log_metric("val_auc", auc)
    mlflow.sklearn.log_model(model, artifact_path="model", registered_model_name=model_name)
    run_id = run.info.run_id

# KubernetesPodOperator(do_xcom_push=True) reads this file through its sidecar
result = {"mlflow_run_id": run_id, "val_auc": float(auc), "train_date": train_date}
os.makedirs("/airflow/xcom", exist_ok=True)
with open("/airflow/xcom/return.json", "w") as f:
    json.dump(result, f)
print(json.dumps(result))  # also visible in the pod logs for humans
```

Step 4 - Champion/Challenger Evaluation

After training, evaluate whether the new model beats the current production model by a meaningful margin. The "champion" is the model currently serving predictions in production. The "challenger" is the newly trained model.

```python
def evaluate_models(ds: str, **context) -> dict:
    """
    Compare the challenger (new) model's validation AUC with the champion's (production).
    Return evaluation results, including whether to deploy.
    """
    import json

    import mlflow

    ti = context["ti"]

    # Get the training result from the KubernetesPodOperator's XCom;
    # with do_xcom_push=True it is the contents of /airflow/xcom/return.json
    training_output = ti.xcom_pull(task_ids="train_churn_model")
    if isinstance(training_output, str):
        training_output = json.loads(training_output)

    challenger_run_id = training_output["mlflow_run_id"]
    challenger_auc = training_output["val_auc"]

    # Fetch champion model from registry
    client = mlflow.tracking.MlflowClient()
    champion_versions = client.get_latest_versions("churn_prediction", stages=["Production"])

    if not champion_versions:
        # No champion exists yet - always deploy the first model
        return {
            "challenger_auc": challenger_auc,
            "champion_auc": 0.0,
            "improvement": challenger_auc,
            "should_deploy": True,
            "challenger_run_id": challenger_run_id,
            "reason": "no_champion_exists",
        }

    champion_version = champion_versions[0]
    champion_run = client.get_run(champion_version.run_id)
    champion_auc = float(champion_run.data.metrics.get("val_auc", 0.0))

    improvement = challenger_auc - champion_auc
    # Deploy only if the challenger is meaningfully better (0.5pp threshold);
    # small improvements may be noise that does not hold on production traffic
    should_deploy = improvement >= 0.005

    result = {
        "challenger_auc": challenger_auc,
        "champion_auc": champion_auc,
        "improvement": improvement,
        "should_deploy": should_deploy,
        "challenger_run_id": challenger_run_id,
        "champion_version": champion_version.version,
        "reason": "improvement_above_threshold" if should_deploy else "improvement_below_threshold",
    }

    print(f"Champion AUC: {champion_auc:.4f}")
    print(f"Challenger AUC: {challenger_auc:.4f}")
    print(f"Improvement: {improvement:+.4f}")
    print(f"Deploy: {should_deploy} ({result['reason']})")

    ti.xcom_push(key="evaluation", value=result)
    return result


evaluate_task = PythonOperator(
    task_id="evaluate_models",
    python_callable=evaluate_models,
    op_kwargs={"ds": "{{ ds }}"},
)
```
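The deployment decision itself is a pure function of two numbers, so it can be unit-tested without Airflow or MLflow. The sketch below assumes the same 0.005 AUC threshold and reason strings as `evaluate_models`. The function name `decide_deployment`, and the use of `None` to mean "no champion", are illustrative choices.

```python
from typing import Optional

MIN_AUC_IMPROVEMENT = 0.005  # same 0.5pp threshold as evaluate_models


def decide_deployment(challenger_auc: float, champion_auc: Optional[float],
                      min_improvement: float = MIN_AUC_IMPROVEMENT) -> dict:
    """Gate deployment; champion_auc is None when nothing is in Production yet."""
    if champion_auc is None:
        return {"should_deploy": True, "improvement": challenger_auc,
                "reason": "no_champion_exists"}
    improvement = challenger_auc - champion_auc
    should_deploy = improvement >= min_improvement
    return {
        "should_deploy": should_deploy,
        "improvement": improvement,
        "reason": "improvement_above_threshold" if should_deploy else "improvement_below_threshold",
    }


print(decide_deployment(0.87, 0.85)["reason"])   # improvement_above_threshold
print(decide_deployment(0.852, 0.85)["reason"])  # improvement_below_threshold
```

Keeping the rule in one small function also means the threshold lives in one place instead of being repeated across the evaluation task, the log message, and the Slack summary.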

Step 5 - Conditional Deployment with BranchPythonOperator

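A BranchPythonOperator's callable returns the task_id (or list of task_ids) that should run next. The other tasks directly downstream of the branch are marked skipped, and the skip propagates to their descendants unless a trigger rule says otherwise.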

```python
from airflow.operators.python import BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule


def choose_deployment_branch(**context) -> str:
    ti = context["ti"]
    evaluation = ti.xcom_pull(task_ids="evaluate_models", key="evaluation")
    if evaluation["should_deploy"]:
        return "deploy_challenger"
    return "log_no_deployment"


deployment_branch = BranchPythonOperator(
    task_id="deployment_decision",
    python_callable=choose_deployment_branch,
)


def deploy_challenger(**context) -> None:
    import mlflow

    ti = context["ti"]
    evaluation = ti.xcom_pull(task_ids="evaluate_models", key="evaluation")
    client = mlflow.tracking.MlflowClient()

    # Find the version registered by this training run. train.py registers it
    # without assigning a stage, so a lookup by "Staging" would find nothing.
    versions = client.search_model_versions(f"run_id='{evaluation['challenger_run_id']}'")
    if not versions:
        raise ValueError(f"No registered model version for run {evaluation['challenger_run_id']}")
    challenger_version = versions[0]

    # Transition challenger to Production, archive previous champion
    client.transition_model_version_stage(
        name="churn_prediction",
        version=challenger_version.version,
        stage="Production",
        archive_existing_versions=True,  # moves old champion to Archived
    )
    print(
        f"Deployed churn_prediction v{challenger_version.version} - "
        f"AUC {evaluation['challenger_auc']:.4f} ({evaluation['improvement']:+.4f} vs champion)"
    )


def log_no_deployment(**context) -> None:
    ti = context["ti"]
    evaluation = ti.xcom_pull(task_ids="evaluate_models", key="evaluation")
    print(
        f"Skipped deployment. Challenger AUC {evaluation['challenger_auc']:.4f} vs "
        f"Champion {evaluation['champion_auc']:.4f} - improvement {evaluation['improvement']:+.4f} "
        f"below 0.005 threshold."
    )


deploy_task = PythonOperator(
    task_id="deploy_challenger",
    python_callable=deploy_challenger,
)

skip_task = PythonOperator(
    task_id="log_no_deployment",
    python_callable=log_no_deployment,
)


# Notification task runs regardless of which branch was taken
def send_pipeline_summary(**context) -> None:
    import os

    import requests

    ti = context["ti"]
    evaluation = ti.xcom_pull(task_ids="evaluate_models", key="evaluation")
    ds = context["ds"]
    action = "DEPLOYED" if evaluation["should_deploy"] else "SKIPPED"
    message = (
        f":robot_face: *Churn Model Pipeline - {ds}*\n"
        f"Challenger AUC: `{evaluation['challenger_auc']:.4f}` "
        f"vs Champion: `{evaluation['champion_auc']:.4f}` "
        f"(improvement: `{evaluation['improvement']:+.4f}`)\n"
        f"Decision: *{action}*"
    )
    requests.post(
        os.environ["SLACK_WEBHOOK_URL"],
        json={"text": message},
    )


summary_task = PythonOperator(
    task_id="pipeline_summary",
    python_callable=send_pipeline_summary,
    trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,  # run after either branch
)

# Wire the deployment section
deployment_branch >> [deploy_task, skip_task] >> summary_task
```

:::warning trigger_rule on post-branch tasks Any task downstream of a BranchPythonOperator that should run regardless of which branch was taken must set trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS (or ALL_DONE). The default trigger rule ALL_SUCCESS will cause the task to be skipped if any upstream task was skipped - which is exactly what happens to the non-chosen branch. :::
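To see why the default rule skips the summary task, here is a deliberately simplified model of three trigger rules. It evaluates them over the final states of a task's upstream tasks. This is an illustrative sketch, not Airflow's implementation, which handles more states and rules. The function name `should_run` is hypothetical.

```python
from collections import Counter


def should_run(trigger_rule: str, upstream_states: list[str]) -> bool:
    """Simplified: decide whether a task runs, given its upstream tasks' final states."""
    counts = Counter(upstream_states)
    if trigger_rule == "all_success":
        return counts["success"] == len(upstream_states)
    if trigger_rule == "none_failed_min_one_success":
        return counts["failed"] == 0 and counts["upstream_failed"] == 0 and counts["success"] >= 1
    if trigger_rule == "all_done":
        # Every upstream task has finished, whatever its outcome
        finished = {"success", "failed", "skipped", "upstream_failed"}
        return all(state in finished for state in upstream_states)
    raise ValueError(f"trigger rule not modeled in this sketch: {trigger_rule}")


# After the branch: the chosen path succeeded, the other one was skipped
after_branch = ["success", "skipped"]
print(should_run("all_success", after_branch))                  # False: summary would be skipped
print(should_run("none_failed_min_one_success", after_branch))  # True: summary runs
```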


SLA Monitoring

SLAs (Service Level Agreements) define the maximum time within which a task should complete. The time is measured relative to the DAG run's start, not the task's own start. If a task has not finished in time, Airflow records an SLA miss and calls the DAG's sla_miss_callback. The task keeps running; an SLA miss does not cancel anything. This describes Airflow 2; the SLA feature was removed in Airflow 3.

```python
from datetime import timedelta

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator


def sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis):
    """
    Called when tasks miss their SLA; attach it with DAG(..., sla_miss_callback=sla_miss_callback).
    dag: the DAG object
    blocking_tis: task instances that caused the SLA miss
    """
    import os

    import requests

    slow_tasks = [ti.task_id for ti in blocking_tis]
    message = (
        f":warning: *SLA MISS - {dag.dag_id}*\n"
        f"Tasks exceeding SLA: `{slow_tasks}`\n"
        f"Expected completion within: 6 hours"
    )
    # The payload shape depends on the alerting integration behind this webhook
    requests.post(os.environ["PAGERDUTY_WEBHOOK_URL"], json={"message": message})


# Apply SLA in default_args (applies to all tasks)
default_args = {
    "sla": timedelta(hours=6),
}

# Or set per-task for granular alerting
train_model_task = KubernetesPodOperator(
    task_id="train_churn_model",
    sla=timedelta(hours=4),  # training should be done within 4 hours of the run's start
    # ... remaining arguments as in Step 3
)
```
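Because the SLA clock starts with the DAG run rather than with the task, it helps to work out the actual deadlines once. For a scheduled run, the run becomes due at the end of its data interval. The standard-library sketch below assumes the weekly `0 2 * * 1` schedule used by the full DAG in this lesson. It uses a hand-rolled cron calculation, not Airflow's timetable code, and the helper name `next_monday_2am` is hypothetical.

```python
from datetime import datetime, timedelta


def next_monday_2am(after: datetime) -> datetime:
    """Next tick of the '0 2 * * 1' cron schedule strictly after `after`."""
    candidate = after.replace(hour=2, minute=0, second=0, microsecond=0)
    candidate += timedelta(days=(-candidate.weekday()) % 7)  # Monday has weekday() == 0
    if candidate <= after:
        candidate += timedelta(days=7)
    return candidate


# The run whose data interval starts Monday 2024-01-01 02:00 becomes due
# at the interval's end, one week later
interval_start = datetime(2024, 1, 1, 2, 0)
scheduled_start = next_monday_2am(interval_start)

slas = {"train_churn_model": timedelta(hours=4), "default": timedelta(hours=6)}
for task_id, sla in slas.items():
    print(task_id, "must finish by", scheduled_start + sla)
# train_churn_model must finish by 2024-01-08 06:00:00
# default must finish by 2024-01-08 08:00:00
```

A training pod that waits an hour in the Kubernetes scheduling queue therefore leaves itself only three hours of its four-hour SLA.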

Sensors for ML Pipeline Dependencies

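Sensors hold back their downstream tasks, polling until a condition is met. For ML pipelines, the most useful sensors are S3KeySensor, which waits for a file to appear, and ExternalTaskSensor, which waits for another DAG's run to finish: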

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.sensors.external_task import ExternalTaskSensor

# Wait for the feature pipeline to finish writing this run's features
wait_for_features = S3KeySensor(
    task_id="wait_for_feature_data",
    bucket_name="data",
    bucket_key="features/churn/date={{ ds }}/features.parquet",  # bucket_key is templated
    aws_conn_id="aws_default",
    timeout=60 * 60 * 4,  # wait at most 4 hours
    poke_interval=60 * 5,  # check every 5 minutes
    mode="reschedule",  # release the worker slot while waiting (not "poke")
)

# Wait for the upstream feature DAG to complete
wait_for_feature_dag = ExternalTaskSensor(
    task_id="wait_for_feature_pipeline",
    external_dag_id="daily_feature_engineering",
    external_task_id=None,  # None = wait for the entire DAG run to succeed
    # Must return the upstream run's exact logical date; the identity
    # mapping only works if both DAGs share the same schedule
    execution_date_fn=lambda dt: dt,
    timeout=60 * 60 * 4,
    poke_interval=60 * 2,
    mode="reschedule",
)
```

:::tip Use mode="reschedule" for long-running sensors mode="poke" (the default) keeps the sensor occupying a worker slot for the whole wait, sleeping between polls. mode="reschedule" releases the slot between polls: the task moves to the up_for_reschedule state and the scheduler re-queues it after poke_interval. For sensors that wait for hours, this frees the slot for almost the entire wait. :::
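One pitfall the sensor examples gloss over is that ExternalTaskSensor looks for an upstream run with exactly the logical date it computes. A weekly DAG scheduled at 02:00 and a daily feature DAG scheduled at another time will never line up by default. Here is a sketch of an `execution_date_fn` that maps the weekly run to the daily run on the same day. It assumes, hypothetically, that `daily_feature_engineering` runs at 01:00, and the helper name `same_day_at` is illustrative.

```python
from datetime import datetime, timezone
from functools import partial


def same_day_at(logical_date: datetime, hour: int, minute: int = 0) -> datetime:
    """Map this DAG's logical date to the upstream run on the same day at hour:minute."""
    return logical_date.replace(hour=hour, minute=minute, second=0, microsecond=0)


# Hypothetical: daily_feature_engineering is scheduled at 01:00 every day
align_to_feature_dag = partial(same_day_at, hour=1)

weekly_logical_date = datetime(2024, 1, 1, 2, 0, tzinfo=timezone.utc)  # Monday 02:00 run
print(align_to_feature_dag(weekly_logical_date))  # 2024-01-01 01:00:00+00:00
```

Passing `execution_date_fn=align_to_feature_dag` would make the sensor wait for the 01:00 feature run on the same day. For a fixed offset like this, ExternalTaskSensor's `execution_delta=timedelta(hours=1)` says the same thing more directly, because the delta is subtracted from the current logical date.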


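Dynamic Task Mapping for Parallel Training

Train multiple model variants in parallel using Airflow's dynamic task mapping. Each variant runs as its own mapped instance of the training task (and in its own pod, if you use the KubernetesExecutor). The instances run simultaneously up to the available concurrency, and a final task compares their results.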

```python
import hashlib
import json
from datetime import datetime

from airflow.decorators import dag, task


@dag(schedule="0 2 * * 0", start_date=datetime(2024, 1, 1), catchup=False)
def hyperparameter_search():

    HYPERPARAMETER_CONFIGS = [
        {"n_estimators": 100, "learning_rate": 0.1, "max_depth": 3},
        {"n_estimators": 200, "learning_rate": 0.05, "max_depth": 4},
        {"n_estimators": 300, "learning_rate": 0.01, "max_depth": 5},
        {"n_estimators": 500, "learning_rate": 0.005, "max_depth": 6},
    ]

    @task
    def get_configs() -> list[dict]:
        return HYPERPARAMETER_CONFIGS

    @task
    def train_variant(config: dict) -> dict:
        """Train one model variant. Runs as a separate mapped task instance per config."""
        from sklearn.ensemble import GradientBoostingClassifier
        from sklearn.metrics import roc_auc_score
        import pandas as pd, mlflow

        df = pd.read_parquet("s3://data/features/training/latest/train_features.parquet")
        X, y = df.drop(columns=["churned", "user_id"]), df["churned"]
        # ... train and evaluate ...
        auc = 0.88  # placeholder
        # Stable ID: built-in hash() of a str changes from one Python process to the next
        config_id = hashlib.sha256(json.dumps(config, sort_keys=True).encode()).hexdigest()[:12]
        return {**config, "val_auc": auc, "config_id": config_id}

    @task
    def select_best(results: list[dict]) -> dict:
        """Pick the best variant from all parallel training results."""
        best = max(results, key=lambda r: r["val_auc"])
        print(f"Best config: {best}")
        return best

    configs = get_configs()
    # .expand() creates one task instance per element in configs
    all_results = train_variant.expand(config=configs)
    best = select_best(all_results)


hyperparameter_search()
```

The Complete Production ML Training DAG

Here is the full DAG integrating all the above components:

```python
import os
from datetime import datetime, timedelta

from airflow.decorators import dag, task
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.trigger_rule import TriggerRule
from kubernetes.client import models as k8s


def sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis):
    import requests
    requests.post(os.environ["SLACK_WEBHOOK_URL"], json={
        "text": f":warning: SLA miss in {dag.dag_id}: {[t.task_id for t in blocking_tis]}"
    })


default_args = {
    "owner": "ml-platform",
    "retries": 2,
    "retry_delay": timedelta(minutes=10),
    "email_on_failure": True,
    "email": ["[email protected]"],
    "sla": timedelta(hours=6),
}


@dag(
    dag_id="weekly_churn_model_training",
    schedule="0 2 * * 1",  # 2am Monday
    start_date=datetime(2024, 1, 1),
    catchup=False,
    default_args=default_args,
    sla_miss_callback=sla_miss_callback,
    tags=["ml", "churn", "production"],
    max_active_runs=1,  # only one retraining run at a time
)
def churn_training_dag():

    # Assumes the feature DAG has a run with this DAG's exact logical date;
    # otherwise set execution_delta or execution_date_fn
    wait_for_features = ExternalTaskSensor(
        task_id="wait_for_feature_pipeline",
        external_dag_id="daily_feature_engineering",
        external_task_id=None,
        timeout=14400,
        poke_interval=120,
        mode="reschedule",
    )

    @task
    def validate_data(ds: str) -> dict:
        import pandas as pd
        path = f"s3://data/features/churn/date={ds}/features.parquet"
        df = pd.read_parquet(path)
        # Explicit raises instead of assert: asserts are stripped under python -O
        if len(df) < 10_000:
            raise ValueError(f"Only {len(df)} rows - need 10,000")
        if df.isnull().mean().max() >= 0.05:
            raise ValueError("Null rate exceeds 5%")
        return {"row_count": len(df), "s3_path": path}

    train = KubernetesPodOperator(
        task_id="train_model",
        name="churn-train-{{ ds_nodash }}",
        namespace="ml-pipelines",
        # Pinned tag so backfills reproduce the original training environment
        image="registry.company.com/churn-trainer:v3.2",
        # train.py reads all of these variables; MLflow reads MLFLOW_TRACKING_URI
        env_vars={
            "TRAIN_DATE": "{{ ds }}",
            "FEATURE_PATH": "{{ ti.xcom_pull(task_ids='validate_data')['s3_path'] }}",
            "MLFLOW_TRACKING_URI": "https://mlflow.internal",
            "MODEL_NAME": "churn_prediction",
            "EXPERIMENT_NAME": "churn_weekly_training",
        },
        container_resources=k8s.V1ResourceRequirements(
            requests={"memory": "32Gi"},
            limits={"memory": "48Gi", "nvidia.com/gpu": "1"},  # GPUs must be set in limits
        ),
        do_xcom_push=True,
        is_delete_operator_pod=True,
        retries=1,
    )

    @task
    def evaluate(train_output) -> dict:
        import json
        import mlflow
        output = json.loads(train_output) if isinstance(train_output, str) else train_output
        challenger_auc = output["val_auc"]
        client = mlflow.tracking.MlflowClient()
        champions = client.get_latest_versions("churn_prediction", stages=["Production"])
        if not champions:
            return {**output, "should_deploy": True, "champion_auc": 0.0, "improvement": challenger_auc}
        champion_run = client.get_run(champions[0].run_id)
        champion_auc = float(champion_run.data.metrics.get("val_auc", 0.0))
        improvement = challenger_auc - champion_auc
        return {**output, "champion_auc": champion_auc, "improvement": improvement, "should_deploy": improvement >= 0.005}

    def branch_fn(**context) -> str:
        ev = context["ti"].xcom_pull(task_ids="evaluate")
        return "deploy" if ev["should_deploy"] else "skip_deploy"

    branch = BranchPythonOperator(task_id="deployment_decision", python_callable=branch_fn)

    @task(task_id="deploy")
    def deploy(**context) -> None:
        import mlflow
        ev = context["ti"].xcom_pull(task_ids="evaluate")
        client = mlflow.tracking.MlflowClient()
        # Look up the version this run registered (train.py assigns it no stage)
        versions = client.search_model_versions(f"run_id='{ev['mlflow_run_id']}'")
        client.transition_model_version_stage(
            "churn_prediction", versions[0].version, "Production", archive_existing_versions=True
        )
        print(f"Deployed v{versions[0].version} - AUC {ev['val_auc']:.4f}")

    @task(task_id="skip_deploy")
    def skip_deploy(**context) -> None:
        ev = context["ti"].xcom_pull(task_ids="evaluate")
        print(f"No deploy. Improvement {ev['improvement']:+.4f} below threshold.")

    @task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
    def notify(**context) -> None:
        import requests
        ev = context["ti"].xcom_pull(task_ids="evaluate")
        action = "DEPLOYED" if ev["should_deploy"] else "SKIPPED"
        requests.post(os.environ["SLACK_WEBHOOK_URL"], json={
            "text": f":robot_face: Churn Model Pipeline - {action}\nChallenger AUC: {ev['val_auc']:.4f} | Champion AUC: {ev.get('champion_auc', 0):.4f} | Δ {ev.get('improvement', 0):+.4f}"
        })

    # Wiring
    validated = validate_data("{{ ds }}")
    wait_for_features >> validated >> train
    evaluation_result = evaluate(train.output)
    deploy_task = deploy()
    skip_task = skip_deploy()
    notify_task = notify()

    evaluation_result >> branch >> [deploy_task, skip_task] >> notify_task


churn_training_dag()
```

Production Engineering Notes

max_active_runs=1: For ML training DAGs, set this to 1. If one run is still going when the next would start (rare for a weekly schedule, but possible with long training, manual triggers, or backfills), you do not want two training jobs competing for GPU resources and writing to the same MLflow experiment simultaneously.

Artifact versioning: Never overwrite model artifacts - always version them. The training container should write to a path that includes the training date or run ID. This preserves the ability to roll back to any previous model.

Staging before Production: Rather than promoting directly to Production, promote to Staging first. Then evaluate the candidate on live traffic, either in shadow mode (mirror production requests to it without serving its predictions) or as a canary (route a small percentage of traffic to it). Promote to Production only after validating live metrics. This is the safest pattern for high-stakes models.

GPU resource requests: Request the exact GPU type your training requires using Kubernetes node affinity rules. Without this, the pod may be scheduled on a machine with an incompatible GPU driver.

:::danger Non-idempotent training tasks If your training task writes to a fixed MLflow run name like churn_v1, running a backfill will overwrite previous training runs. Use the execution date in all artifact paths and MLflow run names: f"churn_{ds}". This makes every historical run independently addressable and reproducible. :::
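Here is a minimal sketch of the get-or-create pattern that makes retries and backfills safe. The run name is derived from the logical date, so a second attempt for the same date reuses the existing run instead of creating a duplicate. An in-memory dict stands in for the tracking server. With MLflow, you would presumably look the name up with `mlflow.search_runs` and a run-name filter before calling `mlflow.start_run`. The class, the helper names, and the `s3://models/...` bucket are all hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class InMemoryRunStore:
    """Hypothetical stand-in for an experiment tracker: maps run names to run IDs."""
    runs: dict[str, str] = field(default_factory=dict)

    def get_or_create(self, run_name: str) -> tuple[str, bool]:
        """Return (run_id, created); an existing run with this name is reused."""
        if run_name in self.runs:
            return self.runs[run_name], False
        run_id = f"run-{len(self.runs) + 1:04d}"
        self.runs[run_name] = run_id
        return run_id, True


def run_name_for(model: str, ds: str) -> str:
    """Deterministic name: the same logical date always maps to the same run."""
    return f"{model}_{ds}"


def artifact_path_for(model: str, ds: str) -> str:
    """Date-partitioned artifact path (hypothetical bucket), never a fixed 'latest' path."""
    return f"s3://models/{model}/date={ds}/model.pkl"


store = InMemoryRunStore()
first = store.get_or_create(run_name_for("churn", "2024-01-01"))
retry = store.get_or_create(run_name_for("churn", "2024-01-01"))  # retry or backfill, same date
print(first, retry)  # ('run-0001', True) ('run-0001', False)
```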

:::danger PythonOperator for heavy training Never run model training inside a PythonOperator. A PythonOperator runs in a task process on the Airflow worker host, sharing that machine's memory and CPU with every other task running there. If the training job needs 32GB of RAM, it takes 32GB from the worker, potentially crashing other tasks on the same worker. Always use KubernetesPodOperator or a similar isolation mechanism for compute-intensive tasks. :::


Interview Q&A

Q: How do you implement a data quality gate in an Airflow ML pipeline, and what happens when it fails?

Use ShortCircuitOperator. When its callable returns False, all downstream tasks are marked skipped (not failed). This is the correct signal: the pipeline did not error, it halted because preconditions were not met. The relevant parameter is ignore_downstream_trigger_rules. It defaults to True, which makes the short-circuit propagate through the entire downstream graph regardless of trigger rules; set it to False to skip only direct children. In practice, a data quality gate checks row count against a historical baseline, null rates per feature, and label class balance. It can optionally check distribution drift with tools like Evidently or Great Expectations. When the gate fires, send a notification (Slack/PagerDuty) so the data engineering team can investigate the upstream pipeline; the model remains on the last successfully deployed version.


Q: Why should ML training tasks use KubernetesPodOperator instead of PythonOperator?

Three reasons: resource isolation, dependency isolation, and reproducibility.

- Resource isolation: training may need 32GB RAM and a GPU, and those resources cannot be shared with other worker tasks.
- Dependency isolation: training may need PyTorch 2.1 while other tasks need TensorFlow 2.12. Installing both into one worker environment invites conflicting pins for shared libraries, whereas each container image can carry exactly what its task needs.
- Reproducibility: pinning the container image tag guarantees that a historical backfill run uses exactly the same code, library versions, and CUDA version as the original run. This makes debugging model regressions straightforward, because you can trace exactly what environment produced each model artifact.


Q: What is the champion/challenger pattern and why is it important for model deployment?

The champion model is the version currently serving predictions in production. The challenger is the newly trained model. Before deploying the challenger, you compare its validation metrics against the champion's. You only promote the challenger if it exceeds the champion by a meaningful margin (e.g., 0.5pp AUC). The margin threshold prevents deploying on statistical noise - a model that is 0.001 AUC better on the training-period validation set may not actually be better in production. Without this gate, every retraining run would automatically replace the production model, which could silently degrade prediction quality whenever the training data is noisy or the model architecture is suboptimal.


Q: How do you handle the case where the BranchPythonOperator's non-chosen branch causes downstream tasks to be skipped?

Downstream tasks that should run regardless of which branch was chosen must set trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS. The default ALL_SUCCESS trigger rule causes a task to be skipped if any upstream task was skipped - which is what happens to the non-executed branch. NONE_FAILED_MIN_ONE_SUCCESS runs the task as long as at least one upstream succeeded and none failed - exactly the right semantics for a post-branch aggregation or notification task. Alternatively, ALL_DONE runs after all upstream tasks complete regardless of their state, including failures - useful for cleanup tasks.


Q: Why should you use mode="reschedule" for sensors instead of mode="poke"?

mode="poke" (the default) keeps the worker process alive and sleeping between polls. If you have 10 sensors each polling every 5 minutes with a 4-hour timeout, that is 10 worker slots occupied for 4 hours - slots that could be used by other tasks. mode="reschedule" releases the worker slot between polls. The sensor task moves to the up_for_reschedule state, the worker finishes, and after poke_interval the Scheduler re-queues the sensor task for another poll. Poke mode mainly makes sense when the polling interval is very short (under 30 seconds), where the overhead of re-queueing and rescheduling exceeds the cost of holding the slot.
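The slot arithmetic in that answer can be worked through directly. This is a back-of-the-envelope sketch: the 2-second cost of a single poke is an assumption for illustration, and the sketch ignores the scheduling overhead of each reschedule. The function name `slot_seconds` is hypothetical.

```python
from math import ceil


def slot_seconds(n_sensors: int, wait_s: int, poke_interval_s: int, mode: str,
                 seconds_per_poke: float = 2.0) -> float:
    """Approximate worker-slot seconds used by n sensors that each wait wait_s seconds."""
    if mode == "poke":
        # The slot is held for the entire wait, sleeping between pokes
        return n_sensors * wait_s
    if mode == "reschedule":
        # The slot is held only while a poke executes; roughly one poke per interval
        pokes = ceil(wait_s / poke_interval_s)
        return n_sensors * pokes * seconds_per_poke
    raise ValueError(f"unknown mode: {mode}")


poke = slot_seconds(10, 4 * 3600, 300, "poke")
resched = slot_seconds(10, 4 * 3600, 300, "reschedule")
print(poke / 3600, resched / 60)  # prints "40.0 16.0": 40 slot-hours vs 16 slot-minutes
```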


Q: How do you make ML pipeline tasks idempotent, and why does it matter for backfill?

A task is idempotent if running it multiple times for the same logical date produces the same result without side effects. For ML pipelines, this means: (1) writing model artifacts to paths that include the execution date, never overwriting a fixed path; (2) creating MLflow runs with a deterministic name that includes the date - and checking if a run with that name already exists before starting a new one; (3) for data loading tasks, using upsert semantics rather than append. Idempotency matters for backfill because Airflow's backfill command runs the same DAG multiple times for historical dates. If your tasks are not idempotent, a backfill will create duplicate model versions, duplicate MLflow runs, and corrupted evaluation results. It also matters for retries - if a task fails halfway through and is retried, the retry should not create a partially-duplicate artifact.


Q: How would you design an Airflow pipeline that trains 12 model variants in parallel and picks the best one?

Use dynamic task mapping. Create a @task that returns a list of hyperparameter configurations. Use .expand() on the training task to create one task instance per configuration - all run in parallel (up to the worker pool concurrency limit). After all variants complete, a downstream aggregation task receives the list of all results (Airflow collects them automatically) and selects the best one by metric. This is cleaner than the older pattern of generating Operators dynamically at parse time, because the list of variants is computed at runtime rather than at DAG parse time - so you can generate it from a database query, a configuration file in S3, or an external service. The fan-out also respects Pools, so you can limit the number of simultaneous training pods to avoid overloading your GPU cluster.
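The fan-out/fan-in shape can be sketched without Airflow. In the pipeline, generate_configs would be a @task, and train_variant would be invoked with .expand(config=...). Here a ThreadPoolExecutor stands in for the mapped task instances, and its max_workers setting plays the role of a Pool limit. The hyperparameter grid and the scoring function are hypothetical placeholders, not real training.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def generate_configs() -> list:
    # 3 depths x 4 learning rates = 12 variants (hypothetical grid)
    return [
        {"max_depth": depth, "learning_rate": lr}
        for depth, lr in product([3, 5, 7], [0.01, 0.05, 0.1, 0.2])
    ]


def train_variant(config: dict) -> dict:
    # Placeholder for a real training job: a made-up deterministic score
    score = (
        0.80
        - abs(config["max_depth"] - 5) * 0.01
        - abs(config["learning_rate"] - 0.1) * 0.1
    )
    return {"config": config, "val_auc": round(score, 4)}


def select_best(results: list) -> dict:
    # Fan-in: receives every variant's result and keeps the best by metric
    return max(results, key=lambda r: r["val_auc"])


if __name__ == "__main__":
    configs = generate_configs()
    # max_workers caps simultaneous "training jobs", like a Pool caps GPU pods
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(train_variant, configs))
    print(select_best(results))
```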


Feature Store Integration

In mature ML platforms, features are not computed inside the training pipeline - they are read from a feature store. The training pipeline becomes a consumer of pre-computed features, not a producer. This separation of concerns reduces training pipeline complexity and ensures that the same feature definitions used for training are also used for online inference.

```python
from airflow.decorators import task

# Feature references in Feast's "feature_view:feature" format
CHURN_FEATURES = [
    "user_stats:days_since_last_login",
    "user_stats:payment_failure_rate",
    "user_stats:avg_session_length_7d",
    "contract_features:plan_tier",
    "support_features:tickets_90d",
]


@task(retries=2)
def fetch_features_from_store(entity_ids: list[str], feature_names: list[str], event_timestamp: str) -> str:
    """
    Pull a training dataset from Feast (feature store).
    Returns the path to the materialized training dataset in S3.
    """
    from feast import FeatureStore
    import pandas as pd

    store = FeatureStore(repo_path="/opt/feast/feature_repo")

    # Build the entity dataframe - tells Feast which entities and at what point in time
    entity_df = pd.DataFrame({
        "user_id": entity_ids,
        "event_timestamp": pd.Timestamp(event_timestamp, tz="UTC"),
    })

    # Feast handles point-in-time correct feature retrieval
    # No future leakage - features are fetched as of event_timestamp per entity
    training_df = store.get_historical_features(
        entity_df=entity_df,
        features=feature_names,  # e.g. CHURN_FEATURES
    ).to_df()

    # Writing to an s3:// path requires s3fs to be installed
    output_path = f"s3://data/training/churn/{event_timestamp}/features.parquet"
    training_df.to_parquet(output_path, index=False)
    return output_path
```

The key advantage of point-in-time correct retrieval is that it prevents data leakage. Each feature value is fetched exactly as it would have been available at the time of the prediction event, not as it exists today. Because the same feature definitions also back online inference, the feature store reduces training-serving skew as well.
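The point-in-time rule itself is simple. For each entity and event timestamp, take the most recent feature value recorded at or before that timestamp. Feast applies this rule internally, along with details such as feature TTLs. The stdlib sketch below shows only the core as-of lookup, using made-up feature history for illustration.

```python
from bisect import bisect_right
from datetime import datetime


def point_in_time_lookup(feature_rows: dict, entity_events: list) -> list:
    """
    feature_rows:  {entity_id: [(feature_ts, value), ...]} sorted by feature_ts
    entity_events: [(entity_id, event_ts), ...]
    Returns (entity_id, event_ts, value) tuples, where value is the latest feature
    recorded at or before event_ts, or None if no such value exists.
    """
    results = []
    for entity_id, event_ts in entity_events:
        history = feature_rows.get(entity_id, [])
        timestamps = [ts for ts, _ in history]
        # Number of feature rows with feature_ts <= event_ts; later rows are never used
        idx = bisect_right(timestamps, event_ts)
        value = history[idx - 1][1] if idx > 0 else None
        results.append((entity_id, event_ts, value))
    return results


if __name__ == "__main__":
    # Hypothetical history of days_since_last_login for one user
    feature_rows = {"u1": [(datetime(2024, 1, 1), 3), (datetime(2024, 1, 10), 12)]}
    events = [("u1", datetime(2024, 1, 5)), ("u1", datetime(2024, 1, 12))]
    # The Jan 5 event sees 3, not the later value 12
    for entity_id, event_ts, value in point_in_time_lookup(feature_rows, events):
        print(entity_id, event_ts.date(), value)
```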


Model Registry Workflow in Detail

The MLflow model registry is the natural hand-off point between training and deployment tasks in an Airflow ML pipeline. Understanding the full lifecycle of a model version (None → Staging → Production → Archived) prevents common mistakes:

```python
from airflow.decorators import task


@task(task_id="register-and-stage-model")
def register_and_stage_model(run_id: str, model_name: str) -> str:
    """
    Register the trained model in MLflow and move it to Staging.
    Returns the model version number.
    """
    import time

    import mlflow
    from airflow.operators.python import get_current_context

    client = mlflow.tracking.MlflowClient()

    # Register the model - this creates a new version in stage "None"
    model_uri = f"runs:/{run_id}/model"
    registered = mlflow.register_model(model_uri, model_name)
    version = str(registered.version)

    # Registration completes asynchronously; poll until the version is READY
    for _ in range(30):
        mv = client.get_model_version(model_name, version)
        if mv.status == "READY":
            break
        time.sleep(2)
    else:
        raise TimeoutError(f"Model version {version} did not reach READY state")

    # Transition to Staging for evaluation
    client.transition_model_version_stage(
        name=model_name,
        version=version,
        stage="Staging",
        archive_existing_versions=False,  # keep previous staging version for comparison
    )

    # Add descriptive tags. run_id is a random hex ID, not a date, so the
    # training date comes from the Airflow logical date instead.
    train_date = get_current_context()["ds"]
    client.set_model_version_tag(model_name, version, "train_date", train_date)
    client.set_model_version_tag(model_name, version, "pipeline", "weekly_churn_training")

    return version


@task(task_id="shadow-evaluate-model")
def shadow_evaluate_model(version: str, model_name: str) -> dict:
    """
    Run the new model on the shadow traffic dataset - a held-out recent slice
    that was not used during training or validation.
    """
    import mlflow
    import mlflow.sklearn
    import pandas as pd
    from sklearn.metrics import roc_auc_score, average_precision_score

    client = mlflow.tracking.MlflowClient()
    model = mlflow.sklearn.load_model(f"models:/{model_name}/{version}")

    # Shadow dataset - last 7 days of actual outcomes (never seen during training)
    shadow_df = pd.read_parquet("s3://data/evaluation/shadow_set/latest.parquet")
    X = shadow_df.drop(columns=["churned", "user_id"])
    y = shadow_df["churned"]

    y_prob = model.predict_proba(X)[:, 1]
    # Cast to plain floats so the returned dict serializes cleanly to XCom
    shadow_auc = float(roc_auc_score(y, y_prob))
    shadow_ap = float(average_precision_score(y, y_prob))

    # Log shadow evaluation to the MLflow run that produced this model version
    with mlflow.start_run(run_id=client.get_model_version(model_name, version).run_id):
        mlflow.log_metric("shadow_auc", shadow_auc)
        mlflow.log_metric("shadow_ap", shadow_ap)

    return {"shadow_auc": shadow_auc, "shadow_ap": shadow_ap, "version": version}
```
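Shadow metrics matter only if a later step acts on them. The sketch below is a minimal version of the "deploy only on measurable improvement" gate from the opening story. It assumes the current Production version has been scored by shadow_evaluate_model on the same shadow dataset. The 0.005 AUC threshold and the version numbers are hypothetical.

```python
from typing import Optional


def should_promote(
    candidate: dict,
    production: Optional[dict],
    metric: str = "shadow_auc",
    min_improvement: float = 0.005,
) -> bool:
    """Return True if the candidate should replace the current Production model."""
    if production is None:
        return True  # nothing in Production yet: promote the first candidate
    # Require a measurable gain, not just a tie or a marginal difference
    return candidate[metric] - production[metric] >= min_improvement


if __name__ == "__main__":
    candidate = {"shadow_auc": 0.812, "version": "7"}
    production = {"shadow_auc": 0.800, "version": "6"}
    print(should_promote(candidate, production))  # True: +0.012 clears the 0.005 bar
```

Scoring both models on the identical shadow set keeps the comparison fair. The minimum-improvement threshold stops the pipeline from churning production on gains too small to matter.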

Handling Long-Running Tasks and Timeouts

Training jobs sometimes hang - a GPU runs out of memory and the process waits indefinitely, or a network connection to S3 stalls. Without timeouts, hung tasks occupy worker slots indefinitely.

```python
from datetime import timedelta

from airflow.decorators import task
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

# Task-level timeout - Airflow kills the task after this duration
train_task = KubernetesPodOperator(
    task_id="train_model",
    execution_timeout=timedelta(hours=5),  # kill task if it runs more than 5 hours
    # ...
)


# TaskFlow/PythonOperator tasks accept the same parameter. Airflow enforces it with
# a signal-based timeout in the worker process, so no extra wrapping is needed.
@task(execution_timeout=timedelta(hours=2))
def long_running_computation() -> dict:
    # If this task runs for more than 2 hours, Airflow raises AirflowTaskTimeout and
    # the task is marked failed (or up_for_retry if retries remain)
    result = run_expensive_algorithm()  # placeholder for your own long-running code
    return result
```

For KubernetesPodOperator, set both an Airflow execution_timeout AND a container-level timeout (via environment variable or training script argument). The Airflow timeout kills the pod from the orchestration side; the container timeout gives the training code a chance to save a checkpoint before being killed.
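The sketch below is a minimal version of the container-side half. It assumes the pod receives a hypothetical TRAIN_TIMEOUT_SECONDS environment variable, set somewhat below execution_timeout. The training loop checks elapsed time before each step. Once it is within margin_s of the deadline, it writes a checkpoint and exits cleanly. The clock and step_fn parameters are injectable so the logic can be tested without waiting hours.

```python
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Callable


def read_timeout_seconds(default: float = 5 * 3600) -> float:
    # Hypothetical env var, set on the pod somewhat below the Airflow execution_timeout
    return float(os.environ.get("TRAIN_TIMEOUT_SECONDS", default))


def train_with_deadline(
    total_steps: int,
    timeout_s: float,
    checkpoint_path: Path,
    margin_s: float = 300.0,
    clock: Callable[[], float] = time.monotonic,
    step_fn: Callable[[int], None] = lambda step: None,
) -> dict:
    """Run training steps until done or until the deadline minus margin_s is reached."""
    start = clock()
    for step in range(total_steps):
        # Leave margin_s seconds to write the checkpoint before the pod is killed
        if clock() - start > timeout_s - margin_s:
            checkpoint_path.write_text(json.dumps({"next_step": step}))
            return {"status": "checkpointed", "next_step": step}
        step_fn(step)  # one unit of training work (placeholder)
    checkpoint_path.write_text(json.dumps({"next_step": total_steps}))
    return {"status": "completed", "next_step": total_steps}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        result = train_with_deadline(
            total_steps=100,
            timeout_s=read_timeout_seconds(),
            checkpoint_path=Path(tmp) / "checkpoint.json",
        )
        print(result)
```

On the next attempt, the training script would read next_step from the checkpoint and resume from that step.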


Trigger Rules Reference

Trigger rules control when a task is eligible to run based on the state of its upstream tasks. This is a frequent source of bugs in complex DAGs.

| Trigger rule | Runs when |
|---|---|
| ALL_SUCCESS (default) | All upstream tasks succeeded |
| ALL_FAILED | All upstream tasks are in the failed or upstream_failed state |
| ALL_DONE | All upstream tasks finished (success, failure, or skip) |
| ONE_SUCCESS | At least one upstream succeeded (does not wait for the others) |
| ONE_FAILED | At least one upstream failed (does not wait for the others) |
| NONE_FAILED | No upstream task failed; skipped upstreams are OK |
| NONE_FAILED_MIN_ONE_SUCCESS | No upstream failed and at least one succeeded |
| ALWAYS (formerly DUMMY) | Always runs, regardless of upstream state |

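The most important interaction to understand involves the default ALL_SUCCESS rule downstream of a BranchPythonOperator. Such a task is skipped whenever one of its direct upstream tasks sits on the branch that was not chosen, because that upstream is skipped and ALL_SUCCESS propagates the skip. Always use NONE_FAILED_MIN_ONE_SUCCESS for tasks that should run after a branch regardless of which path was taken.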

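```python
from airflow.operators.python import PythonOperator
from airflow.utils.trigger_rule import TriggerRule


def cleanup():
    ...  # hypothetical placeholder: delete temporary training files


def notify():
    ...  # hypothetical placeholder: post the run summary to Slack


cleanup_task = PythonOperator(
    task_id="cleanup_temp_files",
    python_callable=cleanup,
    trigger_rule=TriggerRule.ALL_DONE,  # run cleanup even if upstream tasks failed
)

notify_task = PythonOperator(
    task_id="send_notification",
    python_callable=notify,
    trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,  # run after either deploy or skip branch
)
```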
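The sketch below shows why the rule choice matters after a branch. It models the outcome of three trigger rules once every direct upstream task has finished. This is a simplified plain-Python model, not Airflow's implementation. It ignores upstream tasks that are still running and covers only ALL_SUCCESS, NONE_FAILED_MIN_ONE_SUCCESS, and ALL_DONE. The rule strings match Airflow's TriggerRule values.

```python
def evaluate(rule: str, upstream_states: list) -> str:
    """
    Simplified outcome of a trigger rule once all direct upstream tasks have finished.
    Returns "run", "skipped", or "upstream_failed".
    """
    failed = any(s in ("failed", "upstream_failed") for s in upstream_states)
    skipped = any(s == "skipped" for s in upstream_states)
    succeeded = any(s == "success" for s in upstream_states)

    if rule == "all_done":
        return "run"  # runs whatever the upstream outcome
    if rule == "all_success":
        if failed:
            return "upstream_failed"
        return "skipped" if skipped else "run"  # a single skipped upstream propagates
    if rule == "none_failed_min_one_success":
        if failed:
            return "upstream_failed"
        return "run" if succeeded else "skipped"
    raise ValueError(f"rule not modelled in this sketch: {rule}")


if __name__ == "__main__":
    # After a branch: the chosen path succeeded, the other path was skipped
    after_branch = ["success", "skipped"]
    for rule in ("all_success", "none_failed_min_one_success", "all_done"):
        print(rule, "->", evaluate(rule, after_branch))
```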
© 2026 EngineersOfAI. All rights reserved.