Skip to content

Commit

Permalink
AIP-72: Gracefully handle 'not-found' XCOMs in task sdk API client (a…
Browse files Browse the repository at this point in the history
…pache#45344)

* AIP-72: Gracefully handle not-found XCOMs in task sdk API client

* re raising exception for non 404
  • Loading branch information
amoghrajesh authored Jan 3, 2025
1 parent 55a14ef commit aae1a57
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
20 changes: 19 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,25 @@ def get(
params = {}
if map_index is not None:
params.update({"map_index": map_index})
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
try:
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"XCom not found",
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
key=key,
map_index=map_index,
detail=e.detail,
status_code=e.response.status_code,
)
# Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
# Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner
# and the default value of None in xcom_pull is used.
return XComResponse(key=key, value=None)
raise
return XComResponse.model_validate_json(resp.read())

def set(
Expand Down
24 changes: 24 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,30 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert result.key == "test_key"
assert result.value == "test_value"

@mock.patch("time.sleep", return_value=None)
def test_xcom_get_500_error(self, mock_sleep):
# Simulate a successful response from the server returning a 500 error
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
return httpx.Response(
status_code=500,
headers=[("content-Type", "application/json")],
json={
"reason": "invalid_format",
"message": "XCom value is not a valid JSON",
},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
with pytest.raises(ServerResponseError):
client.xcoms.get(
dag_id="dag_id",
run_id="run_id",
task_id="task_id",
key="key",
)

@pytest.mark.parametrize(
"values",
[
Expand Down
8 changes: 8 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,14 @@ def watched_subprocess(self, mocker):
XComResult(key="test_key", value="test_value"),
id="get_xcom_map_index",
),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":null,"type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", None),
XComResult(key="test_key", value=None, type="XComResult"),
id="get_xcom_not_found",
),
pytest.param(
SetXCom(
dag_id="test_dag",
Expand Down

0 comments on commit aae1a57

Please sign in to comment.