Skip to content

Commit

Permalink
rebased... and add more types
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 19, 2024
1 parent c84c942 commit 39d3670
Showing 1 changed file with 28 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ async def async_response_hook(span, request, response):

_logger = logging.getLogger(__name__)

URL = typing.Tuple[bytes, bytes, typing.Optional[int], bytes]
Headers = typing.List[typing.Tuple[bytes, bytes]]
RequestHook = typing.Callable[[Span, "RequestInfo"], None]
ResponseHook = typing.Callable[[Span, "RequestInfo", "ResponseInfo"], None]
AsyncRequestHook = typing.Callable[
Expand All @@ -260,7 +258,7 @@ class RequestInfo(typing.NamedTuple):
class ResponseInfo(typing.NamedTuple):
status_code: int
headers: httpx.Headers | None
stream: typing.Iterable[bytes]
stream: httpx.SyncByteStream | httpx.AsyncByteStream
extensions: dict[str, typing.Any] | None


Expand All @@ -272,13 +270,19 @@ def _get_default_span_name(method: str) -> str:
return method


def _prepare_headers(headers: Headers | None) -> httpx.Headers:
def _prepare_headers(headers: httpx.Headers | None) -> httpx.Headers:
return httpx.Headers(headers)


def _extract_parameters(
args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
):
) -> tuple[
bytes,
httpx.URL,
httpx.Headers | None,
httpx.SyncByteStream | httpx.AsyncByteStream | None,
dict[str, typing.Any],
]:
if isinstance(args[0], httpx.Request):
# In httpx >= 0.20.0, handle_request receives a Request object
request: httpx.Request = args[0]
Expand Down Expand Up @@ -312,8 +316,14 @@ def _inject_propagation_headers(headers, args, kwargs):

def _extract_response(
response: httpx.Response
| tuple[int, Headers, httpx.SyncByteStream, dict[str, typing.Any]],
) -> tuple[int, Headers, httpx.SyncByteStream, dict[str, typing.Any], str]:
| tuple[int, httpx.Headers, httpx.SyncByteStream, dict[str, typing.Any]],
) -> tuple[
int,
httpx.Headers,
httpx.SyncByteStream | httpx.AsyncByteStream,
dict[str, typing.Any],
str,
]:
if isinstance(response, httpx.Response):
status_code = response.status_code
headers = response.headers
Expand All @@ -331,7 +341,7 @@ def _extract_response(

def _apply_request_client_attributes_to_span(
span_attributes: dict[str, typing.Any],
url: typing.Union[str, URL, httpx.URL],
url: str | httpx.URL,
method_original: str,
semconv: _StabilityMode,
):
Expand Down Expand Up @@ -443,7 +453,7 @@ def handle_request(
*args: typing.Any,
**kwargs: typing.Any,
) -> (
tuple[int, Headers, httpx.SyncByteStream, dict[str, typing.Any]]
tuple[int, httpx.Headers, httpx.SyncByteStream, dict[str, typing.Any]]
| httpx.Response
):
"""Add request info to span."""
Expand Down Expand Up @@ -531,9 +541,9 @@ class AsyncOpenTelemetryTransport(httpx.AsyncBaseTransport):
def __init__(
self,
transport: httpx.AsyncBaseTransport,
tracer_provider: typing.Optional[TracerProvider] = None,
request_hook: typing.Optional[AsyncRequestHook] = None,
response_hook: typing.Optional[AsyncResponseHook] = None,
tracer_provider: TracerProvider | None = None,
request_hook: AsyncRequestHook | None = None,
response_hook: AsyncResponseHook | None = None,
):
_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
Expand All @@ -556,19 +566,17 @@ async def __aenter__(self) -> "AsyncOpenTelemetryTransport":

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: typing.Type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self._transport.__aexit__(exc_type, exc_value, traceback)

# pylint: disable=R0914
async def handle_async_request(
self, *args: typing.Any, **kwargs: typing.Any
) -> (
typing.Tuple[
int, Headers, httpx.AsyncByteStream, dict[str, typing.Any]
]
tuple[int, httpx.Headers, httpx.AsyncByteStream, dict[str, typing.Any]]
| httpx.Response
):
"""Add request info to span."""
Expand Down Expand Up @@ -728,7 +736,7 @@ def _handle_request_wrapper( # pylint: disable=too-many-locals
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
tracer: Tracer,
sem_conv_opt_in_mode: _HTTPStabilityMode,
sem_conv_opt_in_mode: _StabilityMode,
request_hook: RequestHook,
response_hook: ResponseHook,
):
Expand Down Expand Up @@ -802,7 +810,7 @@ async def _handle_async_request_wrapper( # pylint: disable=too-many-locals
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
tracer: Tracer,
sem_conv_opt_in_mode: _HTTPStabilityMode,
sem_conv_opt_in_mode: _StabilityMode,
async_request_hook: AsyncRequestHook,
async_response_hook: AsyncResponseHook,
):
Expand Down

0 comments on commit 39d3670

Please sign in to comment.