Skip to content

Commit

Permalink
refresh Cohere (#141)
Browse files Browse the repository at this point in the history
* edit src/alpaca_eval/decoders/cohere.py, remove src/alpaca_eval/models_configs/cohere-chat/configs.yaml and 3 other changes

* edit src/alpaca_eval/decoders/cohere.py, edit src/alpaca_eval/leaderboards/evaluators/evaluators_leaderboard.csv

* updated results and leaderboard

* clean

* fix tests

* fix tests

* fix tests more

* r

* restore val

* restore val

* restore val

* remove chat mode
  • Loading branch information
sanderland authored Oct 1, 2023
1 parent aef168b commit 0ac9b14
Show file tree
Hide file tree
Showing 10 changed files with 823 additions and 5,674 deletions.
4,832 changes: 0 additions & 4,832 deletions results/cohere-chat/model_outputs.json

This file was deleted.

1,606 changes: 803 additions & 803 deletions results/cohere/model_outputs.json

Large diffs are not rendered by default.

35 changes: 14 additions & 21 deletions src/alpaca_eval/decoders/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import multiprocessing
import os
import random
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

import cohere
import tqdm
Expand All @@ -19,7 +19,6 @@
def cohere_completions(
prompts: Sequence[str],
model_name="command",
mode="instruct",
num_procs: int = 5,
**decoding_kwargs,
) -> dict[str, list]:
Expand All @@ -46,29 +45,29 @@ def cohere_completions(
else:
logging.info(f"Using `cohere_completions` on {n_examples} prompts using {model_name}.")

kwargs = dict(model=model_name, mode=mode, **decoding_kwargs)
kwargs = dict(model=model_name, **decoding_kwargs)
logging.info(f"Kwargs to completion: {kwargs}")

with utils.Timer() as t:
if num_procs == 1:
completions = [_cohere_completion_helper(prompt, **kwargs) for prompt in tqdm.tqdm(prompts, desc="prompts")]
completions_and_token_counts = [_cohere_completion_helper(prompt, **kwargs) for prompt in tqdm.tqdm(prompts, desc="prompts")]
else:
with multiprocessing.Pool(num_procs) as p:
partial_completion_helper = functools.partial(_cohere_completion_helper, **kwargs)
completions = list(
completions_and_token_counts = list(
tqdm.tqdm(
p.imap(partial_completion_helper, prompts),
desc="prompts",
total=len(prompts),
)
)
logging.info(f"Completed {n_examples} examples in {t}.")

# cohere charges $2.5 for every 1000 call to API that is less than 1000 characters. Only counting prompts here
price = [2.5 / 1000 * math.ceil(len(prompt) / 1000) for prompt in prompts]
completions, num_tokens = zip(*completions_and_token_counts)
price_per_token = 0.000015 # cohere charges $0.000015 per token.
price_per_example = [price_per_token * n for n in num_tokens]
avg_time = [t.duration / n_examples] * len(completions)

return dict(completions=completions, price_per_example=price, time_per_example=avg_time)
return dict(completions=list(completions), price_per_example=price_per_example, time_per_example=avg_time)


def _cohere_completion_helper(
Expand All @@ -77,9 +76,8 @@ def _cohere_completion_helper(
max_tokens: Optional[int] = 1000,
temperature: Optional[float] = 0.7,
max_tries=5,
mode="instruct",
**kwargs,
) -> str:
) -> Tuple[str,int]:
cohere_api_key = random.choice(cohere_api_keys)
client = cohere.Client(cohere_api_key)

Expand All @@ -88,21 +86,16 @@ def _cohere_completion_helper(

for trynum in range(max_tries): # retry errors
try:
if mode == "instruct":
response = client.generate(prompt=prompt, **curr_kwargs)
text = response[0].text
elif mode == "chat":
response = client.chat(prompt, **curr_kwargs)
text = response.text
else:
raise ValueError(f"Invalid mode {mode} for cohere_completions")
response = client.generate(prompt=prompt, return_likelihoods="ALL", **curr_kwargs)
text = response[0].text
num_tokens = len(response[0].token_likelihoods)

if text == "":
raise CohereError("Empty string response")

return text
return text, num_tokens

except CohereError as e:
print(f"Try #{trynum+1}/{max_tries}: Error running prompt {repr(prompt)}: {e}")

return " " # placeholder response for errors, doesn't allow empty string
return " ", 0 # placeholder response for errors, doesn't allow empty string
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ openbuddy-llama2-70b-v10.1,87.67123287671232,1.1508417516577765,701,96,6,803,com
openchat-v2-w-13b,87.1268656716418,1.1769197439396015,699,102,3,804,community,1566.0
openbuddy-llama-65b-v8,86.53366583541147,1.2029182403474274,693,107,2,802,community,1162.0
wizardlm-13b-v1.1,86.31840796019901,1.2063217831272972,692,108,4,804,community,1525.0
cohere,85.0560398505604,1.2558329840021718,682,119,2,803,community,1715.0
openchat-v2-13b,84.96894409937889,1.2572979835605944,683,120,2,805,community,1564.0
humpback-llama-65b,83.70646766169155,1.3071034735987248,672,130,2,804,community,1269.0
ultralm-13b-v2.0,83.60248447204968,1.30578174546824,673,132,0,805,community,1399.0
Expand Down Expand Up @@ -55,8 +56,6 @@ falcon-40b-instruct,45.71428571428572,1.7524717060805597,366,435,4,805,minimal,6
alpaca-farm-ppo-sim-gpt4-20k,44.099378881987576,1.7399772578861137,350,445,10,805,verified,511.0
pythia-12b-mix-sft,41.86335403726708,1.737637146007538,336,467,2,805,verified,913.0
alpaca-farm-ppo-human,41.24223602484472,1.7271813123250834,328,469,8,805,minimal,803.0
cohere-chat,29.565217391304348,1.5949050483247118,232,561,12,805,community,779.0
cohere,28.385093167701864,1.5717547121761728,221,569,15,805,community,682.0
alpaca-7b,26.459627329192543,1.535711469748,205,584,16,805,minimal,396.0
oasst-sft-pythia-12b,25.962732919254663,1.5261079289535309,201,588,16,805,verified,726.0
falcon-7b-instruct,23.60248447204969,1.4898235369056625,187,612,6,805,verified,478.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ falcon-40b-instruct,46.70807453416149,1.7551420072945083,minimal,4,805,374,427,,
alpaca-farm-ppo-human,46.45962732919255,1.750131850347461,minimal,8,805,370,427,,803
pythia-12b-mix-sft,43.22981366459627,1.7449120766669366,verified,2,805,347,456,,913
oasst-sft-pythia-12b,32.79503105590062,1.6369108459870174,verified,16,805,256,533,,726
cohere-chat,32.79503105590062,1.6416235300873216,community,12,805,258,535,,779
cohere,32.608695652173914,1.635641080422956,community,15,805,255,535,,682
alpaca-7b,32.298136645962735,1.630307861230374,minimal,16,805,252,537,,396
falcon-7b-instruct,29.565217391304348,1.6021542242903124,verified,6,805,235,564,,478
text_davinci_001,21.490683229813666,1.421716368655911,minimal,20,805,163,622,,296
8 changes: 0 additions & 8 deletions src/alpaca_eval/models_configs/cohere-chat/configs.yaml

This file was deleted.

5 changes: 2 additions & 3 deletions src/alpaca_eval/models_configs/cohere/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ cohere:
prompt_template: "cohere/prompt.txt"
fn_completions: "cohere_completions"
completions_kwargs:
model_name: "command"
mode: "instruct"
model_name: "command-nightly"
max_tokens: 2048
pretty_name: "Cohere"
pretty_name: "Cohere Command"
2 changes: 1 addition & 1 deletion src/alpaca_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def make_prompts(
if df.empty:
return [], df

text_to_format = re.findall("{([^ \s]+?)}", template)
text_to_format = re.findall(r"{([^ \s]+?)}", template)
n_occurrences = Counter(text_to_format)

if not all([n == batch_size for n in n_occurrences.values()]):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decoders_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_anthropic_completions(mocker):
def test_cohere_completions(mocker):
mocker.patch(
"alpaca_eval.decoders.cohere._cohere_completion_helper",
return_value="Mocked completion text",
return_value=["Mocked completion text",42],
)
result = cohere_completions(["Prompt 1", "Prompt 2"], num_procs=1)
_run_all_asserts_completions(result)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pairwise_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def expected_annotations():
def single_annotator():
return SinglePairwiseAnnotator(
prompt_template="text_davinci_003/basic_prompt.txt",
completion_parser_kwargs=dict(outputs_to_match={1: "(?:^|\n) ?Output \(a\)", 2: "(?:^|\n) ?Output \(b\)"}),
completion_parser_kwargs=dict(outputs_to_match={1: r"(?:^|\n) ?Output \(a\)", 2: "(?:^|\n) ?Output \(b\)"}),
is_randomize_output_order=False,
is_shuffle=False,
)
Expand Down

0 comments on commit 0ac9b14

Please sign in to comment.