Commit

Add conversation instead of one message.
acer-king committed Oct 16, 2024
1 parent 73120d4 commit dcd3bc0
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions validators/services/validators/text_validator.py
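This commit widens call_api from a single (prompt, image_url) pair to the full query_syn.messages list. The sketch below shows the conversation shape the new signature appears to expect; the "role", "content", and "image" keys are taken from the old call sites in this diff, while the sample values and the surrounding usage are illustrative assumptions, not part of the commit.

    # Illustrative only: a conversation as call_api now consumes it.
    # The keys come from the old call sites; the values are made up.
    conversation = [
        {"role": "user", "content": "What is in this image?",
         "image": "https://example.com/photo.png"},
        {"role": "assistant", "content": "A lighthouse at dusk.", "image": None},
        {"role": "user", "content": "Where might it be?", "image": None},
    ]
    # From an async method on a TextValidator instance (hypothetical call):
    #     answer = await self.call_api(conversation, query_syn)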
@@ -12,6 +12,7 @@
 from cortext.utils import (call_anthropic_bedrock, call_bedrock, call_anthropic, call_gemini,
                            call_groq, call_openai, get_question)
 from validators.utils import save_or_get_answer_from_cache, get_query_synapse_from_cache
+from typing import List, Dict
 
 
 class TextValidator(BaseValidator):
@@ -111,54 +112,72 @@ async def build_wandb_data(self, uid_to_score, responses):
         self.wandb_data["responses"][uid] = response
         return self.wandb_data
 
-    async def call_api(self, prompt: str, image_url: Optional[str], query_syn: StreamPrompting) -> str:
+    async def call_api(self, conversation: List[Dict[str, Optional[str]]], query_syn: StreamPrompting) -> str:
         provider = query_syn.provider
         self.model = query_syn.model
 
         if provider == "OpenAI":
             return await call_openai(
-                [{"role": "user", "content": prompt, "image": image_url}], self.temperature, self.model, self.seed,
+                conversation,
+                self.temperature,
+                self.model,
+                self.seed,
                 self.max_tokens
             )
         elif provider == "AnthropicBedrock":
-            return await call_anthropic_bedrock(prompt, self.temperature, self.model, self.max_tokens, self.top_p,
-                                                self.top_k)
+            prompt = " ".join([m['content'] for m in conversation if m['content']])
+            return await call_anthropic_bedrock(
+                prompt,
+                self.temperature,
+                self.model,
+                self.max_tokens,
+                self.top_p,
+                self.top_k
+            )
         elif provider == "Gemini":
-            return await call_gemini(prompt, self.temperature, self.model, self.max_tokens, self.top_p, self.top_k)
+            prompt = " ".join([m['content'] for m in conversation if m['content']])
+            return await call_gemini(
+                prompt,
+                self.temperature,
+                self.model,
+                self.max_tokens,
+                self.top_p,
+                self.top_k
+            )
         elif provider == "Anthropic":
             return await call_anthropic(
-                [{"role": "user", "content": prompt, "image": image_url}],
+                conversation,
                 self.temperature,
                 self.model,
                 self.max_tokens,
                 self.top_p,
-                self.top_k,
+                self.top_k
             )
         elif provider == "Groq":
+            prompt = " ".join([m['content'] for m in conversation if m['content']])
             return await call_groq(
                 [{"role": "user", "content": prompt}],
                 self.temperature,
                 self.model,
                 self.max_tokens,
                 self.top_p,
-                self.seed,
+                self.seed
             )
         elif provider == "Bedrock":
             return await call_bedrock(
-                [{"role": "user", "content": prompt, "image": image_url}],
+                conversation,
                 self.temperature,
                 self.model,
                 self.max_tokens,
                 self.top_p,
-                self.seed,
+                self.seed
             )
         else:
             bt.logging.error(f"provider {provider} not found")
 
     @save_or_get_answer_from_cache
     async def get_answer_task(self, uid: int, query_syn: StreamPrompting, response):
-        prompt = query_syn.messages[0].get("content")
-        image_url = query_syn.messages[0].get("image")
-        answer = await self.call_api(prompt, image_url, query_syn)
+        answer = await self.call_api(query_syn.messages, query_syn)
         return answer
 
     async def get_scoring_task(self, uid, answer, response):
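Note that for the providers that take a single prompt string (AnthropicBedrock, Gemini, Groq), the new code flattens the conversation by joining message contents, which discards roles and any image attachments. A minimal standalone sketch of that behavior, with an assumed sample conversation:

    # Standalone sketch of the flattening used for single-prompt providers.
    # The sample conversation is assumed; roles and images are discarded,
    # and messages with empty content are skipped by the filter.
    conversation = [
        {"role": "user", "content": "Hello", "image": None},
        {"role": "assistant", "content": None, "image": None},  # skipped
        {"role": "user", "content": "Summarize our chat", "image": None},
    ]
    prompt = " ".join([m['content'] for m in conversation if m['content']])
    print(prompt)  # -> "Hello Summarize our chat"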
