From b56cbf6ac2ead36cf51b0b608ebdd03fafc386d2 Mon Sep 17 00:00:00 2001 From: Giga Chkhikvadze Date: Mon, 12 Feb 2024 15:02:12 +0000 Subject: [PATCH] fix: scoring for text validator --- validators/text_validator.py | 54 +----------------------------------- validators/validator.py | 29 ++++++------------- validators/weight_setter.py | 5 ++-- 3 files changed, 12 insertions(+), 76 deletions(-) diff --git a/validators/text_validator.py b/validators/text_validator.py index cec8d6c4..96ccc050 100644 --- a/validators/text_validator.py +++ b/validators/text_validator.py @@ -32,30 +32,6 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet): "scores": {}, "timestamps": {}, } - - # async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]: - # for uid, messages in query.items(): - # syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k) - # bt.logging.info( - # f"Sending {syn.model} {self.query_type} request to uid: {uid}, " - # f"timeout {self.timeout}: {syn.messages[0]['content']}" - # ) - - # self.wandb_data["prompts"][uid] = messages - # responses = await self.dendrite( - # metagraph.axons[uid], - # syn, - # deserialize=False, - # timeout=self.timeout, - # streaming=self.streaming, - # ) - - # async for resp in responses: - # if not isinstance(resp, str): - # continue - - # bt.logging.trace(resp) - # yield uid, resp async def organic(self, metagraph, available_uids, messages: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]: uid_to_question = {} @@ -68,7 +44,7 @@ async def organic(self, metagraph, available_uids, messages: dict[str, list[dict prompt = message_list[-1]['content'] uid_to_question[uid] = prompt message = message_list - syn = StreamPrompting(messages=message_list, model='gpt-3.5-turbo-16k', seed=self.seed, max_tokens=8096, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k) + syn = StreamPrompting(messages=message_list, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k) bt.logging.info(f"Sending {syn.model} {self.query_type} request to uid: {uid}, timeout {self.timeout}: {message[0]['content']}") self.wandb_data["prompts"][uid] = messages responses = await self.dendrite( @@ -85,34 +61,6 @@ async def organic(self, metagraph, available_uids, messages: dict[str, list[dict bt.logging.trace(resp) yield uid, key, resp - - # async def organic_scoring(self, available_uids, metagraph, messages): - # query_tasks = [] - # uid_to_question = {} - # if len(messages) <= len(available_uids): - # random_uids = random.sample(list(available_uids.keys()), len(messages)) - # else: - # random_uids = [random.choice(list(available_uids.keys())) for _ in range(len(messages))] - # for message_dict, uid in zip(messages, random_uids): # Iterate over each dictionary in the list and random_uids - # (key, message_list), = message_dict.items() - # prompt = message_list[-1]['content'] - # uid_to_question[uid] = prompt - # message = message_list - # syn = StreamPrompting(messages=message_list, model='gpt-3.5-turbo-16k', seed=self.seed, max_tokens=8096, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k) - # bt.logging.info(f"Sending {syn.model} {self.query_type} request to uid: {uid}, timeout {self.timeout}: {message[0]['content']}") - # task = self.query_miner(metagraph, uid, syn) - # query_tasks.append(task) - # self.wandb_data["prompts"][uid] = prompt - - # query_responses = await asyncio.gather(*query_tasks) - # scores, uid_scores_dict, wandb_data = await self.score_responses(query_responses, uid_to_question, metagraph) - - # result = {} - # for (_, value), message_dict in zip(query_responses, messages): - # (key, message_list), = message_dict.items() - # result[key] = value - - # return result, scores, uid_scores_dict, wandb_data async def handle_response(self, uid: str, responses) -> tuple[str, str]: full_response = "" diff --git a/validators/validator.py b/validators/validator.py index c8d7ecd4..f64c4b92 100644 --- a/validators/validator.py +++ b/validators/validator.py @@ -107,9 +107,12 @@ def initialize_components(config: bt.config): def initialize_validators(vali_config, test=False): - global text_vali, image_vali, embed_vali + global text_vali, text_vali_organic, image_vali, embed_vali text_vali = (TextValidator if not test else TestTextValidator)(**vali_config) + text_vali_organic = (TextValidator if not test else TestTextValidator)(**vali_config) + text_vali_organic.model = 'gpt-3.5-turbo-16k' + text_vali_organic.max_tokens = 8096 image_vali = ImageValidator(**vali_config) embed_vali = EmbeddingsValidator(**vali_config) bt.logging.info("initialized_validators") @@ -137,7 +140,7 @@ async def process_text_validator(request: web.Request): key_to_response = {} uid_to_response = {} try: - async for uid, key, content in text_vali.organic(metagraph=validator_app.weight_setter.metagraph, + async for uid, key, content in text_vali_organic.organic(metagraph=validator_app.weight_setter.metagraph, available_uids=validator_app.weight_setter.available_uids, messages=messages): uid_to_response[uid] = uid_to_response.get(uid, '') + content @@ -151,10 +154,12 @@ async def process_text_validator(request: web.Request): validator_app.weight_setter.register_text_validator_organic_query( + text_vali=text_vali_organic, uid_to_response=uid_to_response, messages_dict=prompts ) - await response.write(json.dumps(key_to_response).encode()) + await response.write(json.dumps(key_to_response).encode()) + except Exception as e: bt.logging.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}, ERROR: {e}') await response.write(b'<>') @@ -166,26 +171,8 @@ def __init__(self, *a, **kw): super().__init__(*a, **kw) self.weight_setter: WeightSetter | None = None -# async def organic_scoring(request: web.Request): -# try: -# # Check access key -# access_key = request.headers.get("access-key") -# if access_key != EXPECTED_ACCESS_KEY: -# raise web.Response(status_code=401, detail="Invalid access key") -# body = await request.json() -# messages = body['messages'] - -# responses = await validator_app.weight_setter.perform_api_scoring_and_update_weights(messages) - -# return web.json_response(responses) -# except Exception as e: -# bt.logging.error(f'Organic scoring error: ${e}') -# await web.Response(status_code=400, detail="{e}") - validator_app = ValidatorApplication() validator_app.add_routes([web.post('/text-validator/', process_text_validator)]) -# validator_app.add_routes([web.post('/scoring/', organic_scoring)]) - def main(run_aio_app=True, test=False) -> None: config = get_config() diff --git a/validators/weight_setter.py b/validators/weight_setter.py index b23aa1a8..220f7eab 100644 --- a/validators/weight_setter.py +++ b/validators/weight_setter.py @@ -18,7 +18,7 @@ from embeddings_validator import EmbeddingsValidator iterations_per_set_weights = 5 -scoring_organic_timeout = 60 +scoring_organic_timeout = 120 async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]: @@ -202,12 +202,13 @@ async def set_weights(self, scores): def register_text_validator_organic_query( self, + text_vali, uid_to_response: dict[int, str], # [(uid, response)] messages_dict: dict[int, str], ): self.organic_scoring_tasks.add(asyncio.create_task( wait_for_coro_with_limit( - self.text_vali.score_responses( + text_vali.score_responses( query_responses=list(uid_to_response.items()), uid_to_question=messages_dict, metagraph=self.metagraph,