fix: scoring for text validator
Chkhikvadze committed Feb 12, 2024
1 parent c28b3fd commit b56cbf6
Showing 3 changed files with 12 additions and 76 deletions.
54 changes: 1 addition & 53 deletions validators/text_validator.py
@@ -32,30 +32,6 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
"scores": {},
"timestamps": {},
}

# async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
# for uid, messages in query.items():
# syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
# bt.logging.info(
# f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
# f"timeout {self.timeout}: {syn.messages[0]['content']}"
# )

# self.wandb_data["prompts"][uid] = messages
# responses = await self.dendrite(
# metagraph.axons[uid],
# syn,
# deserialize=False,
# timeout=self.timeout,
# streaming=self.streaming,
# )

# async for resp in responses:
# if not isinstance(resp, str):
# continue

# bt.logging.trace(resp)
# yield uid, resp

async def organic(self, metagraph, available_uids, messages: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
uid_to_question = {}
@@ -68,7 +44,7 @@ async def organic(self, metagraph, available_uids, messages: dict[str, list[dict
prompt = message_list[-1]['content']
uid_to_question[uid] = prompt
message = message_list
syn = StreamPrompting(messages=message_list, model='gpt-3.5-turbo-16k', seed=self.seed, max_tokens=8096, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
syn = StreamPrompting(messages=message_list, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
bt.logging.info(f"Sending {syn.model} {self.query_type} request to uid: {uid}, timeout {self.timeout}: {message[0]['content']}")
self.wandb_data["prompts"][uid] = messages
responses = await self.dendrite(
@@ -85,34 +61,6 @@ async def organic(self, metagraph, available_uids, messages: dict[str, list[dict

bt.logging.trace(resp)
yield uid, key, resp

# async def organic_scoring(self, available_uids, metagraph, messages):
# query_tasks = []
# uid_to_question = {}
# if len(messages) <= len(available_uids):
# random_uids = random.sample(list(available_uids.keys()), len(messages))
# else:
# random_uids = [random.choice(list(available_uids.keys())) for _ in range(len(messages))]
# for message_dict, uid in zip(messages, random_uids): # Iterate over each dictionary in the list and random_uids
# (key, message_list), = message_dict.items()
# prompt = message_list[-1]['content']
# uid_to_question[uid] = prompt
# message = message_list
# syn = StreamPrompting(messages=message_list, model='gpt-3.5-turbo-16k', seed=self.seed, max_tokens=8096, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
# bt.logging.info(f"Sending {syn.model} {self.query_type} request to uid: {uid}, timeout {self.timeout}: {message[0]['content']}")
# task = self.query_miner(metagraph, uid, syn)
# query_tasks.append(task)
# self.wandb_data["prompts"][uid] = prompt

# query_responses = await asyncio.gather(*query_tasks)
# scores, uid_scores_dict, wandb_data = await self.score_responses(query_responses, uid_to_question, metagraph)

# result = {}
# for (_, value), message_dict in zip(query_responses, messages):
# (key, message_list), = message_dict.items()
# result[key] = value

# return result, scores, uid_scores_dict, wandb_data

async def handle_response(self, uid: str, responses) -> tuple[str, str]:
full_response = ""
29 changes: 8 additions & 21 deletions validators/validator.py
@@ -107,9 +107,12 @@ def initialize_components(config: bt.config):


def initialize_validators(vali_config, test=False):
global text_vali, image_vali, embed_vali
global text_vali, text_vali_organic, image_vali, embed_vali

text_vali = (TextValidator if not test else TestTextValidator)(**vali_config)
text_vali_organic = (TextValidator if not test else TestTextValidator)(**vali_config)
text_vali_organic.model = 'gpt-3.5-turbo-16k'
text_vali_organic.max_tokens = 8096
image_vali = ImageValidator(**vali_config)
embed_vali = EmbeddingsValidator(**vali_config)
bt.logging.info("initialized_validators")
@@ -137,7 +140,7 @@ async def process_text_validator(request: web.Request):
key_to_response = {}
uid_to_response = {}
try:
async for uid, key, content in text_vali.organic(metagraph=validator_app.weight_setter.metagraph,
async for uid, key, content in text_vali_organic.organic(metagraph=validator_app.weight_setter.metagraph,
available_uids=validator_app.weight_setter.available_uids,
messages=messages):
uid_to_response[uid] = uid_to_response.get(uid, '') + content
@@ -151,10 +154,12 @@ async for uid, key, content in text_vali.organic(metagraph=validator_app.weight_setter.metagraph,


validator_app.weight_setter.register_text_validator_organic_query(
text_vali=text_vali_organic,
uid_to_response=uid_to_response,
messages_dict=prompts
)
await response.write(json.dumps(key_to_response).encode())
await response.write(json.dumps(key_to_response).encode())

except Exception as e:
bt.logging.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}, ERROR: {e}')
await response.write(b'<<internal error>>')
@@ -166,26 +171,8 @@ def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self.weight_setter: WeightSetter | None = None

# async def organic_scoring(request: web.Request):
# try:
# # Check access key
# access_key = request.headers.get("access-key")
# if access_key != EXPECTED_ACCESS_KEY:
# raise web.Response(status_code=401, detail="Invalid access key")
# body = await request.json()
# messages = body['messages']

# responses = await validator_app.weight_setter.perform_api_scoring_and_update_weights(messages)

# return web.json_response(responses)
# except Exception as e:
# bt.logging.error(f'Organic scoring error: ${e}')
# await web.Response(status_code=400, detail="{e}")

validator_app = ValidatorApplication()
validator_app.add_routes([web.post('/text-validator/', process_text_validator)])
# validator_app.add_routes([web.post('/scoring/', organic_scoring)])


def main(run_aio_app=True, test=False) -> None:
config = get_config()
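For reference, a minimal client sketch for the /text-validator/ route registered above. The payload shape (a "messages" list of single-key dicts holding chat messages) is inferred from organic(), and the access-key header, host, and port are assumptions rather than part of this diff.

# Hypothetical client for the /text-validator/ route. The "access-key" header,
# host/port, and exact payload shape are assumptions inferred from the handler
# and the removed /scoring/ code, not a documented API.
import requests

VALIDATOR_URL = "http://localhost:8000/text-validator/"   # assumed host/port
ACCESS_KEY = "replace-with-EXPECTED_ACCESS_KEY"           # assumed access check

payload = {
    "messages": [
        # one single-key dict per conversation; the key indexes the response map
        {"q1": [{"role": "user", "content": "What does this subnet score?"}]},
    ]
}

resp = requests.post(VALIDATOR_URL, json=payload, headers={"access-key": ACCESS_KEY})
# The handler finally writes json.dumps(key_to_response), so the body should end
# with a JSON object mapping "q1" to the miner's full completion.
print(resp.text)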
5 changes: 3 additions & 2 deletions validators/weight_setter.py
@@ -18,7 +18,7 @@
from embeddings_validator import EmbeddingsValidator

iterations_per_set_weights = 5
scoring_organic_timeout = 60
scoring_organic_timeout = 120


async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]:
@@ -202,12 +202,13 @@ async def set_weights(self, scores):

def register_text_validator_organic_query(
self,
text_vali,
uid_to_response: dict[int, str], # [(uid, response)]
messages_dict: dict[int, str],
):
self.organic_scoring_tasks.add(asyncio.create_task(
wait_for_coro_with_limit(
self.text_vali.score_responses(
text_vali.score_responses(
query_responses=list(uid_to_response.items()),
uid_to_question=messages_dict,
metagraph=self.metagraph,
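For context on the doubled scoring_organic_timeout, wait_for_coro_with_limit (whose signature appears in the hunk above) presumably bounds how long an organic score_responses task may run; raising the limit to 120 gives organic scoring up to two minutes before the task is abandoned. A minimal sketch of that pattern, assuming the wrapper reports a timeout instead of raising:

import asyncio
from typing import Tuple

# Sketch of the timeout wrapper implied by the signature above; the repository's
# actual body may differ (e.g. in logging or in what it returns on timeout).
async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]:
    try:
        result = await asyncio.wait_for(coro, timeout)
    except asyncio.TimeoutError:
        # Scoring exceeded scoring_organic_timeout; report failure rather than
        # raising, so the weight setter can drop or retry the task.
        return False, None
    return True, result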
