From 7cc750dfcfc276538d8763f441b725e54801b0a2 Mon Sep 17 00:00:00 2001 From: acer-king Date: Thu, 17 Oct 2024 08:21:45 -0700 Subject: [PATCH] add conversation to openai --- miner/providers/open_ai.py | 22 ++++++++------- .../services/validators/text_validator.py | 4 ++- validators/utils.py | 27 +++++++++++++++++++ 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/miner/providers/open_ai.py b/miner/providers/open_ai.py index 4cb815cc..221fb76e 100644 --- a/miner/providers/open_ai.py +++ b/miner/providers/open_ai.py @@ -9,17 +9,13 @@ from cortext.protocol import StreamPrompting from miner.error_handler import error_handler + class OpenAI(Provider): def __init__(self, synapse): super().__init__(synapse) self.openai_client = AsyncOpenAI(timeout=config.ASYNC_TIME_OUT, api_key=config.OPENAI_API_KEY) - - @error_handler - async def _prompt(self, synapse: StreamPrompting, send: Send): - - message = self.messages[0] - + def create_filtered_message(self, message): filtered_message: ChatCompletionMessageParam = { "role": message["role"], "content": [], } @@ -42,12 +38,18 @@ async def _prompt(self, synapse: StreamPrompting, send: Send): }, } ) + + return filtered_message + + @error_handler + async def _prompt(self, synapse: StreamPrompting, send: Send): + filtered_messages = [self.create_filtered_message(message) for message in synapse.messages] try: response = await self.openai_client.chat.completions.create( - model=self.model, messages=[filtered_message], - temperature=self.temperature, stream=True, - seed=self.seed, - max_tokens=self.max_tokens, + model=synapse.model, messages=filtered_messages, + temperature=synapse.temperature, stream=True, + seed=synapse.seed, + max_tokens=synapse.max_tokens, ) except Exception as err: bt.logging.exception(err) diff --git a/validators/services/validators/text_validator.py b/validators/services/validators/text_validator.py index cadf617f..6c12983f 100644 --- a/validators/services/validators/text_validator.py +++
b/validators/services/validators/text_validator.py @@ -12,6 +12,7 @@ from cortext.utils import (call_anthropic_bedrock, call_bedrock, call_anthropic, call_gemini, call_groq, call_openai, get_question) from validators.utils import save_or_get_answer_from_cache, get_query_synapse_from_cache +from validators import utils from typing import List, Dict @@ -117,8 +118,9 @@ async def call_api(self, conversation: List[Dict[str, Optional[str]]], query_syn self.model = query_syn.model if provider == "OpenAI": + filtered_messages = [utils.create_filtered_message_open_ai(message) for message in conversation] return await call_openai( - conversation, + filtered_messages, self.temperature, self.model, self.seed, diff --git a/validators/utils.py b/validators/utils.py index 461a5b79..0506e487 100644 --- a/validators/utils.py +++ b/validators/utils.py @@ -222,3 +222,30 @@ def main(urls): queries.append(query) return queries + + +def create_filtered_message_open_ai(message): + filtered_message = { + "role": message["role"], + "content": [], + } + + if message.get("content"): + filtered_message["content"].append( + { + "type": "text", + "text": message["content"], + } + ) + if message.get("image"): + image_url = message.get("image") + filtered_message["content"].append( + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + } + ) + + return filtered_message