From 04806231e4411f37faa3d97f7b9e9fe2c0409303 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 9 Jan 2025 12:45:53 +0530 Subject: [PATCH] AIP-72: Add support for `get_current_context` in Task SDK (#45486) closes https://github.com/apache/airflow/issues/45234 I am putting the logic for `set_current_context` in `execution_time/context.py`. I didn't want to put `_CURRENT_CONTEXT` in `task_sdk/src/airflow/sdk/definitions/contextmanager.py` to avoid execution logic in a user-facing module but I couldn't think of another way to store it from execution & allow retrieving (via `get_current_context` in the Standard Provider) in their Task. Upcoming PRs: - Move most of the internal stuff in Task SDK to a separate module. - Use `create_runtime_ti` fixture more widely in tests --- Tested with the following DAG: ```py import pendulum from airflow.decorators import dag, task from airflow.providers.standard.operators.python import get_current_context @dag( schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, ) def x_get_context(): @task def template_test(data_interval_end): context = get_current_context() # Will print `2024-10-10 00:00:00+00:00`. # Note how we didn't pass this value when calling the task. Instead # it was passed by the decorator from the context print(f"data_interval_end: {data_interval_end}") # Will print the full context dict print(f"context: {context}") template_test() x_get_context() ``` image --- airflow/models/taskinstance.py | 2 +- .../providers/standard/operators/python.py | 25 +++-- .../tests/standard/operators/test_python.py | 6 +- task_sdk/src/airflow/sdk/__init__.py | 3 + .../airflow/sdk/definitions/contextmanager.py | 46 +++++++- .../src/airflow/sdk/execution_time/context.py | 27 ++++- .../airflow/sdk/execution_time/task_runner.py | 65 ++++++----- .../tests/defintions/test_contextmanager.py | 39 +++++++ task_sdk/tests/execution_time/conftest.py | 101 ++++++++++++++++++ task_sdk/tests/execution_time/test_context.py | 42 ++++++++ .../tests/execution_time/test_task_runner.py | 60 +++++------ 11 files changed, 338 insertions(+), 78 deletions(-) create mode 100644 task_sdk/tests/defintions/test_contextmanager.py diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 27293fa2d022e..387ea9122e0cf 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -108,6 +108,7 @@ from airflow.plugins_manager import integrate_macros_plugins from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.templater import SandboxedEnvironment +from airflow.sdk.execution_time.context import _CURRENT_CONTEXT from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats @@ -142,7 +143,6 @@ TR = TaskReschedule -_CURRENT_CONTEXT: list[Context] = [] log = logging.getLogger(__name__) diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 25de405a80c9b..1207d349d10b1 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -43,14 +43,10 @@ ) from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin -from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script -from airflow.providers.standard.version_compat import ( - AIRFLOW_V_2_10_PLUS, - AIRFLOW_V_3_0_PLUS, -) +from airflow.providers.standard.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge from airflow.utils.file import get_unique_dag_module_name @@ -1122,7 +1118,7 @@ def execute(self, context: Context) -> Any: return self.do_branch(context, super().execute(context)) -def get_current_context() -> Context: +def get_current_context() -> Mapping[str, Any]: """ Retrieve the execution context dictionary without altering user method's signature. @@ -1149,9 +1145,22 @@ def my_task(): Current context will only have value if this method was called after an operator was starting to execute. """ + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import get_current_context + + return get_current_context() + else: + return _get_current_context() + + +def _get_current_context() -> Mapping[str, Any]: + # Airflow 2.x + # TODO: To be removed when Airflow 2 support is dropped + from airflow.models.taskinstance import _CURRENT_CONTEXT + if not _CURRENT_CONTEXT: - raise AirflowException( + raise RuntimeError( "Current context was requested but no context was found! " - "Are you running within an airflow task?" + "Are you running within an Airflow task?" ) return _CURRENT_CONTEXT[-1] diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index 240815644895a..e0cbf9e3c2d15 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -1069,7 +1069,7 @@ def f(): with pytest.raises( AirflowException, match="Current context was requested but no context was found! " - "Are you running within an airflow task?", + "Are you running within an Airflow task?", ): self.run_as_task(f, return_ti=True, use_airflow_context=False) @@ -1890,7 +1890,7 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): class TestCurrentContext: def test_current_context_no_context_raise(self): - with pytest.raises(AirflowException): + with pytest.raises(RuntimeError): get_current_context() def test_current_context_roundtrip(self): @@ -1904,7 +1904,7 @@ def test_context_removed_after_exit(self): with set_current_context(example_context): pass - with pytest.raises(AirflowException): + with pytest.raises(RuntimeError): get_current_context() def test_nested_context(self): diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 1117e946f8ea6..a71ab7b2dd893 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -26,6 +26,7 @@ "TaskGroup", "dag", "Connection", + "get_current_context", "__version__", ] @@ -34,6 +35,7 @@ if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection + from airflow.sdk.definitions.contextmanager import get_current_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.taskgroup import TaskGroup @@ -47,6 +49,7 @@ "Label": ".definitions.edges", "Connection": ".definitions.connection", "Variable": ".definitions.variable", + "get_current_context": ".definitions.contextmanager", } diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py index ee08bd19c908a..3880bb6e35700 100644 --- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py +++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py @@ -19,6 +19,7 @@ import sys from collections import deque +from collections.abc import Mapping from types import ModuleType from typing import Any, Generic, TypeVar @@ -27,10 +28,47 @@ T = TypeVar("T") -__all__ = [ - "DagContext", - "TaskGroupContext", -] +__all__ = ["DagContext", "TaskGroupContext", "get_current_context"] + +# This is a global variable that stores the current Task context. +# It is used to push the Context dictionary when Task starts execution +# and it is used to retrieve the current context in PythonOperator or Taskflow API via +# the `get_current_context` function. +_CURRENT_CONTEXT: list[Mapping[str, Any]] = [] + + +def get_current_context() -> Mapping[str, Any]: + """ + Retrieve the execution context dictionary without altering user method's signature. + + This is the simplest method of retrieving the execution context dictionary. + + **Old style:** + + .. code:: python + + def my_task(**context): + ti = context["ti"] + + **New style:** + + .. code:: python + + from airflow.providers.standard.operators.python import get_current_context + + + def my_task(): + context = get_current_context() + ti = context["ti"] + + Current context will only have value if this method was called after an operator + was starting to execute. + """ + if not _CURRENT_CONTEXT: + raise RuntimeError( + "Current context was requested but no context was found! Are you running within an Airflow task?" + ) + return _CURRENT_CONTEXT[-1] # In order to add a `@classproperty`-like thing we need to define a property on a metaclass. diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 72ac2af225e85..c5a1e9dbee47a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -16,10 +16,13 @@ # under the License. from __future__ import annotations +import contextlib +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any import structlog +from airflow.sdk.definitions.contextmanager import _CURRENT_CONTEXT from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.types import NOTSET @@ -28,6 +31,8 @@ from airflow.sdk.definitions.variable import Variable from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult +log = structlog.get_logger(logger_name="task") + def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection: from airflow.sdk.definitions.connection import Connection @@ -55,7 +60,6 @@ def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name="task") SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): @@ -75,7 +79,6 @@ def _get_variable(key: str, deserialize_json: bool) -> Variable: from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name="task") SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): @@ -157,3 +160,23 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, MacrosAccessor): return False return True + + +@contextlib.contextmanager +def set_current_context(context: Mapping[str, Any]) -> Generator[Mapping[str, Any], None, None]: + """ + Set the current execution context to the provided context object. + + This method should be called once per Task execution, before calling operator.execute. + """ + _CURRENT_CONTEXT.append(context) + try: + yield context + finally: + expected_state = _CURRENT_CONTEXT.pop() + if expected_state != context: + log.warning( + "Current context is not equal to the state at context stack.", + expected=context, + got=expected_state, + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 610556ce005e1..fd6155d91f84b 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,13 +44,21 @@ ToTask, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ( + ConnectionAccessor, + MacrosAccessor, + VariableAccessor, + set_current_context, +) if TYPE_CHECKING: import jinja2 from structlog.typing import FilteringBoundLogger as Logger +# TODO: Move this entire class into a separate file: +# `airflow/sdk/execution_time/task_instance.py` +# or `airflow/sdk/execution_time/runtime_ti.py` class RuntimeTaskInstance(TaskInstance): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -426,37 +434,18 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Get a real context object ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() - jinja_env = ti.task.dag.get_template_env() - ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + with set_current_context(context): + jinja_env = ti.task.dag.get_template_env() + ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + result = _execute_task(context, ti.task) + + _push_xcom_if_needed(result, ti) # TODO: Get things from _execute_task_with_callbacks # - Clearing XCom - # - Setting Current Context (set_current_context) - # - Render Templates # - Update RTIF # - Pre Execute # etc - - result = None - if ti.task.execution_timeout: - # TODO: handle timeout in case of deferral - from airflow.utils.timeout import timeout - - timeout_seconds = ti.task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ti.task.execute(context) # type: ignore[attr-defined] - except AirflowTaskTimeout: - # TODO: handle on kill callback here - raise - else: - result = ti.task.execute(context) # type: ignore[attr-defined] - - _push_xcom_if_needed(result, ti) msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc)) except TaskDeferred as defer: classpath, trigger_kwargs = defer.trigger.serialize() @@ -524,6 +513,30 @@ def run(ti: RuntimeTaskInstance, log: Logger): SUPERVISOR_COMMS.send_request(msg=msg, log=log) +def _execute_task(context: Mapping[str, Any], task: BaseOperator): + """Execute Task (optionally with a Timeout) and push Xcom results.""" + from airflow.exceptions import AirflowTaskTimeout + + if task.execution_timeout: + # TODO: handle timeout in case of deferral + from airflow.utils.timeout import timeout + + timeout_seconds = task.execution_timeout.total_seconds() + try: + # It's possible we're already timed out, so fast-fail if true + if timeout_seconds <= 0: + raise AirflowTaskTimeout() + # Run task in timeout wrapper + with timeout(timeout_seconds): + result = task.execute(context) # type: ignore[attr-defined] + except AirflowTaskTimeout: + # TODO: handle on kill callback here + raise + else: + result = task.execute(context) # type: ignore[attr-defined] + return result + + def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance): """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.""" if ti.task.do_xcom_push: diff --git a/task_sdk/tests/defintions/test_contextmanager.py b/task_sdk/tests/defintions/test_contextmanager.py new file mode 100644 index 0000000000000..be624aff3d132 --- /dev/null +++ b/task_sdk/tests/defintions/test_contextmanager.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.sdk import get_current_context + + +class TestCurrentContext: + def test_current_context_no_context_raise(self): + with pytest.raises(RuntimeError): + get_current_context() + + def test_get_current_context_with_context(self, monkeypatch): + mock_context = {"ti": "task_instance", "key": "value"} + monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", [mock_context]) + result = get_current_context() + assert result == mock_context + + def test_get_current_context_without_context(self, monkeypatch): + monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", []) + with pytest.raises(RuntimeError, match="Current context was requested but no context was found!"): + get_current_context() diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index bf482e5ec7b03..032e67ae343ca 100644 --- a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -18,8 +18,17 @@ from __future__ import annotations import sys +from typing import TYPE_CHECKING from unittest import mock +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.api.datamodels._generated import TIRunContext + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + import pytest @@ -40,3 +49,95 @@ def mock_supervisor_comms(): "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as supervisor_comms: yield supervisor_comms + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. + + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse + from airflow.utils import timezone + + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag + t = dag.task_dict[task.task_id] + ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +@pytest.fixture +def create_runtime_ti( + mocked_parse: Callable[[StartupDetails, str, BaseOperator], RuntimeTaskInstance], + make_ti_context: Callable[..., TIRunContext], +) -> Callable[[BaseOperator, TIRunContext | None, StartupDetails | None], RuntimeTaskInstance]: + """ + Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. + + This fixture sets up a `RuntimeTaskInstance` with default or custom `TIRunContext` and `StartupDetails`, + making it easy to simulate task execution scenarios in tests. + + Example usage: :: + + def test_custom_task_instance(create_runtime_ti): + class MyTaskOperator(BaseOperator): + def execute(self, context): + assert context["dag_run"].run_id == "test_run" + + task = MyTaskOperator(task_id="test_task") + ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) + # Further test logic... + """ + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import StartupDetails + + def _create_task_instance( + task, context_from_server: TIRunContext | None = None, startup_details: StartupDetails | None = None + ) -> RuntimeTaskInstance: + if context_from_server is None: + context_from_server = make_ti_context() + + if not startup_details: + startup_details = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=context_from_server.dag_run.dag_id, + run_id=context_from_server.dag_run.run_id, + try_number=1, + ), + file="", + requests_fd=0, + ti_context=context_from_server, + ) + + ti = mocked_parse(startup_details, context_from_server.dag_run.dag_id, task) + return ti + + return _create_task_instance diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index d3bf589c84e1d..21a79ae5c3eb9 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -19,7 +19,10 @@ from unittest.mock import MagicMock, patch +import pytest + from airflow.sdk.definitions.connection import Connection +from airflow.sdk.definitions.contextmanager import get_current_context from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult @@ -28,6 +31,7 @@ VariableAccessor, _convert_connection_result_conn, _convert_variable_result_to_variable, + set_current_context, ) @@ -206,3 +210,41 @@ def test_get_method_with_default(self, mock_supervisor_comms): var = accessor.get("nonexistent_var_key", default_var=default_var) assert var == default_var + + +class TestCurrentContext: + def test_current_context_roundtrip(self): + example_context = {"Hello": "World"} + + with set_current_context(example_context): + assert get_current_context() == example_context + + def test_context_removed_after_exit(self): + example_context = {"Hello": "World"} + + with set_current_context(example_context): + pass + with pytest.raises(RuntimeError): + get_current_context() + + def test_nested_context(self): + """ + Nested execution context should be supported in case the user uses multiple context managers. + Each time the execute method of an operator is called, we set a new 'current' context. + This test verifies that no matter how many contexts are entered - order is preserved + """ + max_stack_depth = 15 + ctx_list = [] + for i in range(max_stack_depth): + # Create all contexts in ascending order + new_context = {"ContextId": i} + # Like 15 nested with statements + ctx_obj = set_current_context(new_context) + ctx_obj.__enter__() + ctx_list.append(ctx_obj) + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + ctx = get_current_context() + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 2dc8c0ef5aad5..6a7b5743209f5 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -66,40 +66,6 @@ def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG: return dag -@pytest.fixture -def mocked_parse(spy_agency): - """ - Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you - want to isolate and test `parse` or `run` logic without having to define a DAG file. - - This fixture returns a helper function `set_dag` that: - 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) - 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. - 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. - - After adding the fixture in your test function signature, you can use it like this :: - - mocked_parse( - StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), - file="", - requests_fd=0, - ), - "example_dag_id", - CustomOperator(task_id="hello"), - ) - """ - - def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: - dag = get_inline_dag(dag_id, task) - t = dag.task_dict[task.task_id] - ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t) - spy_agency.spy_on(parse, call_fake=lambda _: ti) - return ti - - return set_dag - - class CustomOperator(BaseOperator): def execute(self, context): task_id = context["task_instance"].task_id @@ -559,6 +525,32 @@ def test_startup_and_run_dag_with_templated_fields( assert ti.task.bash_command == rendered_command +def test_get_context_in_task(create_runtime_ti, time_machine, mock_supervisor_comms): + """Test that the `get_current_context` & `set_current_context` work correctly.""" + + class MyContextAssertOperator(BaseOperator): + def execute(self, context): + from airflow.sdk import get_current_context + + # Ensure the context returned by get_current_context is the same as the + # context passed to the operator + assert context == get_current_context() + + task = MyContextAssertOperator(task_id="assert_context") + + ti = create_runtime_ti(task=task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + run(ti, log=mock.MagicMock()) + + # Ensure the task is Successful + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY + ) + + @pytest.mark.parametrize( ["dag_id", "task_id", "fail_with_exception"], [