Настройка Apache Airflow для ML-пайплайнов
Apache Airflow — зрелый оркестратор DAG-пайплайнов. Для ML используется там, где нужна гибкость в смешивании data engineering и ML шагов, или где Airflow уже используется для других ETL задач.
Airflow vs Kubeflow для ML
Airflow предпочтительнее когда: уже используется в data engineering, нужна интеграция с ML и non-ML задачами в одном DAG, команда знает Airflow.
Kubeflow Pipelines предпочтительнее когда: ML-центричная команда, нужны нативные ML-примитивы (metrics, artifacts), Kubernetes-native workflow.
Установка с KubernetesExecutor
# Установка через Helm (рекомендуется)
helm repo add apache-airflow https://airflow.apache.org
helm upgrade --install airflow apache-airflow/airflow \
--namespace airflow \
--create-namespace \
--set executor=KubernetesExecutor \
--set config.logging.logging_level=INFO \
--values airflow-values.yaml
ML-пайплайн как Airflow DAG
from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from datetime import datetime, timedelta
default_args = {
"owner": "ml-team",
"retries": 2,
"retry_delay": timedelta(minutes=5),
"on_failure_callback": notify_on_slack,
}
with DAG(
"fraud_detection_training",
default_args=default_args,
schedule="0 2 * * 1", # по понедельникам в 2:00
start_date=datetime(2024, 1, 1),
catchup=False,
tags=["ml", "fraud-detection"],
) as dag:
# Подготовка данных — на обычном поде
prepare_data = KubernetesPodOperator(
task_id="prepare_data",
image="ml-pipeline:latest",
cmds=["python", "prepare_data.py"],
arguments=["--date={{ ds }}", "--output=s3://bucket/features/{{ ds }}/"],
namespace="ml-pipelines",
resources={"request_memory": "4Gi", "request_cpu": "2"},
get_logs=True,
is_delete_operator_pod=True,
)
# Обучение — на GPU поде
train_model = KubernetesPodOperator(
task_id="train_model",
image="ml-pipeline-gpu:latest",
cmds=["python", "train.py"],
arguments=[
"--data=s3://bucket/features/{{ ds }}/",
"--run-name=fraud-{{ ds }}",
],
namespace="ml-pipelines",
resources={
"request_memory": "32Gi",
"request_cpu": "8",
"limit_gpu": "1",
},
annotations={"nvidia.com/gpu": "1"},
tolerations=[{"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}],
get_logs=True,
)
# Evaluation gate — Python оператор (дешево)
def check_model_quality(**context):
import mlflow
client = mlflow.tracking.MlflowClient()
run = client.search_runs(
experiment_ids=[EXPERIMENT_ID],
filter_string=f"tags.run_date = '{context['ds']}'",
order_by=["metrics.f1 DESC"],
max_results=1
)[0]
f1 = run.data.metrics.get("test_f1", 0)
if f1 < 0.90:
raise ValueError(f"Model quality too low: F1={f1:.3f} < 0.90")
context["ti"].xcom_push(key="run_id", value=run.info.run_id)
quality_gate = PythonOperator(
task_id="quality_gate",
python_callable=check_model_quality,
)
# Промоция — только если quality_gate прошёл
promote_model = KubernetesPodOperator(
task_id="promote_to_staging",
image="ml-pipeline:latest",
cmds=["python", "promote_model.py"],
arguments=["--run-id={{ ti.xcom_pull(task_ids='quality_gate', key='run_id') }}"],
namespace="ml-pipelines",
)
# Зависимости
prepare_data >> train_model >> quality_gate >> promote_model
TaskFlow API (современный подход)
from airflow.decorators import dag, task
@dag(schedule="0 2 * * 1", start_date=datetime(2024, 1, 1))
def ml_pipeline():
@task
def prepare_data(execution_date: str) -> str:
# Подготовка данных
return f"s3://bucket/features/{execution_date}/"
@task
def train_model(data_path: str) -> dict:
# Запуск обучения (или триггер внешнего job)
return {"run_id": "xxx", "f1": 0.924}
@task
def promote_if_good(metrics: dict) -> None:
if metrics["f1"] >= 0.90:
promote_to_staging(metrics["run_id"])
data = prepare_data()
metrics = train_model(data)
promote_if_good(metrics)
ml_pipeline()
Мониторинг Airflow DAG
Airflow UI показывает: статус каждого запуска, длительность каждого task, логи. Интеграция с Prometheus через airflow-exporter: airflow_dag_run_duration_seconds, airflow_task_fail_count. Алерт при failed task через Slack/PagerDuty через on_failure_callback.







