diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 76291c8a057f9..6e00f718be25a 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -27,6 +27,7 @@ from airflow.models.taskinstance import PAST_DEPENDS_MET from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.state import TaskInstanceState +from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule as TR if TYPE_CHECKING: @@ -63,8 +64,7 @@ def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTISta ``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 skipped upstreams, one of which is a setup, then counter will show 2 skipped and setup counter will show 1. - :param ti: the ti that we want to calculate deps for - :param finished_tis: all the finished tasks of the dag_run + :param finished_upstreams: all the finished upstreams of the dag_run """ counter: dict[str, int] = Counter() setup_counter: dict[str, int] = Counter() @@ -143,6 +143,19 @@ def _get_expanded_ti_count() -> int: return ti.task.get_mapped_ti_count(ti.run_id, session=session) + def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> Iterator[str]: + from airflow.models.mappedoperator import MappedOperator + + if isinstance(ti.task, MappedOperator): + for op in ti.task.iter_mapped_dependencies(): + yield op.task_id + if task_group and task_group.iter_mapped_task_groups(): + yield from ( + op.task_id + for tg in task_group.iter_mapped_task_groups() + for op in tg.iter_mapped_dependencies() + ) + @functools.lru_cache def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: """ @@ -156,6 +169,13 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: assert ti.task assert isinstance(ti.task.dag, DAG) + if isinstance(ti.task.task_group, MappedTaskGroup): + is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE) + if is_fast_triggered and upstream_id not in set( + _iter_expansion_dependencies(task_group=ti.task.task_group) + ): + return None + try: expanded_ti_count = _get_expanded_ti_count() except (NotFullyPopulated, NotMapped): @@ -217,7 +237,7 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators] for upstream_id in relevant_tasks: map_indexes = _get_relevant_upstream_map_indexes(upstream_id) if map_indexes is None: # All tis of this upstream are dependencies. - yield (TaskInstance.task_id == upstream_id) + yield TaskInstance.task_id == upstream_id continue # At this point we know we want to depend on only selected tis # of this upstream task. Since the upstream may not have been @@ -237,11 +257,9 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators] def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]: """ - Evaluate whether ``ti``'s trigger rule was met. + Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint. - :param ti: Task instance to evaluate the trigger rule of. - :param dep_context: The current dependency context. - :param session: Database session. + :param relevant_setups: Relevant setups for the current task instance. """ if TYPE_CHECKING: assert ti.task @@ -327,13 +345,7 @@ def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus ) def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: - """ - Evaluate whether ``ti``'s trigger rule was met. - - :param ti: Task instance to evaluate the trigger rule of. - :param dep_context: The current dependency context. - :param session: Database session. - """ + """Evaluate whether ``ti``'s trigger rule in direct relatives was met.""" if TYPE_CHECKING: assert ti.task @@ -433,7 +445,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: ) if not past_depends_met: yield self._failing_status( - reason=("Task should be skipped but the past depends are not met") + reason="Task should be skipped but the past depends are not met" ) return changed = ti.set_state(new_state, session) diff --git a/newsfragments/44937.bugfix.rst b/newsfragments/44937.bugfix.rst new file mode 100644 index 0000000000000..d50da4de82fc9 --- /dev/null +++ b/newsfragments/44937.bugfix.rst @@ -0,0 +1 @@ +Fix pre-mature evaluation of tasks in mapped task group. The origins of the bug are in ``TriggerRuleDep``, when dealing with ``TriggerRule`` that is fastly triggered (i.e, ``ONE_FAILED``, ``ONE_SUCCESS`, or ``ONE_DONE``). Please note that at time of merging, this fix has been applied only for Airflow version > 2.10.4 and < 3, and should be ported to v3 after merging PR #40460. diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index cf547912fb924..d1e896200c7e6 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -37,7 +37,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.xcom_arg import XComArg from airflow.operators.python import PythonOperator -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import State, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule @@ -1784,3 +1784,87 @@ def group(n: int) -> None: "group.last": {0: "success", 1: "skipped", 2: "success"}, } assert states == expected + + +def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session): + """Test that one failed trigger rule works well in mapped task group""" + with dag_maker() as dag: + + @dag.task + def t1(): + return [1, 2, 3] + + @task_group("tg1") + def tg1(a): + @dag.task() + def t2(a): + return a + + @dag.task(trigger_rule=TriggerRule.ONE_FAILED) + def t3(a): + return a + + t2(a) >> t3(a) + + t = t1() + tg1.expand(a=t) + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance(task_id="t1") + ti.run() + dr.task_instance_scheduling_decisions() + ti3 = dr.get_task_instance(task_id="tg1.t3") + assert not ti3.state + + +def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete__mapped_skip_with_all_success( + dag_maker, session +): + with dag_maker(): + + @task + def make_list(): + return [4, 42, 2] + + @task + def double(n): + if n == 42: + raise AirflowSkipException("42") + return n * 2 + + @task + def last(n): + print(n) + + @task_group + def group(n: int) -> None: + last(double(n)) + + list = make_list() + group.expand(n=list) + + dr = dag_maker.create_dagrun() + + def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]: + decision = dr.task_instance_scheduling_decisions(session=session) + return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} + + tis = _one_scheduling_decision_iteration() + tis["make_list", -1].run() + assert tis["make_list", -1].state == State.SUCCESS + + tis = _one_scheduling_decision_iteration() + tis["group.double", 0].run() + tis["group.double", 1].run() + tis["group.double", 2].run() + + assert tis["group.double", 0].state == State.SUCCESS + assert tis["group.double", 1].state == State.SKIPPED + assert tis["group.double", 2].state == State.SUCCESS + + tis = _one_scheduling_decision_iteration() + tis["group.last", 0].run() + tis["group.last", 2].run() + assert tis["group.last", 0].state == State.SUCCESS + assert dr.get_task_instance("group.last", map_index=1, session=session).state == State.SKIPPED + assert tis["group.last", 2].state == State.SUCCESS