Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to HTTPX #3098

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,13 @@ async def async_response_hook(span, request, response):
NETWORK_PEER_ADDRESS,
NETWORK_PEER_PORT,
)
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer
from opentelemetry.trace import SpanKind, Tracer, TracerProvider, get_tracer
from opentelemetry.trace.span import Span
from opentelemetry.trace.status import StatusCode
from opentelemetry.util.http import remove_url_credentials, sanitize_method

_logger = logging.getLogger(__name__)

URL = typing.Tuple[bytes, bytes, typing.Optional[int], bytes]
Headers = typing.List[typing.Tuple[bytes, bytes]]
RequestHook = typing.Callable[[Span, "RequestInfo"], None]
ResponseHook = typing.Callable[[Span, "RequestInfo", "ResponseInfo"], None]
AsyncRequestHook = typing.Callable[
Expand All @@ -253,17 +251,15 @@ class RequestInfo(typing.NamedTuple):
method: bytes
url: httpx.URL
headers: httpx.Headers | None
stream: typing.Optional[
typing.Union[httpx.SyncByteStream, httpx.AsyncByteStream]
]
extensions: typing.Optional[dict]
stream: httpx.SyncByteStream | httpx.AsyncByteStream | None
extensions: dict[str, typing.Any] | None


class ResponseInfo(typing.NamedTuple):
status_code: int
headers: httpx.Headers | None
stream: typing.Iterable[bytes]
extensions: typing.Optional[dict]
stream: httpx.SyncByteStream | httpx.AsyncByteStream
extensions: dict[str, typing.Any] | None


def _get_default_span_name(method: str) -> str:
Expand All @@ -274,11 +270,19 @@ def _get_default_span_name(method: str) -> str:
return method


def _prepare_headers(headers: typing.Optional[Headers]) -> httpx.Headers:
def _prepare_headers(headers: httpx.Headers | None) -> httpx.Headers:
return httpx.Headers(headers)


def _extract_parameters(args, kwargs):
def _extract_parameters(
args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
) -> tuple[
bytes,
httpx.URL,
httpx.Headers | None,
httpx.SyncByteStream | httpx.AsyncByteStream | None,
dict[str, typing.Any],
]:
if isinstance(args[0], httpx.Request):
# In httpx >= 0.20.0, handle_request receives a Request object
request: httpx.Request = args[0]
Expand Down Expand Up @@ -311,10 +315,15 @@ def _inject_propagation_headers(headers, args, kwargs):


def _extract_response(
response: typing.Union[
httpx.Response, typing.Tuple[int, Headers, httpx.SyncByteStream, dict]
],
) -> typing.Tuple[int, Headers, httpx.SyncByteStream, dict, str]:
response: httpx.Response
| tuple[int, httpx.Headers, httpx.SyncByteStream, dict[str, typing.Any]],
) -> tuple[
int,
httpx.Headers,
httpx.SyncByteStream | httpx.AsyncByteStream,
dict[str, typing.Any],
str,
]:
if isinstance(response, httpx.Response):
status_code = response.status_code
headers = response.headers
Expand All @@ -331,8 +340,8 @@ def _extract_response(


def _apply_request_client_attributes_to_span(
span_attributes: dict,
url: typing.Union[str, URL, httpx.URL],
span_attributes: dict[str, typing.Any],
url: str | httpx.URL,
method_original: str,
semconv: _StabilityMode,
):
Expand Down Expand Up @@ -407,9 +416,9 @@ class SyncOpenTelemetryTransport(httpx.BaseTransport):
def __init__(
self,
transport: httpx.BaseTransport,
tracer_provider: typing.Optional[TracerProvider] = None,
request_hook: typing.Optional[RequestHook] = None,
response_hook: typing.Optional[ResponseHook] = None,
tracer_provider: TracerProvider | None = None,
request_hook: RequestHook | None = None,
response_hook: ResponseHook | None = None,
):
_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
Expand All @@ -426,27 +435,27 @@ def __init__(
self._request_hook = request_hook
self._response_hook = response_hook

def __enter__(self) -> "SyncOpenTelemetryTransport":
def __enter__(self) -> SyncOpenTelemetryTransport:
self._transport.__enter__()
return self

def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
self._transport.__exit__(exc_type, exc_value, traceback)

# pylint: disable=R0914
def handle_request(
self,
*args,
**kwargs,
) -> typing.Union[
typing.Tuple[int, "Headers", httpx.SyncByteStream, dict],
httpx.Response,
]:
*args: typing.Any,
**kwargs: typing.Any,
) -> (
tuple[int, httpx.Headers, httpx.SyncByteStream, dict[str, typing.Any]]
| httpx.Response
):
"""Add request info to span."""
if not is_http_instrumentation_enabled():
return self._transport.handle_request(*args, **kwargs)
Expand Down Expand Up @@ -532,9 +541,9 @@ class AsyncOpenTelemetryTransport(httpx.AsyncBaseTransport):
def __init__(
self,
transport: httpx.AsyncBaseTransport,
tracer_provider: typing.Optional[TracerProvider] = None,
request_hook: typing.Optional[AsyncRequestHook] = None,
response_hook: typing.Optional[AsyncResponseHook] = None,
tracer_provider: TracerProvider | None = None,
request_hook: AsyncRequestHook | None = None,
response_hook: AsyncResponseHook | None = None,
):
_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
Expand All @@ -557,19 +566,19 @@ async def __aenter__(self) -> "AsyncOpenTelemetryTransport":

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
exc_type: typing.Type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
await self._transport.__aexit__(exc_type, exc_value, traceback)

# pylint: disable=R0914
async def handle_async_request(
self, *args, **kwargs
) -> typing.Union[
typing.Tuple[int, "Headers", httpx.AsyncByteStream, dict],
httpx.Response,
]:
self, *args: typing.Any, **kwargs: typing.Any
) -> (
tuple[int, httpx.Headers, httpx.AsyncByteStream, dict[str, typing.Any]]
| httpx.Response
):
"""Add request info to span."""
if not is_http_instrumentation_enabled():
return await self._transport.handle_async_request(*args, **kwargs)
Expand Down Expand Up @@ -653,7 +662,7 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> typing.Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: typing.Any):
"""Instruments httpx Client and AsyncClient

Args:
Expand Down Expand Up @@ -716,20 +725,20 @@ def _instrument(self, **kwargs):
),
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: typing.Any):
unwrap(httpx.HTTPTransport, "handle_request")
unwrap(httpx.AsyncHTTPTransport, "handle_async_request")

@staticmethod
def _handle_request_wrapper( # pylint: disable=too-many-locals
wrapped,
instance,
args,
kwargs,
tracer,
sem_conv_opt_in_mode,
request_hook,
response_hook,
wrapped: typing.Callable[..., typing.Any],
instance: httpx.HTTPTransport,
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
tracer: Tracer,
sem_conv_opt_in_mode: _StabilityMode,
request_hook: RequestHook,
response_hook: ResponseHook,
):
if not is_http_instrumentation_enabled():
return wrapped(*args, **kwargs)
Expand Down Expand Up @@ -796,14 +805,14 @@ def _handle_request_wrapper( # pylint: disable=too-many-locals

@staticmethod
async def _handle_async_request_wrapper( # pylint: disable=too-many-locals
wrapped,
instance,
args,
kwargs,
tracer,
sem_conv_opt_in_mode,
async_request_hook,
async_response_hook,
wrapped: typing.Callable[..., typing.Awaitable[typing.Any]],
instance: httpx.AsyncHTTPTransport,
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
tracer: Tracer,
sem_conv_opt_in_mode: _StabilityMode,
async_request_hook: AsyncRequestHook,
async_response_hook: AsyncResponseHook,
):
if not is_http_instrumentation_enabled():
return await wrapped(*args, **kwargs)
Expand Down Expand Up @@ -872,14 +881,10 @@ async def _handle_async_request_wrapper( # pylint: disable=too-many-locals
@classmethod
def instrument_client(
cls,
client: typing.Union[httpx.Client, httpx.AsyncClient],
tracer_provider: TracerProvider = None,
request_hook: typing.Union[
typing.Optional[RequestHook], typing.Optional[AsyncRequestHook]
] = None,
response_hook: typing.Union[
typing.Optional[ResponseHook], typing.Optional[AsyncResponseHook]
] = None,
client: httpx.Client | httpx.AsyncClient,
tracer_provider: TracerProvider | None = None,
request_hook: RequestHook | AsyncRequestHook | None = None,
response_hook: ResponseHook | AsyncResponseHook | None = None,
) -> None:
"""Instrument httpx Client or AsyncClient

Expand Down Expand Up @@ -977,9 +982,7 @@ def instrument_client(
client._is_instrumented_by_opentelemetry = True

@staticmethod
def uninstrument_client(
client: typing.Union[httpx.Client, httpx.AsyncClient],
):
def uninstrument_client(client: httpx.Client | httpx.AsyncClient) -> None:
"""Disables instrumentation for the given client instance

Args:
Expand Down
Loading