Skip to content

Commit

Permalink
add conversation to openai
Browse files Browse the repository at this point in the history
  • Loading branch information
acer-king committed Oct 17, 2024
1 parent dcd3bc0 commit 7cc750d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
22 changes: 12 additions & 10 deletions miner/providers/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@
from cortext.protocol import StreamPrompting
from miner.error_handler import error_handler


class OpenAI(Provider):
def __init__(self, synapse):
super().__init__(synapse)
self.openai_client = AsyncOpenAI(timeout=config.ASYNC_TIME_OUT, api_key=config.OPENAI_API_KEY)


@error_handler
async def _prompt(self, synapse: StreamPrompting, send: Send):

message = self.messages[0]

def create_filtered_message(self, message):
filtered_message: ChatCompletionMessageParam = {
"role": message["role"],
"content": [],
Expand All @@ -42,12 +38,18 @@ async def _prompt(self, synapse: StreamPrompting, send: Send):
},
}
)

return filtered_message

@error_handler
async def _prompt(self, synapse: StreamPrompting, send: Send):
filtered_messages = [self.create_filtered_message(message) for message in synapse.messages]
try:
response = await self.openai_client.chat.completions.create(
model=self.model, messages=[filtered_message],
temperature=self.temperature, stream=True,
seed=self.seed,
max_tokens=self.max_tokens,
model=synapse.model, messages=filtered_messages,
temperature=synapse.temperature, stream=True,
seed=synapse.seed,
max_tokens=synapse.max_tokens,
)
except Exception as err:
bt.logging.exception(err)
Expand Down
4 changes: 3 additions & 1 deletion validators/services/validators/text_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from cortext.utils import (call_anthropic_bedrock, call_bedrock, call_anthropic, call_gemini,
call_groq, call_openai, get_question)
from validators.utils import save_or_get_answer_from_cache, get_query_synapse_from_cache
from validators import utils
from typing import List, Dict


Expand Down Expand Up @@ -117,8 +118,9 @@ async def call_api(self, conversation: List[Dict[str, Optional[str]]], query_syn
self.model = query_syn.model

if provider == "OpenAI":
filtered_messages = [utils.create_filtered_message_open_ai(message) for message in conversation]
return await call_openai(
conversation,
filtered_messages,
self.temperature,
self.model,
self.seed,
Expand Down
27 changes: 27 additions & 0 deletions validators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,30 @@ def main(urls):
queries.append(query)

return queries


def create_filtered_message_open_ai(message):
    """Convert a simple chat message dict into OpenAI's multimodal content format.

    Args:
        message: Mapping with a required ``"role"`` key and optional
            ``"content"`` (text string) and ``"image"`` (image URL string) keys.

    Returns:
        A dict with the original ``"role"`` and a ``"content"`` list of typed
        parts (``"text"`` and/or ``"image_url"``) as expected by the OpenAI
        chat completions API. Empty or missing text/image values are skipped,
        so the content list may be empty.
    """
    filtered_message = {
        "role": message["role"],
        "content": [],
    }

    text = message.get("content")
    if text:
        filtered_message["content"].append(
            {
                "type": "text",
                "text": text,
            }
        )

    # Hoist the lookup so "image" is read only once.
    image_url = message.get("image")
    if image_url:
        filtered_message["content"].append(
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url,
                },
            }
        )

    return filtered_message

0 comments on commit 7cc750d

Please sign in to comment.