Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
surcyf123 committed Jan 11, 2024
1 parent 8193eeb commit 80e0148
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 44 deletions.
2 changes: 1 addition & 1 deletion state.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


# version must stay on line 22
__version__ = "3.1.1"
__version__ = "3.1.2"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
Expand Down
80 changes: 40 additions & 40 deletions validators/image_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def start_query(self, available_uids, metagraph):
uid_to_question = {}

# Randomly choose the provider based on specified probabilities
providers = ["OpenAI"] * 7 + ["Stability"] * 3
providers = ["OpenAI"] * 6 + ["Stability"] * 4
self.provider = random.choice(providers)

if self.provider == "Stability":
Expand Down Expand Up @@ -84,7 +84,6 @@ async def b64_to_image(self, b64):
return await asyncio.to_thread(Image.open, BytesIO(image_data))

async def download_image(self, url, session):
bt.logging.debug(f"Starting download for URL: {url}")
try:
async with session.get(url) as response:
content = await response.read()
Expand All @@ -93,20 +92,20 @@ async def download_image(self, url, session):
bt.logging.error(f"Exception occurred while downloading image: {traceback.format_exc()}")
raise

async def process_download_result(self, uid, download_task):
try:
image = await download_task
self.wandb_data["images"][uid] = wandb.Image(image)
except Exception as e:
bt.logging.error(f"Error downloading image for UID {uid}: {traceback.format_exc()}")

async def process_score_result(self, uid, score_task, scores, uid_scores_dict):
try:
scored_response = await score_task
score = scored_response if scored_response is not None else 0
scores[uid] = uid_scores_dict[uid] = score
except Exception as e:
bt.logging.error(f"Error scoring image for UID {uid}: {traceback.format_exc()}")
# async def process_download_result(self, uid, download_task):
# try:
# image = await download_task
# self.wandb_data["images"][uid] = wandb.Image(image)
# except Exception as e:
# bt.logging.error(f"Error downloading image for UID {uid}: {traceback.format_exc()}")

# async def process_score_result(self, uid, score_task, scores, uid_scores_dict):
# try:
# scored_response = await score_task
# score = scored_response if scored_response is not None else 0
# scores[uid] = uid_scores_dict[uid] = score
# except Exception as e:
# bt.logging.error(f"Error scoring image for UID {uid}: {traceback.format_exc()}")

async def score_responses(self, query_responses, uid_to_question, metagraph):
scores = torch.zeros(len(metagraph.hotkeys))
Expand All @@ -115,6 +114,7 @@ async def score_responses(self, query_responses, uid_to_question, metagraph):
score_tasks = []
rand = random.random()
will_score_all = rand < 1/1

async with aiohttp.ClientSession() as session:
for uid, syn in query_responses:
syn = syn[0]
Expand All @@ -127,40 +127,40 @@ async def score_responses(self, query_responses, uid_to_question, metagraph):
if syn.provider == "OpenAI":
image_url = completion["url"]
bt.logging.info(f"UID {uid} response = {image_url}")
download_tasks.append((uid, asyncio.create_task(self.download_image(image_url, session))))
download_tasks.append(asyncio.create_task(self.download_image(image_url, session)))
else: # Stability
b64s = completion["b64s"]
bt.logging.info(f"UID {uid} responded with an image")
for b64 in b64s:
download_tasks.append((uid, asyncio.create_task(self.b64_to_image(b64))))
download_tasks.append(asyncio.create_task(self.b64_to_image(b64)))

if will_score_all:
if syn.provider == "OpenAI":
score_task = template.reward.dalle_score(uid, image_url, self.size, syn.messages, self.weight)
score_tasks.append((uid, asyncio.create_task(score_task)))
else:
continue
score_task = template.reward.deterministic_score(uid, syn, self.weight)

# score_tasks.append((uid, asyncio.create_task(score_task)))

await asyncio.gather(*(dt[1] for dt in download_tasks), *(st[1] for st in score_tasks))

bt.logging.info("Processing download results.")
download_results = [self.process_download_result(uid, dt) for uid, dt in download_tasks]
await asyncio.gather(*download_results)
bt.logging.info("Completed processing download results.")

bt.logging.info(f"random number = {rand}, will score all = {will_score_all}")

bt.logging.info("Processing score results.")
score_results = [self.process_score_result(uid, st, scores, uid_scores_dict) for uid, st in score_tasks]
await asyncio.gather(*score_results)
bt.logging.info("Completed processing score results.")

if uid_scores_dict != {}:
bt.logging.info(f"Final scores: {uid_scores_dict}")

score_tasks.append(asyncio.create_task(score_task))

# Wait for all tasks to complete
download_results = await asyncio.gather(*download_tasks)
score_results = await asyncio.gather(*score_tasks, return_exceptions=True)

# Process download results
for image, uid in zip(download_results, [uid for uid, _ in query_responses]):
try:
self.wandb_data["images"][uid] = wandb.Image(image)
except Exception as e:
bt.logging.error(f"Error processing image for UID {uid}: {traceback.format_exc()}")

# Process score results
for score, uid in zip(score_results, [uid for uid, _ in query_responses]):
try:
final_score = score if score is not None else 0
scores[uid] = uid_scores_dict[uid] = final_score
except Exception as e:
bt.logging.error(f"Error processing score for UID {uid}: {traceback.format_exc()}")

bt.logging.info(f"Final scores: {uid_scores_dict}")
bt.logging.info("score_responses process completed.")
return scores, uid_scores_dict, self.wandb_data

2 changes: 1 addition & 1 deletion validators/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# organic requests are scored, the tasks are stored in this queue
# for later being consumed by `query_synapse` cycle:
organic_scoring_tasks = set()
EXPECTED_ACCESS_KEY = env('EXPECTED_ACCESS_KEY')
EXPECTED_ACCESS_KEY = env.get('EXPECTED_ACCESS_KEY', "hello")


def get_config() -> bt.config:
Expand Down
2 changes: 1 addition & 1 deletion validators/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from image_validator import ImageValidator
from embeddings_validator import EmbeddingsValidator

iterations_per_set_weights = 12
iterations_per_set_weights = 10
scoring_organic_timeout = 60


Expand Down

0 comments on commit 80e0148

Please sign in to comment.