diff --git a/README.md b/README.md index 36aa8de0..dd17332b 100644 --- a/README.md +++ b/README.md @@ -236,7 +236,7 @@ wand login You can launch your validator using following command ```python -pm2 start start_validator.py --interpreter python3 -- --wallet.name default --wallet.hotkey defualt --subtensor.chain_endpoint --autoupdate --wandb_on +pm2 start start_validator.py --interpreter python3 -- --wallet_name "default" --wallet_hotkey "default" --subtensor.chain_endpoint --autoupdate --wandb_on ``` --- diff --git a/diff.diff b/diff.diff new file mode 100644 index 00000000..e9d0cec2 --- /dev/null +++ b/diff.diff @@ -0,0 +1,1774 @@ +diff --git a/cortext/__init__.py b/cortext/__init__.py +index d52e2d5..348a394 100644 +--- a/cortext/__init__.py ++++ b/cortext/__init__.py +@@ -19,7 +19,7 @@ + + + # version must stay on line 22 +-__version__ = "4.0.6" ++__version__ = "4.0.4" + version_split = __version__.split(".") + __spec_version__ = ( + (1000 * int(version_split[0])) +@@ -27,7 +27,7 @@ __spec_version__ = ( + + (1 * int(version_split[2])) + ) + +-u64_max = 2 ** 64 - 9 ++u64_max = 2 ** 64 - 10 + __weights_version__ = u64_max + + import os +@@ -36,7 +36,7 @@ from typing import Union + + from openai import AsyncOpenAI + +-from cortext.protocol import StreamPrompting, Embeddings, ImageResponse, IsAlive ++from cortext.protocol import StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive + + load_dotenv() + try: +@@ -3768,4 +3768,4 @@ IMAGE_THEMES = [ + ] + + +-ALL_SYNAPSE_TYPE = Union[StreamPrompting, Embeddings, ImageResponse, IsAlive] ++ALL_SYNAPSE_TYPE = Union[StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive] +diff --git a/cortext/protocol.py b/cortext/protocol.py +index 4fec337..bea6ba7 100644 +--- a/cortext/protocol.py ++++ b/cortext/protocol.py +@@ -1,3 +1,4 @@ ++from enum import Enum + from typing import AsyncIterator, Dict, List, Optional, Union + import bittensor as bt + import pydantic +@@ -14,9 +15,6 @@ class IsAlive(bt.Synapse): + ) + + +-class Bandwidth(bt.Synapse): +- bandwidth_rpm: Optional[Dict[str, int]] = None +- + class ImageResponse(bt.Synapse): + """ A class to represent the response for an image-related request. """ + # https://platform.stability.ai/docs/api-reference#tag/v1generation/operation/textToImage +@@ -322,14 +320,95 @@ class StreamPrompting(bt.StreamingSynapse): + "axon": extract_info("bt_header_axon"), + "messages": self.messages, + "completion": self.completion, +- "provider": self.provider, +- "model": self.model, +- "seed": self.seed, +- "max_tokens": self.max_tokens, +- "temperature": self.temperature, +- "top_p": self.top_p, +- "top_k": self.top_k, +- "timeout": self.timeout, +- "streaming": self.streaming, +- "uid": self.uid, +- } +\ No newline at end of file ++ } ++ ++ ++class TextPrompting(bt.Synapse): ++ messages: List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]] = pydantic.Field( ++ ..., ++ title="Messages", ++ description="A list of messages in the StreamPrompting scenario, " ++ "each containing a role and content. Immutable.", ++ allow_mutation=False, ++ ) ++ ++ required_hash_fields: List[str] = pydantic.Field( ++ ["messages"], ++ title="Required Hash Fields", ++ description="A list of required fields for the hash.", ++ allow_mutation=False, ++ ) ++ ++ seed: int = pydantic.Field( ++ default="1234", ++ title="Seed", ++ description="Seed for text generation. This attribute is immutable and cannot be updated.", ++ ) ++ ++ temperature: float = pydantic.Field( ++ default=0.0001, ++ title="Temperature", ++ description="Temperature for text generation. " ++ "This attribute is immutable and cannot be updated.", ++ ) ++ ++ max_tokens: int = pydantic.Field( ++ default=2048, ++ title="Max Tokens", ++ description="Max tokens for text generation. " ++ "This attribute is immutable and cannot be updated.", ++ ) ++ ++ top_p: float = pydantic.Field( ++ default=0.001, ++ title="Top_p", ++ description="Top_p for text generation. The sampler will pick one of " ++ "the top p percent tokens in the logit distirbution. " ++ "This attribute is immutable and cannot be updated.", ++ ) ++ ++ top_k: int = pydantic.Field( ++ default=1, ++ title="Top_k", ++ description="Top_k for text generation. Sampler will pick one of " ++ "the k most probablistic tokens in the logit distribtion. " ++ "This attribute is immutable and cannot be updated.", ++ ) ++ ++ completion: str = pydantic.Field( ++ None, ++ title="Completion", ++ description="Completion status of the current StreamPrompting object. " ++ "This attribute is mutable and can be updated.", ++ ) ++ ++ provider: str = pydantic.Field( ++ default="OpenAI", ++ title="Provider", ++ description="The provider to use when calling for your response. " ++ "Options: OpenAI, Anthropic, Gemini, Groq, Bedrock", ++ ) ++ ++ model: str = pydantic.Field( ++ default="gpt-3.5-turbo", ++ title="model", ++ description="The model to use when calling provider for your response.", ++ ) ++ ++ uid: int = pydantic.Field( ++ default=3, ++ title="uid", ++ description="The UID to send the streaming synapse to", ++ ) ++ ++ timeout: int = pydantic.Field( ++ default=60, ++ title="timeout", ++ description="The timeout for the dendrite of the streaming synapse", ++ ) ++ ++ streaming: bool = pydantic.Field( ++ default=True, ++ title="streaming", ++ description="whether to stream the output", ++ ) +diff --git a/cortext/reward.py b/cortext/reward.py +index 24bd52b..f5c4a78 100644 +--- a/cortext/reward.py ++++ b/cortext/reward.py +@@ -24,8 +24,13 @@ hf_logging.set_verbosity_error() + import re + import io + import torch ++import openai ++import typing ++import difflib + import asyncio ++import logging + import aiohttp ++import requests + import traceback + import numpy as np + from numpy.linalg import norm +@@ -66,8 +71,8 @@ async def api_score(api_answer: str, response: str, weight: float, temperature: + words_in_response = len(response.split()) + words_in_api = len(api_answer.split()) + +- word_count_over_threshold = words_in_api * 1.4 +- word_count_under_threshold = words_in_api * 0.50 ++ word_count_over_threshold = words_in_api * 1.20 ++ word_count_under_threshold = words_in_api * 0.60 + + # Check if the word count of the response is within the thresholds + if words_in_response <= word_count_over_threshold and words_in_response >= word_count_under_threshold: +@@ -153,7 +158,7 @@ def calculate_image_similarity(image, description, max_length: int = 77): + # Calculate cosine similarity + return torch.cosine_similarity(image_embedding, text_embedding, dim=1).item() + +-async def dalle_score(uid, url, desired_size, description, weight, similarity_threshold=0.21) -> float: ++async def dalle_score(uid, url, desired_size, description, weight, similarity_threshold=0.23) -> float: + """Calculate the image score based on similarity and size asynchronously.""" + + if not re.match(url_regex, url): +diff --git a/cortext/utils.py b/cortext/utils.py +index 0ded531..5f745f9 100644 +--- a/cortext/utils.py ++++ b/cortext/utils.py +@@ -163,7 +163,7 @@ def fetch_random_image_urls(num_images): + images = response.json().get('hits', []) + return [image['webformatURL'] for image in images] + except Exception as e: +- bt.logging.error(f"Error fetching random images: {e}") ++ print(f"Error fetching random images: {e}") + return [] + + +@@ -315,9 +315,9 @@ async def update_counters_and_get_new_list(category, item_type, num_questions_ne + item = await get_item_from_list(items, vision) + + if not item: +- bt.logging.trace(f"Item not founded in items: {items}. Calling get_items!") ++ bt.logging.info(f"Item not founded in items: {items}. Calling get_items!") + items = await get_items(category, item_type, theme) +- bt.logging.trace(f"Items generated: {items}") ++ bt.logging.info(f"Items generated: {items}") + state[category][item_type] = items + bt.logging.debug(f"Fetched new list for {list_type}, containing {len(items)} items") + +@@ -503,7 +503,7 @@ async def call_openai(messages, temperature, model, seed=1234, max_tokens=2048, + + + async def call_gemini(messages, temperature, model, max_tokens, top_p, top_k): +- bt.logging.debug(f"Calling Gemini. Temperature = {temperature}, Model = {model}, Messages = {messages}") ++ print(f"Calling Gemini. Temperature = {temperature}, Model = {model}, Messages = {messages}") + try: + model = genai.GenerativeModel(model) + response = model.generate_content( +@@ -520,10 +520,10 @@ async def call_gemini(messages, temperature, model, max_tokens, top_p, top_k): + ), + ) + +- bt.logging.trace(f"validator response is {response.text}") ++ print(f"validator response is {response.text}") + return response.text + except: +- bt.logging.error(f"error in call_gemini {traceback.format_exc()}") ++ print(f"error in call_gemini {traceback.format_exc()}") + + + # anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) +@@ -625,7 +625,7 @@ async def call_anthropic(messages, temperature, model, max_tokens, top_p, top_k) + kwargs["system"] = system_prompt + + message = await anthropic_client.messages.create(**kwargs) +- bt.logging.trace(f"validator response is {message.content[0].text}") ++ bt.logging.debug(f"validator response is {message.content[0].text}") + return message.content[0].text + except: + bt.logging.error(f"error in call_anthropic {traceback.format_exc()}") +@@ -647,7 +647,7 @@ async def call_groq(messages, temperature, model, max_tokens, top_p, seed): + } + + message = await groq_client.chat.completions.create(**kwargs) +- bt.logging.trace(f"validator response is {message.choices[0].message.content}") ++ bt.logging.debug(f"validator response is {message.choices[0].message.content}") + return message.choices[0].message.content + except: + bt.logging.error(f"error in call_groq {traceback.format_exc()}") +@@ -738,7 +738,7 @@ async def call_bedrock(messages, temperature, model, max_tokens, top_p, seed): + message = await response['body'].read() + message = await extract_message(message) + +- bt.logging.trace(f"validator response is {message}") ++ bt.logging.debug(f"validator response is {message}") + return message + except: + bt.logging.error(f"error in call_bedrock {traceback.format_exc()}") +diff --git a/env.example b/env.example +index c9d20f6..67c5b46 100644 +--- a/env.example ++++ b/env.example +@@ -1,12 +1,12 @@ +-# for validators ++ENV=test + WANDB_API_KEY= +-OPENAI_API_KEY= +-PIXABAY_API_KEY= + +-# For validators and miners ++# used both by validator and miner: ++OPENAI_API_KEY= + GOOGLE_API_KEY= + ANTHROPIC_API_KEY= + GROQ_API_KEY=test + AWS_ACCESS_KEY= + AWS_SECRET_KEY= ++PIXABAY_API_KEY= + +diff --git a/miner/config.py b/miner/config.py +index 7cdcb40..3ad29e9 100644 +--- a/miner/config.py ++++ b/miner/config.py +@@ -43,7 +43,6 @@ class Config: + self.NO_SET_WEIGHTS = os.getenv('NO_SET_WEIGHTS', False) + self.NO_SERVE = os.getenv('NO_SERVE', False) + +- + def __repr__(self): + return ( + f"Config(BT_SUBTENSOR_NETWORK={self.BT_SUBTENSOR_NETWORK}, WALLET_NAME={self.WALLET_NAME}, HOT_KEY={self.HOT_KEY}" +diff --git a/miner/constants.py b/miner/constants.py +deleted file mode 100644 +index ef7371a..0000000 +--- a/miner/constants.py ++++ /dev/null +@@ -1,28 +0,0 @@ +-from cortext import ImageResponse, StreamPrompting +-from miner.providers import OpenAI, Anthropic, AnthropicBedrock, Groq, Gemini, Bedrock +- +-task_image = ImageResponse.__name__ +-task_stream = StreamPrompting.__name__ +- +-openai_provider = OpenAI.__name__ +-anthropic_provider = Anthropic.__name__ +-anthropic_bedrock_provider = AnthropicBedrock.__name__ +-groq_provider = Groq.__name__ +-gemini_provider = Gemini.__name__ +-bedrock_provider = Bedrock.__name__ +- +-capacity_to_task_and_provider = { +- f"{task_image}_{openai_provider}": 1, +- f"{task_image}_{anthropic_provider}": 1, +- f"{task_image}_{anthropic_bedrock_provider}": 1, +- f"{task_image}_{groq_provider}": 1, +- f"{task_image}_{gemini_provider}": 1, +- f"{task_image}_{bedrock_provider}": 1, +- +- f"{task_stream}_{openai_provider}": 1, +- f"{task_stream}_{anthropic_provider}": 1, +- f"{task_stream}_{anthropic_bedrock_provider}": 1, +- f"{task_stream}_{groq_provider}": 1, +- f"{task_stream}_{gemini_provider}": 1, +- f"{task_stream}_{bedrock_provider}": 1, +-} +diff --git a/miner/providers/base.py b/miner/providers/base.py +index 9123f09..18a0734 100644 +--- a/miner/providers/base.py ++++ b/miner/providers/base.py +@@ -5,7 +5,7 @@ import httpx + from starlette.types import Send + from abc import abstractmethod + +-from cortext.protocol import StreamPrompting, Embeddings, ImageResponse, IsAlive ++from cortext.protocol import StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive + from cortext import ALL_SYNAPSE_TYPE + from cortext.metaclasses import ProviderRegistryMeta + from miner.error_handler import error_handler +@@ -15,7 +15,7 @@ class Provider(metaclass=ProviderRegistryMeta): + self.model = synapse.model + self.uid = synapse.uid + self.timeout = synapse.timeout +- if type(synapse) in [StreamPrompting]: ++ if type(synapse) in [StreamPrompting, TextPrompting]: + self.messages = synapse.messages + self.required_hash_fields = synapse.required_hash_fields + self.seed = synapse.seed +diff --git a/miner/services/__init__.py b/miner/services/__init__.py +index 0823ea8..8309158 100644 +--- a/miner/services/__init__.py ++++ b/miner/services/__init__.py +@@ -4,7 +4,6 @@ from .image import ImageService + from .embedding import EmbeddingService + from .text import TextService + from .check_status import IsAliveService +-from .capacity import CapacityService + +-ALL_SERVICE_TYPE = Union[PromptService, ImageService, EmbeddingService, TextService, IsAliveService, CapacityService] +-__all__ = [PromptService, ImageService, EmbeddingService, CapacityService, ALL_SERVICE_TYPE] ++ALL_SERVICE_TYPE = Union[PromptService, ImageService, EmbeddingService, TextService, IsAliveService] ++__all__ = [PromptService, ImageService, EmbeddingService, ALL_SERVICE_TYPE] +diff --git a/miner/services/capacity.py b/miner/services/capacity.py +deleted file mode 100644 +index eab539b..0000000 +--- a/miner/services/capacity.py ++++ /dev/null +@@ -1,24 +0,0 @@ +-import bittensor as bt +- +-from cortext.protocol import Bandwidth +-from typing import Tuple +- +-from .base import BaseService +-from cortext import ISALIVE_BLACKLIST_STAKE +-from miner.constants import capacity_to_task_and_provider +- +- +-class CapacityService(BaseService): +- def __init__(self, metagraph, blacklist_amt=ISALIVE_BLACKLIST_STAKE): +- super().__init__(metagraph, blacklist_amt) +- +- async def forward_fn(self, synapse: Bandwidth): +- bt.logging.debug("capacity request is being processed") +- synapse.bandwidth_rpm = capacity_to_task_and_provider +- bt.logging.info("check status is executed.") +- return synapse +- +- def blacklist_fn(self, synapse: Bandwidth) -> Tuple[bool, str]: +- blacklist = self.base_blacklist(synapse) +- bt.logging.info(blacklist[1]) +- return blacklist +diff --git a/miner/services/text.py b/miner/services/text.py +new file mode 100644 +index 0000000..3f0c853 +--- /dev/null ++++ b/miner/services/text.py +@@ -0,0 +1,19 @@ ++import bittensor as bt ++from cortext.protocol import TextPrompting ++from typing import Tuple ++ ++from .base import BaseService ++from miner.config import config ++ ++ ++class TextService(BaseService): ++ def __init__(self, metagraph, blacklist_amt=config.BLACKLIST_AMT): ++ super().__init__(metagraph, blacklist_amt) ++ ++ async def forward_fn(self, synapse: TextPrompting): ++ synapse.completion = "completed by miner" ++ bt.logging.info("text service is executed.") ++ return synapse ++ ++ def blacklist_fn(self, synapse: TextPrompting) -> Tuple[bool, str]: ++ return False, "" +diff --git a/start_validator.py b/start_validator.py +index ad013e1..4b54570 100644 +--- a/start_validator.py ++++ b/start_validator.py +@@ -4,18 +4,19 @@ import subprocess + import cortext + from cortext.utils import get_version, send_discord_alert + ++default_address = "wss://bittensor-finney.api.onfinality.io/public-ws" + webhook_url = "" + current_version = cortext.__version__ + + + def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, autoupdate, logging, wandb_on): + global current_version +- wandb = "--wandb_on" if wandb_on else "" ++ wandb = "" if wandb_on else "--wandb_off" + subprocess.run(["pm2", "start", "--name", pm2_name, f"python3 -m validators.validator --wallet.name {wallet_name}" + f" --wallet.hotkey {wallet_hotkey} " + f" --netuid {netuid} " + f"--subtensor.chain_endpoint {address} " +- f"--logging.level {logging} {wandb}"]) ++ f"--logging.{logging} {wandb}"]) + while True: + latest_version = get_version() + print(f"Current version: {current_version}") +@@ -24,7 +25,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au + if current_version != latest_version and latest_version != None: + if not autoupdate: + send_discord_alert( +- f"Your validator not running the latest code ({current_version}). You will quickly lose vtrust if you don't update to version {latest_version}", ++ f"Your validator not running the latest code ({current_version}). You will quickly lose vturst if you don't update to version {latest_version}", + webhook_url) + print("Updating to the latest version...") + subprocess.run(["pm2", "delete", pm2_name]) +@@ -36,7 +37,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au + f" --wallet.hotkey {wallet_hotkey} " + f" --netuid {netuid} " + f"--subtensor.chain_endpoint {address} " +- f"--logging.level {logging} {wandb}"]) ++ f"--logging.{logging} {wandb}"]) + current_version = latest_version + + print("All up to date!") +@@ -48,15 +49,16 @@ if __name__ == "__main__": + description="Automatically update and restart the validator process when a new version is released." + ) + +- parser.add_argument("--pm2_name", required=False, default="autoupdater", help="Name of the PM2 process.") +- parser.add_argument("--wallet_name", required=False, default="default", help="Name of the wallet.") +- parser.add_argument("--wallet_hotkey", required=False, default="default", help="Hotkey for the wallet.") +- parser.add_argument("--netuid", required=False, default=18, help="netuid for validator") +- parser.add_argument("--subtensor.chain_endpoint", required=False, default="wss://entrypoint-finney.opentensor.ai:443", dest="address") +- parser.add_argument("--autoupdate", action='store_true', dest="autoupdate") +- parser.add_argument("--logging", required=False, default="info") +- parser.add_argument("--wandb_on", action='store_true', required=False, dest="wandb_on") +- parser.add_argument("--max_miners_cnt", type=int, default=30) ++ parser.add_argument("--pm2_name", required=True, help="Name of the PM2 process.") ++ parser.add_argument("--wallet_name", required=True, help="Name of the wallet.") ++ parser.add_argument("--wallet_hotkey", required=True, help="Hotkey for the wallet.") ++ parser.add_argument("--netuid", required=True, help="netuid for validator") ++ parser.add_argument("--subtensor.chain_endpoint", default=default_address, dest='address', ++ help="Subtensor chain_endpoint, defaults to 'wss://bittensor-finney.api.onfinality.io/public-ws' if not provided.") ++ parser.add_argument("--autoupdate", action='store_true', dest='autoupdate', ++ help="Disable automatic update. Only send a Discord alert. Add your webhook at the top of the script.") ++ parser.add_argument("--logging", required=False, default="debug") ++ parser.add_argument("--wandb_on", action='store_true', required=False, dest='wandb_on') + + args = parser.parse_args() + +@@ -64,4 +66,4 @@ if __name__ == "__main__": + update_and_restart(args.pm2_name, args.netuid, args.wallet_name, args.wallet_hotkey, args.address, + args.autoupdate, args.logging, args.wandb_on) + except Exception as e: +- parser.error(f"An error occurred: {e}") +\ No newline at end of file ++ parser.error(f"An error occurred: {e}") +diff --git a/validators/config.py b/validators/config.py +new file mode 100644 +index 0000000..2280fd0 +--- /dev/null ++++ b/validators/config.py +@@ -0,0 +1,76 @@ ++import bittensor as bt ++ ++from dotenv import load_dotenv ++import argparse ++import os ++from pathlib import Path ++ ++load_dotenv() # Load environment variables from .env file ++ ++ ++class Config: ++ def __init__(self): ++ super().__init__() ++ ++ self.ENV = os.getenv('ENV') ++ self.ASYNC_TIME_OUT = int(os.getenv('ASYNC_TIME_OUT', 60)) ++ self.BT_SUBTENSOR_NETWORK = 'test' if self.ENV == 'test' else 'finney' ++ self.SLEEP_PER_ITERATION = 1 ++ self.IMAGE_VALIDATOR_CHOOSE_PROBABILITY = 0.03 ++ ++ @staticmethod ++ def check_required_env_vars(): ++ AWS_ACCESS_KEY = os.getenv('AWS_ACCESS_KEY') ++ AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY') ++ if all([AWS_SECRET_KEY, AWS_ACCESS_KEY]): ++ pass ++ else: ++ bt.logging.info("AWS_KEY is not provided correctly. so exit system") ++ exit(0) ++ ++ ++def get_config() -> bt.config: ++ Config.check_required_env_vars() ++ parser = argparse.ArgumentParser() ++ ++ parser.add_argument("--subtensor.chain_endpoint", type=str) ++ parser.add_argument("--wallet.name", type=str) ++ parser.add_argument("--wallet.hotkey", type=str) ++ parser.add_argument("--netuid", type=int) ++ parser.add_argument("--wandb_off", action="store_true", dest="wandb_off") ++ parser.add_argument("--max_miners_cnt", type=int, default=30) ++ parser.add_argument("--axon.port", type=int, default=8000) ++ parser.add_argument('--logging.info', action='store_true') ++ parser.add_argument('--logging.debug', action='store_true') ++ parser.add_argument('--logging.trace', action='store_true') ++ ++ # Activating the parser to read any command-line inputs. ++ # To print help message, run python3 template/miner.py --help ++ bt_config_ = bt.config(parser) ++ bt.configs.append(bt_config_) ++ bt_config_ = bt.config.merge_all(bt.configs) ++ ++ # Logging captures events for diagnosis or understanding miner's behavior. ++ full_path = Path(f"{bt_config_.logging.logging_dir}/{bt_config_.wallet.name}/{bt_config_.wallet.hotkey}" ++ f"/netuid{bt_config_.netuid}/miner").expanduser() ++ bt_config_.full_path = str(full_path) ++ # Ensure the directory for logging exists, else create one. ++ full_path.mkdir(parents=True, exist_ok=True) ++ ++ bt.axon.check_config(bt_config_) ++ bt.logging.check_config(bt_config_) ++ ++ local_host_str = ['local', '127.0.0.1', '0.0.0.0'] ++ if 'test' in bt_config_.subtensor.chain_endpoint: ++ bt_config_.subtensor.network = 'test' ++ elif any(word in bt_config_.subtensor.chain_endpoint for word in local_host_str): ++ bt_config_.subtensor.network = 'local' ++ else: ++ bt_config_.subtensor.network = 'finney' ++ ++ bt.logging.info(bt_config_) ++ return bt_config_ ++ ++ ++app_config = Config() ++bt_config = get_config() +diff --git a/validators/services/__init__.py b/validators/services/__init__.py +index 0fc0233..9fe505e 100644 +--- a/validators/services/__init__.py ++++ b/validators/services/__init__.py +@@ -1,2 +1,2 @@ + from validators.services.validators import * +-from .capacity import CapacityService +\ No newline at end of file ++from .bittensor import bt_validator +\ No newline at end of file +diff --git a/validators/services/bittensor.py b/validators/services/bittensor.py +new file mode 100644 +index 0000000..8dc1ef2 +--- /dev/null ++++ b/validators/services/bittensor.py +@@ -0,0 +1,34 @@ ++from validators.config import bt_config ++import bittensor as bt ++import sys ++ ++ ++class BittensorValidator: ++ def __init__(self): ++ self.config = bt_config ++ bt.logging(config=self.config, logging_dir=self.config.full_path) ++ self.logging = bt.logging ++ self.logging.info( ++ f"Running validator for subnet: {self.config.netuid} on network: {self.config.subtensor.network}") ++ self.wallet = bt.wallet(config=self.config) ++ self.subtensor = bt.subtensor(config=self.config, network=self.config.subtensor.network) ++ self.metagraph = self.subtensor.metagraph(netuid=self.config.netuid) ++ self.axon = bt.axon(wallet=self.wallet, port=self.config.axon.port) ++ self.dendrite = bt.dendrite(wallet=self.wallet) ++ self.my_uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) ++ self.check_wallet_registered_in_network() ++ ++ ++ def check_wallet_registered_in_network(self): ++ if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys: ++ bt.logging.error( ++ f"Your validator: {self.wallet} is not registered to chain connection: " ++ f"{self.subtensor}. Run btcli register --netuid 18 and try again." ++ ) ++ sys.exit() ++ ++ def refresh_network(self): ++ pass ++ ++ ++bt_validator = BittensorValidator() +diff --git a/validators/services/capacity.py b/validators/services/capacity.py +deleted file mode 100644 +index 51a26f4..0000000 +--- a/validators/services/capacity.py ++++ /dev/null +@@ -1,32 +0,0 @@ +-import asyncio +- +-from cortext.protocol import Bandwidth +-import bittensor as bt +- +- +-class CapacityService: +- def __init__(self, metagraph, dendrite): +- self.metagraph = metagraph +- self.dendrite: bt.dendrite = dendrite +- self.timeout = 4 +- +- async def query_capacity_to_miners(self, available_uids): +- capacity_query_tasks = [] +- +- # Query all images concurrently +- for uid in available_uids: +- syn = Bandwidth() +- bt.logging.info(f"querying capacity to uid = {uid}") +- task = self.dendrite.call(self.metagraph.axons[uid], syn, +- timeout=self.timeout) +- capacity_query_tasks.append(task) +- +- # Query responses is (uid. syn) +- query_responses = await asyncio.gather(*capacity_query_tasks, return_exceptions=True) +- uid_to_capacity = {} +- for uid, resp in zip(available_uids, query_responses): +- if isinstance(resp, Exception): +- bt.logging.error(f"exception happens while querying capacity to miner {uid}, {resp}") +- else: +- uid_to_capacity[uid] = resp +- return uid_to_capacity +diff --git a/validators/services/validators/__init__.py b/validators/services/validators/__init__.py +index d461d47..27e4483 100644 +--- a/validators/services/validators/__init__.py ++++ b/validators/services/validators/__init__.py +@@ -1,3 +1,4 @@ ++from .text_validator import TextValidator + from .image_validator import ImageValidator + from .embeddings_validator import EmbeddingsValidator + from .base_validator import BaseValidator +\ No newline at end of file +diff --git a/validators/services/validators/base_validator.py b/validators/services/validators/base_validator.py +index 50e7d7c..5451ec8 100644 +--- a/validators/services/validators/base_validator.py ++++ b/validators/services/validators/base_validator.py +@@ -4,25 +4,27 @@ from datasets import load_dataset + import random + from typing import List, Tuple + +-import bittensor as bt ++import bittensor + + from cortext.metaclasses import ValidatorRegistryMeta + from cortext import utils ++from validators.services.bittensor import bt_validator as bt ++from validators.config import app_config + + dataset = None + + + class BaseValidator(metaclass=ValidatorRegistryMeta): +- def __init__(self, config, metagraph): +- self.config = config +- self.metagraph = metagraph +- self.dendrite = config.dendrite +- self.wallet = config.wallet +- self.timeout = config.async_time_out ++ def __init__(self): ++ self.dendrite = bt.dendrite ++ self.config = bt.config ++ self.subtensor = bt.subtensor ++ self.wallet = bt.wallet ++ self.metagraph = bt.metagraph ++ self.timeout = app_config.ASYNC_TIME_OUT + self.streaming = False + self.provider = None + self.model = None +- self.seed = random.randint(1111, 9999) + self.uid_to_questions = dict() + self.available_uids = [] + self.num_samples = 100 +@@ -41,7 +43,7 @@ class BaseValidator(metaclass=ValidatorRegistryMeta): + for index, uid in enumerate(available_uids): + + if item_type == "images": +- content = await utils.get_question("images", len(available_uids)) ++ content = await utils.get_question("images", len(available_uids)) + self.uid_to_questions[uid] = content # Store messages for each UID + elif item_type == "text": + question = await utils.get_question("text", len(available_uids), vision) +@@ -69,13 +71,13 @@ class BaseValidator(metaclass=ValidatorRegistryMeta): + bt.logging.error(f"Exception during query for uid {uid}: {e}") + return uid, None + +- async def handle_response(self, uid, response) -> Tuple[int, bt.Synapse]: ++ async def handle_response(self, uid, response) -> Tuple[int, bittensor.Synapse]: + if type(response) == list and response: + response = response[0] + return uid, response + + @abstractmethod +- async def start_query(self, available_uids: List[int]) -> bt.Synapse: ++ async def start_query(self, available_uids: List[int]) -> bittensor.Synapse: + pass + + @abstractmethod +@@ -99,6 +101,7 @@ class BaseValidator(metaclass=ValidatorRegistryMeta): + uid_scores_dict = {} + scored_response = [] + ++ + for uid, syn in responses: + task = self.get_answer_task(uid, syn) + answering_tasks.append((uid, task)) +@@ -111,8 +114,7 @@ class BaseValidator(metaclass=ValidatorRegistryMeta): + + # Await all scoring tasks + scored_responses = await asyncio.gather(*[task for _, task in scoring_tasks]) +- average_score = sum(0 if score is None else score for score in scored_responses) / len( +- scored_responses) if scored_responses else 0 ++ average_score = sum(scored_responses) / len(scored_responses) if scored_responses else 0 + bt.logging.debug(f"scored responses = {scored_responses}, average score = {average_score}") + + for (uid, _), scored_response in zip(scoring_tasks, scored_responses): +@@ -123,12 +125,12 @@ class BaseValidator(metaclass=ValidatorRegistryMeta): + + if uid_scores_dict != {}: + bt.logging.info(f"text_scores is {uid_scores_dict}") +- bt.logging.trace("score_responses process completed.") ++ bt.logging.info("score_responses process completed.") + + return uid_scores_dict, scored_response, responses + + async def get_and_score(self, available_uids: List[int]): +- bt.logging.trace("starting query") ++ bt.logging.info("starting query") + query_responses = await self.start_query(available_uids) +- bt.logging.trace("scoring query with query responses") ++ bt.logging.info("scoring query with query responses") + return await self.score_responses(query_responses) +diff --git a/validators/services/validators/constants.py b/validators/services/validators/constants.py +index fb75163..5ad2e2b 100644 +--- a/validators/services/validators/constants.py ++++ b/validators/services/validators/constants.py +@@ -3,19 +3,23 @@ TEXT_PROVIDER = "OpenAI" + TEXT_MAX_TOKENS = 4096 + TEXT_TEMPERATURE = 0.001 + TEXT_WEIGHT = 1 ++TEXT_SEED = 1234 + TEXT_TOP_P = 0.01 + TEXT_TOP_K = 1 + VISION_MODELS = ["gpt-4o", "claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-5-sonnet-20240620"] ++DEFAULT_NUM_UID_PICK = 30 ++DEFAULT_NUM_UID_PICK_ANTHROPIC = 1 + TEXT_VALI_MODELS_WEIGHTS = { + "AnthropicBedrock": { + "anthropic.claude-v2:1": 1 + }, + "OpenAI": { + "gpt-4o": 1, +- "gpt-3.5-turbo": 1000, +- "o1-preview": 1, +- "o1-mini": 1, ++ "gpt-4-1106-preview": 1, ++ "gpt-3.5-turbo": 1, ++ "gpt-3.5-turbo-16k": 1, ++ "gpt-3.5-turbo-0125": 1, + }, + "Gemini": { + "gemini-pro": 1, +@@ -26,23 +30,23 @@ TEXT_VALI_MODELS_WEIGHTS = { + "claude-3-5-sonnet-20240620": 1, + "claude-3-opus-20240229": 1, + "claude-3-sonnet-20240229": 1, +- "claude-3-haiku-20240307": 1000, ++ "claude-3-haiku-20240307": 1 + }, + "Groq": { +- "gemma-7b-it": 500, ++ "gemma-7b-it": 1, + "llama3-70b-8192": 1, +- "llama3-8b-8192": 500, ++ "llama3-8b-8192": 1, + "mixtral-8x7b-32768": 1, + }, + "Bedrock": { +- # "anthropic.claude-3-sonnet-20240229-v1:0": 1, ++ "anthropic.claude-3-sonnet-20240229-v1:0": 1, + "cohere.command-r-v1:0": 1, +- # "meta.llama2-70b-chat-v1": 1, +- # "amazon.titan-text-express-v1": 1, ++ "meta.llama2-70b-chat-v1": 1, ++ "amazon.titan-text-express-v1": 1, + "mistral.mistral-7b-instruct-v0:2": 1, + "ai21.j2-mid-v1": 1, +- # "anthropic.claude-3-5-sonnet-20240620-v1:0": 1, +- # "anthropic.claude-3-opus-20240229-v1:0": 1, +- # "anthropic.claude-3-haiku-20240307-v1:0": 1 ++ "anthropic.claude-3-5-sonnet-20240620-v1:0": 1, ++ "anthropic.claude-3-opus-20240229-v1:0": 1, ++ "anthropic.claude-3-haiku-20240307-v1:0": 1 + } + } +diff --git a/validators/services/validators/embeddings_validator.py b/validators/services/validators/embeddings_validator.py +index 53d8926..99494c2 100644 +--- a/validators/services/validators/embeddings_validator.py ++++ b/validators/services/validators/embeddings_validator.py +@@ -1,7 +1,8 @@ + from __future__ import annotations ++from bittensor import Synapse + import random + import asyncio +-import bittensor as bt ++from validators.services.bittensor import bt_validator as bt + import cortext.reward + from cortext import client + from cortext.protocol import Embeddings +@@ -9,10 +10,9 @@ from validators.services.validators.base_validator import BaseValidator + + + class EmbeddingsValidator(BaseValidator): +- def __init__(self, config): +- super().__init__(config) ++ def __init__(self): ++ super().__init__() + self.streaming = False +- self.config = config + self.query_type = "embeddings" + self.model = "text-embedding-ada-002" + self.weight = 1 +@@ -63,7 +63,7 @@ class EmbeddingsValidator(BaseValidator): + # bt.logging.error(f"Error in processing batch: {e}") + return all_embeddings + +- async def start_query(self, available_uids) -> tuple[(int, bt.Synapse)] | None: ++ async def start_query(self, available_uids) -> tuple[(int, Synapse)] | None: + if not available_uids: + return None + +diff --git a/validators/services/validators/image_validator.py b/validators/services/validators/image_validator.py +index 0bcd243..46e41f3 100644 +--- a/validators/services/validators/image_validator.py ++++ b/validators/services/validators/image_validator.py +@@ -5,15 +5,15 @@ import wandb + + import cortext.reward + from cortext.protocol import ImageResponse ++from validators.services.bittensor import bt_validator as bt + from validators.services.validators.base_validator import BaseValidator + from validators import utils + from validators.utils import error_handler +-import bittensor as bt + + + class ImageValidator(BaseValidator): +- def __init__(self, config, metagraph=None): +- super().__init__(config, metagraph) ++ def __init__(self): ++ super().__init__() + self.num_uids_to_pick = 30 + self.streaming = False + self.query_type = "images" +@@ -26,6 +26,7 @@ class ImageValidator(BaseValidator): + self.quality = "standard" + self.style = "vivid" + self.steps = 30 ++ self.seed = 123456 + self.wandb_data = { + "modality": "images", + "prompts": {}, +diff --git a/validators/services/validators/text_validator.py b/validators/services/validators/text_validator.py +index e739e3f..8adf3af 100644 +--- a/validators/services/validators/text_validator.py ++++ b/validators/services/validators/text_validator.py +@@ -1,9 +1,9 @@ + import asyncio + import random +-import bittensor as bt + from typing import AsyncIterator + + from cortext.reward import model ++from validators.services.bittensor import bt_validator as bt + from . import constants + import cortext.reward + from validators.services.validators.base_validator import BaseValidator +@@ -12,23 +12,22 @@ from typing import Optional + from cortext.protocol import StreamPrompting + from cortext.utils import (call_anthropic_bedrock, call_bedrock, call_anthropic, call_gemini, + call_groq, call_openai, get_question) +-from validators.utils import get_should_i_score_arr_for_text + + + class TextValidator(BaseValidator): +- gen_should_i_score = get_should_i_score_arr_for_text() +- def __init__(self, config, provider: str = None, model: str = None, metagraph=None): +- super().__init__(config, metagraph) ++ def __init__(self, provider: str = None, model: str = None): ++ super().__init__() + self.streaming = True + self.query_type = "text" +- self.metagraph = metagraph + self.model = model or constants.TEXT_MODEL + self.max_tokens = constants.TEXT_MAX_TOKENS + self.temperature = constants.TEXT_TEMPERATURE + self.weight = constants.TEXT_WEIGHT ++ self.seed = constants.TEXT_SEED + self.top_p = constants.TEXT_TOP_P + self.top_k = constants.TEXT_TOP_K + self.provider = provider or constants.TEXT_PROVIDER ++ self.num_uids_to_pick = constants.DEFAULT_NUM_UID_PICK + + self.wandb_data = { + "modality": "text", +@@ -78,7 +77,7 @@ class TextValidator(BaseValidator): + if isinstance(chunk, str): + bt.logging.trace(chunk) + full_response += chunk +- bt.logging.trace(f"full_response for uid {uid}: {full_response}") ++ bt.logging.debug(f"full_response for uid {uid}: {full_response}") + break + return uid, full_response + +@@ -89,7 +88,7 @@ class TextValidator(BaseValidator): + await self.load_questions(available_uids, "text", is_vision_model) + + query_tasks = [] +- bt.logging.trace(f"provider = {self.provider} model = {self.model}") ++ bt.logging.info(f"provider = {self.provider}\nmodel = {self.model}") + for uid, question in self.uid_to_questions.items(): + prompt = question.get("prompt") + image = question.get("image") +@@ -118,17 +117,22 @@ class TextValidator(BaseValidator): + + def select_random_provider_and_model(self): + # AnthropicBedrock should only be used if a validators' anthropic account doesn't work +- providers = ["OpenAI"] * 55 + ["AnthropicBedrock"] * 0 + ["Gemini"] * 1 + ["Anthropic"] * 20 + [ ++ providers = ["OpenAI"] * 55 + ["AnthropicBedrock"] * 0 + ["Gemini"] * 2 + ["Anthropic"] * 20 + [ + "Groq"] * 30 + ["Bedrock"] * 0 + self.provider = random.choice(providers) ++ self.num_uids_to_pick = constants.DEFAULT_NUM_UID_PICK + + model_to_weights = constants.TEXT_VALI_MODELS_WEIGHTS[self.provider] + self.model = random.choices(list(model_to_weights.keys()), + weights=list(model_to_weights.values()), k=1)[0] + +- @classmethod +- def should_i_score(cls): +- return next(cls.gen_should_i_score) ++ return self.num_uids_to_pick ++ ++ def should_i_score(self): ++ random_number = random.random() ++ will_score_all = random_number < 1 / 5 ++ bt.logging.info(f"Random Number: {random_number}, Will score text responses: {will_score_all}") ++ return will_score_all + + @error_handler + async def build_wandb_data(self, uid_to_score, responses): +@@ -185,6 +189,7 @@ class TextValidator(BaseValidator): + question = self.uid_to_questions[uid] + prompt = question.get("prompt") + image_url = question.get("image") ++ bt.logging.info(f"processing image url {image_url}") + return await self.call_api(prompt, image_url, self.provider) + + async def get_scoring_task(self, uid, answer, response): +diff --git a/validators/utils.py b/validators/utils.py +index b8ea92a..95b4ecf 100644 +--- a/validators/utils.py ++++ b/validators/utils.py +@@ -1,7 +1,6 @@ + import aiohttp + import asyncio + import base64 +-import itertools + import bittensor as bt + + from PIL import Image +@@ -37,13 +36,3 @@ def error_handler(func): + return result + + return wrapper +- +- +-def get_should_i_score_arr_for_text(): +- for i in itertools.count(): +- yield (i % 5) != 0 +- +- +-def get_should_i_score_arr_for_image(): +- for i in itertools.count(): +- yield (i % 1) != 0 +diff --git a/validators/validator.py b/validators/validator.py +index 9ec4507..15477bf 100644 +--- a/validators/validator.py ++++ b/validators/validator.py +@@ -1,130 +1,45 @@ +-import os +-import random +-import time +-import argparse + import asyncio +-from pathlib import Path +-from dotenv import load_dotenv +-import bittensor as bt +-import wandb ++import os ++ ++import base # noqa + import cortext ++import wandb + from cortext import utils + from validators.weight_setter import WeightSetter ++from validators.config import bt_config ++from validators.services import bt_validator as bt + +-# Load environment variables from .env file +-load_dotenv() +-random.seed(time.time()) +- +-class NestedNamespace(argparse.Namespace): +- def __setattr__(self, name, value): +- if '.' in name: +- group, name = name.split('.', 1) +- ns = getattr(self, group, NestedNamespace()) +- setattr(ns, name, value) +- self.__dict__[group] = ns +- else: +- self.__dict__[name] = value +- +- def get(self, key, default=None): +- if '.' in key: +- group, key = key.split('.', 1) +- return getattr(self, group, NestedNamespace()).get(key, default) +- return self.__dict__.get(key, default) +- +- +-class Config: +- def __init__(self, args): +- +- # Add command-line arguments to the Config object +- for key, value in vars(args).items(): +- setattr(self, key, value) +- +- @staticmethod +- def check_required_env_vars(): +- required_vars = ['AWS_ACCESS_KEY', 'AWS_SECRET_KEY'] +- missing_vars = [var for var in required_vars if not os.getenv(var)] +- if missing_vars: +- bt.logging.error(f"Missing required environment variables: {', '.join(missing_vars)}") +- exit(1) + +- def get(self, key, default=None): +- return getattr(self, key, default) +- +- +-def parse_arguments(): +- parser = argparse.ArgumentParser(description="Validator Configuration") +- parser.add_argument("--subtensor.chain_endpoint", type=str, default="wss://entrypoint-finney.opentensor.ai:443") +- parser.add_argument("--wallet.name", type=str, default="default") +- parser.add_argument("--wallet.hotkey", type=str, default="default") +- parser.add_argument("--netuid", type=int, default=18) +- parser.add_argument("--wandb_on", action="store_true") +- parser.add_argument("--max_miners_cnt", type=int, default=30) +- parser.add_argument("--axon.port", type=int, default=8000) +- parser.add_argument("--logging.level", choices=['info', 'debug', 'trace'], default='info') +- parser.add_argument("--autoupdate", action="store_true", help="Enable auto-updates") +- parser.add_argument("--image_validator_probability", type=float, default=0.001) +- parser.add_argument("--async_time_out", type=int, default=60) +- return parser.parse_args(namespace=NestedNamespace()) +- +- +-def setup_logging(config): +- if config.logging.level == 'trace': +- bt.logging.set_trace() +- elif config.logging.level == 'debug': +- bt.logging.set_debug() +- else: +- # set to info by default +- pass +- bt.logging.info(f"Set logging level to {config.logging.level}") +- +- full_path = Path( +- f"~/.bittensor/validators/{config.wallet.name}/{config.wallet.hotkey}/netuid{config.netuid}/validator").expanduser() +- full_path.mkdir(parents=True, exist_ok=True) +- config.full_path = str(full_path) +- +- bt.logging.info(f"Arguments: {vars(config)}") +- +- +-def init_wandb(config): +- if not config.wandb_on: ++def init_wandb(): ++ if not bt_config.wandb_on: + return + +- wallet = bt.wallet(name=config.wallet.name, hotkey=config.wallet.hotkey) +- run_name = f"validator-{wallet.hotkey.ss58_address}-{cortext.__version__}" +- config.run_name = run_name +- config.version = cortext.__version__ +- config.type = "validator" ++ run_name = f"validator-{bt.my_uid}-{cortext.__version__}" ++ bt_config.uid = bt.my_uid ++ bt_config.hotkey = bt.wallet.hotkey.ss58_address ++ bt_config.run_name = run_name ++ bt_config.version = cortext.__version__ ++ bt_config.type = "validator" + ++ # Initialize the wandb run for the single project + run = wandb.init( +- name=run_name, +- project=cortext.PROJECT_NAME, +- entity="cortex-t", +- config=config.__dict__, +- dir=config.full_path, ++ name=run_name, project=cortext.PROJECT_NAME, entity="cortex-t", config=bt_config, dir=bt_config.full_path, + reinit=True + ) + +- signature = wallet.hotkey.sign(run.id.encode()).hex() +- config.signature = signature +- wandb.config.update(config.__dict__, allow_val_change=True) ++ # Sign the run to ensure it's from the correct hotkey ++ signature = bt.wallet.hotkey.sign(run.id.encode()).hex() ++ bt_config.signature = signature ++ wandb.config.update(bt_config, allow_val_change=True) + + bt.logging.success(f"Started wandb run for project '{cortext.PROJECT_NAME}'") + + +-def main(): +- Config.check_required_env_vars() +- args = parse_arguments() +- config = Config(args) +- config.wallet = bt.wallet(name=config.wallet.name, hotkey=config.wallet.hotkey) +- config.dendrite = bt.dendrite(wallet=config.wallet) +- setup_logging(config) +- +- bt.logging.info(f"Config: {vars(config)}") +- +- init_wandb(config) ++def main(test=False) -> None: ++ init_wandb() + loop = asyncio.get_event_loop() +- weight_setter = WeightSetter(config=config) +- state_path = os.path.join(config.full_path, "state.json") ++ weight_setter = WeightSetter(loop) ++ state_path = os.path.join(bt_config.full_path, "state.json") + utils.get_state(state_path) + try: + loop.run_forever() +@@ -136,7 +51,7 @@ def main(): + bt.logging.info("updating status before exiting validator") + state = utils.get_state(state_path) + utils.save_state_to_file(state, state_path) +- if config.wandb_on: ++ if bt_config.wandb_on: + wandb.finish() + + +diff --git a/validators/weight_setter.py b/validators/weight_setter.py +index 8c1d52a..08e4b09 100644 +--- a/validators/weight_setter.py ++++ b/validators/weight_setter.py +@@ -1,65 +1,45 @@ + import asyncio + import concurrent ++import itertools + import random + import torch + import traceback +-from substrateinterface import SubstrateInterface + from functools import partial + from typing import Tuple + import wandb +-import bittensor as bt + + import cortext ++from cortext.protocol import TextPrompting + + from starlette.types import Send + + from cortext.protocol import IsAlive, StreamPrompting, ImageResponse, Embeddings + from cortext.metaclasses import ValidatorRegistryMeta +-from validators.services import BaseValidator, TextValidator, CapacityService ++from validators.services import BaseValidator, TextValidator ++from validators.config import bt_config, app_config ++from validators.services.bittensor import bt_validator as bt + ++iterations_per_set_weights = 10 + scoring_organic_timeout = 60 + + + class WeightSetter: +- def __init__(self, config): +- self.uid_to_capacity = {} +- self.available_uids = None +- self.NUM_QUERIES_PER_UID = 10 +- self.remaining_queries = [] +- bt.logging.info("Initializing WeightSetter") +- self.config = config +- self.wallet = config.wallet +- self.subtensor = bt.subtensor(config=config) +- self.node = SubstrateInterface(url=config.subtensor.chain_endpoint) +- self.netuid = self.config.netuid +- self.metagraph = bt.metagraph(netuid=self.netuid, network=config.subtensor.chain_endpoint) +- self.my_uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) +- bt.logging.info(f"Running validator on subnet: {self.netuid} with uid: {self.my_uid}") +- +- # Initialize scores +- self.total_scores = {} +- self.score_counts = {} +- self.moving_average_scores = None +- +- # Set up axon and dendrite +- self.axon = bt.axon(wallet=self.wallet, config=self.config) +- bt.logging.info(f"Axon server started on port {self.config.axon.port}") +- self.dendrite = config.dendrite +- +- # Set up async-related attributes +- self.lock = asyncio.Lock() +- self.loop = asyncio.get_event_loop() ++ def __init__(self, loop: asyncio.AbstractEventLoop): ++ bt.logging.info("starting weight setter") ++ self.config = bt_config ++ bt.logging.info(f"config:\n{self.config}") ++ self.prompt_cache: dict[str, Tuple[str, int]] = {} + self.request_timestamps = {} ++ self.loop = loop ++ self.dendrite = bt.dendrite ++ self.subtensor = bt.subtensor ++ self.wallet = bt.wallet ++ self.moving_average_scores = None ++ self.axon = bt.axon ++ self.metagraph = bt.metagraph ++ self.my_uid = bt.my_uid ++ self.total_scores = torch.zeros(len(self.metagraph.hotkeys)) + self.organic_scoring_tasks = set() +- +- # Initialize prompt cache +- self.prompt_cache = {} +- +- # Get network tempo +- self.tempo = self.subtensor.tempo(self.netuid) +- self.weights_rate_limit = self.get_weights_rate_limit() +- +- # Set up async tasks + self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio') + self.loop.create_task(self.consume_organic_scoring()) + self.loop.create_task(self.perform_synthetic_scoring_and_update_weights()) +@@ -67,226 +47,24 @@ class WeightSetter: + async def run_sync_in_async(self, fn): + return await self.loop.run_in_executor(self.thread_executor, fn) + +- def get_current_block(self): +- return self.node.query("System", "Number", []).value +- +- def get_weights_rate_limit(self): +- return self.node.query("SubtensorModule", "WeightsSetRateLimit", [self.netuid]).value +- +- def get_last_update(self, block): +- try: +- last_update_blocks = block - self.node.query("SubtensorModule", "LastUpdate", [self.netuid]).value[ +- self.my_uid] +- except Exception as err: +- bt.logging.error(f"Error getting last update: {traceback.format_exc()}") +- bt.logging.exception(err) +- # means that the validator is not registered yet. The validator should break if this is the case anyways +- last_update_blocks = 1000 +- +- bt.logging.trace(f"last set weights successfully {last_update_blocks} blocks ago") +- return last_update_blocks +- +- def get_blocks_til_epoch(self, block): +- return self.tempo - (block + 19) % (self.tempo + 1) +- +- async def refresh_metagraph(self): +- await self.run_sync_in_async(lambda: self.metagraph.sync()) +- +- async def initialize_uids_and_capacities(self): +- self.available_uids = await self.get_available_uids() +- bt.logging.info(f"Available UIDs: {list(self.available_uids.keys())}") +- # self.uid_to_capacity = await self.get_capacities_for_uids(self.available_uids) +- # bt.logging.info(f"Capacities for miners: {self.uid_to_capacity}") +- self.total_scores = {uid: 0.0 for uid in self.available_uids.keys()} +- self.score_counts = {uid: 0 for uid in self.available_uids.keys()} +- self.remaining_queries = self.shuffled(list(self.available_uids.keys()) * self.NUM_QUERIES_PER_UID) +- +- async def update_and_refresh(self, last_update): +- bt.logging.info(f"setting weights, last update {last_update} blocks ago") +- await self.update_weights() +- +- bt.logging.info("Refreshing metagraph...") +- await self.refresh_metagraph() +- +- bt.logging.info("Refreshing available UIDs...") +- self.available_uids = await self.get_available_uids() +- bt.logging.info(f"Available UIDs: {list(self.available_uids.keys())}") +- +- # bt.logging.info("Refreshing capacities...") +- # self.uid_to_capacity = await self.get_capacities_for_uids(self.available_uids) +- +- self.total_scores = {uid: 0 for uid in self.available_uids.keys()} +- self.score_counts = {uid: 0 for uid in self.available_uids.keys()} +- self.remaining_queries = self.shuffled(list(self.available_uids.keys()) * self.NUM_QUERIES_PER_UID) +- +- +- async def perform_synthetic_scoring_and_update_weights(self): +- while True: +- if self.available_uids is None: +- await self.initialize_uids_and_capacities() +- +- current_block = self.get_current_block() +- last_update = self.get_last_update(current_block) +- +- if last_update >= self.tempo * 2 or ( +- self.get_blocks_til_epoch(current_block) < 10 and last_update >= self.weights_rate_limit): +- +- await self.update_and_refresh(last_update) +- +- if not self.remaining_queries: +- bt.logging.info("No more queries to perform until next epoch.") +- continue +- +- bt.logging.debug(f"not setting weights, last update {last_update} blocks ago, " +- f"{self.get_blocks_til_epoch(current_block)} blocks til epoch") +- +- selected_validator = self.select_validator() +- num_uids_to_query = min(self.config.max_miners_cnt, len(self.remaining_queries)) +- +- # Pop UIDs to query from the remaining_queries list +- uids_to_query = [self.remaining_queries.pop() for _ in range(num_uids_to_query)] +- uid_to_scores = await self.process_modality(selected_validator, uids_to_query) +- +- bt.logging.info(f"Remaining queries: {len(self.remaining_queries)}") +- +- if uid_to_scores is None: +- bt.logging.trace("uid_to_scores is None.") +- continue +- +- for uid, score in uid_to_scores.items(): +- async with self.lock: +- self.total_scores[uid] += score +- self.score_counts[uid] += 1 +- +- # Slow down the validator steps if necessary +- await asyncio.sleep(1) +- +- def select_validator(self): +- rand = random.random() +- text_validator = ValidatorRegistryMeta.get_class('TextValidator')(config=self.config, metagraph=self.metagraph) +- image_validator = ValidatorRegistryMeta.get_class('ImageValidator')(config=self.config, +- metagraph=self.metagraph) +- if rand > self.config.image_validator_probability: +- return text_validator +- else: +- return image_validator +- +- async def get_capacities_for_uids(self, uids): +- capacity_service = CapacityService(metagraph=self.metagraph, dendrite=self.dendrite) +- uid_to_capacity = await capacity_service.query_capacity_to_miners(uids) +- return uid_to_capacity +- +- async def get_available_uids(self): +- """Get a dictionary of available UIDs and their axons asynchronously.""" +- await self.dendrite.aclose_session() +- tasks = {uid.item(): self.check_uid(self.metagraph.axons[uid.item()], uid.item()) for uid in +- self.metagraph.uids} +- results = await asyncio.gather(*tasks.values()) +- +- # Create a dictionary of UID to axon info for active UIDs +- available_uids = {uid: axon_info for uid, axon_info in zip(tasks.keys(), results) if axon_info is not None} +- +- bt.logging.info(f"Available UIDs: {list(available_uids.keys())}") +- +- return available_uids +- +- async def check_uid(self, axon, uid): +- """Asynchronously check if a UID is available.""" +- try: +- response = await self.dendrite(axon, IsAlive(), timeout=4) +- if response.completion == 'True': +- bt.logging.trace(f"UID {uid} is active") +- return axon # Return the axon info instead of the UID +- +- bt.logging.error(f"UID {uid} is not active") +- return None +- +- except Exception as err: +- bt.logging.error(f"Error checking UID {uid}: {err}") +- return None +- +- @staticmethod +- def shuffled(list_: list) -> list: +- list_ = list_.copy() +- random.shuffle(list_) +- return list_ +- +- async def process_modality(self, selected_validator: BaseValidator, available_uids): +- if not available_uids: +- bt.logging.info("No available uids.") +- return None +- bt.logging.info(f"starting query {selected_validator.__class__.__name__} for miners {available_uids}") +- query_responses = await selected_validator.start_query(available_uids) +- +- if not selected_validator.should_i_score(): +- bt.logging.info("we don't score this time.") +- return None +- +- bt.logging.debug(f"scoring query with query responses for " +- f"these uids: {available_uids}") +- uid_scores_dict, scored_responses, responses = await selected_validator.score_responses(query_responses) +- wandb_data = await selected_validator.build_wandb_data(uid_scores_dict, responses) +- if self.config.wandb_on and not wandb_data: +- wandb.log(wandb_data) +- bt.logging.success("wandb_log successful") +- return uid_scores_dict +- +- async def update_weights(self): +- """Update weights based on average scores, using min-max normalization.""" +- bt.logging.info("Updating weights...") +- avg_scores = {} +- +- # Compute average scores per UID +- for uid in self.total_scores: +- count = self.score_counts[uid] +- if count > 0: +- avg_scores[uid] = self.total_scores[uid] / count +- else: +- avg_scores[uid] = 0.0 +- +- bt.logging.info(f"Average scores = {avg_scores}") +- +- # Convert avg_scores to a tensor aligned with metagraph UIDs +- weights = torch.zeros(len(self.metagraph.uids)) +- for uid, score in avg_scores.items(): +- weights[uid] = score +- +- await self.set_weights(weights) +- +- async def set_weights(self, scores): +- # alpha of .3 means that each new score replaces 30% of the weight of the previous weights +- alpha = .3 +- if self.moving_average_scores is None: +- self.moving_average_scores = scores.clone() +- +- # Update the moving average scores +- self.moving_average_scores = alpha * scores + (1 - alpha) * self.moving_average_scores +- bt.logging.info(f"Updated moving average of weights: {self.moving_average_scores}") +- await self.run_sync_in_async( +- lambda: self.subtensor.set_weights( +- netuid=self.config.netuid, +- wallet=self.wallet, +- uids=self.metagraph.uids, +- weights=self.moving_average_scores, +- wait_for_inclusion=True, +- version_key=cortext.__weights_version__, +- ) +- ) +- bt.logging.success("Successfully included weights in block.") +- + def blacklist_prompt(self, synapse: StreamPrompting) -> Tuple[bool, str]: + blacklist = self.base_blacklist(synapse, cortext.PROMPT_BLACKLIST_STAKE) ++ bt.logging.info(blacklist[1]) ++ return blacklist ++ ++ def blacklist_is_alive(self, synapse: IsAlive) -> Tuple[bool, str]: ++ blacklist = self.base_blacklist(synapse, cortext.ISALIVE_BLACKLIST_STAKE) + bt.logging.debug(blacklist[1]) + return blacklist + + def blacklist_images(self, synapse: ImageResponse) -> Tuple[bool, str]: + blacklist = self.base_blacklist(synapse, cortext.IMAGE_BLACKLIST_STAKE) +- bt.logging.debug(blacklist[1]) ++ bt.logging.info(blacklist[1]) + return blacklist + + def blacklist_embeddings(self, synapse: Embeddings) -> Tuple[bool, str]: + blacklist = self.base_blacklist(synapse, cortext.EMBEDDING_BLACKLIST_STAKE) +- bt.logging.debug(blacklist[1]) ++ bt.logging.info(blacklist[1]) + return blacklist + + def base_blacklist(self, synapse, blacklist_amt=20000) -> Tuple[bool, str]: +@@ -308,7 +86,7 @@ class WeightSetter: + async def images(self, synapse: ImageResponse) -> ImageResponse: + bt.logging.info(f"received {synapse}") + +- synapse = await self.dendrite(self.metagraph.axons[synapse.uid], synapse, deserialize=False, ++ synapse = self.dendrite.query(self.metagraph.axons[synapse.uid], synapse, deserialize=False, + timeout=synapse.timeout) + + bt.logging.info(f"new synapse = {synapse}") +@@ -326,24 +104,27 @@ class WeightSetter: + async def prompt(self, synapse: StreamPrompting) -> StreamPrompting: + bt.logging.info(f"received {synapse}") + +- # Return the streaming response as before + async def _prompt(synapse, send: Send): +- bt.logging.info(f"Sending {synapse} request to uid: {synapse.uid}") ++ bt.logging.info( ++ f"Sending {synapse} request to uid: {synapse.uid}, " ++ ) + + async def handle_response(responses): + for resp in responses: + async for chunk in resp: + if isinstance(chunk, str): +- await send({ +- "type": "http.response.body", +- "body": chunk.encode("utf-8"), +- "more_body": True, +- }) ++ await send( ++ { ++ "type": "http.response.body", ++ "body": chunk.encode("utf-8"), ++ "more_body": True, ++ } ++ ) + bt.logging.info(f"Streamed text: {chunk}") +- await send({"type": "http.response.body", "body": b'', "more_body": False}) ++ await send({"type": "http.response.body", "body": b'', "more_body": False}) + + axon = self.metagraph.axons[synapse.uid] +- responses = await self.dendrite( ++ responses = self.dendrite.query( + axons=[axon], + synapse=synapse, + deserialize=False, +@@ -355,6 +136,16 @@ class WeightSetter: + token_streamer = partial(_prompt, synapse) + return synapse.create_streaming_response(token_streamer) + ++ def text(self, synapse: TextPrompting) -> TextPrompting: ++ synapse.completion = "completed" ++ bt.logging.info("completed") ++ ++ synapse = self.dendrite.query(self.metagraph.axons[synapse.uid], synapse, deserialize=False, ++ timeout=synapse.timeout) ++ ++ bt.logging.info(f"synapse = {synapse}") ++ return synapse ++ + async def consume_organic_scoring(self): + bt.logging.info("Attaching forward function to axon.") + self.axon.attach( +@@ -366,14 +157,170 @@ class WeightSetter: + ).attach( + forward_fn=self.embeddings, + blacklist_fn=self.blacklist_embeddings, ++ ).attach( ++ forward_fn=self.text, + ) +- self.axon.serve(netuid=self.netuid) ++ self.axon.serve(netuid=self.config.netuid, subtensor=self.subtensor) + self.axon.start() + bt.logging.info(f"Running validator on uid: {self.my_uid}") + while True: + try: +- # Check for organic scoring tasks here +- await asyncio.sleep(60) ++ if self.organic_scoring_tasks: ++ completed, _ = await asyncio.wait(self.organic_scoring_tasks, timeout=1, ++ return_when=asyncio.FIRST_COMPLETED) ++ for task in completed: ++ if task.exception(): ++ bt.logging.error( ++ f'Encountered in {TextValidator.score_responses.__name__} task:\n' ++ f'{"".join(traceback.format_exception(task.exception()))}' ++ ) ++ else: ++ success, data = task.result() ++ if not success: ++ continue ++ self.total_scores += data[0] ++ self.organic_scoring_tasks.difference_update(completed) ++ else: ++ await asyncio.sleep(60) + except Exception as err: + bt.logging.exception(err) +- await asyncio.sleep(10) +\ No newline at end of file ++ await asyncio.sleep(10) ++ ++ async def refresh_metagraph(self): ++ await self.run_sync_in_async(lambda: self.metagraph.sync()) ++ ++ async def perform_synthetic_scoring_and_update_weights(self): ++ cur_block = self.subtensor.block ++ while True: ++ bt.logging.info("start validating process.") ++ for steps_passed in itertools.count(): ++ ++ selected_validator = self.select_validator() ++ if not selected_validator.should_i_score(): ++ bt.logging.info("We don't score this time.") ++ await asyncio.sleep(app_config.SLEEP_PER_ITERATION) ++ continue ++ ++ available_uids = await self.get_available_uids() ++ if not len(available_uids): ++ bt.logging.info("no available uids. so referesh network and continue.") ++ await asyncio.sleep(app_config.SLEEP_PER_ITERATION) ++ continue ++ bt.logging.info(f"available uids: {available_uids.keys()}") ++ if bt_config.max_miners_cnt < len(available_uids): ++ available_uids = random.sample(list(available_uids.keys()), bt_config.max_miners_cnt) ++ ++ uid_to_scores = await self.process_modality(selected_validator, available_uids) ++ ++ if uid_to_scores is None: ++ bt.logging.info("uid_to_scores is None.") ++ continue ++ ++ for uid, score in uid_to_scores.items(): ++ self.total_scores[uid] += score ++ ++ # if we want to slow down the speed of the validator steps ++ await asyncio.sleep(app_config.SLEEP_PER_ITERATION) ++ ++ if (self.subtensor.block - cur_block) >= 360: ++ bt.logging.info("refreshing metagraph...") ++ cur_block = self.subtensor.block ++ await self.refresh_metagraph() ++ bt.logging.info("updating weights...") ++ await self.update_weights(steps_passed) ++ ++ @staticmethod ++ def select_validator(): ++ rand = random.random() ++ text_validator = ValidatorRegistryMeta.get_class('TextValidator')() ++ image_validator = ValidatorRegistryMeta.get_class('ImageValidator')() ++ if rand > app_config.IMAGE_VALIDATOR_CHOOSE_PROBABILITY: ++ bt.logging.info("text_validator is selected.") ++ return text_validator ++ else: ++ bt.logging.info("image_validator is selected.") ++ return image_validator ++ ++ async def get_available_uids(self): ++ """Get a dictionary of available UIDs and their axons asynchronously.""" ++ await self.dendrite.aclose_session() ++ tasks = {uid.item(): self.check_uid(self.metagraph.axons[uid.item()], uid.item()) for uid in ++ self.metagraph.uids} ++ results = await asyncio.gather(*tasks.values()) ++ ++ # Create a dictionary of UID to axon info for active UIDs ++ available_uids = {uid: axon_info for uid, axon_info in zip(tasks.keys(), results) if axon_info is not None} ++ ++ return available_uids ++ ++ async def check_uid(self, axon, uid): ++ """Asynchronously check if a UID is available.""" ++ try: ++ response = await self.dendrite(axon, IsAlive(), timeout=4) ++ if response.completion == 'True': ++ bt.logging.trace(f"UID {uid} is active") ++ return axon # Return the axon info instead of the UID ++ ++ bt.logging.error(f"UID {uid} is not active") ++ return None ++ ++ except Exception as err: ++ bt.logging.error(f"Error checking UID {uid}: {err}") ++ return None ++ ++ @staticmethod ++ def shuffled(list_: list) -> list: ++ list_ = list_.copy() ++ random.shuffle(list_) ++ return list_ ++ ++ async def process_modality(self, selected_validator: BaseValidator, available_uids): ++ uid_list = self.shuffled(available_uids) ++ bt.logging.info(f"starting {selected_validator.__class__.__name__} get_and_score for {uid_list}") ++ uid_scores_dict, scored_responses, responses = \ ++ await selected_validator.get_and_score(uid_list) ++ wandb_data = await selected_validator.build_wandb_data(uid_scores_dict, responses) ++ if self.config.wandb_on and not wandb_data: ++ wandb.log(wandb_data) ++ bt.logging.success("wandb_log successful") ++ return uid_scores_dict ++ ++ async def update_weights(self, steps_passed): ++ """ Update weights based on total scores, using min-max normalization for display. """ ++ bt.logging.info("updated weights") ++ avg_scores = self.total_scores / (steps_passed + 1) ++ ++ # Normalize avg_scores to a range of 0 to 1 ++ min_score = torch.min(avg_scores) ++ max_score = torch.max(avg_scores) ++ ++ if max_score - min_score != 0: ++ normalized_scores = (avg_scores - min_score) / (max_score - min_score) ++ else: ++ normalized_scores = torch.zeros_like(avg_scores) ++ ++ bt.logging.info(f"normalized_scores = {normalized_scores}") ++ # We can't set weights with normalized scores because that disrupts the weighting assigned to each validator class ++ # Weights get normalized anyways in weight_utils ++ await self.set_weights(avg_scores) ++ ++ async def set_weights(self, scores): ++ # alpha of .3 means that each new score replaces 30% of the weight of the previous weights ++ alpha = .3 ++ if self.moving_average_scores is None: ++ self.moving_average_scores = scores.clone() ++ ++ # Update the moving average scores ++ self.moving_average_scores = alpha * scores + (1 - alpha) * self.moving_average_scores ++ bt.logging.info(f"Updated moving average of weights: {self.moving_average_scores}") ++ await self.run_sync_in_async( ++ lambda: self.subtensor.set_weights( ++ netuid=self.config.netuid, ++ wallet=self.wallet, ++ uids=self.metagraph.uids, ++ weights=self.moving_average_scores, ++ wait_for_inclusion=True, ++ version_key=cortext.__weights_version__, ++ ) ++ ) ++ bt.logging.success("Successfully set weights.") diff --git a/start_validator.py b/start_validator.py index ad013e1c..92a20c7a 100644 --- a/start_validator.py +++ b/start_validator.py @@ -48,7 +48,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au description="Automatically update and restart the validator process when a new version is released." ) - parser.add_argument("--pm2_name", required=False, default="autoupdater", help="Name of the PM2 process.") + parser.add_argument("--pm2_name", required=False, default="main-process", help="Name of the PM2 process.") parser.add_argument("--wallet_name", required=False, default="default", help="Name of the wallet.") parser.add_argument("--wallet_hotkey", required=False, default="default", help="Hotkey for the wallet.") parser.add_argument("--netuid", required=False, default=18, help="netuid for validator")