Commit 7f84cbd
add sqlite database for saving all questions and answers.
acer-king committed Oct 14, 2024
1 parent 9018f4e commit 7f84cbd
Showing 5 changed files with 45 additions and 27 deletions.
miner/providers/anthropic.py (40 changes: 23 additions & 17 deletions)
@@ -7,38 +7,44 @@
 from cortext.protocol import StreamPrompting
 from miner.error_handler import error_handler


 class Anthropic(Provider):
     def __init__(self, synapse):
         super().__init__(synapse)
-        self.anthropic_client = AsyncAnthropic(timeout=config.ASYNC_TIME_OUT, api_key=config.ANTHROPIC_API_KEY)
+        try:
+            self.anthropic_client = AsyncAnthropic(timeout=config.ASYNC_TIME_OUT, api_key=config.ANTHROPIC_API_KEY)
+        except Exception as err:
+            bt.logging.error(f"api_key looks incorrect or expired. {err}")
+            self.anthropic_client = None

     @error_handler
     async def _prompt(self, synapse: StreamPrompting, send: Send):
         filtered_messages, system_prompt = self.generate_messages_to_claude(self.messages)

         stream_kwargs = {
-            "max_tokens": self.max_tokens,
+            "max_tokens": synapse.max_tokens,
             "messages": filtered_messages,
-            "model": self.model,
+            "model": synapse.model,
         }

         if system_prompt:
             stream_kwargs["system"] = system_prompt

         completion = self.anthropic_client.messages.stream(**stream_kwargs)
-        async with completion as stream:
-            async for text in stream.text_stream:
-                await send(
-                    {
-                        "type": "http.response.body",
-                        "body": text.encode("utf-8"),
-                        "more_body": True,
-                    }
-                )
-                bt.logging.info(f"Streamed text: {text}")
-
-        # Send final message to close the stream
-        await send({"type": "http.response.body", "body": b'', "more_body": False})
+        try:
+            async with completion as stream:
+                async for text in stream.text_stream:
+                    await send(
+                        {
+                            "type": "http.response.body",
+                            "body": text.encode("utf-8"),
+                            "more_body": True,
+                        }
+                    )
+                    bt.logging.info(f"Streamed text: {text}")
+        except Exception as err:
+            bt.logging.error(f"{err}")
+        finally:
+            await send({"type": "http.response.body", "body": b'', "more_body": False})

     def image_service(self, synapse):
         pass
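Note: the try/except/finally introduced above guarantees that the terminating ASGI message (more_body: False) is sent even when the Anthropic stream fails midway, so the HTTP response always gets closed instead of hanging. A minimal runnable sketch of the same pattern, with a hypothetical fake_stream generator and send callable standing in for the real stream and ASGI channel:

import asyncio

async def fake_stream():
    # Hypothetical stand-in for stream.text_stream; fails after two chunks.
    for chunk in ("hello ", "world"):
        yield chunk
    raise RuntimeError("provider dropped the connection")

async def relay(send):
    try:
        async for text in fake_stream():
            await send({"type": "http.response.body",
                        "body": text.encode("utf-8"),
                        "more_body": True})
    except Exception as err:
        print(f"stream failed: {err}")
    finally:
        # Always close the response, even on failure.
        await send({"type": "http.response.body", "body": b"", "more_body": False})

async def main():
    async def send(message):  # hypothetical ASGI send callable
        print(message)
    await relay(send)

asyncio.run(main())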
miner/providers/groq.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ async def _prompt(self, synapse: StreamPrompting, send: Send):
                         "more_body": True,
                     }
                 )
-                bt.logging.info(f"Streamed tokens: {joined_buffer}")
+                bt.logging.trace(f"Streamed tokens: {joined_buffer}")
                 buffer = []
         await send(
             {
miner/providers/open_ai.py (2 changes: 1 addition & 1 deletion)
@@ -66,7 +66,7 @@ async def _prompt(self, synapse: StreamPrompting, send: Send):
                         "more_body": True,
                     }
                 )
-                bt.logging.info(f"Streamed tokens: {joined_buffer}")
+                bt.logging.trace(f"Streamed tokens: {joined_buffer}")
                 buffer = []

         if buffer:
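Note: groq.py and open_ai.py make the same change here, demoting the per-chunk "Streamed tokens" message from info to trace, so steady-state streaming no longer floods the default log output while the detail remains available when trace logging is switched on. The sketch below illustrates the idea with the standard library by registering an analogous TRACE level; bt.logging has its own trace level, and this is not the bittensor API:

import logging

TRACE = 5  # below DEBUG, analogous to bt.logging.trace
logging.addLevelName(TRACE, "TRACE")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("miner")

logger.log(TRACE, "Streamed tokens: %s", "hello")  # hidden at the default level
logger.info("request finished")                    # shown

logger.setLevel(TRACE)  # opt in to per-token output when debugging
logger.log(TRACE, "Streamed tokens: %s", "world")  # now visible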
validators/services/cache.py (5 changes: 3 additions & 2 deletions)
@@ -34,6 +34,7 @@ def generate_hash(input_string):
         return hashlib.sha256(input_string.encode('utf-8')).hexdigest()

     def set_cache(self, question, answer, provider, model, ttl=3600 * 24):
+        return
         p_key = self.generate_hash(str(question) + str(provider) + str(model))
         expires_at = time.time() + ttl
         cursor = self.conn.cursor()
@@ -47,8 +48,8 @@ def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24):
         datas = []
         expires_at = time.time() + ttl
         for syn in syns:
-            p_key = self.generate_hash(str(syn.messages) + str(syn.provider) + str(syn.model))
-            datas.append((p_key, syn.messages, syn.completion, syn.provider, syn.model, expires_at))
+            p_key = self.generate_hash(str(expires_at) + str(syn.messages) + str(syn.provider) + str(syn.model))
+            datas.append((p_key, syn.json(), syn.completion, syn.provider, syn.model, expires_at))

         # Insert multiple records
         cursor = self.conn.cursor()
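Note: two things change in the cache layer. set_cache is short-circuited with an early return, leaving the batch path as the only writer, and the batch key is now salted with expires_at while the full synapse (syn.json()) is stored instead of just the messages, so repeated runs append fresh rows rather than overwriting earlier ones. A sketch of what the batch write plausibly looks like with executemany; the table name and column layout are assumptions, since the schema is not part of this diff:

import hashlib
import sqlite3
import time

def set_cache_in_batch(conn: sqlite3.Connection, syns, ttl=3600 * 24):
    # Assumed schema: cache(p_key TEXT PRIMARY KEY, question TEXT, answer TEXT,
    #                       provider TEXT, model TEXT, expires_at REAL)
    expires_at = time.time() + ttl
    datas = []
    for syn in syns:
        # Salting the key with expires_at makes each batch write unique.
        raw = str(expires_at) + str(syn.messages) + str(syn.provider) + str(syn.model)
        p_key = hashlib.sha256(raw.encode("utf-8")).hexdigest()
        datas.append((p_key, syn.json(), syn.completion, syn.provider, syn.model, expires_at))
    cursor = conn.cursor()
    cursor.executemany("INSERT OR REPLACE INTO cache VALUES (?, ?, ?, ?, ?, ?)", datas)
    conn.commit()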
validators/weight_setter.py (23 changes: 17 additions & 6 deletions)
@@ -242,17 +242,22 @@ async def perform_synthetic_queries(self):
         while batched_tasks:
             start_time_batch = time.time()
             await self.dendrite.aclose_session()
-            await asyncio.gather(*batched_tasks)
+            await asyncio.gather(*batched_tasks, return_exceptions=True)
             bt.logging.debug(
                 f"batch size {len(batched_tasks)} has been processed and time elapsed: {time.time() - start_time_batch}")
             batched_tasks, remain_tasks = self.pop_synthetic_tasks_max_100_per_miner(remain_tasks)
             bt.logging.debug(f"remain tasks: {len(remain_tasks)}")

-        self.synthetic_task_done = True
-        batched_tasks, remain_tasks = self.pop_synthetic_tasks_max_100_per_miner(remain_tasks)
-
-        bt.logging.info(f"saving responses...")
         bt.logging.info(
             f"synthetic queries has been processed successfully."
             f"total queries are {len(query_synapses)}: total {time.time() - start_time} elapsed")
+        bt.logging.info(f"saving responses...")
+        self.cache.set_cache_in_batch([item.get('synapse') for item in self.query_database])
+        self.synthetic_task_done = True
+
+        bt.logging.info(
+            f"synthetic queries and answers has been saved in cache successfully. total times {time.time() - start_time}")

     def pop_synthetic_tasks_max_100_per_miner(self, synthetic_tasks):
         batch_size = 50000
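Note: return_exceptions=True keeps one failed miner query from raising out of asyncio.gather and abandoning the rest of the batch; failures come back as exception objects in the result list instead. A minimal sketch with a hypothetical query_miner coroutine:

import asyncio

async def query_miner(uid: int) -> str:
    # Hypothetical stand-in for a single synthetic query.
    if uid == 2:
        raise TimeoutError(f"miner {uid} timed out")
    return f"answer from miner {uid}"

async def main():
    results = await asyncio.gather(*(query_miner(uid) for uid in range(4)),
                                   return_exceptions=True)
    for uid, result in enumerate(results):
        if isinstance(result, Exception):
            print(f"miner {uid} failed: {result}")  # the batch keeps going
        else:
            print(result)

asyncio.run(main())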
@@ -550,8 +555,14 @@ async def process_queries_from_database(self):
         while True:
             await asyncio.sleep(1)  # Adjust the sleep time as needed
             # accumulate all query results for 36 blocks
-            if not self.query_database or not self.is_epoch_end():
-                bt.logging.trace("no data in query_database. so continue...")
+            if not self.query_database:
+                bt.logging.debug("no data in query_database. so continue...")
+                continue
+            if not self.is_epoch_end():
+                bt.logging.debug("no end of epoch. so continue...")
+                continue
+            if not self.synthetic_task_done:
+                bt.logging.debug("wait for synthetic tasks to complete.")
                 continue

             bt.logging.info(f"start scoring process...")
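Note: the scoring loop now uses separate guard clauses with their own debug messages, and gains a third gate on synthetic_task_done, so scoring cannot start before the synthetic pass has finished and its answers have been written to the cache. A stripped-down sketch of this flag-based handoff between the two coroutines; the names are illustrative, not the repo's:

import asyncio

class Coordinator:
    def __init__(self):
        self.synthetic_task_done = False

    async def produce(self):
        await asyncio.sleep(0.1)         # stands in for querying and caching
        self.synthetic_task_done = True  # set only after the cache write

    async def score(self):
        while not self.synthetic_task_done:
            await asyncio.sleep(0.01)    # guard clause: wait for the producer
        print("start scoring process...")

async def main():
    c = Coordinator()
    await asyncio.gather(c.produce(), c.score())

asyncio.run(main())

An asyncio.Event would give the same handoff without polling, but since the real loop already polls once per second, a plain flag is adequate here.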
