Skip to content

Commit

Permalink
fix: set weight periodically
Browse files Browse the repository at this point in the history
  • Loading branch information
Chkhikvadze committed Feb 13, 2024
1 parent b80ef37 commit 51f0981
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
3 changes: 2 additions & 1 deletion validators/text_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@ async def score_responses(
query_responses: list[tuple[int, str]], # [(uid, response)]
uid_to_question: dict[int, str], # uid -> prompt
metagraph: bt.metagraph,
is_score_all : False
) -> tuple[torch.Tensor, dict[int, float], dict]:
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
response_tasks = []

# Decide to score all UIDs this round based on a chance
will_score_all = self.should_i_score()
will_score_all = True if is_score_all else self.should_i_score()

for uid, response in query_responses:
self.wandb_data["responses"][uid] = response
Expand Down
49 changes: 42 additions & 7 deletions validators/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ def __init__(self, loop: asyncio.AbstractEventLoop, dendrite, subtensor, config,
self.loop.create_task(self.update_available_uids_periodically())
self.available_uids = {}
self.loop.create_task(self.consume_organic_scoring())
# self.loop.create_task(self.perform_synthetic_scoring_and_update_weights())


self.loop.create_task(self.perform_synthetic_scoring_and_update_weights())
self.loop.create_task(self.update_weights_periodically())

async def update_weights_periodically(self):
while True:
if len(self.available_uids) == 0 or \
torch.all(self.total_scores == 0):
await asyncio.sleep(10)
continue

await self.update_weights(self.steps_passed)
await asyncio.sleep(600) # 600 seconds = 10 minutes


async def update_available_uids_periodically(self):
while True:
self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid))
Expand Down Expand Up @@ -99,16 +111,18 @@ async def consume_organic_scoring(self):

async def perform_synthetic_scoring_and_update_weights(self):
while True:
for steps_passed in itertools.count():
# self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid))
if len(self.available_uids) == 0:
await asyncio.sleep(10)
continue

for steps_passed in itertools.count():
available_uids = self.available_uids
selected_validator = self.select_validator(steps_passed)
scores, _ = await self.process_modality(selected_validator, available_uids)
self.total_scores += scores

steps_since_last_update = steps_passed % iterations_per_set_weights

if steps_since_last_update == iterations_per_set_weights - 1:
await self.update_weights(steps_passed)
else:
Expand Down Expand Up @@ -200,22 +214,43 @@ async def set_weights(self, scores):
)
bt.logging.success("Successfully set weights.")

def handle_task_result_organic_query(self, task):
try:
success, data = task.result()
if success:
scores, uid_scores_dict, wandb_data = data
if self.config.wandb_on:
wandb.log(wandb_data)
bt.logging.success("wandb_log successful")
self.total_scores += scores
bt.logging.success(f"Task completed successfully. Scores updated.")
else:
bt.logging.error("Task failed. No scores updated.")
except Exception as e:
# Handle exceptions raised during task execution
bt.logging.error(f"handle_task_result_organic_query An error occurred during task execution: {e}")

def register_text_validator_organic_query(
self,
text_vali,
uid_to_response: dict[int, str], # [(uid, response)]
messages_dict: dict[int, str],
):
self.organic_scoring_tasks.add(asyncio.create_task(
self.steps_passed += 1

task = asyncio.create_task(
wait_for_coro_with_limit(
text_vali.score_responses(
query_responses=list(uid_to_response.items()),
uid_to_question=messages_dict,
metagraph=self.metagraph,
is_score_all=True
),
scoring_organic_timeout
)
))
)
task.add_done_callback(self.handle_task_result_organic_query) # Attach the callback
self.organic_scoring_tasks.add(task)


class TestWeightSetter(WeightSetter):
Expand Down

0 comments on commit 51f0981

Please sign in to comment.