[v2-10-test] Fix premature evaluation in mapped task group (apache#44937)

* Fix docstrings and warnings in trigger_rule_dep.py

* Fix premature evaluation of tasks in mapped task group

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: Ephraim Anierobi <splendidzigy24@gmail.com>

* Add newsfragment

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: Ephraim Anierobi <splendidzigy24@gmail.com>
3 people authored Dec 17, 2024
1 parent 0ad24cf commit 4b27c3f
Showing 3 changed files with 113 additions and 16 deletions.
42 changes: 27 additions & 15 deletions airflow/ti_deps/deps/trigger_rule_dep.py
@@ -27,6 +27,7 @@
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
@@ -63,8 +64,7 @@ def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTISta
``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 skipped upstreams, one
of which is a setup, then counter will show 2 skipped and setup counter will show 1.
:param ti: the ti that we want to calculate deps for
:param finished_tis: all the finished tasks of the dag_run
:param finished_upstreams: all the finished upstreams of the dag_run
"""
counter: dict[str, int] = Counter()
setup_counter: dict[str, int] = Counter()
@@ -143,6 +143,19 @@ def _get_expanded_ti_count() -> int:

return ti.task.get_mapped_ti_count(ti.run_id, session=session)

def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> Iterator[str]:
from airflow.models.mappedoperator import MappedOperator

if isinstance(ti.task, MappedOperator):
for op in ti.task.iter_mapped_dependencies():
yield op.task_id
if task_group and task_group.iter_mapped_task_groups():
yield from (
op.task_id
for tg in task_group.iter_mapped_task_groups()
for op in tg.iter_mapped_dependencies()
)

@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
"""
@@ -156,6 +169,13 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
assert ti.task
assert isinstance(ti.task.dag, DAG)

if isinstance(ti.task.task_group, MappedTaskGroup):
is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE)
if is_fast_triggered and upstream_id not in set(
_iter_expansion_dependencies(task_group=ti.task.task_group)
):
return None

try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
@@ -217,7 +237,7 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]
for upstream_id in relevant_tasks:
map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
if map_indexes is None: # All tis of this upstream are dependencies.
yield (TaskInstance.task_id == upstream_id)
yield TaskInstance.task_id == upstream_id
continue
# At this point we know we want to depend on only selected tis
# of this upstream task. Since the upstream may not have been
@@ -237,11 +257,9 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]

def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
"""
Evaluate whether ``ti``'s trigger rule was met.
Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint.
:param ti: Task instance to evaluate the trigger rule of.
:param dep_context: The current dependency context.
:param session: Database session.
:param relevant_setups: Relevant setups for the current task instance.
"""
if TYPE_CHECKING:
assert ti.task
@@ -327,13 +345,7 @@ def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus
)

def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
"""
Evaluate whether ``ti``'s trigger rule was met.
:param ti: Task instance to evaluate the trigger rule of.
:param dep_context: The current dependency context.
:param session: Database session.
"""
"""Evaluate whether ``ti``'s trigger rule in direct relatives was met."""
if TYPE_CHECKING:
assert ti.task

@@ -433,7 +445,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
reason="Task should be skipped but the past depends are not met"
)
return
changed = ti.set_state(new_state, session)
1 change: 1 addition & 0 deletions newsfragments/44937.bugfix.rst
@@ -0,0 +1 @@
Fix premature evaluation of tasks in mapped task groups. The origin of the bug is in ``TriggerRuleDep``, when dealing with a ``TriggerRule`` that is fast-triggered (i.e., ``ONE_FAILED``, ``ONE_SUCCESS``, or ``ONE_DONE``). Please note that at the time of merging, this fix has been applied only for Airflow versions > 2.10.4 and < 3, and should be ported to v3 after merging PR #40460.
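
For context, the shape of pipeline this fixes can be illustrated with a minimal sketch (illustrative only, not part of the commit; the task names are hypothetical and Airflow 2.10's TaskFlow API is assumed). A fast-triggered rule such as ``ONE_FAILED`` inside an expanded task group must wait for the upstreams of its own expansion rather than react to the first finished task anywhere in the DAG:

# Illustrative sketch only, not part of the diff: a hypothetical DAG with the
# shape addressed by this fix, assuming Airflow 2.10's TaskFlow API.
import pendulum

from airflow.decorators import dag, task, task_group
from airflow.utils.trigger_rule import TriggerRule


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def mapped_group_one_failed():
    @task
    def make_values():
        return [1, 2, 3]

    @task
    def work(value):
        return value * 2

    # Fast-triggered rule: before this fix, "cleanup" could be evaluated before
    # the upstreams inside its own expanded group instance had finished.
    @task(trigger_rule=TriggerRule.ONE_FAILED)
    def cleanup(value):
        return value

    @task_group
    def per_value(value):
        work(value) >> cleanup(value)

    per_value.expand(value=make_values())


mapped_group_one_failed()

The new tests in ``tests/models/test_mappedoperator.py`` below exercise exactly this shape: with the fix, ``_get_relevant_upstream_map_indexes`` returns ``None`` for upstreams of a fast-triggered task that are not among the group's expansion dependencies, so all task instances of that upstream are treated as dependencies and ``cleanup`` is no longer scheduled prematurely.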
86 changes: 85 additions & 1 deletion tests/models/test_mappedoperator.py
@@ -37,7 +37,7 @@
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.operators.python import PythonOperator
from airflow.utils.state import TaskInstanceState
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.trigger_rule import TriggerRule
@@ -1784,3 +1784,87 @@ def group(n: int) -> None:
"group.last": {0: "success", 1: "skipped", 2: "success"},
}
assert states == expected


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
"""Test that one failed trigger rule works well in mapped task group"""
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a):
@dag.task()
def t2(a):
return a

@dag.task(trigger_rule=TriggerRule.ONE_FAILED)
def t3(a):
return a

t2(a) >> t3(a)

t = t1()
tg1.expand(a=t)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id="t1")
ti.run()
dr.task_instance_scheduling_decisions()
ti3 = dr.get_task_instance(task_id="tg1.t3")
assert not ti3.state


def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete__mapped_skip_with_all_success(
dag_maker, session
):
with dag_maker():

@task
def make_list():
return [4, 42, 2]

@task
def double(n):
if n == 42:
raise AirflowSkipException("42")
return n * 2

@task
def last(n):
print(n)

@task_group
def group(n: int) -> None:
last(double(n))

list = make_list()
group.expand(n=list)

dr = dag_maker.create_dagrun()

def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
decision = dr.task_instance_scheduling_decisions(session=session)
return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}

tis = _one_scheduling_decision_iteration()
tis["make_list", -1].run()
assert tis["make_list", -1].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["group.double", 0].run()
tis["group.double", 1].run()
tis["group.double", 2].run()

assert tis["group.double", 0].state == State.SUCCESS
assert tis["group.double", 1].state == State.SKIPPED
assert tis["group.double", 2].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["group.last", 0].run()
tis["group.last", 2].run()
assert tis["group.last", 0].state == State.SUCCESS
assert dr.get_task_instance("group.last", map_index=1, session=session).state == State.SKIPPED
assert tis["group.last", 2].state == State.SUCCESS
