From dcd3bc0fb826975eb690541e787219f11fe96d9a Mon Sep 17 00:00:00 2001 From: acer-king Date: Wed, 16 Oct 2024 11:01:53 -0700 Subject: [PATCH] add conversation instead of one message. --- .../services/validators/text_validator.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/validators/services/validators/text_validator.py b/validators/services/validators/text_validator.py index 608a96b5..cadf617f 100644 --- a/validators/services/validators/text_validator.py +++ b/validators/services/validators/text_validator.py @@ -12,6 +12,7 @@ from cortext.utils import (call_anthropic_bedrock, call_bedrock, call_anthropic, call_gemini, call_groq, call_openai, get_question) from validators.utils import save_or_get_answer_from_cache, get_query_synapse_from_cache +from typing import List, Dict class TextValidator(BaseValidator): @@ -111,54 +112,72 @@ async def build_wandb_data(self, uid_to_score, responses): self.wandb_data["responses"][uid] = response return self.wandb_data - async def call_api(self, prompt: str, image_url: Optional[str], query_syn: StreamPrompting) -> str: + async def call_api(self, conversation: List[Dict[str, Optional[str]]], query_syn: StreamPrompting) -> str: provider = query_syn.provider self.model = query_syn.model + if provider == "OpenAI": return await call_openai( - [{"role": "user", "content": prompt, "image": image_url}], self.temperature, self.model, self.seed, + conversation, + self.temperature, + self.model, + self.seed, self.max_tokens ) elif provider == "AnthropicBedrock": - return await call_anthropic_bedrock(prompt, self.temperature, self.model, self.max_tokens, self.top_p, - self.top_k) + prompt = " ".join([m['content'] for m in conversation if m['content']]) + return await call_anthropic_bedrock( + prompt, + self.temperature, + self.model, + self.max_tokens, + self.top_p, + self.top_k + ) elif provider == "Gemini": - return await call_gemini(prompt, self.temperature, self.model, self.max_tokens, self.top_p, self.top_k) + prompt = " ".join([m['content'] for m in conversation if m['content']]) + return await call_gemini( + prompt, + self.temperature, + self.model, + self.max_tokens, + self.top_p, + self.top_k + ) elif provider == "Anthropic": return await call_anthropic( - [{"role": "user", "content": prompt, "image": image_url}], + conversation, self.temperature, self.model, self.max_tokens, self.top_p, - self.top_k, + self.top_k ) elif provider == "Groq": + prompt = " ".join([m['content'] for m in conversation if m['content']]) return await call_groq( [{"role": "user", "content": prompt}], self.temperature, self.model, self.max_tokens, self.top_p, - self.seed, + self.seed ) elif provider == "Bedrock": return await call_bedrock( - [{"role": "user", "content": prompt, "image": image_url}], + conversation, self.temperature, self.model, self.max_tokens, self.top_p, - self.seed, + self.seed ) else: bt.logging.error(f"provider {provider} not found") @save_or_get_answer_from_cache async def get_answer_task(self, uid: int, query_syn: StreamPrompting, response): - prompt = query_syn.messages[0].get("content") - image_url = query_syn.messages[0].get("image") - answer = await self.call_api(prompt, image_url, query_syn) + answer = await self.call_api(query_syn.messages, query_syn) return answer async def get_scoring_task(self, uid, answer, response):