diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index fd4dd6c7e6cf5..4c6a02efe9451 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -219,7 +219,25 @@ def get( params = {} if map_index is not None: params.update({"map_index": map_index}) - resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) + try: + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.error( + "XCom not found", + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + key=key, + map_index=map_index, + detail=e.detail, + status_code=e.response.status_code, + ) + # Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None + # Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner + # and the default value of None in xcom_pull is used. + return XComResponse(key=key, value=None) + raise return XComResponse.model_validate_json(resp.read()) def set( diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index c52feb9676670..31580892dc27f 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -460,6 +460,30 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.key == "test_key" assert result.value == "test_value" + @mock.patch("time.sleep", return_value=None) + def test_xcom_get_500_error(self, mock_sleep): + # Simulate a successful response from the server returning a 500 error + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/xcoms/dag_id/run_id/task_id/key": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "invalid_format", + "message": "XCom value is not a valid JSON", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + with pytest.raises(ServerResponseError): + client.xcoms.get( + dag_id="dag_id", + run_id="run_id", + task_id="task_id", + key="key", + ) + @pytest.mark.parametrize( "values", [ diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 9cfe456962bb9..3ced432b2eea3 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -831,6 +831,14 @@ def watched_subprocess(self, mocker): XComResult(key="test_key", value="test_value"), id="get_xcom_map_index", ), + pytest.param( + GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), + b'{"key":"test_key","value":null,"type":"XComResult"}\n', + "xcoms.get", + ("test_dag", "test_run", "test_task", "test_key", None), + XComResult(key="test_key", value=None, type="XComResult"), + id="get_xcom_not_found", + ), pytest.param( SetXCom( dag_id="test_dag",