From 038eb0ddd99059d97434cccad46581c734d0523a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Nowacki?= Date: Thu, 4 Jan 2024 15:02:25 +0100 Subject: [PATCH] v2/text-validator endpoint --- tests/integration/test_integration.py | 15 +++ tests/weights/test_weights.py | 12 ++- validators/base_validator.py | 6 +- validators/image_validator.py | 17 ++- validators/text_validator.py | 147 +++++++++++++++----------- validators/validator.py | 88 ++++++++++++++- validators/weight_setter.py | 17 ++- 7 files changed, 228 insertions(+), 74 deletions(-) diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 5cb8d44c..fd8d2049 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -47,3 +47,18 @@ def test_text_validator(self): resp.raise_for_status() assert "cucumber" in resp.text print(resp.text) + + def test_text_validator_v2(self): + resp = requests.post( + f'http://localhost:{VALIDATOR_PORT}/v2/text-validator/', + headers={'Authorization': 'token hello'}, + json={ + 'content': 'please write a sentence using the word "cucumber"', + 'provider': 'openai', + 'miner_uid': 1, + }, + timeout=15, + ) + resp.raise_for_status() + assert "cucumber" in resp.text + print(resp.text) \ No newline at end of file diff --git a/tests/weights/test_weights.py b/tests/weights/test_weights.py index 43acfa17..f9986291 100644 --- a/tests/weights/test_weights.py +++ b/tests/weights/test_weights.py @@ -120,7 +120,17 @@ async def test_synthetic_and_organic(aiohttp_client): client = await aiohttp_client(validator_app) - resp = await client.post('/text-validator/', headers={'access-key': 'hello'}, json={'4': organic_question_1}) + resp = await client.post( + '/v2/text-validator/', + headers={ + 'Authorization': 'token hello', + }, + json={ + 'content': organic_question_1, + 'miner_uid': 4, + 'provider': 'openai', + }, + ) resp_content = (await resp.content.read()).decode() assert resp_content == organic_answer_1 
diff --git a/validators/base_validator.py b/validators/base_validator.py index 7abae535..5aea70af 100644 --- a/validators/base_validator.py +++ b/validators/base_validator.py @@ -33,6 +33,6 @@ async def start_query(self, available_uids) -> tuple[list, dict]: async def score_responses(self, responses): ... - async def get_and_score(self, available_uids, metagraph): - query_responses, uid_to_question = await self.start_query(available_uids, metagraph) - return await self.score_responses(query_responses, uid_to_question, metagraph) + async def get_and_score(self, available_uids, metagraph, provider): + query_responses, uid_to_question = await self.start_query(available_uids, metagraph, provider) + return await self.score_responses(query_responses, uid_to_question, metagraph, provider) diff --git a/validators/image_validator.py b/validators/image_validator.py index 61ca195b..93ce85e0 100644 --- a/validators/image_validator.py +++ b/validators/image_validator.py @@ -1,3 +1,5 @@ +import enum + import torch import wandb import random @@ -13,6 +15,11 @@ from template.protocol import ImageResponse +class Provider(enum.Enum): + openai = 'openai' + anthropic = 'stability' + + class ImageValidator(BaseValidator): def __init__(self, dendrite, config, subtensor, wallet): super().__init__(dendrite, config, subtensor, wallet, timeout=60) @@ -33,7 +40,7 @@ def __init__(self, dendrite, config, subtensor, wallet): "timestamps": {}, } - async def start_query(self, available_uids, metagraph): + async def start_query(self, available_uids, metagraph, provider): # Query all images concurrently query_tasks = [] uid_to_messages = {} @@ -58,7 +65,13 @@ async def download_image(self, url): content = await response.read() return Image.open(BytesIO(content)) - async def score_responses(self, query_responses, uid_to_messages, metagraph): + async def score_responses( + self, + query_responses, + uid_to_messages, + metagraph, + provider: Provider, + ): scores = torch.zeros(len(metagraph.hotkeys)) 
uid_scores_dict = {} download_tasks = [] diff --git a/validators/text_validator.py b/validators/text_validator.py index 1a40a1c4..80f65ab7 100644 --- a/validators/text_validator.py +++ b/validators/text_validator.py @@ -1,4 +1,5 @@ import asyncio +import enum import random from typing import AsyncIterator, Tuple @@ -11,6 +12,11 @@ from template.utils import call_openai, get_question +class Provider(enum.Enum): + openai = 'openai' + anthropic = 'anthropic' + + class TextValidator(BaseValidator): def __init__(self, dendrite, config, subtensor, wallet: bt.wallet): super().__init__(dendrite, config, subtensor, wallet, timeout=75) @@ -28,28 +34,38 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet): "timestamps": {}, } - async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]: - for uid, messages in query.items(): - syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed) - bt.logging.info( - f"Sending {syn.model} {self.query_type} request to uid: {uid}, " - f"timeout {self.timeout}: {syn.messages[0]['content']}" - ) - self.wandb_data["prompts"][uid] = messages - responses = await self.dendrite( - metagraph.axons[uid], - syn, - deserialize=False, - timeout=self.timeout, - streaming=self.streaming, - ) - - async for resp in responses: - if not isinstance(resp, str): - continue - - bt.logging.trace(resp) - yield uid, resp + async def organic( + self, + metagraph, + query: dict[str, list[dict[str, str]]], + provider: Provider, + ) -> AsyncIterator[tuple[int, str]]: + if provider == Provider.openai: + for uid, messages in query.items(): + syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed) + bt.logging.info( + f"Sending {syn.model} {self.query_type} request to uid: {uid}, " + f"timeout {self.timeout}: {syn.messages[0]['content']}" + ) + self.wandb_data["prompts"][uid] = messages + responses = await self.dendrite( + metagraph.axons[uid], + syn, + deserialize=False, 
+ timeout=self.timeout, + streaming=self.streaming, + ) + + async for resp in responses: + if not isinstance(resp, str): + continue + + bt.logging.trace(resp) + yield uid, resp + elif provider == Provider.anthropic: + raise NotImplementedError(f'{provider=} is not supported') + else: + raise NotImplementedError(f'{provider=} is not supported') async def handle_response(self, uid: str, responses) -> tuple[str, str]: full_response = "" @@ -65,7 +81,7 @@ async def handle_response(self, uid: str, responses) -> tuple[str, str]: async def get_question(self, qty): return await get_question("text", qty) - async def start_query(self, available_uids, metagraph) -> tuple[list, dict]: + async def start_query(self, available_uids, metagraph, provider) -> tuple[list, dict]: query_tasks = [] uid_to_question = {} for uid in available_uids: @@ -98,43 +114,49 @@ async def score_responses( query_responses: list[tuple[int, str]], # [(uid, response)] uid_to_question: dict[int, str], # uid -> prompt metagraph: bt.metagraph, + provider: Provider, ) -> tuple[torch.Tensor, dict[int, float], dict]: - scores = torch.zeros(len(metagraph.hotkeys)) - uid_scores_dict = {} - openai_response_tasks = [] - - # Decide to score all UIDs this round based on a chance - will_score_all = self.should_i_score() - - for uid, response in query_responses: - self.wandb_data["responses"][uid] = response - if will_score_all and response: - prompt = uid_to_question[uid] - openai_response_tasks.append((uid, self.call_openai(prompt))) - - openai_responses = await asyncio.gather(*[task for _, task in openai_response_tasks]) - - scoring_tasks = [] - for (uid, _), openai_answer in zip(openai_response_tasks, openai_responses): - if openai_answer: - response = next(res for u, res in query_responses if u == uid) # Find the matching response - task = template.reward.openai_score(openai_answer, response, self.weight) - scoring_tasks.append((uid, task)) - - scored_responses = await asyncio.gather(*[task for _, task in 
scoring_tasks]) - - for (uid, _), scored_response in zip(scoring_tasks, scored_responses): - if scored_response is not None: - scores[uid] = scored_response - uid_scores_dict[uid] = scored_response - else: - scores[uid] = 0 - uid_scores_dict[uid] = 0 - # self.wandb_data["scores"][uid] = score - - if uid_scores_dict != {}: - bt.logging.info(f"text_scores is {uid_scores_dict}") - return scores, uid_scores_dict, self.wandb_data + if provider == Provider.openai: + scores = torch.zeros(len(metagraph.hotkeys)) + uid_scores_dict = {} + openai_response_tasks = [] + + # Decide to score all UIDs this round based on a chance + will_score_all = self.should_i_score() + + for uid, response in query_responses: + self.wandb_data["responses"][uid] = response + if will_score_all and response: + prompt = uid_to_question[uid] + openai_response_tasks.append((uid, self.call_openai(prompt))) + + openai_responses = await asyncio.gather(*[task for _, task in openai_response_tasks]) + + scoring_tasks = [] + for (uid, _), openai_answer in zip(openai_response_tasks, openai_responses): + if openai_answer: + response = next(res for u, res in query_responses if u == uid) # Find the matching response + task = template.reward.openai_score(openai_answer, response, self.weight) + scoring_tasks.append((uid, task)) + + scored_responses = await asyncio.gather(*[task for _, task in scoring_tasks]) + + for (uid, _), scored_response in zip(scoring_tasks, scored_responses): + if scored_response is not None: + scores[uid] = scored_response + uid_scores_dict[uid] = scored_response + else: + scores[uid] = 0 + uid_scores_dict[uid] = 0 + # self.wandb_data["scores"][uid] = score + + if uid_scores_dict != {}: + bt.logging.info(f"text_scores is {uid_scores_dict}") + return scores, uid_scores_dict, self.wandb_data + elif provider == Provider.anthropic: + raise NotImplementedError(f'{provider=} is not supported') + else: + raise NotImplementedError(f'{provider=} is not supported') class 
TestTextValidator(TextValidator): @@ -173,7 +195,12 @@ async def get_question(self, qty): async def query_miner(self, metagraph, uid, syn: StreamPrompting): return uid, await self.call_openai(syn.messages[0]['content']) - async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]: + async def organic( + self, + metagraph, + query: dict[str, list[dict[str, str]]], + provider: Provider, + ) -> AsyncIterator[tuple[int, str]]: for uid, messages in query.items(): for msg in messages: yield uid, await self.call_openai(msg['content']) diff --git a/validators/validator.py b/validators/validator.py index 4c747ebd..c96c9322 100644 --- a/validators/validator.py +++ b/validators/validator.py @@ -1,7 +1,11 @@ +import json import logging +import re import time from typing import Tuple +import pydantic + import base # noqa import argparse @@ -16,6 +20,7 @@ from aiohttp import web from aiohttp.web_response import Response from bittensor.btlogging import logger +from validators import text_validator, image_validator from validators.image_validator import ImageValidator from validators.text_validator import TextValidator, TestTextValidator from envparse import env @@ -34,7 +39,7 @@ # organic requests are scored, the tasks are stored in this queue # for later being consumed by `query_synapse` cycle: organic_scoring_tasks = set() -EXPECTED_ACCESS_KEY = env('EXPECTED_ACCESS_KEY', default='hello') +EXPECTED_ACCESS_KEYS = env('EXPECTED_ACCESS_KEY', default='hello').split(',') def get_config() -> bt.config: @@ -114,9 +119,11 @@ def initialize_validators(vali_config, test=False): async def process_text_validator(request: web.Request): + # TODO: this is deprecated in favor process_text_validator_v2 + # Check access key access_key = request.headers.get("access-key") - if access_key != EXPECTED_ACCESS_KEY: + if access_key not in EXPECTED_ACCESS_KEYS: return Response(status=401, reason="Invalid access key") try: @@ -129,11 +136,17 @@ async def 
process_text_validator(request: web.Request): uid_to_response = dict.fromkeys(messages_dict, "") try: - async for uid, content in text_vali.organic(validator_app.weight_setter.metagraph, messages_dict): + async for uid, content in text_vali.organic( + validator_app.weight_setter.metagraph, + messages_dict, + text_validator.Provider.openai + ): uid_to_response[uid] += content await response.write(content.encode()) validator_app.weight_setter.register_text_validator_organic_query( - uid_to_response, {k: v[0]['content'] for k, v in messages_dict.items()} + uid_to_response, + {k: v[0]['content'] for k, v in messages_dict.items()}, + text_validator.Provider.openai ) except Exception: logger.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}') @@ -142,6 +155,72 @@ async def process_text_validator(request: web.Request): return response +auth_regex = re.compile('token (?P<key>.+)') + + +def is_authorized(request: web.Request) -> bool: + if not (authorization := request.headers.get("Authorization")): + return False + + if not (match := auth_regex.match(authorization)): + return False + + if match.group('key') not in EXPECTED_ACCESS_KEYS: + return False + + return True + + +class TextValidatorRequestPayload(pydantic.BaseModel): + provider: text_validator.Provider + content: str + miner_uid: int + + +async def write_error_message(response: web.StreamResponse, msg: str): + await response.write(f'\n--ERROR-- {msg}'.encode()) + + +async def process_text_validator_v2(request: web.Request): + # Check access key + if not is_authorized(request): + return Response(status=401, reason="Invalid access key") + + try: + payload: TextValidatorRequestPayload = TextValidatorRequestPayload.parse_raw(await request.text()) + except pydantic.ValidationError as e: + return Response(status=400, reason=json.dumps(e.json())) + + messages_dict = {payload.miner_uid: [{'role': 'user', 'content': payload.content}]} + + text_response = "" + + response = web.StreamResponse() + 
await response.prepare(request) + + try: + async for uid, content in text_vali.organic( + validator_app.weight_setter.metagraph, + messages_dict, + payload.provider, + ): + text_response += content + await response.write(content.encode()) + if text_response: + validator_app.weight_setter.register_text_validator_organic_query( + {payload.miner_uid: text_response}, + {k: v[0]['content'] for k, v in messages_dict.items()}, + payload.provider, + ) + except Exception: + logger.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}') + await write_error_message(response, 'INTERNAL') + if not text_response: + await write_error_message(response, 'MINER OFFLINE') + + return response + + class ValidatorApplication(web.Application): def __init__(self, *a, **kw): super().__init__(*a, **kw) @@ -150,6 +229,7 @@ def __init__(self, *a, **kw): validator_app = ValidatorApplication() validator_app.add_routes([web.post('/text-validator/', process_text_validator)]) +validator_app.add_routes([web.post('/v2/text-validator/', process_text_validator_v2)]) def main(run_aio_app=True, test=False) -> None: diff --git a/validators/weight_setter.py b/validators/weight_setter.py index 4c7f7731..24c50b58 100644 --- a/validators/weight_setter.py +++ b/validators/weight_setter.py @@ -3,7 +3,7 @@ import itertools import traceback import random -from typing import Tuple +from typing import Tuple, Union import bittensor as bt import torch @@ -11,6 +11,7 @@ from bittensor.btlogging import logger from template.protocol import IsAlive +from validators import text_validator, image_validator from validators.text_validator import TextValidator iterations_per_set_weights = 12 @@ -79,7 +80,8 @@ async def perform_synthetic_scoring_and_update_weights(self): available_uids = await self.get_available_uids() selected_validator = self.select_validator(steps_passed) - scores, _ = await self.process_modality(selected_validator, available_uids) + provider = 
self.select_provider(selected_validator, steps_passed) + scores, _ = await self.process_modality(selected_validator, available_uids, provider) self.total_scores += scores steps_since_last_update = steps_passed % iterations_per_set_weights @@ -93,6 +95,10 @@ async def perform_synthetic_scoring_and_update_weights(self): await asyncio.sleep(0.5) + def select_provider(self, validator, steps_passed) -> Union[text_validator.Provider, image_validator.Provider]: + # TODO: implement + return text_validator.Provider.openai + def select_validator(self, steps_passed): return self.text_vali if steps_passed % 5 in (0, 1, 2) else self.image_vali @@ -126,10 +132,11 @@ def shuffled(self, list_: list) -> list: random.shuffle(list_) return list_ - async def process_modality(self, selected_validator, available_uids): + async def process_modality(self, selected_validator, available_uids, + provider: Union[text_validator.Provider, image_validator.Provider]): uid_list = self.shuffled(list(available_uids.keys())) bt.logging.info(f"starting {selected_validator.__class__.__name__} get_and_score for {uid_list}") - scores, uid_scores_dict, wandb_data = await selected_validator.get_and_score(uid_list, self.metagraph) + scores, uid_scores_dict, wandb_data = await selected_validator.get_and_score(uid_list, self.metagraph, provider) if self.config.wandb_on: wandb.log(wandb_data) bt.logging.success("wandb_log successful") @@ -177,6 +184,7 @@ def register_text_validator_organic_query( self, uid_to_response: dict[int, str], # [(uid, response)] messages_dict: dict[int, str], + provider: text_validator.Provider, ): self.organic_scoring_tasks.add(asyncio.create_task( wait_for_coro_with_limit( @@ -184,6 +192,7 @@ def register_text_validator_organic_query( query_responses=list(uid_to_response.items()), uid_to_question=messages_dict, metagraph=self.metagraph, + provider=provider, ), scoring_organic_timeout )