diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index ece3bc96009b..3706c7378126 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -23,6 +23,7 @@ import httpx import methodtools +import msgspec import structlog from pydantic import BaseModel from uuid6 import uuid7 @@ -48,7 +49,6 @@ __all__ = [ "Client", "ConnectionOperations", - "ErrorBody", "ServerResponseError", "TaskInstanceOperations", ] @@ -177,8 +177,9 @@ def connections(self) -> ConnectionOperations: return ConnectionOperations(self) -class ErrorBody(BaseModel): - detail: list[RemoteValidationError] | dict[str, Any] +# This is only used for parsing. ServerResponseError is raised instead +class _ErrorBody(BaseModel): + detail: list[RemoteValidationError] | str def __repr__(self): return repr(self.detail) @@ -188,7 +189,7 @@ class ServerResponseError(httpx.HTTPStatusError): def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response): super().__init__(message, request=request, response=response) - detail: ErrorBody + detail: list[RemoteValidationError] | str | dict[str, Any] | None @classmethod def from_response(cls, response: httpx.Response) -> ServerResponseError | None: @@ -201,16 +202,23 @@ def from_response(cls, response: httpx.Response) -> ServerResponseError | None: if response.headers.get("content-type") != "application/json": return None + detail: list[RemoteValidationError] | dict[str, Any] | None = None try: - err = ErrorBody.model_validate_json(response.read()) - if isinstance(err.detail, list): + body = _ErrorBody.model_validate_json(response.read()) + + if isinstance(body.detail, list): + detail = body.detail msg = "Remote server returned validation error" else: - msg = err.detail.get("message", "") or "Un-parseable error" + msg = body.detail or "Un-parseable error" except Exception: - err = ErrorBody.model_validate_json(response.content) + try: + detail = msgspec.json.decode(response.content) + except Exception: + # Fallback to a normal httpx error + return None msg = "Server returned error" self = cls(msg, request=response.request, response=response) - self.detail = err + self.detail = detail return self diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index a32b321545dd..eb6611b0f57c 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -20,7 +20,7 @@ import httpx import pytest -from airflow.sdk.api.client import Client, ErrorBody, RemoteValidationError, ServerResponseError +from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError class TestClient: @@ -40,8 +40,8 @@ def handle_request(request: httpx.Request) -> httpx.Response: client.get("http://error") assert isinstance(err.value, ServerResponseError) - assert isinstance(err.value.detail, ErrorBody) - assert err.value.detail.detail == [ + assert isinstance(err.value.detail, list) + assert err.value.detail == [ RemoteValidationError(loc=["#0"], msg="err", type="required"), ] @@ -60,3 +60,17 @@ def handle_request(request: httpx.Request) -> httpx.Response: with pytest.raises(httpx.HTTPStatusError) as err: client.get("http://error") assert not isinstance(err.value, ServerResponseError) + + def test_error_parsing_other_json(self): + def handle_request(request: httpx.Request) -> httpx.Response: + # Some other json than an error body. + return httpx.Response(404, json={"detail": "Not found"}) + + client = Client( + base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)} + ) + + with pytest.raises(ServerResponseError) as err: + client.get("http://error") + assert err.value.args == ("Not found",) + assert err.value.detail is None