Merge pull request #98 from Datura-ai/hotfix-main-bittensor
Hotfix main bittensor
surcyf123 authored Oct 29, 2024
2 parents 52f806a + d7cf6b1 commit 63aa4d7
Showing 10 changed files with 101 additions and 64 deletions.
18 changes: 11 additions & 7 deletions cortext/dendrite.py
@@ -2,7 +2,7 @@

import aiohttp
import bittensor as bt
from aiohttp import ServerTimeoutError
from aiohttp import ServerTimeoutError, ClientConnectorError
from bittensor import dendrite
import traceback
import time
@@ -47,14 +47,11 @@ async def call_stream(
# Preprocess synapse for making a request
synapse: StreamPrompting = self.preprocess_synapse_for_request(target_axon, synapse, timeout) # type: ignore
max_try = 0
session = CortexDendrite.miner_to_session.get(endpoint)
timeout = aiohttp.ClientTimeout(total=300, connect=timeout, sock_connect=timeout, sock_read=timeout)
connector = aiohttp.TCPConnector(limit=200)
session = aiohttp.ClientSession(timeout=timeout, connector=connector)
try:
while max_try < 3:
if not session:
timeout = aiohttp.ClientTimeout(total=300, connect=timeout, sock_connect=timeout, sock_read=timeout)
connector = aiohttp.TCPConnector(limit=200)
session = aiohttp.ClientSession(timeout=timeout, connector=connector)
CortexDendrite.miner_to_session[endpoint] = session
async with session.post(
url,
headers=synapse.to_headers(),
@@ -70,6 +67,12 @@ async def call_stream(
bt.logging.error(f"timeout error happens. max_try is {max_try}")
max_try += 1
continue
except ConnectionRefusedError as err:
bt.logging.error(f"can not connect to miner for now. connection failed")
break
except ClientConnectorError as err:
bt.logging.error(f"can not connect to miner for now. connection failed")
break
except ServerTimeoutError as err:
bt.logging.error(f"timeout error happens. max_try is {max_try}")
max_try += 1
@@ -84,6 +87,7 @@ async def call_stream(
bt.logging.error(f"{e} {traceback.format_exc()}")
finally:
synapse.dendrite.process_time = str(time.time() - start_time)
await session.close()

async def call_stream_in_batch(
self,
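The hunks above drop the cached per-miner session in favour of a fresh aiohttp session per call, closed in the finally block, and treat connection failures as non-retryable. A minimal standalone sketch of that pattern, assuming nothing beyond aiohttp itself (the function name fetch_stream and the payload shape are illustrative, not from the repository):

import asyncio
import aiohttp
from aiohttp import ClientConnectorError, ServerTimeoutError

async def fetch_stream(url: str, payload: dict, timeout_s: float = 60.0) -> str:
    # Fresh session per request, mirroring the diff; always closed in finally.
    timeout = aiohttp.ClientTimeout(total=300, connect=timeout_s,
                                    sock_connect=timeout_s, sock_read=timeout_s)
    connector = aiohttp.TCPConnector(limit=200)
    session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    chunks = []
    try:
        async with session.post(url, json=payload) as resp:
            async for raw in resp.content.iter_any():
                chunks.append(raw.decode(errors="ignore"))
    except (ConnectionRefusedError, ClientConnectorError):
        pass  # miner unreachable: give up instead of retrying, as the diff does
    except (ServerTimeoutError, asyncio.TimeoutError):
        pass  # a caller could retry a few times here, as call_stream does
    finally:
        await session.close()
    return "".join(chunks)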
2 changes: 1 addition & 1 deletion organic.py
@@ -162,7 +162,7 @@ async def query_and_log(synapse):

responses = await asyncio.gather(*[query_and_log(synapse) for synapse in synapses])

cache_service.set_cache_in_batch(synapses)
print(responses[0], len(responses))

print("Responses saved to cache database")

3 changes: 1 addition & 2 deletions requirements.txt
@@ -16,5 +16,4 @@ pyOpenSSL==24.*
google-generativeai
groq==0.5.0
aioboto3
tabulate
uvloop
tabulate
17 changes: 14 additions & 3 deletions server/app/curd.py
@@ -1,34 +1,45 @@
import psycopg2
import os
from typing import List
from . import models, schemas
from .database import cur, TABEL_NAME, conn
from .database import cur, TABEL_NAME, conn, DATABASE_URL
from fastapi import HTTPException


def create_item(item: schemas.ItemCreate):
global cur, conn
query = f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp) VALUES (%s, %s, %s, %s, %s, %s)"
cur.execute(query, item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp)
conn.commit() # Save changes to the database
return item


def create_items(items: List[schemas.ItemCreate]):
conn = psycopg2.connect(DATABASE_URL)
# Create a cursor object to interact with the database
cur = conn.cursor()
query = f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp) VALUES (%s, %s, %s, %s, %s, %s)"
datas = []
for item in items:
datas.append((item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp))
try:
if conn.closed:
print("connection is closed already")
cur.executemany(query, datas)
conn.commit() # Save changes to the database
print("successfully saved in database")
except Exception as err:
raise HTTPException(status_code=500, detail=f"Internal Server Error {err}")


def get_items(skip: int = 0, limit: int = 10):
query = f"SELECT * FROM {TABEL_NAME} LIMIT {limit} OFFSET {skip};"
conn = psycopg2.connect(DATABASE_URL)
# Create a cursor object to interact with the database
cur = conn.cursor()
query = f"SELECT * FROM {TABEL_NAME} offset {skip} limit {limit};"
cur.execute(query)
items = cur.fetchall() # Fetch all results
return [dict(item) for item in items]
return [item for item in items]


def get_item(p_key: int):
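create_items now opens its own psycopg2 connection and bulk-inserts with executemany. A hedged sketch of that path, assuming DATABASE_URL points at a reachable Postgres instance and a table with the same six columns; TABLE_NAME here stands in for the repository's TABEL_NAME, and the values must match the table's actual column types:

import os
import psycopg2

DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://user:pass@localhost:5432/cache")  # placeholder DSN
TABLE_NAME = "cache"  # assumed table name

def insert_rows(rows):
    # One connection, one executemany, one commit; close everything afterwards.
    conn = psycopg2.connect(DATABASE_URL)
    cur = conn.cursor()
    query = (f"INSERT INTO {TABLE_NAME} "
             "(p_key, question, answer, provider, model, timestamp) "
             "VALUES (%s, %s, %s, %s, %s, %s)")
    try:
        cur.executemany(query, rows)
        conn.commit()
    finally:
        cur.close()
        conn.close()

# rows is a list of 6-tuples in the column order above
insert_rows([("abc123", "what is bittensor?", "a decentralized network",
              "OpenAI", "gpt-4o", "2024-10-29 00:00:00")])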
8 changes: 0 additions & 8 deletions server/app/database.py
@@ -43,12 +43,4 @@ async def create_table(app)
except Exception as e:
print(f"Error creating table: {e}")

finally:
# Close the cursor and connection
if cur:
cur.close()
if conn:
conn.close()


create_table(None)
10 changes: 8 additions & 2 deletions server/app/main.py
@@ -1,7 +1,7 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from . import curd, models, schemas
from .database import create_table
from .database import create_table, conn, cur
from typing import List


@@ -15,14 +15,20 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)


@app.on_event("shutdown")
async def shutdown_event():
cur.close()
conn.close()


# Create an item
@app.post("/items")
def create_item(items: List[schemas.ItemCreate]):
return curd.create_items(items=items)


# Read all items
@app.get("/items", response_model=list)
@app.get("/items")
def read_items(skip: int = 0, limit: int = 10):
items = curd.get_items(skip=skip, limit=limit)
return items
1 change: 1 addition & 0 deletions start_validator.py
@@ -31,6 +31,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au
subprocess.run(["git", "reset", "--hard"])
subprocess.run(["git", "pull"])
subprocess.run(["pip", "install", "-e", "."])
subprocess.run(["pip", "uninstall", "uvloop"])
subprocess.run(
["pm2", "start", "--name", pm2_name, f"python3 -m validators.validator --wallet.name {wallet_name}"
f" --wallet.hotkey {wallet_hotkey} "
51 changes: 41 additions & 10 deletions validators/services/cache.py
@@ -3,7 +3,9 @@
import time
import hashlib
from typing import List

import json
import requests
import bittensor as bt
from cortext import StreamPrompting


@@ -51,7 +53,28 @@ def set_cache(self, question, answer, provider, model, ttl=3600 * 24):
''', (p_key, question, answer, provider, model, expires_at))
self.conn.commit()

def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24, block_num=0, cycle_num=0, epoch_num=0):
def send_to_central_server(self, url, datas):
start_time = time.time()
if not url:
return
bt.logging.info("sending datas to central server.")
headers = {
'Content-Type': 'application/json' # Specify that we're sending JSON
}
response = requests.post(url, data=json.dumps(datas), headers=headers)
# Check the response
if response.status_code == 200:
bt.logging.info(
f"Successfully sent data to central server. {time.time() - start_time} sec total elapsed for sending to central server.")
return True
else:
bt.logging.info(
f"Failed to send data. Status code: {response.status_code} {time.time() - start_time} sec total elapsed for sending to central server.")
bt.logging.info(f"Response:{response.text}")
return False

def set_cache_in_batch(self, central_server_url, syns: List[StreamPrompting],
ttl=3600 * 24, block_num=0, cycle_num=0, epoch_num=0):
datas = []
last_update_time = time.time()
for syn in syns:
@@ -62,21 +85,29 @@ def set_cache_in_batch(self, syns: List[StreamPrompting], ttl=3600 * 24, block_n
syn.block_num = block_num
syn.epoch_num = epoch_num
syn.cycle_num = cycle_num
datas.append((p_key, syn.json(
exclude={"dendrite", "completion", "total_size", "header_size", "axon", "uid", "provider", "model",
"required_hash_fields", "computed_body_hash", "streaming", "deserialize_flag", "task_id", }),
syn.completion, syn.provider, syn.model,
last_update_time))

# Insert multiple records
datas.append({"p_key": p_key,
"question": json.dumps(syn.json(
exclude={"dendrite", "completion", "total_size", "header_size", "axon", "uid", "provider",
"model",
"required_hash_fields", "computed_body_hash", "streaming", "deserialize_flag",
"task_id", })),
"answer": syn.completion,
"provider": syn.provider,
"model": syn.model,
"timestamp": last_update_time})

if self.send_to_central_server(central_server_url, datas):
return
# if not saved in central server successfully, then just save local cache.db file
cursor = self.conn.cursor()
cursor.executemany('''
INSERT OR IGNORE INTO cache (p_key, question, answer, provider, model, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
''', datas)
''', [list(item.values()) for item in datas])

# Commit the transaction
self.conn.commit()
return datas

def get_answer(self, question, provider, model):
p_key = self.generate_hash(str(question) + str(provider) + str(model))
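set_cache_in_batch now tries a central server first and only falls back to the local SQLite cache. A short usage sketch of the HTTP leg, posting the batch as JSON to the /items route added in server/app/main.py; the localhost URL and the sample row are placeholders:

import json
import time
import requests

def post_batch(url, datas):
    # Illustrative client for the central /items endpoint; True on HTTP 200.
    if not url:
        return False
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, data=json.dumps(datas), headers=headers, timeout=30)
    return response.status_code == 200

rows = [{"p_key": "abc123", "question": "{}", "answer": "hi there",
         "provider": "OpenAI", "model": "gpt-4o", "timestamp": time.time()}]
post_batch("http://localhost:8000/items", rows)  # placeholder URL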
6 changes: 3 additions & 3 deletions validators/validator.py
@@ -145,9 +145,9 @@ def main():
bt.logging.info("Keyboard interrupt detected. Exiting validator.")
finally:
bt.logging.info("stopping axon server.")
bt.logging.info(
f"closing all sessins. total connections is {len(dendrite.CortexDendrite.miner_to_session.keys())}")
asyncio.run(close_all_connections())
# bt.logging.info(
# f"closing all sessins. total connections is {len(dendrite.CortexDendrite.miner_to_session.keys())}")
# asyncio.run(close_all_connections())
weight_setter.axon.stop()
bt.logging.info("updating status before exiting validator")
state = utils.get_state(state_path)
49 changes: 21 additions & 28 deletions validators/weight_setter.py
@@ -2,6 +2,7 @@
import concurrent
import random
import threading
import traceback

import torch
import time
@@ -87,16 +88,26 @@ def __init__(self, config, cache: QueryResponseCache, loop=None):
bt.logging.info(f"total loaded questions are {len(self.queries)}")
self.set_up_next_block_to_wait()
# Set up async tasks
self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
self.loop.create_task(self.consume_organic_queries())
self.loop.create_task(self.perform_synthetic_queries())
self.loop.create_task(self.process_queries_from_database())

self.saving_datas = []
self.url = None
self.url = "http://ec2-3-239-8-190.compute-1.amazonaws.com:8000/items"
daemon_thread = threading.Thread(target=self.saving_resp_answers_from_miners)
daemon_thread.start()

synthetic_thread = threading.Thread(target=self.process_synthetic_tasks)
synthetic_thread.start()

organic_thread = threading.Thread(target=self.start_axon_server)
organic_thread.start()

def start_axon_server(self):
asyncio.run(self.consume_organic_queries())

def process_synthetic_tasks(self):
bt.logging.info("starting synthetic tasks.")
asyncio.run(self.perform_synthetic_queries())

def saving_resp_answers_from_miners(self):
self.cache = QueryResponseCache()
self.cache.set_vali_info(vali_uid=self.my_uid, vali_hotkey=self.wallet.hotkey.ss58_address)
@@ -106,27 +117,12 @@ def saving_resp_answers_from_miners(self):
else:
bt.logging.info(f"saving responses...")
start_time = time.time()
self.cache.set_cache_in_batch([item.get('synapse') for item in self.saving_datas],
block_num=self.current_block,
cycle_num=self.current_block // 36, epoch_num=self.current_block // 360)
self.cache.set_cache_in_batch(self.url, [item.get('synapse') for item in self.saving_datas],
block_num=self.current_block or 0,
cycle_num=(self.current_block or 0) // 36,
epoch_num=(self.current_block or 0) // 360)
bt.logging.info(f"total saved responses is {len(self.saving_datas)}")
self.saving_datas.clear()
if not self.url:
return
bt.logging.info("sending datas to central server.")
json_data = [item.get('synapse').dict() for item in self.saving_datas]
headers = {
'Content-Type': 'application/json' # Specify that we're sending JSON
}
response = requests.post(self.url, data=json.dumps(json_data), headers=headers)
# Check the response
if response.status_code == 200:
bt.logging.info(
f"Successfully sent data to central server. {time.time() - start_time} sec total elapsed for sending to central server.")
else:
bt.logging.info(
f"Failed to send data. Status code: {response.status_code} {time.time() - start_time} sec total elapsed for sending to central server.")
bt.logging.info(f"Response:{response.text}")

async def run_sync_in_async(self, fn):
return await self.loop.run_in_executor(None, fn)
@@ -485,13 +481,9 @@ async def embeddings(self, synapse: Embeddings) -> Embeddings:
async def prompt(self, synapse: StreamPrompting) -> StreamingSynapse.BTStreamingResponse:
bt.logging.info(f"Received {synapse}")

# Return the streaming response
async def _prompt(query_synapse: StreamPrompting, send: Send):
bt.logging.info(f"Sending {synapse} request to uid: {synapse.uid}")

query_synapse.deserialize_flag = False
query_synapse.streaming = True
query_synapse.validator_uid = self.my_uid or 0
query_synapse.block_num = self.current_block or 0
uid = self.task_mgr.assign_task(query_synapse)
query_synapse.uid = uid
@@ -529,14 +521,15 @@ async def handle_response(resp):
await send({"type": "http.response.body", "body": b'', "more_body": False})

axon = self.metagraph.axons[uid]
bt.logging.trace(f"Sending {query_synapse} request to uid: {query_synapse.uid}")
responses = self.dendrite.call_stream(
target_axon=axon,
synapse=synapse,
timeout=synapse.timeout,
)
return await handle_response(responses)

token_streamer = partial(_prompt, synapse)

return synapse.create_streaming_response(token_streamer)

async def consume_organic_queries(self):
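The constructor changes above move the organic axon server, the synthetic query loop, and the response saver out of the shared event loop and into their own threads, each driving its own loop with asyncio.run. A minimal sketch of that layout; the class and method names are illustrative, and daemon=True is an assumption, the diff starts plain threads:

import asyncio
import threading

class WorkerLayout:
    def __init__(self):
        # One background thread per long-running async loop.
        threading.Thread(target=self._run_saver, daemon=True).start()
        threading.Thread(target=self._run_synthetic, daemon=True).start()

    def _run_saver(self):
        asyncio.run(self._save_loop())       # each thread owns its own event loop

    def _run_synthetic(self):
        asyncio.run(self._synthetic_loop())

    async def _save_loop(self):
        while True:
            await asyncio.sleep(1)           # flush queued miner responses here

    async def _synthetic_loop(self):
        while True:
            await asyncio.sleep(1)           # issue synthetic queries here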
