Skip to content

Commit

Permalink
AIP-72: Add support for get_current_context in Task SDK (apache#45486)
Browse files Browse the repository at this point in the history
closes apache#45234

I am putting the logic for `set_current_context` in `execution_time/context.py`. I didn't want to put `_CURRENT_CONTEXT` in `task_sdk/src/airflow/sdk/definitions/contextmanager.py` to avoid execution logic in a user-facing module but I couldn't think of another way to store it from execution & allow retrieving (via `get_current_context` in the Standard Provider) in their Task.

Upcoming PRs:
- Move most of the internal stuff in Task SDK to a separate module.
- Use `create_runtime_ti` fixture more widely in tests

---

Tested with the following DAG:

```py
import pendulum

from airflow.decorators import dag, task
from airflow.providers.standard.operators.python import get_current_context

@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
)
def x_get_context():
    @task
    def template_test(data_interval_end):
        context = get_current_context()

        # Will print `2024-10-10 00:00:00+00:00`.
        # Note how we didn't pass this value when calling the task. Instead
        # it was passed by the decorator from the context
        print(f"data_interval_end: {data_interval_end}")

        # Will print the full context dict
        print(f"context: {context}")

    template_test()

x_get_context()

```
<img width="1703" alt="image" src="https://github.com/user-attachments/assets/2763963a-d299-412f-bee3-3b20904ca7c8" />
  • Loading branch information
kaxil authored Jan 9, 2025
1 parent b703d53 commit 0480623
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 78 deletions.
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sdk.execution_time.context import _CURRENT_CONTEXT
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
Expand Down Expand Up @@ -142,7 +143,6 @@

TR = TaskReschedule

_CURRENT_CONTEXT: list[Context] = []
log = logging.getLogger(__name__)


Expand Down
25 changes: 17 additions & 8 deletions providers/src/airflow/providers/standard/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,10 @@
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.providers.standard.version_compat import (
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
Expand Down Expand Up @@ -1122,7 +1118,7 @@ def execute(self, context: Context) -> Any:
return self.do_branch(context, super().execute(context))


def get_current_context() -> Context:
def get_current_context() -> Mapping[str, Any]:
"""
Retrieve the execution context dictionary without altering user method's signature.
Expand All @@ -1149,9 +1145,22 @@ def my_task():
Current context will only have value if this method was called after an operator
was starting to execute.
"""
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import get_current_context

return get_current_context()
else:
return _get_current_context()


def _get_current_context() -> Mapping[str, Any]:
# Airflow 2.x
# TODO: To be removed when Airflow 2 support is dropped
from airflow.models.taskinstance import _CURRENT_CONTEXT

if not _CURRENT_CONTEXT:
raise AirflowException(
raise RuntimeError(
"Current context was requested but no context was found! "
"Are you running within an airflow task?"
"Are you running within an Airflow task?"
)
return _CURRENT_CONTEXT[-1]
6 changes: 3 additions & 3 deletions providers/tests/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def f():
with pytest.raises(
AirflowException,
match="Current context was requested but no context was found! "
"Are you running within an airflow task?",
"Are you running within an Airflow task?",
):
self.run_as_task(f, return_ti=True, use_airflow_context=False)

Expand Down Expand Up @@ -1890,7 +1890,7 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs):

class TestCurrentContext:
def test_current_context_no_context_raise(self):
with pytest.raises(AirflowException):
with pytest.raises(RuntimeError):
get_current_context()

def test_current_context_roundtrip(self):
Expand All @@ -1904,7 +1904,7 @@ def test_context_removed_after_exit(self):

with set_current_context(example_context):
pass
with pytest.raises(AirflowException):
with pytest.raises(RuntimeError):
get_current_context()

def test_nested_context(self):
Expand Down
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"TaskGroup",
"dag",
"Connection",
"get_current_context",
"__version__",
]

Expand All @@ -34,6 +35,7 @@
if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.contextmanager import get_current_context
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
Expand All @@ -47,6 +49,7 @@
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
"get_current_context": ".definitions.contextmanager",
}


Expand Down
46 changes: 42 additions & 4 deletions task_sdk/src/airflow/sdk/definitions/contextmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import sys
from collections import deque
from collections.abc import Mapping
from types import ModuleType
from typing import Any, Generic, TypeVar

Expand All @@ -27,10 +28,47 @@

T = TypeVar("T")

__all__ = [
"DagContext",
"TaskGroupContext",
]
__all__ = ["DagContext", "TaskGroupContext", "get_current_context"]

# This is a global variable that stores the current Task context.
# It is used to push the Context dictionary when Task starts execution
# and it is used to retrieve the current context in PythonOperator or Taskflow API via
# the `get_current_context` function.
_CURRENT_CONTEXT: list[Mapping[str, Any]] = []


def get_current_context() -> Mapping[str, Any]:
"""
Retrieve the execution context dictionary without altering user method's signature.
This is the simplest method of retrieving the execution context dictionary.
**Old style:**
.. code:: python
def my_task(**context):
ti = context["ti"]
**New style:**
.. code:: python
from airflow.providers.standard.operators.python import get_current_context
def my_task():
context = get_current_context()
ti = context["ti"]
Current context will only have value if this method was called after an operator
was starting to execute.
"""
if not _CURRENT_CONTEXT:
raise RuntimeError(
"Current context was requested but no context was found! Are you running within an Airflow task?"
)
return _CURRENT_CONTEXT[-1]


# In order to add a `@classproperty`-like thing we need to define a property on a metaclass.
Expand Down
27 changes: 25 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# under the License.
from __future__ import annotations

import contextlib
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any

import structlog

from airflow.sdk.definitions.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.types import NOTSET

Expand All @@ -28,6 +31,8 @@
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult

log = structlog.get_logger(logger_name="task")


def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
from airflow.sdk.definitions.connection import Connection
Expand Down Expand Up @@ -55,7 +60,6 @@ def _get_connection(conn_id: str) -> Connection:
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
Expand All @@ -75,7 +79,6 @@ def _get_variable(key: str, deserialize_json: bool) -> Variable:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
Expand Down Expand Up @@ -157,3 +160,23 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MacrosAccessor):
return False
return True


@contextlib.contextmanager
def set_current_context(context: Mapping[str, Any]) -> Generator[Mapping[str, Any], None, None]:
"""
Set the current execution context to the provided context object.
This method should be called once per Task execution, before calling operator.execute.
"""
_CURRENT_CONTEXT.append(context)
try:
yield context
finally:
expected_state = _CURRENT_CONTEXT.pop()
if expected_state != context:
log.warning(
"Current context is not equal to the state at context stack.",
expected=context,
got=expected_state,
)
65 changes: 39 additions & 26 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,21 @@
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
MacrosAccessor,
VariableAccessor,
set_current_context,
)

if TYPE_CHECKING:
import jinja2
from structlog.typing import FilteringBoundLogger as Logger


# TODO: Move this entire class into a separate file:
# `airflow/sdk/execution_time/task_instance.py`
# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -426,37 +434,18 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Setting Current Context (set_current_context)
# - Render Templates
# - Update RTIF
# - Pre Execute
# etc

result = None
if ti.task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout

timeout_seconds = ti.task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ti.task.execute(context) # type: ignore[attr-defined]
except AirflowTaskTimeout:
# TODO: handle on kill callback here
raise
else:
result = ti.task.execute(context) # type: ignore[attr-defined]

_push_xcom_if_needed(result, ti)
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
Expand Down Expand Up @@ -524,6 +513,30 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _execute_task(context: Mapping[str, Any], task: BaseOperator):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
from airflow.exceptions import AirflowTaskTimeout

if task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout

timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = task.execute(context) # type: ignore[attr-defined]
except AirflowTaskTimeout:
# TODO: handle on kill callback here
raise
else:
result = task.execute(context) # type: ignore[attr-defined]
return result


def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
"""Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
if ti.task.do_xcom_push:
Expand Down
39 changes: 39 additions & 0 deletions task_sdk/tests/defintions/test_contextmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.sdk import get_current_context


class TestCurrentContext:
def test_current_context_no_context_raise(self):
with pytest.raises(RuntimeError):
get_current_context()

def test_get_current_context_with_context(self, monkeypatch):
mock_context = {"ti": "task_instance", "key": "value"}
monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", [mock_context])
result = get_current_context()
assert result == mock_context

def test_get_current_context_without_context(self, monkeypatch):
monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", [])
with pytest.raises(RuntimeError, match="Current context was requested but no context was found!"):
get_current_context()
Loading

0 comments on commit 0480623

Please sign in to comment.