NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The hunk layout below is reconstructed, so
verify context-line indentation against the target tree (or apply with
`git apply --ignore-whitespace`) before relying on it.

The added test differs from the collapsed copy in two ways:
  * the forked child always leaves through os._exit() (try/finally), so an
    exception in the child can no longer unwind back into the pytest session;
  * the parent asserts the child's exit status instead of discarding it.
The last hunk's line count was updated for these additions; the post-image
blob id on the tests `index` line predates them.

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2c1f52fc40296..f84f898f3384c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -247,7 +247,7 @@ def _run_raw_task(
     ti.hostname = get_hostname()
     ti.pid = os.getpid()
     if not test_mode:
-        TaskInstance.save_to_db(ti=ti, session=session)
+        TaskInstance.save_to_db(ti=ti, session=session, refresh_dag=False)
     actual_start_date = timezone.utcnow()
     Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags)
     # Same metric with tagging
@@ -1241,7 +1241,7 @@ def _handle_failure(
         )
 
     if not test_mode:
-        TaskInstance.save_to_db(failure_context["ti"], session)
+        TaskInstance.save_to_db(task_instance, session)
 
     with Trace.start_span_from_taskinstance(ti=task_instance) as span:
         # ---- error info ----
@@ -3395,7 +3395,11 @@ def fetch_handle_failure_context(
     @staticmethod
     @internal_api_call
     @provide_session
-    def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION):
+    def save_to_db(
+        ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION, refresh_dag: bool = True
+    ):
+        if refresh_dag and isinstance(ti, TaskInstance):
+            ti.get_dagrun().refresh_from_db()
         ti = _coalesce_to_orm_ti(ti=ti, session=session)
         ti.updated_at = timezone.utcnow()
         session.merge(ti)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 9b7b5d957122d..bb877704339cb 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -26,6 +26,7 @@
 import pickle
 import signal
 import sys
+import time
 import urllib
 from traceback import format_exception
 from typing import cast
@@ -34,6 +35,7 @@
 from uuid import uuid4
 
 import pendulum
+import psutil
 import pytest
 import time_machine
 from sqlalchemy import select
@@ -83,7 +85,7 @@
 from airflow.sensors.base import BaseSensorOperator
 from airflow.sensors.python import PythonSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
-from airflow.settings import TIMEZONE, TracebackSessionForTests
+from airflow.settings import TIMEZONE, TracebackSessionForTests, reconfigure_orm
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@@ -3587,6 +3589,53 @@ def test_handle_failure(self, create_dummy_dag, session=None):
         assert "task_instance" in context_arg_3
         mock_on_retry_3.assert_not_called()
 
+    @provide_session
+    def test_handle_failure_does_not_push_stale_dagrun_model(self, dag_maker, create_dummy_dag, session=None):
+        """A forked child failing its TI must not overwrite a DagRun state committed meanwhile."""
+        session = settings.Session()
+        with dag_maker():
+
+            def method(): ...
+
+            task = PythonOperator(task_id="mytask", python_callable=method)
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance(task.task_id)
+        ti.state = State.RUNNING
+
+        assert dr.state == DagRunState.RUNNING
+
+        session.merge(ti)
+        session.flush()
+        session.commit()
+
+        pid = os.fork()
+        if pid:
+            process = psutil.Process(pid)
+            time.sleep(1)
+
+            # Parent: finish the DagRun while the child still holds its stale copy.
+            dr.state = DagRunState.SUCCESS
+            session.merge(dr)
+            session.flush()
+            session.commit()
+            # psutil returns the exit status for a direct child; non-zero means the
+            # child did not complete handle_failure() and the check below proves nothing.
+            assert process.wait(timeout=7) == 0
+        else:
+            # Child: must always leave through os._exit(). If an exception escaped,
+            # the forked copy would unwind back into pytest and keep running the session.
+            exit_code = 1
+            try:
+                reconfigure_orm(disable_connection_pool=True)
+                time.sleep(2)
+                ti.handle_failure("should not update related models")
+                exit_code = 0
+            finally:
+                os._exit(exit_code)
+
+        dr.refresh_from_db()
+        assert dr.state == DagRunState.SUCCESS
+
     @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode
     def test_handle_failure_updates_queued_task_updates_state(self, dag_maker):
         session = settings.Session()