feat: upgrade to openai 1.0.0 (#157)
* feat: upgrade to openai 1.0.0

* make tests pass for openai 1.0.0
YannDubs authored Nov 7, 2023
1 parent 2799ef7 commit dd202c7
Showing 6 changed files with 52 additions and 46 deletions.
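
At a glance, the diff moves every module-level OpenAI call to the client interface introduced in openai 1.0 (installable with `pip install -U "openai>=1.0.0"`, matching the updated requirements pin). A minimal sketch of the before/after pattern, using a placeholder model name, prompt, and credentials rather than values from this repository:

# openai < 1.0: configuration and calls hang off the module.
#   openai.api_key = "sk-..."
#   completion = openai.ChatCompletion.create(model="gpt-4", messages=messages)
#   text = completion.choices[0].message.content
#   errors are raised as openai.error.OpenAIError

# openai >= 1.0: configuration lives on a client object.
from openai import OpenAI

client = OpenAI(api_key="sk-placeholder")  # organization= and base_url= are also accepted
completion = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Say hello"}],
)
choice = completion.choices[0].model_dump()  # responses are pydantic objects; dump to dicts as the diff does
text = choice["message"]["content"]
# errors are now raised as openai.OpenAIError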
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
tqdm
datasets
openai
openai>=1.0.0
pandas
tiktoken>=0.3.2
fire
2 changes: 1 addition & 1 deletion src/alpaca_eval/decoders/__init__.py
@@ -81,7 +81,7 @@ def get_fn_completions(name: Union[str, Callable]) -> Callable:
packages = ["vllm", "ray", "transformers"]
logging.exception(f"You need {packages} to use vllm_completions. Error:")
raise e

elif name == "bedrock_anthropic_completions":
try:
from .bedrock_anthropic import bedrock_anthropic_completions
38 changes: 15 additions & 23 deletions src/alpaca_eval/decoders/bedrock_anthropic.py
@@ -1,23 +1,24 @@
import copy
import functools
import json
import logging
import multiprocessing
import random
import time
from typing import Optional, Sequence, Union

import botocore.exceptions
import boto3
import botocore.exceptions
import numpy as np
import tqdm
import json

from .. import utils

__all__ = ["bedrock_anthropic_completions"]

DEFAULT_NUM_PROCS = 3


def bedrock_anthropic_completions(
prompts: Sequence[str],
max_tokens_to_sample: Union[int, Sequence[int]] = 2048,
@@ -61,7 +62,9 @@ def bedrock_anthropic_completions(
logging.info(f"Kwargs to completion: {kwargs_to_log}")
with utils.Timer() as t:
if num_procs == 1:
responses = [_bedrock_anthropic_completion_helper(inp, **kwargs) for inp in tqdm.tqdm(inputs, desc="prompts")]
responses = [
_bedrock_anthropic_completion_helper(inp, **kwargs) for inp in tqdm.tqdm(inputs, desc="prompts")
]
else:
with multiprocessing.Pool(num_procs) as p:
partial_completion_helper = functools.partial(_bedrock_anthropic_completion_helper, **kwargs)
@@ -87,7 +90,7 @@ def bedrock_anthropic_completions(
def _bedrock_anthropic_completion_helper(
args: tuple[str, int],
sleep_time: int = 2,
region: Optional[str] = 'us-west-2',
region: Optional[str] = "us-west-2",
model_name: str = "anthropic.claude-v1",
temperature: Optional[float] = 0.7,
**kwargs,
@@ -97,36 +100,25 @@ def _bedrock_anthropic_completion_helper(
if not utils.check_pkg_atleast_version("boto3", "1.28.58"):
raise ValueError("boto3 version must be at least 1.28.58 Use `pip install -U boto3`.")

bedrock = boto3.client(
service_name='bedrock-runtime',
region_name=region
)
accept = 'application/json'
contentType = 'application/json'
bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
accept = "application/json"
contentType = "application/json"

kwargs.update(dict(max_tokens_to_sample=max_tokens, temperature=temperature))
curr_kwargs = copy.deepcopy(kwargs)
while True:
try:
body = json.dumps(
{
**{
'prompt':prompt
},
**curr_kwargs}
)
response = bedrock.invoke_model(
body=body, modelId=model_name, accept=accept, contentType=contentType
)
response = json.loads(response.get('body').read()).get('completion')
body = json.dumps({**{"prompt": prompt}, **curr_kwargs})
response = bedrock.invoke_model(body=body, modelId=model_name, accept=accept, contentType=contentType)
response = json.loads(response.get("body").read()).get("completion")
break
except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == 'ThrottlingException':
if e.response["Error"]["Code"] == "ThrottlingException":
logging.warning(f"Hit throttling error: {e}.")
logging.warning(f"Rate limit hit. Sleeping for {sleep_time} seconds.")
time.sleep(sleep_time)
except Exception as e:
logging.error(f'Hit unknown error : {e}')
logging.error(f"Hit unknown error : {e}")
raise e

return response
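
For reference, the body that `json.dumps({**{"prompt": prompt}, **curr_kwargs})` sends to Bedrock, and the field the helper reads back, look roughly like the following; the prompt text and sampling values are illustrative, not taken from the repository:

import json

# Illustrative request body for anthropic.claude-v1 on Bedrock; prompt and
# sampling parameters below are placeholders.
body = json.dumps(
    {
        "prompt": "\n\nHuman: What is 1 + 1?\n\nAssistant:",
        "max_tokens_to_sample": 2048,
        "temperature": 0.7,
    }
)

# The helper keeps only the "completion" field of the JSON response:
#   completion_text = json.loads(response.get("body").read()).get("completion")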
45 changes: 29 additions & 16 deletions src/alpaca_eval/decoders/openai.py
@@ -12,12 +12,13 @@
import openai
import tiktoken
import tqdm
from openai import OpenAI

from .. import constants, utils

__all__ = ["openai_completions"]

DEFAULT_OPENAI_API_BASE = openai.api_base
DEFAULT_OPENAI_API_BASE = openai.base_url


def openai_completions(
@@ -156,7 +157,7 @@ def openai_completions(

# flatten the list and select only the text
completions_all = [completion for completion_batch in completions for completion in completion_batch]
completions_text = [completion.text for completion in completions_all]
completions_text = [completion["text"] for completion in completions_all]

price = [
completion["total_tokens"] * _get_price_per_token(model_name)
@@ -185,18 +186,21 @@ def _openai_completion_helper(
**kwargs,
):
prompt_batch, max_tokens = args
client_kwargs = dict()

# randomly select orgs
if openai_organization_ids is not None:
openai.organization = random.choice(openai_organization_ids)
client_kwargs["organization"] = random.choice(openai_organization_ids)

openai_api_keys = openai_api_keys or constants.OPENAI_API_KEYS

if openai_api_keys is not None:
openai.api_key = random.choice(openai_api_keys)
client_kwargs["api_key"] = random.choice(openai_api_keys)

# set api base
openai.api_base = openai_api_base if openai_api_base is not None else DEFAULT_OPENAI_API_BASE
client_kwargs["base_url"] = base_url = openai_api_base if openai_api_base is not None else DEFAULT_OPENAI_API_BASE

client = OpenAI(**client_kwargs)

# copy shared_kwargs to avoid modifying it
kwargs.update(dict(max_tokens=max_tokens, top_p=top_p, temperature=temperature))
@@ -205,28 +209,31 @@
while True:
try:
if is_chat:
completion_batch = openai.ChatCompletion.create(messages=prompt_batch[0], **curr_kwargs)
completion_batch = client.chat.completions.create(messages=prompt_batch[0], **curr_kwargs)

choices = completion_batch.choices
for choice in choices:
for i, choice in enumerate(choices):
# openai now returns pydantic objects => convert to dict to keep all code
# TODO should just rewrite code to use pydantic objects
choices[i] = choice.model_dump()
assert choice.message.role == "assistant"
if choice.message.content == "":
choice["text"] = " " # annoying doesn't allow empty string
choices[i]["text"] = " " # annoying doesn't allow empty string
else:
choice["text"] = choice.message.content
choices[i]["text"] = choice.message.content

if choice.message.get("function_call"):
if choice.message.function_call:
# currently we only use function calls to get a JSON object => return raw text of json
choice["text"] = choice.message.function_call.arguments
choices[i]["text"] = choice.message.function_call.arguments

else:
completion_batch = openai.Completion.create(prompt=prompt_batch, **curr_kwargs)
completion_batch = client.completions.create(prompt=prompt_batch, **curr_kwargs)
choices = completion_batch.choices

for choice in choices:
choice["total_tokens"] = completion_batch.usage.total_tokens / len(prompt_batch)
break
except openai.error.OpenAIError as e:
except openai.OpenAIError as e:
logging.warning(f"OpenAIError: {e}.")
if "Please reduce your prompt" in str(e):
kwargs["max_tokens"] = int(kwargs["max_tokens"] * 0.8)
@@ -240,12 +247,14 @@
else:
logging.warning(f"Unknown error {e}. \n It's likely a rate limit so we are retrying...")
if openai_organization_ids is not None and len(openai_organization_ids) > 1:
openai.organization = random.choice(
client_kwargs["organization"] = organization = random.choice(
[o for o in openai_organization_ids if o != openai.organization]
)
client = OpenAI(**client_kwargs)
logging.info(f"Switching OAI organization.")
if openai_api_keys is not None and len(openai_api_keys) > 1:
openai.api_key = random.choice([o for o in openai_api_keys if o != openai.api_key])
client_kwargs["api_key"] = random.choice([o for o in openai_api_keys if o != openai.api_key])
client = OpenAI(**client_kwargs)
logging.info(f"Switching OAI API key.")
logging.info(f"Sleeping {sleep_time} before retrying to call openai API...")
time.sleep(sleep_time) # Annoying rate limit on requests.
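
Because openai 1.0 reads credentials from the client instance rather than module globals, the retry branch above now rebuilds the client whenever it rotates organizations or API keys. Roughly, the pattern is the following; the organization IDs and key below are placeholders:

import random
from openai import OpenAI

client_kwargs = {"api_key": "sk-placeholder", "organization": "org-A"}
client = OpenAI(**client_kwargs)

# On a rate-limit error, switch to a different organization and rebuild the client.
openai_organization_ids = ["org-A", "org-B"]
client_kwargs["organization"] = random.choice(
    [o for o in openai_organization_ids if o != client_kwargs["organization"]]
)
client = OpenAI(**client_kwargs)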
@@ -323,7 +332,11 @@ def _string_to_dict(to_convert):

def _get_price_per_token(model):
"""Returns the price per token for a given model"""
if "gpt-4" in model:
if "gpt-4-1106" in model:
return (
0.01 / 1000
) # that's not completely true because decoding is 0.03 but close enough given that most is context
elif "gpt-4" in model:
return (
0.03 / 1000
) # that's not completely true because decoding is 0.06 but close enough given that most is context
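
A quick check of the two branches visible in this hunk; the import path follows the file shown above, and the model names are examples chosen only to match the substring tests:

from alpaca_eval.decoders.openai import _get_price_per_token

# "gpt-4-1106" models are priced at $0.01 / 1K tokens, other "gpt-4" models at $0.03 / 1K
# (prompt-token price, used as an approximation for the whole request).
assert _get_price_per_token("gpt-4-1106-preview") == 0.01 / 1000
assert _get_price_per_token("gpt-4-0613") == 0.03 / 1000

# openai_completions then estimates cost as completion["total_tokens"] * _get_price_per_token(model_name).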
7 changes: 5 additions & 2 deletions tests/integration_tests/test_decoders_integration.py
@@ -4,11 +4,12 @@

from alpaca_eval import constants, utils
from alpaca_eval.decoders.anthropic import anthropic_completions
from alpaca_eval.decoders.bedrock_anthropic import bedrock_anthropic_completions
from alpaca_eval.decoders.cohere import cohere_completions
from alpaca_eval.decoders.huggingface_api import huggingface_api_completions
from alpaca_eval.decoders.huggingface_local import huggingface_local_completions
from alpaca_eval.decoders.openai import openai_completions
from alpaca_eval.decoders.bedrock_anthropic import bedrock_anthropic_completions


def _get_formatted_prompts(model):
filename = list((constants.MODELS_CONFIG_DIR / model).glob("*.txt"))[0]
@@ -17,6 +18,7 @@ def _get_formatted_prompts(model):
prompts = [template.format(instruction=prompt) for prompt in prompts]
return prompts


@pytest.mark.slow
def test_openai_completions_integration():
prompts = _get_formatted_prompts("gpt4")
@@ -72,10 +74,11 @@ def test_vllm_local_completions_integration():
)
assert len(results["completions"]) == len(prompts)


@pytest.mark.slow
def test_bedrock_anthropic_completions_integration():
prompts = _get_formatted_prompts("claude")
results = bedrock_anthropic_completions(prompts)
assert len(results["completions"]) == len(prompts)
assert "2" in results["completions"][0]
assert "4" in results["completions"][1]
assert "4" in results["completions"][1]
4 changes: 1 addition & 3 deletions tests/test_decoders_unit.py
@@ -2,9 +2,7 @@
import math
from types import SimpleNamespace

import anthropic.resources
import pytest
from openai.openai_object import OpenAIObject

from alpaca_eval.decoders.anthropic import anthropic_completions
from alpaca_eval.decoders.cohere import cohere_completions
@@ -17,7 +15,7 @@
@pytest.fixture
def mock_openai_completion():
# Create a mock Completion object
completion_mock = OpenAIObject()
completion_mock = dict()
completion_mock["total_tokens"] = 3
completion_mock["text"] = MOCKED_COMPLETION
return completion_mock