From 684ecb5cf6e77e759a3bad6eb2157894727e0e00 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Thu, 12 Oct 2023 23:57:15 +0100 Subject: [PATCH] chore: add case insensitive get header function (#3) --- src/anthropic_bedrock/_utils/__init__.py | 1 + src/anthropic_bedrock/_utils/_utils.py | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/anthropic_bedrock/_utils/__init__.py b/src/anthropic_bedrock/_utils/__init__.py index 6d13a36..26dc560 100644 --- a/src/anthropic_bedrock/_utils/__init__.py +++ b/src/anthropic_bedrock/_utils/__init__.py @@ -22,6 +22,7 @@ from ._utils import is_required_type as is_required_type from ._utils import is_annotated_type as is_annotated_type from ._utils import maybe_coerce_float as maybe_coerce_float +from ._utils import get_required_header as get_required_header from ._utils import maybe_coerce_boolean as maybe_coerce_boolean from ._utils import maybe_coerce_integer as maybe_coerce_integer from ._utils import strip_annotated_type as strip_annotated_type diff --git a/src/anthropic_bedrock/_utils/_utils.py b/src/anthropic_bedrock/_utils/_utils.py index e43ef6f..d4eafd4 100644 --- a/src/anthropic_bedrock/_utils/_utils.py +++ b/src/anthropic_bedrock/_utils/_utils.py @@ -1,13 +1,14 @@ from __future__ import annotations import os +import re import inspect import functools from typing import Any, Mapping, TypeVar, Callable, Iterable, Sequence, cast, overload from pathlib import Path from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin -from .._types import NotGiven, FileTypes, NotGivenOr +from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike from .._compat import is_union as _is_union from .._compat import parse_date as parse_date from .._compat import parse_datetime as parse_datetime @@ -351,3 +352,22 @@ def file_from_path(path: str) -> FileTypes: contents = Path(path).read_bytes() file_name = os.path.basename(path) return (file_name, contents) + + +def get_required_header(headers: HeadersLike, header: str) -> str: + lower_header = header.lower() + if isinstance(headers, Mapping): + headers = cast(Headers, headers) + for k, v in headers.items(): + if k.lower() == lower_header and isinstance(v, str): + return v + + """ to deal with the case where the header looks like Finch-Event-Id """ + intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) + + for normalized_header in [header, lower_header, header.upper(), intercaps_header]: + value = headers.get(normalized_header) + if value: + return value + + raise ValueError(f"Could not find {header} header")