From 7f84cbd3de08f49eebb5eb2949a8e3dfcbb477a7 Mon Sep 17 00:00:00 2001 From: acer-king Date: Mon, 14 Oct 2024 05:05:02 -0700 Subject: [PATCH] add sqlite database for saving all questions and answers. --- miner/providers/anthropic.py | 40 +++++++++++++++++++++--------------- miner/providers/groq.py | 2 +- miner/providers/open_ai.py | 2 +- validators/services/cache.py | 5 +++-- validators/weight_setter.py | 23 +++++++++++++++------ 5 files changed, 45 insertions(+), 27 deletions(-) diff --git a/miner/providers/anthropic.py b/miner/providers/anthropic.py index e15f41ca..b9b52640 100644 --- a/miner/providers/anthropic.py +++ b/miner/providers/anthropic.py @@ -7,38 +7,44 @@ from cortext.protocol import StreamPrompting from miner.error_handler import error_handler + class Anthropic(Provider): def __init__(self, synapse): super().__init__(synapse) - self.anthropic_client = AsyncAnthropic(timeout=config.ASYNC_TIME_OUT, api_key=config.ANTHROPIC_API_KEY) + try: + self.anthropic_client = AsyncAnthropic(timeout=config.ASYNC_TIME_OUT, api_key=config.ANTHROPIC_API_KEY) + except Exception as err: + bt.logging.error(f"api_key looks incorrect or expired. {err}") + self.anthropic_client = None @error_handler async def _prompt(self, synapse: StreamPrompting, send: Send): filtered_messages, system_prompt = self.generate_messages_to_claude(self.messages) stream_kwargs = { - "max_tokens": self.max_tokens, + "max_tokens": synapse.max_tokens, "messages": filtered_messages, - "model": self.model, + "model": synapse.model, } if system_prompt: stream_kwargs["system"] = system_prompt - completion = self.anthropic_client.messages.stream(**stream_kwargs) - async with completion as stream: - async for text in stream.text_stream: - await send( - { - "type": "http.response.body", - "body": text.encode("utf-8"), - "more_body": True, - } - ) - bt.logging.info(f"Streamed text: {text}") - - # Send final message to close the stream - await send({"type": "http.response.body", "body": b'', "more_body": False}) + try: + async with completion as stream: + async for text in stream.text_stream: + await send( + { + "type": "http.response.body", + "body": text.encode("utf-8"), + "more_body": True, + } + ) + bt.logging.info(f"Streamed text: {text}") + except Exception as err: + bt.logging.error(f"{err}") + finally: + await send({"type": "http.response.body", "body": b'', "more_body": False}) def image_service(self, synapse): pass diff --git a/miner/providers/groq.py b/miner/providers/groq.py index 6947be93..6fc0d7a9 100644 --- a/miner/providers/groq.py +++ b/miner/providers/groq.py @@ -43,7 +43,7 @@ async def _prompt(self, synapse: StreamPrompting, send: Send): "more_body": True, } ) - bt.logging.info(f"Streamed tokens: {joined_buffer}") + bt.logging.trace(f"Streamed tokens: {joined_buffer}") buffer = [] await send( { diff --git a/miner/providers/open_ai.py b/miner/providers/open_ai.py index f734a5d3..4cb815cc 100644 --- a/miner/providers/open_ai.py +++ b/miner/providers/open_ai.py @@ -66,7 +66,7 @@ async def _prompt(self, synapse: StreamPrompting, send: Send): "more_body": True, } ) - bt.logging.info(f"Streamed tokens: {joined_buffer}") + bt.logging.trace(f"Streamed tokens: {joined_buffer}") buffer = [] if buffer: diff --git a/validators/services/cache.py b/validators/services/cache.py index 429c57f5..77d84a76 100644 --- a/validators/services/cache.py +++ b/validators/services/cache.py @@ -34,6 +34,7 @@ def generate_hash(input_string): return hashlib.sha256(input_string.encode('utf-8')).hexdigest() def set_cache(self, question, answer, provider, model, ttl=3600 * 24): + return p_key = self.generate_hash(str(question) + str(provider) + str(model)) expires_at = time.time() + ttl cursor = self.conn.cursor() @@ -47,8 +48,8 @@ def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24): datas = [] expires_at = time.time() + ttl for syn in syns: - p_key = self.generate_hash(str(syn.messages) + str(syn.provider) + str(syn.model)) - datas.append((p_key, syn.messages, syn.completion, syn.provider, syn.model, expires_at)) + p_key = self.generate_hash(str(expires_at) + str(syn.messages) + str(syn.provider) + str(syn.model)) + datas.append((p_key, syn.json(), syn.completion, syn.provider, syn.model, expires_at)) # Insert multiple records cursor = self.conn.cursor() diff --git a/validators/weight_setter.py b/validators/weight_setter.py index 873af948..efc5d878 100644 --- a/validators/weight_setter.py +++ b/validators/weight_setter.py @@ -242,17 +242,22 @@ async def perform_synthetic_queries(self): while batched_tasks: start_time_batch = time.time() await self.dendrite.aclose_session() - await asyncio.gather(*batched_tasks) + await asyncio.gather(*batched_tasks, return_exceptions=True) bt.logging.debug( f"batch size {len(batched_tasks)} has been processed and time elapsed: {time.time() - start_time_batch}") - batched_tasks, remain_tasks = self.pop_synthetic_tasks_max_100_per_miner(remain_tasks) + bt.logging.debug(f"remain tasks: {len(remain_tasks)}") - self.synthetic_task_done = True + batched_tasks, remain_tasks = self.pop_synthetic_tasks_max_100_per_miner(remain_tasks) - bt.logging.info(f"saving responses...") bt.logging.info( f"synthetic queries has been processed successfully." f"total queries are {len(query_synapses)}: total {time.time() - start_time} elapsed") + bt.logging.info(f"saving responses...") + self.cache.set_cache_in_batch([item.get('synapse') for item in self.query_database]) + self.synthetic_task_done = True + + bt.logging.info( + f"synthetic queries and answers has been saved in cache successfully. total times {time.time() - start_time}") def pop_synthetic_tasks_max_100_per_miner(self, synthetic_tasks): batch_size = 50000 @@ -550,8 +555,14 @@ async def process_queries_from_database(self): while True: await asyncio.sleep(1) # Adjust the sleep time as needed # accumulate all query results for 36 blocks - if not self.query_database or not self.is_epoch_end(): - bt.logging.trace("no data in query_database. so continue...") + if not self.query_database: + bt.logging.debug("no data in query_database. so continue...") + continue + if not self.is_epoch_end(): + bt.logging.debug("no end of epoch. so continue...") + continue + if not self.synthetic_task_done: + bt.logging.debug("wait for synthetic tasks to complete.") continue bt.logging.info(f"start scoring process...")