Skip to content

Commit

Permalink
fix: Run Bedrock calls in executor for async
Browse files Browse the repository at this point in the history
  • Loading branch information
lou-k committed Jul 16, 2024
1 parent bae8d19 commit 00bbf63
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
18 changes: 15 additions & 3 deletions packages/phoenix-evals/src/phoenix/evals/models/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import functools
import json
import logging
from dataclasses import dataclass, field
Expand All @@ -21,8 +23,8 @@ class BedrockModel(BaseModel):
AWS API are dynamically throttled when encountering rate limit errors. Requires the `boto3`
package to be installed.
Supports Async:
`boto3` does not support async calls
Supports Async: 🟡
`boto3` does not support async calls, so it's wrapped in an executor.
Args:
model_id (str): The model name to use.
Expand Down Expand Up @@ -109,7 +111,17 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
return self._parse_output(response) or ""

async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
    """Generate a completion without blocking the running event loop.

    The synchronous ``_generate`` is submitted to the event loop's default
    executor (``run_in_executor(None, ...)``) and awaited, instead of being
    called directly on the event loop.

    Args:
        prompt: Prompt text forwarded to ``_generate``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_generate``.

    Returns:
        The value returned by ``_generate`` for this prompt.

    Raises:
        Exception: Whatever ``_generate`` raises; the exception is re-raised
            here when the executor future is awaited.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # recommended accessor and never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    # run_in_executor forwards positional arguments only, so the keyword
    # arguments are bound up front with functools.partial.
    blocking_call = functools.partial(self._generate, prompt, **kwargs)
    return await loop.run_in_executor(None, blocking_call)

def _rate_limited_completion(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
Expand Down
12 changes: 12 additions & 0 deletions packages/phoenix-evals/tests/phoenix/evals/models/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import asyncio

import boto3
import pytest
from phoenix.evals import BedrockModel


def test_bedrock_model_can_be_instantiated():
    """A `BedrockModel` built from an explicit boto3 session is truthy."""
    bedrock = BedrockModel(session=boto3.Session(region_name="us-west-2"))
    assert bedrock


def test_bedrock_async_propagates_errors():
    """An error raised by the blocking call surfaces from `_async_generate`.

    The model's client is replaced with ``None`` after construction, so the
    generate call is expected to fail with an ``AttributeError`` mentioning
    ``invoke_model``. The test checks that running the coroutine re-raises
    that error to the caller.
    """
    session = boto3.Session(region_name="us-west-2")
    client = session.client("bedrock-runtime")
    model = BedrockModel(session=session, client=client)
    # Presumably `_generate` ends up calling `self.client.invoke_model`;
    # removing the client forces that lookup to fail.
    model.client = None

    # Only the call under test is inside pytest.raises, so an AttributeError
    # raised during setup fails the test instead of passing it by accident.
    with pytest.raises(AttributeError, match="'NoneType' object has no attribute 'invoke_model'"):
        asyncio.run(model._async_generate("prompt"))

0 comments on commit 00bbf63

Please sign in to comment.