Skip to content

Commit

Permalink
v2/text-validator endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mpnowacki-reef committed Jan 4, 2024
1 parent b271265 commit 038eb0d
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 74 deletions.
15 changes: 15 additions & 0 deletions tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,18 @@ def test_text_validator(self):
resp.raise_for_status()
assert "cucumber" in resp.text
print(resp.text)

def test_text_validator_v2(self):
    """End-to-end smoke test of the v2 text-validator endpoint.

    Posts an organic prompt to the local validator and checks that the
    (streamed, concatenated) answer actually uses the requested word.
    """
    url = f'http://localhost:{VALIDATOR_PORT}/v2/text-validator/'
    payload = {
        'content': 'please write a sentence using the word "cucumber"',
        'provider': 'openai',
        'miner_uid': 1,
    }
    resp = requests.post(
        url,
        headers={'Authorization': 'token hello'},
        json=payload,
        timeout=15,
    )
    resp.raise_for_status()
    # The prompt demands the word appear, so its absence means a bad answer.
    assert "cucumber" in resp.text
    print(resp.text)
12 changes: 11 additions & 1 deletion tests/weights/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,17 @@ async def test_synthetic_and_organic(aiohttp_client):

client = await aiohttp_client(validator_app)

resp = await client.post('/text-validator/', headers={'access-key': 'hello'}, json={'4': organic_question_1})
resp = await client.post(
'/v2/text-validator/',
headers={
'Authorization': 'token hello',
},
json={
'content': organic_question_1,
'miner_uid': 4,
'provider': 'openai',
},
)
resp_content = (await resp.content.read()).decode()
assert resp_content == organic_answer_1

Expand Down
6 changes: 3 additions & 3 deletions validators/base_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ async def start_query(self, available_uids) -> tuple[list, dict]:
async def score_responses(self, query_responses, uid_to_question, metagraph, provider):
    """Score miner responses; implemented by subclasses.

    Signature matches how ``get_and_score`` invokes it:

    Args:
        query_responses: list of ``(uid, response)`` pairs from ``start_query``.
        uid_to_question: mapping of miner uid -> prompt that was sent.
        metagraph: bittensor metagraph used to size the score tensor.
        provider: which upstream provider to score against.
    """
    # Fix: the stub previously took only ``(self, responses)``, which no
    # longer matched the call in get_and_score or the subclass overrides.
    ...

async def get_and_score(self, available_uids, metagraph):
query_responses, uid_to_question = await self.start_query(available_uids, metagraph)
return await self.score_responses(query_responses, uid_to_question, metagraph)
async def get_and_score(self, available_uids, metagraph, provider):
    """Query the available miners, then score what they returned.

    Thin orchestration helper: fan out via ``start_query`` and hand the
    collected responses (plus the prompts they answered) to
    ``score_responses``.
    """
    responses, questions = await self.start_query(available_uids, metagraph, provider)
    return await self.score_responses(responses, questions, metagraph, provider)
17 changes: 15 additions & 2 deletions validators/image_validator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import enum

import torch
import wandb
import random
Expand All @@ -13,6 +15,11 @@
from template.protocol import ImageResponse


class Provider(enum.Enum):
    # Image-generation backends selectable by callers of this validator.
    openai = 'openai'
    # NOTE(review): member is named `anthropic` but its value is 'stability'
    # (Stability AI) — looks like a copy-paste from text_validator.Provider.
    # Renaming it here would break any caller using Provider.anthropic, so
    # confirm the intended member name/value with the author before changing.
    anthropic = 'stability'


class ImageValidator(BaseValidator):
def __init__(self, dendrite, config, subtensor, wallet):
super().__init__(dendrite, config, subtensor, wallet, timeout=60)
Expand All @@ -33,7 +40,7 @@ def __init__(self, dendrite, config, subtensor, wallet):
"timestamps": {},
}

async def start_query(self, available_uids, metagraph):
async def start_query(self, available_uids, metagraph, provider):
# Query all images concurrently
query_tasks = []
uid_to_messages = {}
Expand All @@ -58,7 +65,13 @@ async def download_image(self, url):
content = await response.read()
return Image.open(BytesIO(content))

async def score_responses(self, query_responses, uid_to_messages, metagraph):
async def score_responses(
self,
query_responses,
uid_to_messages,
metagraph,
provider: Provider,
):
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
download_tasks = []
Expand Down
147 changes: 87 additions & 60 deletions validators/text_validator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import enum
import random
from typing import AsyncIterator, Tuple

Expand All @@ -11,6 +12,11 @@
from template.utils import call_openai, get_question


# Upstream LLM backends a text-validation request can be routed to.
# Functional Enum form: members and values identical to the class form.
Provider = enum.Enum('Provider', {'openai': 'openai', 'anthropic': 'anthropic'})


class TextValidator(BaseValidator):
def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
super().__init__(dendrite, config, subtensor, wallet, timeout=75)
Expand All @@ -28,28 +34,38 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
"timestamps": {},
}

async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
for uid, messages in query.items():
syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed)
bt.logging.info(
f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
f"timeout {self.timeout}: {syn.messages[0]['content']}"
)
self.wandb_data["prompts"][uid] = messages
responses = await self.dendrite(
metagraph.axons[uid],
syn,
deserialize=False,
timeout=self.timeout,
streaming=self.streaming,
)

async for resp in responses:
if not isinstance(resp, str):
continue

bt.logging.trace(resp)
yield uid, resp
async def organic(
    self,
    metagraph,
    query: dict[str, list[dict[str, str]]],
    provider: Provider,
) -> AsyncIterator[tuple[int, str]]:
    """Stream miner answers for organic (user-submitted) prompts.

    Args:
        metagraph: bittensor metagraph; ``metagraph.axons[uid]`` is the
            endpoint queried for each miner.
        query: mapping of miner uid -> chat messages to send.
        provider: backend to use; only ``Provider.openai`` is implemented.

    Yields:
        ``(uid, chunk)`` pairs, one per streamed string fragment.

    Raises:
        NotImplementedError: for any provider other than ``Provider.openai``.
    """
    # Guard clause: the original if/elif/else raised the *same* error in
    # both non-openai branches, so they collapse into one check.
    if provider is not Provider.openai:
        raise NotImplementedError(f'{provider=} is not supported')

    for uid, messages in query.items():
        syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed)
        bt.logging.info(
            f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
            f"timeout {self.timeout}: {syn.messages[0]['content']}"
        )
        # Record the outbound prompt for wandb reporting.
        self.wandb_data["prompts"][uid] = messages
        responses = await self.dendrite(
            metagraph.axons[uid],
            syn,
            deserialize=False,
            timeout=self.timeout,
            streaming=self.streaming,
        )

        async for resp in responses:
            # Skip any non-string chunk in the stream.
            if not isinstance(resp, str):
                continue

            bt.logging.trace(resp)
            yield uid, resp

async def handle_response(self, uid: str, responses) -> tuple[str, str]:
full_response = ""
Expand All @@ -65,7 +81,7 @@ async def handle_response(self, uid: str, responses) -> tuple[str, str]:
async def get_question(self, qty):
    """Fetch ``qty`` prompts from the shared "text" question source."""
    question = await get_question("text", qty)
    return question

async def start_query(self, available_uids, metagraph) -> tuple[list, dict]:
async def start_query(self, available_uids, metagraph, provider) -> tuple[list, dict]:
query_tasks = []
uid_to_question = {}
for uid in available_uids:
Expand Down Expand Up @@ -98,43 +114,49 @@ async def score_responses(
query_responses: list[tuple[int, str]], # [(uid, response)]
uid_to_question: dict[int, str], # uid -> prompt
metagraph: bt.metagraph,
provider: Provider,
) -> tuple[torch.Tensor, dict[int, float], dict]:
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
openai_response_tasks = []

# Decide to score all UIDs this round based on a chance
will_score_all = self.should_i_score()

for uid, response in query_responses:
self.wandb_data["responses"][uid] = response
if will_score_all and response:
prompt = uid_to_question[uid]
openai_response_tasks.append((uid, self.call_openai(prompt)))

openai_responses = await asyncio.gather(*[task for _, task in openai_response_tasks])

scoring_tasks = []
for (uid, _), openai_answer in zip(openai_response_tasks, openai_responses):
if openai_answer:
response = next(res for u, res in query_responses if u == uid) # Find the matching response
task = template.reward.openai_score(openai_answer, response, self.weight)
scoring_tasks.append((uid, task))

scored_responses = await asyncio.gather(*[task for _, task in scoring_tasks])

for (uid, _), scored_response in zip(scoring_tasks, scored_responses):
if scored_response is not None:
scores[uid] = scored_response
uid_scores_dict[uid] = scored_response
else:
scores[uid] = 0
uid_scores_dict[uid] = 0
# self.wandb_data["scores"][uid] = score

if uid_scores_dict != {}:
bt.logging.info(f"text_scores is {uid_scores_dict}")
return scores, uid_scores_dict, self.wandb_data
if provider == Provider.openai:
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
openai_response_tasks = []

# Decide to score all UIDs this round based on a chance
will_score_all = self.should_i_score()

for uid, response in query_responses:
self.wandb_data["responses"][uid] = response
if will_score_all and response:
prompt = uid_to_question[uid]
openai_response_tasks.append((uid, self.call_openai(prompt)))

openai_responses = await asyncio.gather(*[task for _, task in openai_response_tasks])

scoring_tasks = []
for (uid, _), openai_answer in zip(openai_response_tasks, openai_responses):
if openai_answer:
response = next(res for u, res in query_responses if u == uid) # Find the matching response
task = template.reward.openai_score(openai_answer, response, self.weight)
scoring_tasks.append((uid, task))

scored_responses = await asyncio.gather(*[task for _, task in scoring_tasks])

for (uid, _), scored_response in zip(scoring_tasks, scored_responses):
if scored_response is not None:
scores[uid] = scored_response
uid_scores_dict[uid] = scored_response
else:
scores[uid] = 0
uid_scores_dict[uid] = 0
# self.wandb_data["scores"][uid] = score

if uid_scores_dict != {}:
bt.logging.info(f"text_scores is {uid_scores_dict}")
return scores, uid_scores_dict, self.wandb_data
elif provider == Provider.anthropic:
raise NotImplementedError(f'{provider=} is not supported')
else:
raise NotImplementedError(f'{provider=} is not supported')


class TestTextValidator(TextValidator):
Expand Down Expand Up @@ -173,7 +195,12 @@ async def get_question(self, qty):
async def query_miner(self, metagraph, uid, syn: StreamPrompting):
    """Simulate a miner: answer the first message via the mocked OpenAI call."""
    prompt = syn.messages[0]['content']
    answer = await self.call_openai(prompt)
    return uid, answer

async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
async def organic(
    self,
    metagraph,
    query: dict[str, list[dict[str, str]]],
    provider: Provider,
) -> AsyncIterator[tuple[int, str]]:
    """Yield a mocked answer for every message in ``query``, skipping the network.

    Test double for ``TextValidator.organic``: ``metagraph`` and ``provider``
    are accepted for signature compatibility but unused.
    """
    for uid, conversation in query.items():
        for message in conversation:
            yield uid, await self.call_openai(message['content'])
Loading

0 comments on commit 038eb0d

Please sign in to comment.