From 673bc3a3b199bfc60600fa35cf72fdde69a966a2 Mon Sep 17 00:00:00 2001 From: a-a-ronchen <53277813+a-a-ronchen@users.noreply.github.com> Date: Mon, 30 Dec 2024 11:44:31 -0600 Subject: [PATCH] Pass aws_region_name to get_aws_service_client() in SageMakerLLM (#12000) --- .../llama_index/llms/sagemaker_endpoint/base.py | 17 +++++++++++++++-- .../pyproject.toml | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py index fef6cded1349f..82e3da8c26a6b 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py @@ -29,6 +29,8 @@ ) from llama_index.llms.sagemaker_endpoint.utils import BaseIOHandler, IOHandler +import warnings + DEFAULT_IO_HANDLER = IOHandler() LLAMA_MESSAGES_TO_PROMPT = messages_to_prompt LLAMA_COMPLETION_TO_PROMPT = completion_to_prompt @@ -130,7 +132,7 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, - region_name: Optional[str] = None, + aws_region_name: Optional[str] = None, max_retries: Optional[int] = 3, timeout: Optional[float] = 60.0, temperature: Optional[float] = 0.5, @@ -176,10 +178,21 @@ def __init__( output_parser=output_parser, ) self._completion_to_prompt = completion_to_prompt + + region_name = kwargs.pop("region_name", None) + if region_name is not None: + warnings.warn( + "Kwarg `region_name` is deprecated and will be removed in a future version. " + "Please use `aws_region_name` instead.", + DeprecationWarning, + ) + if not aws_region_name: + aws_region_name = region_name + self._client = get_aws_service_client( service_name="sagemaker-runtime", profile_name=profile_name, - region_name=region_name, + region_name=aws_region_name, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml index 58e42cd1af49f..5d7b0112113d0 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-sagemaker-endpoint" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.9,<4.0"