diff --git a/cortext/dendrite.py b/cortext/dendrite.py
index 3c9527cf..bfa12e8e 100644
--- a/cortext/dendrite.py
+++ b/cortext/dendrite.py
@@ -2,7 +2,7 @@
 import aiohttp
 import bittensor as bt
-from aiohttp import ServerTimeoutError
+from aiohttp import ServerTimeoutError, ClientConnectorError
 from bittensor import dendrite
 import traceback
 import time
 
@@ -47,14 +47,11 @@ async def call_stream(
         # Preprocess synapse for making a request
         synapse: StreamPrompting = self.preprocess_synapse_for_request(target_axon, synapse, timeout)  # type: ignore
         max_try = 0
-        session = CortexDendrite.miner_to_session.get(endpoint)
+        timeout = aiohttp.ClientTimeout(total=300, connect=timeout, sock_connect=timeout, sock_read=timeout)
+        connector = aiohttp.TCPConnector(limit=200)
+        session = aiohttp.ClientSession(timeout=timeout, connector=connector)
         try:
             while max_try < 3:
-                if not session:
-                    timeout = aiohttp.ClientTimeout(total=300, connect=timeout, sock_connect=timeout, sock_read=timeout)
-                    connector = aiohttp.TCPConnector(limit=200)
-                    session = aiohttp.ClientSession(timeout=timeout, connector=connector)
-                    CortexDendrite.miner_to_session[endpoint] = session
                 async with session.post(
                         url,
                         headers=synapse.to_headers(),
@@ -70,6 +67,12 @@ async def call_stream(
                     bt.logging.error(f"timeout error happens. max_try is {max_try}")
                     max_try += 1
                     continue
+                except ConnectionRefusedError as err:
+                    bt.logging.error(f"cannot connect to miner for now; connection failed")
+                    break
+                except ClientConnectorError as err:
+                    bt.logging.error(f"cannot connect to miner for now; connection failed")
+                    break
                 except ServerTimeoutError as err:
                     bt.logging.error(f"timeout error happens. max_try is {max_try}")
                     max_try += 1
@@ -84,6 +87,7 @@ async def call_stream(
             bt.logging.error(f"{e} {traceback.format_exc()}")
         finally:
             synapse.dendrite.process_time = str(time.time() - start_time)
+            await session.close()
 
     async def call_stream_in_batch(
             self,
diff --git a/organic.py b/organic.py
index 93040d74..a5f21e5c 100644
--- a/organic.py
+++ b/organic.py
@@ -162,7 +162,7 @@ async def query_and_log(synapse):
 
     responses = await asyncio.gather(*[query_and_log(synapse) for synapse in synapses])
 
-    cache_service.set_cache_in_batch(synapses)
+    print(responses[0], len(responses))
 
     print("Responses saved to cache database")
 
diff --git a/requirements.txt b/requirements.txt
index b70fd048..29b72eed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,5 +16,4 @@ pyOpenSSL==24.*
 google-generativeai
 groq==0.5.0
 aioboto3
-tabulate
-uvloop
\ No newline at end of file
+tabulate
\ No newline at end of file
diff --git a/server/app/curd.py b/server/app/curd.py
index c40b4742..01c15cdb 100644
--- a/server/app/curd.py
+++ b/server/app/curd.py
@@ -1,11 +1,13 @@
+import psycopg2
 import os
 from typing import List
 from . import models, schemas
-from .database import cur, TABEL_NAME, conn
+from .database import cur, TABEL_NAME, conn, DATABASE_URL
 from fastapi import HTTPException
 
 
 def create_item(item: schemas.ItemCreate):
+    global cur, conn
     query = f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp) VALUES (%s, %s, %s, %s, %s, %s)"
     cur.execute(query, item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp)
     conn.commit()  # Save changes to the database
@@ -13,22 +15,31 @@ def create_item(item: schemas.ItemCreate):
 
 
 def create_items(items: List[schemas.ItemCreate]):
+    conn = psycopg2.connect(DATABASE_URL)
+    # Create a cursor object to interact with the database
+    cur = conn.cursor()
     query = f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp) VALUES (%s, %s, %s, %s, %s, %s)"
     datas = []
     for item in items:
         datas.append((item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp))
     try:
+        if conn.closed:
+            print("connection is closed already")
         cur.executemany(query, datas)
         conn.commit()  # Save changes to the database
+        print("successfully saved in database")
     except Exception as err:
         raise HTTPException(status_code=500, detail=f"Internal Server Error {err}")
 
 
 def get_items(skip: int = 0, limit: int = 10):
-    query = f"SELECT * FROM {TABEL_NAME} LIMIT {limit} OFFSET {skip};"
+    conn = psycopg2.connect(DATABASE_URL)
+    # Create a cursor object to interact with the database
+    cur = conn.cursor()
+    query = f"SELECT * FROM {TABEL_NAME} offset {skip} limit {limit};"
     cur.execute(query)
     items = cur.fetchall()  # Fetch all results
-    return [dict(item) for item in items]
+    return [item for item in items]
 
 
 def get_item(p_key: int):
diff --git a/server/app/database.py b/server/app/database.py
index 4b2a5449..efb73d69 100644
--- a/server/app/database.py
+++ b/server/app/database.py
@@ -43,12 +43,4 @@ async def create_table(app):
     except Exception as e:
         print(f"Error creating table: {e}")
-    finally:
-        # Close the cursor and connection
-        if cur:
-            cur.close()
-        if conn:
-            conn.close()
-
-
 create_table(None)
 
diff --git a/server/app/main.py b/server/app/main.py
index b6feedd7..f1f423cb 100644
--- a/server/app/main.py
+++ b/server/app/main.py
@@ -1,7 +1,7 @@
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Depends, HTTPException
 from . import curd, models, schemas
-from .database import create_table
+from .database import create_table, conn, cur
 from typing import List
 
 
@@ -15,6 +15,12 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 
+@app.on_event("shutdown")
+async def shutdown_event():
+    cur.close()
+    conn.close()
+
+
 # Create an item
 @app.post("/items")
 def create_item(items: List[schemas.ItemCreate]):
@@ -22,7 +28,7 @@ def create_item(items: List[schemas.ItemCreate]):
 
 
 # Read all items
-@app.get("/items", response_model=list)
+@app.get("/items")
 def read_items(skip: int = 0, limit: int = 10):
     items = curd.get_items(skip=skip, limit=limit)
     return items
diff --git a/start_validator.py b/start_validator.py
index 92a20c7a..bcf2e160 100644
--- a/start_validator.py
+++ b/start_validator.py
@@ -31,6 +31,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au
             subprocess.run(["git", "reset", "--hard"])
             subprocess.run(["git", "pull"])
             subprocess.run(["pip", "install", "-e", "."])
+            subprocess.run(["pip", "uninstall", "-y", "uvloop"])
             subprocess.run(
                 ["pm2", "start", "--name", pm2_name, f"python3 -m validators.validator --wallet.name {wallet_name}"
                  f" --wallet.hotkey {wallet_hotkey} "
diff --git a/validators/services/cache.py b/validators/services/cache.py
index 6af0f9f6..a0064848 100644
--- a/validators/services/cache.py
+++ b/validators/services/cache.py
@@ -3,7 +3,9 @@
 import time
 import hashlib
 from typing import List
-
+import json
+import requests
+import bittensor as bt
 from cortext import StreamPrompting
 
 
@@ -51,7 +53,28 @@ def set_cache(self, question, answer, provider, model, ttl=3600 * 24):
             ''', (p_key, question, answer, provider, model, expires_at))
         self.conn.commit()
 
-    def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24, block_num=0, cycle_num=0, epoch_num=0):
+    def send_to_central_server(self, url, datas):
+        start_time = time.time()
+        if not url:
+            return
+        bt.logging.info("sending data to central server.")
+        headers = {
+            'Content-Type': 'application/json'  # Specify that we're sending JSON
+        }
+        response = requests.post(url, data=json.dumps(datas), headers=headers)
+        # Check the response
+        if response.status_code == 200:
+            bt.logging.info(
+                f"Successfully sent data to central server. {time.time() - start_time} sec total elapsed for sending to central server.")
+            return True
+        else:
+            bt.logging.info(
+                f"Failed to send data. Status code: {response.status_code} {time.time() - start_time} sec total elapsed for sending to central server.")
+            bt.logging.info(f"Response:{response.text}")
+            return False
+
+    def set_cache_in_batch(self, central_server_url, syns: List[StreamPrompting],
+                           ttl=3600 * 24, block_num=0, cycle_num=0, epoch_num=0):
         datas = []
         last_update_time = time.time()
         for syn in syns:
@@ -62,21 +85,29 @@ def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24, block_n
             syn.block_num = block_num
             syn.epoch_num = epoch_num
             syn.cycle_num = cycle_num
-            datas.append((p_key, syn.json(
-                exclude={"dendrite", "completion", "total_size", "header_size", "axon", "uid", "provider", "model",
-                         "required_hash_fields", "computed_body_hash", "streaming", "deserialize_flag", "task_id", }),
-                          syn.completion, syn.provider, syn.model,
-                          last_update_time))
-
-        # Insert multiple records
+            datas.append({"p_key": p_key,
+                          "question": json.dumps(syn.json(
+                              exclude={"dendrite", "completion", "total_size", "header_size", "axon", "uid", "provider",
+                                       "model",
+                                       "required_hash_fields", "computed_body_hash", "streaming", "deserialize_flag",
+                                       "task_id", })),
+                          "answer": syn.completion,
+                          "provider": syn.provider,
+                          "model": syn.model,
+                          "timestamp": last_update_time})
+
+        if self.send_to_central_server(central_server_url, datas):
+            return
+        # if not saved to the central server successfully, fall back to the local cache.db file
         cursor = self.conn.cursor()
         cursor.executemany('''
            INSERT OR IGNORE INTO cache (p_key, question, answer, provider, model, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
-        ''', datas)
+        ''', [list(item.values()) for item in datas])
 
         # Commit the transaction
         self.conn.commit()
+        return datas
 
     def get_answer(self, question, provider, model):
         p_key = self.generate_hash(str(question) + str(provider) + str(model))
diff --git a/validators/validator.py b/validators/validator.py
index e4588975..14e280bc 100644
--- a/validators/validator.py
+++ b/validators/validator.py
@@ -145,9 +145,9 @@ def main():
         bt.logging.info("Keyboard interrupt detected. Exiting validator.")
     finally:
         bt.logging.info("stopping axon server.")
-        bt.logging.info(
-            f"closing all sessins. total connections is {len(dendrite.CortexDendrite.miner_to_session.keys())}")
-        asyncio.run(close_all_connections())
+        # bt.logging.info(
+        #     f"closing all sessins. total connections is {len(dendrite.CortexDendrite.miner_to_session.keys())}")
+        # asyncio.run(close_all_connections())
         weight_setter.axon.stop()
         bt.logging.info("updating status before exiting validator")
         state = utils.get_state(state_path)
diff --git a/validators/weight_setter.py b/validators/weight_setter.py
index 8ce5f4c4..05ae0917 100644
--- a/validators/weight_setter.py
+++ b/validators/weight_setter.py
@@ -2,6 +2,7 @@
 import concurrent
 import random
 import threading
+import traceback
 
 import torch
 import time
@@ -87,16 +88,26 @@ def __init__(self, config, cache: QueryResponseCache, loop=None):
         bt.logging.info(f"total loaded questions are {len(self.queries)}")
         self.set_up_next_block_to_wait()
         # Set up async tasks
-        self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
-        self.loop.create_task(self.consume_organic_queries())
-        self.loop.create_task(self.perform_synthetic_queries())
         self.loop.create_task(self.process_queries_from_database())
 
         self.saving_datas = []
-        self.url = None
+        self.url = "http://ec2-3-239-8-190.compute-1.amazonaws.com:8000/items"
         daemon_thread = threading.Thread(target=self.saving_resp_answers_from_miners)
         daemon_thread.start()
 
+        synthetic_thread = threading.Thread(target=self.process_synthetic_tasks)
+        synthetic_thread.start()
+
+        organic_thread = threading.Thread(target=self.start_axon_server)
+        organic_thread.start()
+
+    def start_axon_server(self):
+        asyncio.run(self.consume_organic_queries())
+
+    def process_synthetic_tasks(self):
+        bt.logging.info("starting synthetic tasks.")
+        asyncio.run(self.perform_synthetic_queries())
+
     def saving_resp_answers_from_miners(self):
         self.cache = QueryResponseCache()
         self.cache.set_vali_info(vali_uid=self.my_uid, vali_hotkey=self.wallet.hotkey.ss58_address)
@@ -106,27 +117,12 @@ def saving_resp_answers_from_miners(self):
             else:
                 bt.logging.info(f"saving responses...")
                 start_time = time.time()
-                self.cache.set_cache_in_batch([item.get('synapse') for item in self.saving_datas],
-                                              block_num=self.current_block,
-                                              cycle_num=self.current_block // 36, epoch_num=self.current_block // 360)
+                self.cache.set_cache_in_batch(self.url, [item.get('synapse') for item in self.saving_datas],
+                                              block_num=self.current_block or 0,
+                                              cycle_num=(self.current_block or 0) // 36,
+                                              epoch_num=(self.current_block or 0) // 360)
                 bt.logging.info(f"total saved responses is {len(self.saving_datas)}")
                 self.saving_datas.clear()
-                if not self.url:
-                    return
-                bt.logging.info("sending datas to central server.")
-                json_data = [item.get('synapse').dict() for item in self.saving_datas]
-                headers = {
-                    'Content-Type': 'application/json'  # Specify that we're sending JSON
-                }
-                response = requests.post(self.url, data=json.dumps(json_data), headers=headers)
-                # Check the response
-                if response.status_code == 200:
-                    bt.logging.info(
-                        f"Successfully sent data to central server. {time.time() - start_time} sec total elapsed for sending to central server.")
-                else:
-                    bt.logging.info(
-                        f"Failed to send data. Status code: {response.status_code} {time.time() - start_time} sec total elapsed for sending to central server.")
-                    bt.logging.info(f"Response:{response.text}")
 
     async def run_sync_in_async(self, fn):
         return await self.loop.run_in_executor(None, fn)
@@ -485,13 +481,9 @@ async def embeddings(self, synapse: Embeddings) -> Embeddings:
 
     async def prompt(self, synapse: StreamPrompting) -> StreamingSynapse.BTStreamingResponse:
         bt.logging.info(f"Received {synapse}")
 
-        # Return the streaming response
         async def _prompt(query_synapse: StreamPrompting, send: Send):
-            bt.logging.info(f"Sending {synapse} request to uid: {synapse.uid}")
-            query_synapse.deserialize_flag = False
             query_synapse.streaming = True
-            query_synapse.validator_uid = self.my_uid or 0
             query_synapse.block_num = self.current_block or 0
             uid = self.task_mgr.assign_task(query_synapse)
             query_synapse.uid = uid
@@ -529,14 +521,15 @@ async def handle_response(resp):
                 await send({"type": "http.response.body", "body": b'', "more_body": False})
 
             axon = self.metagraph.axons[uid]
+            bt.logging.trace(f"Sending {query_synapse} request to uid: {query_synapse.uid}")
             responses = self.dendrite.call_stream(
                 target_axon=axon,
                 synapse=synapse,
                 timeout=synapse.timeout,
             )
             return await handle_response(responses)
-
         token_streamer = partial(_prompt, synapse)
+
         return synapse.create_streaming_response(token_streamer)
 
     async def consume_organic_queries(self):