Skip to content

Commit

Permalink
do not update DR on TI update after task execution (apache#45348)
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski authored Jan 8, 2025
1 parent 2cd40ca commit 586f1ea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
10 changes: 7 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _run_raw_task(
ti.hostname = get_hostname()
ti.pid = os.getpid()
if not test_mode:
TaskInstance.save_to_db(ti=ti, session=session)
TaskInstance.save_to_db(ti=ti, session=session, refresh_dag=False)
actual_start_date = timezone.utcnow()
Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags)
# Same metric with tagging
Expand Down Expand Up @@ -1241,7 +1241,7 @@ def _handle_failure(
)

if not test_mode:
TaskInstance.save_to_db(failure_context["ti"], session)
TaskInstance.save_to_db(task_instance, session)

with Trace.start_span_from_taskinstance(ti=task_instance) as span:
# ---- error info ----
Expand Down Expand Up @@ -3395,7 +3395,11 @@ def fetch_handle_failure_context(
@staticmethod
@internal_api_call
@provide_session
def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION):
def save_to_db(
ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION, refresh_dag: bool = True
):
if refresh_dag and isinstance(ti, TaskInstance):
ti.get_dagrun().refresh_from_db()
ti = _coalesce_to_orm_ti(ti=ti, session=session)
ti.updated_at = timezone.utcnow()
session.merge(ti)
Expand Down
41 changes: 40 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pickle
import signal
import sys
import time
import urllib
from traceback import format_exception
from typing import cast
Expand All @@ -34,6 +35,7 @@
from uuid import uuid4

import pendulum
import psutil
import pytest
import time_machine
from sqlalchemy import select
Expand Down Expand Up @@ -83,7 +85,7 @@
from airflow.sensors.base import BaseSensorOperator
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.settings import TIMEZONE, TracebackSessionForTests
from airflow.settings import TIMEZONE, TracebackSessionForTests, reconfigure_orm
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
Expand Down Expand Up @@ -3587,6 +3589,43 @@ def test_handle_failure(self, create_dummy_dag, session=None):
assert "task_instance" in context_arg_3
mock_on_retry_3.assert_not_called()

@provide_session
def test_handle_failure_does_not_push_stale_dagrun_model(self, dag_maker, create_dummy_dag, session=None):
session = settings.Session()
with dag_maker():

def method(): ...

task = PythonOperator(task_id="mytask", python_callable=method)
dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task.task_id)
ti.state = State.RUNNING

assert dr.state == DagRunState.RUNNING

session.merge(ti)
session.flush()
session.commit()

pid = os.fork()
if pid:
process = psutil.Process(pid)
time.sleep(1)

dr.state = DagRunState.SUCCESS
session.merge(dr)
session.flush()
session.commit()
process.wait(timeout=7)
else:
reconfigure_orm(disable_connection_pool=True)
time.sleep(2)
ti.handle_failure("should not update related models")
os._exit(0)

dr.refresh_from_db()
assert dr.state == DagRunState.SUCCESS

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_handle_failure_updates_queued_task_updates_state(self, dag_maker):
session = settings.Session()
Expand Down

0 comments on commit 586f1ea

Please sign in to comment.