From 51f09818e68464d6395cf97e4d4a3e379ea16448 Mon Sep 17 00:00:00 2001 From: Giga Chkhikvadze Date: Tue, 13 Feb 2024 10:14:19 +0000 Subject: [PATCH] fix: set weight periodically --- validators/text_validator.py | 3 ++- validators/weight_setter.py | 49 ++++++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/validators/text_validator.py b/validators/text_validator.py index 96ccc050..08755930 100644 --- a/validators/text_validator.py +++ b/validators/text_validator.py @@ -125,13 +125,14 @@ async def score_responses( query_responses: list[tuple[int, str]], # [(uid, response)] uid_to_question: dict[int, str], # uid -> prompt metagraph: bt.metagraph, + is_score_all : False ) -> tuple[torch.Tensor, dict[int, float], dict]: scores = torch.zeros(len(metagraph.hotkeys)) uid_scores_dict = {} response_tasks = [] # Decide to score all UIDs this round based on a chance - will_score_all = self.should_i_score() + will_score_all = True if is_score_all else self.should_i_score() for uid, response in query_responses: self.wandb_data["responses"][uid] = response diff --git a/validators/weight_setter.py b/validators/weight_setter.py index 220f7eab..5cb4332b 100644 --- a/validators/weight_setter.py +++ b/validators/weight_setter.py @@ -52,9 +52,21 @@ def __init__(self, loop: asyncio.AbstractEventLoop, dendrite, subtensor, config, self.loop.create_task(self.update_available_uids_periodically()) self.available_uids = {} self.loop.create_task(self.consume_organic_scoring()) - # self.loop.create_task(self.perform_synthetic_scoring_and_update_weights()) - + self.loop.create_task(self.perform_synthetic_scoring_and_update_weights()) + self.loop.create_task(self.update_weights_periodically()) + + async def update_weights_periodically(self): + while True: + if len(self.available_uids) == 0 or \ + torch.all(self.total_scores == 0): + await asyncio.sleep(10) + continue + + await self.update_weights(self.steps_passed) + await asyncio.sleep(600) # 600 seconds = 10 minutes + + async def update_available_uids_periodically(self): while True: self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid)) @@ -99,16 +111,18 @@ async def consume_organic_scoring(self): async def perform_synthetic_scoring_and_update_weights(self): while True: - for steps_passed in itertools.count(): - # self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid)) + if len(self.available_uids) == 0: + await asyncio.sleep(10) + continue + for steps_passed in itertools.count(): available_uids = self.available_uids selected_validator = self.select_validator(steps_passed) scores, _ = await self.process_modality(selected_validator, available_uids) self.total_scores += scores steps_since_last_update = steps_passed % iterations_per_set_weights - + if steps_since_last_update == iterations_per_set_weights - 1: await self.update_weights(steps_passed) else: @@ -200,22 +214,43 @@ async def set_weights(self, scores): ) bt.logging.success("Successfully set weights.") + def handle_task_result_organic_query(self, task): + try: + success, data = task.result() + if success: + scores, uid_scores_dict, wandb_data = data + if self.config.wandb_on: + wandb.log(wandb_data) + bt.logging.success("wandb_log successful") + self.total_scores += scores + bt.logging.success(f"Task completed successfully. Scores updated.") + else: + bt.logging.error("Task failed. No scores updated.") + except Exception as e: + # Handle exceptions raised during task execution + bt.logging.error(f"handle_task_result_organic_query An error occurred during task execution: {e}") + def register_text_validator_organic_query( self, text_vali, uid_to_response: dict[int, str], # [(uid, response)] messages_dict: dict[int, str], ): - self.organic_scoring_tasks.add(asyncio.create_task( + self.steps_passed += 1 + + task = asyncio.create_task( wait_for_coro_with_limit( text_vali.score_responses( query_responses=list(uid_to_response.items()), uid_to_question=messages_dict, metagraph=self.metagraph, + is_score_all=True ), scoring_organic_timeout ) - )) + ) + task.add_done_callback(self.handle_task_result_organic_query) # Attach the callback + self.organic_scoring_tasks.add(task) class TestWeightSetter(WeightSetter):