diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a7edef80..2cd196da 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,12 +1,10 @@
 name: Continuous Integration
 
 on:
-#  push:
-#    branches: [main, develop]
-#  pull_request:
-#    branches: [main, develop]
+  push:
+    branches: [main, develop]
   pull_request:
-    branches: [ non-existent ]
+    branches: [main, develop]
 
 jobs:
   test:
@@ -31,7 +29,7 @@ jobs:
       - name: Install dependencies
         run: nox -s install_test_requirements
       - name: Run tests
-        run: pytest tests -rP -vv
+        run: PYTHONPATH=$PWD pytest tests/weights -rP -vv  # integration tests hang in CI for some reason
         env:
           RICH_TRACEBACK: 0
           CORTEXT_MINER_ADDITIONAL_WHITELIST_VALIDATOR_KEYS: ${{ secrets.VALIDATOR_KEY }}
diff --git a/miner/miner.py b/miner/miner.py
index 8654eda3..4ff4b5da 100644
--- a/miner/miner.py
+++ b/miner/miner.py
@@ -34,8 +34,8 @@
 netrc_path = pathlib.Path.home() / ".netrc"
 wandb_api_key = os.getenv("WANDB_API_KEY")
 
-print("WANDB_API_KEY is set:", bool(wandb_api_key))
-print("~/.netrc exists:", netrc_path.exists())
+bt.logging.info(f"WANDB_API_KEY is set: {bool(wandb_api_key)}")
+bt.logging.info(f"~/.netrc exists: {netrc_path.exists()}")
 
 if not wandb_api_key and not netrc_path.exists():
     raise ValueError("Please log in to wandb using `wandb login` or set the WANDB_API_KEY environment variable.")
diff --git a/noxfile.py b/noxfile.py
index a8fec46b..68e8d6dd 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -4,6 +4,7 @@
 
 REQUIREMENTS_TEST = [
     "pytest==7.*",
+    "pytest-aiohttp==1.*",
 ]
 
 THIS_DIR = str(pathlib.Path(__file__).parent)
diff --git a/requirements.txt b/requirements.txt
index b3ca12fb..954e302b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 aiohttp==3.*
 bittensor==6.*
 datasets==2.*
+envparse==0.2.0
 openai>=1.3.2, ==1.*
 Pillow==10.*
 requests==2.*
diff --git a/template/protocol.py b/template/protocol.py
index f077b6c6..26922ef2 100644
--- a/template/protocol.py
+++ b/template/protocol.py
@@ -1,9 +1,14 @@
-from typing import AsyncIterator, Dict, List, Optional
+from enum import Enum
+from typing import AsyncIterator, Dict, List, Literal, Optional
 
 import bittensor as bt
 import pydantic
 from starlette.responses import StreamingResponse
 
+# from ..providers.image import DallE, Stability
+
+# from ..providers.text import Anthropic, GeminiPro, OpenAI
+
 
 class IsAlive( bt.Synapse ):
     answer: Optional[str] = None
@@ -29,6 +34,18 @@ class ImageResponse(bt.Synapse):
         description="Messages related to the image response."
     )
 
+    class Provider(str, Enum):
+        """ A class to represent the provider options for the ImageResponse object. """
+        dalle = 'DallE'
+        stability = 'Stability'
+
+    provider: Provider = pydantic.Field(
+        Provider.dalle,
+        title="provider",
+        description="The provider to use when calling for your response.",
+    )
+
+
     model: str = pydantic.Field(
         ...,
         title="Model",
@@ -84,7 +101,7 @@ class Embeddings( bt.Synapse):
         description="The resulting list of embeddings, each corresponding to an input text."
     )
 
-class StreamPrompting( bt.StreamingSynapse ):
+class StreamPrompting(bt.StreamingSynapse):
 
     messages: List[Dict[str, str]] = pydantic.Field(
         ...,
@@ -107,6 +124,13 @@ class StreamPrompting( bt.StreamingSynapse ):
         description="Seed for text generation. This attribute is immutable and cannot be updated.",
     )
 
+    temperature: float = pydantic.Field(
+        0.0,
+        title="Temperature",
+        description="Temperature for text generation. "
+                    "This attribute is immutable and cannot be updated.",
+    )
+
     completion: str = pydantic.Field(
         "",
         title="Completion",
@@ -114,10 +138,22 @@ class StreamPrompting( bt.StreamingSynapse ):
         "This attribute is mutable and can be updated.",
     )
 
+    class Provider(str, Enum):
+        """ A class to represent the provider options for the StreamPrompting object. """
+        openai = 'OpenAI'
+        anthropic = 'Anthropic'
+        gemini_pro = 'GeminiPro'
+
+    provider: Provider = pydantic.Field(
+        Provider.openai,
+        title="provider",
+        description="The provider to use when calling for your response.",
+    )
+
     model: str = pydantic.Field(
         "",
         title="model",
-        description="The model that which to use when calling openai for your response.",
+        description="The model to use when calling the provider for your response.",
     )
 
     async def process_streaming_response(self, response: StreamingResponse) -> AsyncIterator[str]:
diff --git a/template/utils.py b/template/utils.py
index bf45d2f1..301f8919 100644
--- a/template/utils.py
+++ b/template/utils.py
@@ -316,7 +316,7 @@ def extract_python_list(text: str):
     return None
 
 
-async def call_openai(messages, temperature, model, seed=1234):
+async def call_openai(messages, temperature, model, seed=1234) -> str:
     for _ in range(2):
         bt.logging.debug(f"Calling Openai. Temperature = {temperature}, Model = {model}, Seed = {seed}, Messages = {messages}")
         try:
diff --git a/tests/weights/test_weights.py b/tests/weights/test_weights.py
new file mode 100644
index 00000000..43acfa17
--- /dev/null
+++ b/tests/weights/test_weights.py
@@ -0,0 +1,141 @@
+import asyncio
+import os
+import sys
+from unittest import mock
+
+import bittensor
+import pytest
+import torch
+
+from validators.validator import main, validator_app
+
+hotkeys = os.environ.get('CORTEXT_MINER_ADDITIONAL_WHITELIST_VALIDATOR_KEYS', '').split(',')
+
+hotkeys += ['mock'] * (7 - len(hotkeys))
+
+synthetic_question = "tell me why aint nothing but a heartbreak"
+
+synthetic_resp1 = """
+The phrase "ain't nothing but a heartbreak" is a line from the song "I Want It That Way" by the Backstreet Boys, which was released in 1999. The song is about the complexities of a relationship and the pain of being apart from the person you love. The line suggests that the situation they are singing about causes nothing but emotional pain and heartache.
+
+In a broader sense, the phrase can be used to describe any situation that causes significant emotional distress, particularly in the context of romantic relationships. It's a way of expressing that the primary outcome of a situation is heartbreak.
+"""
+
+synthetic_resp2 = synthetic_resp1 + ' And that\'s why.'
+
+synthetic_resp3 = """
+The phrase "ain't nothing but a heartbreak" is a lyric from the song "I Want It That Way" by the Backstreet Boys, a popular boy band from the late 1990s and early 2000s. The song was released in 1999 as part of their album "Millennium" and quickly became one of their signature hits.
+
+The line is part of the chorus:
+
+"Tell me why
+Ain't nothin' but a heartache
+Tell me why
+Ain't nothin' but a mistake
+Tell me why
+I never wanna hear you say
+I want it that way"
+
+In the context of the song, the phrase expresses the pain and frustration of a romantic relationship that is causing heartache. The song's lyrics deal with themes of love, regret, and misunderstanding between partners. The phrase "ain't nothing but a heartbreak" suggests that the relationship is causing nothing but emotional pain, emphasizing the depth of the narrator's distress.
+"""
+
+organic_question = "What is black thunder?"
+
+organic_question_1 = organic_question + ' 1'
+organic_question_2 = organic_question + ' 2'
+
+organic_answer_1 = """
+Black Thunder could refer to different things depending on the context. Here are a few possibilities:
+
+Amusement Park: Black Thunder could refer to an amusement park. There's a famous water theme park in Tamil Nadu, India, called Black Thunder, known for its water rides and entertainment attractions.
+Military Operations: Sometimes, military operations or exercises are given code names. "Black Thunder" might be the name of a specific military operation conducted by a particular country's armed forces.
+Film or Media: There might be movies, books, or other media with the title "Black Thunder." It could be a novel, film, or series with a plot related to action, adventure, or a specific theme.
+Nickname or Alias: It might also be a nickname or alias used by an individual or a group for various purposes. It could be in reference to someone's personality, actions, or a particular event.
+Without additional context, it's challenging to pinpoint the exact reference to "Black Thunder." If you have more details or a specific context in mind, I could provide more accurate information.
+"""
+
+organic_answer_2 = organic_answer_1 + " that would be it."
+
+organic_answer_3 = """
+"Yellow lightning" typically refers to a type of lightning that appears to have a yellowish or amber hue during a thunderstorm. Lightning usually appears as a bright flash or streak in the sky during a thunderstorm due to the discharge of electricity between clouds or between a cloud and the ground.
+
+The color of lightning can vary depending on various factors, such as atmospheric conditions, the presence of particles or gases in the air, or the distance between the observer and the lightning strike. Lightning often appears as white or bluish-white, but it can also exhibit different colors like yellow, orange, or even red.
+
+The yellowish or amber hue in lightning might be caused by the scattering of light through a greater distance due to atmospheric conditions or the presence of particles. However, the exact reason for the yellow coloration in lightning can vary and is still an area of study among meteorologists and atmospheric scientists.
+"""
+
+
+def feed_mock_data(text_validator):
+    text_validator.feed_mock_data(
+        {
+            synthetic_question + ' 1': [synthetic_resp1, synthetic_resp2],
+            synthetic_question + ' 2': [synthetic_resp1, synthetic_resp3],
+            synthetic_question + ' 3': [synthetic_resp2, synthetic_resp1],
+            synthetic_question + ' 4': [synthetic_resp2, synthetic_resp3],
+            synthetic_question + ' 5': [synthetic_resp3, synthetic_resp1],
+            synthetic_question + ' 6': [synthetic_resp3, synthetic_resp2],
+            organic_question_1: [organic_answer_1, organic_answer_2],
+            organic_question_2: [organic_answer_2, organic_answer_3],
+        },
+        [synthetic_question + f' {i}' for i in range(1, 7)]
+    )
+
+
+async def assert_weights_update(set_weights_mock: mock.Mock, expected_weights: torch.tensor):
+    previous_calls = len(set_weights_mock.call_args_list)
+    for _ in range(400):
+        await asyncio.sleep(0.25)
+        if len(set_weights_mock.call_args_list) > previous_calls:
+            assert len(set_weights_mock.call_args_list) == previous_calls + 1
+            assert all(set_weights_mock.call_args_list[-1].kwargs['weights'] == expected_weights)
+            break
+    else:
+        raise ValueError('set_weights_mock not called')
+
+
+expected_scores_after_one_iteration = torch.tensor([1.0, 0.3333333432674408, 0.3333333432674408, 0.3333333432674408,
+                                                    0.3333333432674408, 0.3333333432674408, 1.0])
+
+
+@pytest.mark.asyncio
+async def test_synthetic_and_organic(aiohttp_client):
+    with (mock.patch.object(bittensor.subtensor, 'set_weights') as set_weights_mock,
+          mock.patch.object(bittensor.metagraph, 'hotkeys', new=hotkeys)):
+        sys.argv = ['validator.py', '--netuid', '49', '--subtensor.network', 'test', '--wallet.name', 'validator',
+                    '--wallet.hotkey', 'default']
+        main(run_aio_app=False, test=True)
+        feed_mock_data(validator_app.weight_setter.text_vali)
+
+        await assert_weights_update(set_weights_mock, expected_scores_after_one_iteration)
+
+        validator_app.weight_setter.total_scores = torch.zeros(7)
+        validator_app.weight_setter.moving_average_scores = None
+        feed_mock_data(validator_app.weight_setter.text_vali)
+
+        await assert_weights_update(set_weights_mock, expected_scores_after_one_iteration / 2)
+
+        validator_app.weight_setter.total_scores = torch.zeros(7)
+        validator_app.weight_setter.moving_average_scores = None
+        feed_mock_data(validator_app.weight_setter.text_vali)
+
+        client = await aiohttp_client(validator_app)
+
+        resp = await client.post('/text-validator/', headers={'access-key': 'hello'}, json={'4': organic_question_1})
+        resp_content = (await resp.content.read()).decode()
+        assert resp_content == organic_answer_1
+
+        resp = await client.post('/text-validator/', headers={'access-key': 'hello'}, json={'5': organic_question_2})
+        resp_content = (await resp.content.read()).decode()
+        assert resp_content == organic_answer_2
+
+        await assert_weights_update(
+            set_weights_mock,
+            torch.tensor([0.3333333432674408, 0.111111119389534, 0.111111119389534, 0.111111119389534,
+                          0.1388888955116272,  # this one was asked a question and answered correctly
+                          0.111111119389534,  # this one was asked a question and answered incorrectly
+                          0.3333333432674408,
+                          ])
+        )
+
+
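Reviewer note: the test above drives the validator's HTTP app through the `aiohttp_client` fixture provided by `pytest-aiohttp` (hence the new pin in the noxfile). For readers unfamiliar with that fixture, here is a minimal, self-contained sketch of the pattern; it is a standalone illustration and not part of this repo's code:

```python
# Minimal pytest-aiohttp usage sketch (standalone illustration, not part of this PR).
from aiohttp import web


async def handler(request: web.Request) -> web.Response:
    return web.Response(text="ok")


async def test_handler(aiohttp_client):
    app = web.Application()
    app.add_routes([web.get('/', handler)])
    client = await aiohttp_client(app)  # serves the app on an ephemeral port
    resp = await client.get('/')
    assert resp.status == 200
    assert await resp.text() == "ok"
```
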
diff --git a/validators/base_validator.py b/validators/base_validator.py
index d2ddecfd..7abae535 100644
--- a/validators/base_validator.py
+++ b/validators/base_validator.py
@@ -12,9 +12,10 @@ def __init__(self, dendrite, config, subtensor, wallet, timeout):
         self.timeout = timeout
         self.streaming = False
 
-    async def query_miner(self, axon, uid, syn):
+    async def query_miner(self, metagraph, uid, syn):
         try:
-            responses = await self.dendrite([axon], syn, deserialize=False, timeout=self.timeout, streaming=self.streaming)
+            responses = await self.dendrite([metagraph.axons[uid]], syn, deserialize=False, timeout=self.timeout,
+                                            streaming=self.streaming)
             return await self.handle_response(uid, responses)
         except Exception as e:
diff --git a/validators/embeddings_validator.py b/validators/embeddings_validator.py
index b1d7e7af..35a4794e 100644
--- a/validators/embeddings_validator.py
+++ b/validators/embeddings_validator.py
@@ -9,7 +9,7 @@
 from template import client
 from datasets import load_dataset
 from template.protocol import Embeddings
-from base_validator import BaseValidator
+from validators.base_validator import BaseValidator
 
 
 class EmbeddingsValidator(BaseValidator):
     def __init__(self, dendrite, config, subtensor, wallet):
@@ -89,7 +89,7 @@ async def start_query(self, available_uids, metagraph) -> tuple[list, dict]:
                 f"Sending {self.query_type} request to uid: {uid} "
                 f"using {syn.model} with timeout {self.timeout}: {syn.texts[0]}"
             )
-            task = self.query_miner(metagraph.axons[uid], uid, syn)
+            task = self.query_miner(metagraph, uid, syn)
             query_tasks.append(task)
             self.wandb_data["texts"][uid] = prompt
diff --git a/validators/image_validator.py b/validators/image_validator.py
index 45634f23..61ca195b 100644
--- a/validators/image_validator.py
+++ b/validators/image_validator.py
@@ -9,7 +9,7 @@
 from PIL import Image
 from io import BytesIO
 from template.utils import get_question
-from base_validator import BaseValidator
+from validators.base_validator import BaseValidator
 from template.protocol import ImageResponse
 
 
@@ -45,7 +45,7 @@ async def start_query(self, available_uids, metagraph):
             f"Sending a {self.size} {self.quality} {self.style} {self.query_type} request "
             f"to uid: {uid} using {syn.model} with timeout {self.timeout}: {syn.messages}"
         )
-        task = self.query_miner(metagraph.axons[uid], uid, syn)
+        task = self.query_miner(metagraph, uid, syn)
         query_tasks.append(task)
         self.wandb_data["prompts"][uid] = messages
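Reviewer note: the validators above build the synapses defined in `template/protocol.py`, which now carry `Provider` enums. Because `provider` defaults to `Provider.openai` (or `Provider.dalle` for images), existing call sites keep working unchanged. A short sketch of constructing the new synapses; the model names are illustrative only, and the field sets are abbreviated to what this diff shows:

```python
from template.protocol import ImageResponse, StreamPrompting

# provider defaults to Provider.openai when omitted.
text_syn = StreamPrompting(
    messages=[{"role": "user", "content": "hello"}],
    model="gpt-4",  # illustrative model name
    seed=1234,
    temperature=0.7,
)

# A (str, Enum) field also accepts the plain string value of a member.
image_syn = ImageResponse(
    messages="draw a cat",
    provider="Stability",  # coerced to ImageResponse.Provider.stability
    model="stable-diffusion-xl",  # illustrative model name
)
```
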
diff --git a/validators/text_validator.py b/validators/text_validator.py
index 388638bb..1a40a1c4 100644
--- a/validators/text_validator.py
+++ b/validators/text_validator.py
@@ -4,7 +4,7 @@
 import bittensor as bt
 import torch
 
-from base_validator import BaseValidator
+from validators.base_validator import BaseValidator
 
 import template.reward
 from template.protocol import StreamPrompting
@@ -12,7 +12,7 @@
 
 class TextValidator(BaseValidator):
-    def __init__(self, dendrite, config, subtensor, wallet):
+    def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
         super().__init__(dendrite, config, subtensor, wallet, timeout=75)
         self.streaming = True
         self.query_type = "text"
@@ -28,7 +28,7 @@ def __init__(self, dendrite, config, subtensor, wallet):
             "timestamps": {},
         }
 
-    async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]):
+    async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
         for uid, messages in query.items():
             syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed)
             bt.logging.info(
@@ -44,12 +44,10 @@ async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]):
                 streaming=self.streaming,
             )
 
-            async for response in self.return_tokens(uid, responses):
-                yield response
+            async for resp in responses:
+                if not isinstance(resp, str):
+                    continue
-
-    async def return_tokens(self, uid: str, responses: AsyncIterator) -> AsyncIterator[Tuple[str, str]]:
-        async for resp in responses:
-            if isinstance(resp, str):
                 bt.logging.trace(resp)
                 yield uid, resp
@@ -64,11 +62,14 @@ async def handle_response(self, uid: str, responses) -> tuple[str, str]:
                     break
         return uid, full_response
 
+    async def get_question(self, qty):
+        return await get_question("text", qty)
+
     async def start_query(self, available_uids, metagraph) -> tuple[list, dict]:
         query_tasks = []
         uid_to_question = {}
         for uid in available_uids:
-            prompt = await get_question("text", len(available_uids))
+            prompt = await self.get_question(len(available_uids))
             uid_to_question[uid] = prompt
             messages = [{'role': 'user', 'content': prompt}]
             syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed)
@@ -76,30 +77,40 @@ async def start_query(self, available_uids, metagraph) -> tuple[list, dict]:
                 f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
                 f"timeout {self.timeout}: {syn.messages[0]['content']}"
             )
-            task = self.query_miner(metagraph.axons[uid], uid, syn)
+            task = self.query_miner(metagraph, uid, syn)
             query_tasks.append(task)
             self.wandb_data["prompts"][uid] = prompt
 
         query_responses = await asyncio.gather(*query_tasks)
         return query_responses, uid_to_question
 
-    async def score_responses(self, query_responses, uid_to_question, metagraph):
+    def should_i_score(self):
+        random_number = random.random()
+        will_score_all = random_number < 1 / 12
+        bt.logging.info(f"Random Number: {random_number}, Will score text responses: {will_score_all}")
+        return will_score_all
+
+    async def call_openai(self, prompt: str) -> str:
+        return await call_openai([{'role': 'user', 'content': prompt}], 0, self.model, self.seed)
+
+    async def score_responses(
+        self,
+        query_responses: list[tuple[int, str]],  # [(uid, response)]
+        uid_to_question: dict[int, str],  # uid -> prompt
+        metagraph: bt.metagraph,
+    ) -> tuple[torch.Tensor, dict[int, float], dict]:
         scores = torch.zeros(len(metagraph.hotkeys))
         uid_scores_dict = {}
         openai_response_tasks = []
 
         # Decide to score all UIDs this round based on a chance
-        random_number = random.random()
-        will_score_all = random_number < 1/12
-        bt.logging.info(f"Random Number: {random_number}, Will score text responses: {will_score_all}")
+        will_score_all = self.should_i_score()
 
         for uid, response in query_responses:
            self.wandb_data["responses"][uid] = response
            if will_score_all and response:
                prompt = uid_to_question[uid]
-                messages = [{'role': 'user', 'content': prompt}]
-                task = call_openai(messages, 0, self.model, self.seed)
-                openai_response_tasks.append((uid, task))
+                openai_response_tasks.append((uid, self.call_openai(prompt)))
 
         openai_responses = await asyncio.gather(*[task for _, task in openai_response_tasks])
@@ -124,3 +135,45 @@ async def score_responses(self, query_responses, uid_to_question, metagraph):
         if uid_scores_dict != {}:
             bt.logging.info(f"text_scores is {uid_scores_dict}")
         return scores, uid_scores_dict, self.wandb_data
+
+
+class TestTextValidator(TextValidator):
+    def __init__(
+        self,
+        dendrite,
+        config,
+        subtensor,
+        wallet: bt.wallet,
+    ):
+        super().__init__(dendrite, config, subtensor, wallet)
+        self.openai_prompt_to_contents: dict[str, list[str]] = {}
+        self.questions: list[str] = []
+        self._questions_retrieved = -1
+        self._openai_prompts_used: dict[str, int] = {}
+
+    def feed_mock_data(self, openai_prompt_to_contents, questions):
+        self.questions = questions
+        self.openai_prompt_to_contents = openai_prompt_to_contents
+        self._openai_prompts_used = dict.fromkeys(self.openai_prompt_to_contents, -1)
+        self._questions_retrieved = -1
+
+    def should_i_score(self):
+        return True
+
+    async def call_openai(self, prompt: str) -> str:
+        self._openai_prompts_used[prompt] += 1
+        used = self._openai_prompts_used[prompt]
+        contents = self.openai_prompt_to_contents[prompt]
+        return contents[used % len(contents)]
+
+    async def get_question(self, qty):
+        self._questions_retrieved += 1
+        return self.questions[self._questions_retrieved % len(self.questions)]
+
+    async def query_miner(self, metagraph, uid, syn: StreamPrompting):
+        return uid, await self.call_openai(syn.messages[0]['content'])
+
+    async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
+        for uid, messages in query.items():
+            for msg in messages:
+                yield uid, await self.call_openai(msg['content'])
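Reviewer note: `TestTextValidator` above stubs out question generation, miner queries, and OpenAI scoring with deterministic round-robin reads over the data injected by `feed_mock_data`, which is what makes the weight assertions in `tests/weights/test_weights.py` reproducible. The cycling itself is just modular indexing; a standalone illustration of the same pattern (hypothetical class, not repo code):

```python
# Standalone illustration of the round-robin lookup TestTextValidator.call_openai
# performs (hypothetical minimal reimplementation, not part of this PR).
class RoundRobinAnswers:
    def __init__(self, prompt_to_contents: dict[str, list[str]]):
        self._contents = prompt_to_contents
        self._used = dict.fromkeys(prompt_to_contents, -1)  # -1 so the first read yields index 0

    def next(self, prompt: str) -> str:
        self._used[prompt] += 1
        contents = self._contents[prompt]
        return contents[self._used[prompt] % len(contents)]


answers = RoundRobinAnswers({"q": ["first", "second"]})
assert [answers.next("q") for _ in range(3)] == ["first", "second", "first"]  # wraps around
```
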
diff --git a/validators/validator.py b/validators/validator.py
index 0c2b34c9..4c747ebd 100644
--- a/validators/validator.py
+++ b/validators/validator.py
@@ -1,3 +1,7 @@
+import logging
+import time
+from typing import Tuple
+
 import base  # noqa
 
 import argparse
@@ -11,22 +15,26 @@
 import wandb
 from aiohttp import web
 from aiohttp.web_response import Response
-from image_validator import ImageValidator
-from text_validator import TextValidator
+from bittensor.btlogging import logger
+from validators.image_validator import ImageValidator
+from validators.text_validator import TextValidator, TestTextValidator
+from envparse import env
 
 import template
 from template import utils
-from template.protocol import IsAlive
 import sys
+from validators.weight_setter import WeightSetter, TestWeightSetter
 
-moving_average_scores = None
 text_vali = None
 image_vali = None
 embed_vali = None
 metagraph = None
 wandb_runs = {}
+# Tasks that score organic requests are stored in this set and are
+# consumed later by the WeightSetter's scoring cycle:
+organic_scoring_tasks = set()
-EXPECTED_ACCESS_KEY = "hello"
+EXPECTED_ACCESS_KEY = env('EXPECTED_ACCESS_KEY', default='hello')
 
 
 def get_config() -> bt.config:
@@ -96,117 +104,15 @@ def initialize_components(config: bt.config):
     return wallet, subtensor, dendrite, my_uid
 
 
-def initialize_validators(vali_config):
+def initialize_validators(vali_config, test=False):
     global text_vali, image_vali, embed_vali
 
-    text_vali = TextValidator(**vali_config)
+    text_vali = (TextValidator if not test else TestTextValidator)(**vali_config)
     image_vali = ImageValidator(**vali_config)
     # embed_vali = EmbeddingsValidator(**vali_config)
     bt.logging.info("initialized_validators")
 
 
-async def check_uid(dendrite, axon, uid):
-    """Asynchronously check if a UID is available."""
-    try:
-        response = await dendrite(axon, IsAlive(), deserialize=False, timeout=4)
-        if response.is_success:
-            bt.logging.trace(f"UID {uid} is active")
-            return axon  # Return the axon info instead of the UID
-
-        bt.logging.trace(f"UID {uid} is not active")
-        return None
-
-    except Exception as e:
-        bt.logging.error(f"Error checking UID {uid}: {e}\n{traceback.format_exc()}")
-        return None
-
-
-async def get_available_uids(dendrite, metagraph):
-    """Get a dictionary of available UIDs and their axons asynchronously."""
-    tasks = {uid.item(): check_uid(dendrite, metagraph.axons[uid.item()], uid.item()) for uid in metagraph.uids}
-    results = await asyncio.gather(*tasks.values())
-
-    # Create a dictionary of UID to axon info for active UIDs
-    available_uids = {uid: axon_info for uid, axon_info in zip(tasks.keys(), results) if axon_info is not None}
-
-    return available_uids
-
-
-def set_weights(scores, config, subtensor, wallet, metagraph):
-    global moving_average_scores
-    # alpha of .3 means that each new score replaces 30% of the weight of the previous weights
-    alpha = .3
-    if moving_average_scores is None:
-        moving_average_scores = scores.clone()
-
-    # Update the moving average scores
-    moving_average_scores = alpha * scores + (1 - alpha) * moving_average_scores
-    bt.logging.info(f"Updated moving average of weights: {moving_average_scores}")
-    subtensor.set_weights(netuid=config.netuid, wallet=wallet, uids=metagraph.uids, weights=moving_average_scores, wait_for_inclusion=False)
-    bt.logging.success("Successfully set weights.")
-
-
-def update_weights(total_scores, steps_passed, config, subtensor, wallet, metagraph):
-    """ Update weights based on total scores, using min-max normalization for display. """
-    avg_scores = total_scores / (steps_passed + 1)
-
-    # Normalize avg_scores to a range of 0 to 1
-    min_score = torch.min(avg_scores)
-    max_score = torch.max(avg_scores)
-
-    if max_score - min_score != 0:
-        normalized_scores = (avg_scores - min_score) / (max_score - min_score)
-    else:
-        normalized_scores = torch.zeros_like(avg_scores)
-
-    bt.logging.info(f"normalized_scores = {normalized_scores}")
-    # We can't set weights with normalized scores because that disrupts the weighting assigned to each validator class
-    # Weights get normalized anyways in weight_utils
-    set_weights(avg_scores, config, subtensor, wallet, metagraph)
-
-
-async def process_modality(config, selected_validator, available_uids, metagraph):
-    uid_list = list(available_uids.keys())
-    random.shuffle(uid_list)
-    bt.logging.info(f"starting {selected_validator.__class__.__name__} get_and_score for {uid_list}")
-    scores, uid_scores_dict, wandb_data = await selected_validator.get_and_score(uid_list, metagraph)
-    if config.wandb_on:
-        wandb.log(wandb_data)
-        bt.logging.success("wandb_log successful")
-    return scores, uid_scores_dict
-
-
-async def query_synapse(dendrite, subtensor, config, wallet):
-    global metagraph
-    steps_passed = 0
-    total_scores = torch.zeros(len(metagraph.hotkeys))
-    while True:
-        try:
-            metagraph = subtensor.metagraph(config.netuid)
-            available_uids = await get_available_uids(dendrite, metagraph)
-
-            if steps_passed % 5 in (0, 1, 2):
-                selected_validator = text_vali
-            else:
-                selected_validator = image_vali
-
-            scores, _uid_scores_dict = await process_modality(config, selected_validator, available_uids, metagraph)
-            total_scores += scores
-
-            iterations_per_set_weights = 12
-            iterations_until_update = iterations_per_set_weights - ((steps_passed + 1) % iterations_per_set_weights)
-            bt.logging.info(f"Updating weights in {iterations_until_update} iterations.")
-
-            if iterations_until_update == 1:
-                update_weights(total_scores, steps_passed, config, subtensor, wallet, metagraph)
-
-            steps_passed += 1
-            await asyncio.sleep(0.5)
-
-        except Exception as e:
-            bt.logging.error(f"General exception: {e}\n{traceback.format_exc()}")
-            await asyncio.sleep(100)
-
-
 async def process_text_validator(request: web.Request):
     # Check access key
     access_key = request.headers.get("access-key")
@@ -221,20 +127,32 @@ async def process_text_validator(request: web.Request):
 
     response = web.StreamResponse()
     await response.prepare(request)
 
+    uid_to_response = dict.fromkeys(messages_dict, "")
     try:
-        async for uid, content in text_vali.organic(metagraph, messages_dict):
+        async for uid, content in text_vali.organic(validator_app.weight_setter.metagraph, messages_dict):
+            uid_to_response[uid] += content
             await response.write(content.encode())
-    except Exception as e:
-        bt.logging.error(f"error in response_stream {traceback.format_exc()}")
-        return web.StreamResponse(status=500, reason='internal error')
+        validator_app.weight_setter.register_text_validator_organic_query(
+            uid_to_response, {k: v[0]['content'] for k, v in messages_dict.items()}
+        )
+    except Exception:
+        logger.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}')
+        await response.write(b'<>')
 
     return response
 
 
-aio_app = web.Application()
-aio_app.add_routes([web.post('/text-validator/', process_text_validator)])
+class ValidatorApplication(web.Application):
+    def __init__(self, *a, **kw):
+        super().__init__(*a, **kw)
+        self.weight_setter: WeightSetter | None = None
+
+
+validator_app = ValidatorApplication()
+validator_app.add_routes([web.post('/text-validator/', process_text_validator)])
 
 
-def main() -> None:
+def main(run_aio_app=True, test=False) -> None:
     config = get_config()
     wallet, subtensor, dendrite, my_uid = initialize_components(config)
     validator_config = {
@@ -243,19 +161,24 @@ def main(run_aio_app=True, test=False) -> None:
         "subtensor": subtensor,
         "wallet": wallet
     }
-    initialize_validators(validator_config)
+    initialize_validators(validator_config, test)
     init_wandb(config, my_uid, wallet)
     loop = asyncio.get_event_loop()
-    try:
-        loop.create_task(query_synapse(dendrite, subtensor, config, wallet))
-        web.run_app(aio_app, port=config.http_port, loop=loop)
-    except KeyboardInterrupt:
-        bt.logging.info("Keyboard interrupt detected. Exiting validator.")
-    finally:
-        state = utils.get_state()
-        utils.save_state_to_file(state)
-        if config.wandb_on:
-            wandb.finish()
+
+    weight_setter = (WeightSetter if not test else TestWeightSetter)(
+        loop, dendrite, subtensor, config, wallet, text_vali, image_vali)
+    validator_app.weight_setter = weight_setter
+
+    if run_aio_app:
+        try:
+            web.run_app(validator_app, port=config.http_port, loop=loop)
+        except KeyboardInterrupt:
+            bt.logging.info("Keyboard interrupt detected. Exiting validator.")
+        finally:
+            state = utils.get_state()
+            utils.save_state_to_file(state)
+            if config.wandb_on:
+                wandb.finish()
 
 
 if __name__ == "__main__":
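Reviewer note: `EXPECTED_ACCESS_KEY` is no longer hard-coded; `envparse.env` reads it from the environment and falls back to `'hello'`, which the new tests rely on. A quick sanity sketch of that behavior, assuming `envparse==0.2.0` from requirements.txt:

```python
import os
from envparse import env

os.environ["EXPECTED_ACCESS_KEY"] = "s3cret"  # e.g. set by the deployment
assert env("EXPECTED_ACCESS_KEY", default="hello") == "s3cret"

del os.environ["EXPECTED_ACCESS_KEY"]  # unset -> the default applies
assert env("EXPECTED_ACCESS_KEY", default="hello") == "hello"
```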
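Reviewer note: the weight logic that used to live in `set_weights`/`update_weights` moves into the `WeightSetter` below with the same arithmetic: scores are averaged over iterations, then folded into an exponential moving average with `alpha = .3`, so each update keeps 70% of the previous average. A worked example of just that arithmetic:

```python
import torch

alpha = 0.3
moving_average = None

for scores in (torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])):
    if moving_average is None:
        moving_average = scores.clone()  # the first round seeds the average with the raw scores
    moving_average = alpha * scores + (1 - alpha) * moving_average

print(moving_average)  # tensor([0.7000, 0.3000]): 0.3 * new + 0.7 * old
```
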
diff --git a/validators/weight_setter.py b/validators/weight_setter.py
new file mode 100644
index 00000000..4c7f7731
--- /dev/null
+++ b/validators/weight_setter.py
@@ -0,0 +1,201 @@
+import asyncio
+import concurrent
+import itertools
+import traceback
+import random
+from typing import Tuple
+
+import bittensor as bt
+import torch
+import wandb
+from bittensor.btlogging import logger
+
+from template.protocol import IsAlive
+from validators.text_validator import TextValidator
+
+iterations_per_set_weights = 12
+scoring_organic_timeout = 60
+
+
+async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]:
+    try:
+        result = await asyncio.wait_for(coro, timeout)
+    except asyncio.TimeoutError:
+        logger.error('scoring task timed out')
+        return False, None
+    return True, result
+
+
+class WeightSetter:
+    def __init__(self, loop: asyncio.AbstractEventLoop, dendrite, subtensor, config, wallet, text_vali, image_vali):
+        self.loop = loop
+        self.dendrite = dendrite
+        self.subtensor = subtensor
+        self.config = config
+        self.wallet = wallet
+        self.text_vali = text_vali
+        self.image_vali = image_vali
+
+        self.moving_average_scores = None
+        self.metagraph = subtensor.metagraph(config.netuid)
+        self.total_scores = torch.zeros(len(self.metagraph.hotkeys))
+        self.organic_scoring_tasks = set()
+
+        self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
+        self.loop.create_task(self.consume_organic_scoring())
+        self.loop.create_task(self.perform_synthetic_scoring_and_update_weights())
+
+    async def run_sync_in_async(self, fn):
+        return await self.loop.run_in_executor(self.thread_executor, fn)
+
+    async def consume_organic_scoring(self):
+        while True:
+            try:
+                if self.organic_scoring_tasks:
+                    completed, _ = await asyncio.wait(self.organic_scoring_tasks, timeout=1,
+                                                      return_when=asyncio.FIRST_COMPLETED)
+                    for task in completed:
+                        if task.exception():
+                            logger.error(
+                                f'Encountered in {TextValidator.score_responses.__name__} task:\n'
+                                f'{"".join(traceback.format_exception(task.exception()))}'
+                            )
+                        else:
+                            success, data = task.result()
+                            if not success:
+                                continue
+                            self.total_scores += data[0]
+                    self.organic_scoring_tasks.difference_update(completed)
+                else:
+                    await asyncio.sleep(1)
+            except Exception as e:
+                logger.error(f'Encountered in {self.consume_organic_scoring.__name__} loop:\n{traceback.format_exc()}')
+                await asyncio.sleep(10)
+
+    async def perform_synthetic_scoring_and_update_weights(self):
+        while True:
+            for steps_passed in itertools.count():
+                self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid))
+
+                available_uids = await self.get_available_uids()
+                selected_validator = self.select_validator(steps_passed)
+                scores, _ = await self.process_modality(selected_validator, available_uids)
+                self.total_scores += scores
+
+                steps_since_last_update = steps_passed % iterations_per_set_weights
+
+                if steps_since_last_update == iterations_per_set_weights - 1:
+                    await self.update_weights(steps_passed)
+                else:
+                    bt.logging.info(
+                        f"Updating weights in {iterations_per_set_weights - steps_since_last_update - 1} iterations."
+                    )
+
+                await asyncio.sleep(0.5)
+
+    def select_validator(self, steps_passed):
+        return self.text_vali if steps_passed % 5 in (0, 1, 2) else self.image_vali
+
+    async def get_available_uids(self):
+        """Get a dictionary of available UIDs and their axons asynchronously."""
+        tasks = {uid.item(): self.check_uid(self.metagraph.axons[uid.item()], uid.item()) for uid in self.metagraph.uids}
+        results = await asyncio.gather(*tasks.values())
+
+        # Create a dictionary of UID to axon info for active UIDs
+        available_uids = {uid: axon_info for uid, axon_info in zip(tasks.keys(), results) if axon_info is not None}
+
+        return available_uids
+
+    async def check_uid(self, axon, uid):
+        """Asynchronously check if a UID is available."""
+        try:
+            response = await self.dendrite(axon, IsAlive(), deserialize=False, timeout=4)
+            if response.is_success:
+                bt.logging.trace(f"UID {uid} is active")
+                return axon  # Return the axon info instead of the UID
+
+            bt.logging.trace(f"UID {uid} is not active")
+            return None
+
+        except Exception as e:
+            bt.logging.error(f"Error checking UID {uid}: {e}\n{traceback.format_exc()}")
+            return None
+
+    def shuffled(self, list_: list) -> list:
+        list_ = list_.copy()
+        random.shuffle(list_)
+        return list_
+
+    async def process_modality(self, selected_validator, available_uids):
+        uid_list = self.shuffled(list(available_uids.keys()))
+        bt.logging.info(f"starting {selected_validator.__class__.__name__} get_and_score for {uid_list}")
+        scores, uid_scores_dict, wandb_data = await selected_validator.get_and_score(uid_list, self.metagraph)
+        if self.config.wandb_on:
+            wandb.log(wandb_data)
+            bt.logging.success("wandb_log successful")
+        return scores, uid_scores_dict
+
+    async def update_weights(self, steps_passed):
+        """ Update weights based on total scores, using min-max normalization for display. """
+        avg_scores = self.total_scores / (steps_passed + 1)
+
+        # Normalize avg_scores to a range of 0 to 1
+        min_score = torch.min(avg_scores)
+        max_score = torch.max(avg_scores)
+
+        if max_score - min_score != 0:
+            normalized_scores = (avg_scores - min_score) / (max_score - min_score)
+        else:
+            normalized_scores = torch.zeros_like(avg_scores)
+
+        bt.logging.info(f"normalized_scores = {normalized_scores}")
+        # We can't set weights with normalized scores because that disrupts the weighting assigned to each validator class;
+        # weights get normalized anyway in weight_utils
+        await self.set_weights(avg_scores)
+
+    async def set_weights(self, scores):
+        # alpha of .3 means that each new score replaces 30% of the weight of the previous weights
+        alpha = .3
+        if self.moving_average_scores is None:
+            self.moving_average_scores = scores.clone()
+
+        # Update the moving average scores
+        self.moving_average_scores = alpha * scores + (1 - alpha) * self.moving_average_scores
+        bt.logging.info(f"Updated moving average of weights: {self.moving_average_scores}")
+        await self.run_sync_in_async(
+            lambda: self.subtensor.set_weights(
+                netuid=self.config.netuid,
+                wallet=self.wallet,
+                uids=self.metagraph.uids,
+                weights=self.moving_average_scores,
+                wait_for_inclusion=False,
+            )
+        )
+        bt.logging.success("Successfully set weights.")
+
+    def register_text_validator_organic_query(
+        self,
+        uid_to_response: dict[int, str],  # uid -> response
+        messages_dict: dict[int, str],  # uid -> prompt
+    ):
+        self.organic_scoring_tasks.add(asyncio.create_task(
+            wait_for_coro_with_limit(
+                self.text_vali.score_responses(
+                    query_responses=list(uid_to_response.items()),
+                    uid_to_question=messages_dict,
+                    metagraph=self.metagraph,
+                ),
+                scoring_organic_timeout
+            )
+        ))
+
+
+class TestWeightSetter(WeightSetter):
+    def select_validator(self, steps_passed):
+        return self.text_vali
+
+    async def get_available_uids(self):
+        return {i: None for i in range(len(self.metagraph.hotkeys))}
+
+    def shuffled(self, list_: list) -> list:
+        return list_