Skip to content

Commit

Permalink
init task_manager and use aioredis
Browse files Browse the repository at this point in the history
  • Loading branch information
acer-king committed Sep 24, 2024
1 parent 2c13979 commit 9509794
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 65 deletions.
2 changes: 1 addition & 1 deletion cortext/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class StreamPrompting(bt.StreamingSynapse):
title="streaming",
description="whether to stream the output",
)
deserialize: bool = pydantic.Field(
deserialize_flag: bool = pydantic.Field(
default=True
)
task_id: int = pydantic.Field(
Expand Down
12 changes: 6 additions & 6 deletions validators/services/redis.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import redis
import aioredis
import asyncio
import bittensor as bt
from cortext import REDIS_RESULT_STREAM


class Redis:
redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)

def __init__(self):
pass

async def get_stream_result(self, task_id):
@staticmethod
async def get_stream_result(redis_client, task_id):
last_id = '0' # Start reading from the beginning of the stream
bt.logging.trace(f"Waiting for results of task {task_id}...")
stream_name = REDIS_RESULT_STREAM + f"{task_id}"

while True:
# Read from the Redis stream
result_entries = Redis.redis_client.xread({stream_name: last_id}, block=5000)
result_entries = redis_client.xread({stream_name: last_id}, block=5000)
result_entries = result_entries or []

for entry in result_entries:
Expand All @@ -31,7 +31,7 @@ async def get_stream_result(self, task_id):
bt.logging.trace("No new results, waiting...")
break
bt.logging.trace(f"stream exit. delete old messages from queue.")
await self.redis_client.xtrim(stream_name, maxlen=0, approximate=False)
await redis_client.xtrim(stream_name, maxlen=0, approximate=False)

def get_result(self, task_id):
pass
14 changes: 7 additions & 7 deletions validators/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from cortext import ALL_SYNAPSE_TYPE
from validators.workers import Worker
from validators import utils
from validators.services.redis import Redis


class TaskMgr:
def __init__(self, uid_to_capacities, dendrite, metagraph):
def __init__(self, uid_to_capacities, dendrite, metagraph, redis_client):
# Initialize Redis client
self.redis_client = Redis.redis_client
self.redis_client = redis_client
self.resources = {}
self.init_resources(uid_to_capacities)
self.dendrite = dendrite
Expand All @@ -26,14 +25,15 @@ def assign_task(self, synapse: ALL_SYNAPSE_TYPE):
return None

task_id = utils.create_hash_value((synapse.json()))
synapse.task_id = task_id
bt.logging.trace(f"Assigning task {task_id} to {resource_key}")

# decrease remaining capacity after sending request.
self.resources[resource_key] -= 1
# Push task to the selected worker's task queue
worker = Worker(task_id=task_id, dendrite=self.dendrite, axon=self.get_axon_from_resource_key(resource_key))
self.redis_client.rpush(f"tasks:{task_id}", synapse.json())
asyncio.create_task(worker.pull_and_run_task())
worker = Worker(synapse=synapse, dendrite=self.dendrite, axon=self.get_axon_from_resource_key(resource_key),
redis_client=self.redis_client)
asyncio.create_task(worker.run_task())

def get_axon_from_resource_key(self, resource_key):
uid = resource_key.split("_")[0]
Expand All @@ -45,4 +45,4 @@ def init_resources(self, uid_to_capacities):
for provider, model_to_cap in cap_info.items():
for model, cap in model_to_cap.items():
resource_key = f"{uid}_{provider}_{model}"
self.resources[resource_key] = cap
self.resources[resource_key] = cap
16 changes: 9 additions & 7 deletions validators/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
import torch
import traceback
import time

import aioredis

from black.trans import defaultdict
from substrateinterface import SubstrateInterface
from functools import partial
from typing import Tuple, List
import bittensor as bt
from bittensor import StreamingSynapse
import redis

import cortext

from starlette.types import Send
Expand All @@ -32,10 +30,11 @@ class WeightSetter:
def __init__(self, config, cache: QueryResponseCache):

# Cache object using sqlite3.
self.task_mgr = None
self.in_cache_processing = False
self.batch_size = config.max_miners_cnt
self.cache = cache
self.redis_client = redis.Redis.redis_client
self.redis_client = aioredis.from_url("redis://localhost", encoding="utf-8", decode_responses=True)

self.uid_to_capacity = {}
self.available_uid_to_axons = {}
Expand Down Expand Up @@ -79,12 +78,12 @@ def __init__(self, config, cache: QueryResponseCache):
self.tempo = self.subtensor.tempo(self.netuid)
self.weights_rate_limit = self.get_weights_rate_limit()

# initialize uid and capacities.
asyncio.run(self.initialize_uids_and_capacities())
self.task_mgr = TaskMgr(uid_to_capacities=self.uid_to_capacity, config=config)
# Set up async tasks
self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
self.loop.create_task(self.consume_organic_queries())
# self.loop.create_task(self.perform_synthetic_queries())
self.loop.create_task(self.perform_synthetic_queries())
self.loop.create_task(self.process_queries_from_database())

async def run_sync_in_async(self, fn):
Expand Down Expand Up @@ -125,6 +124,9 @@ async def initialize_uids_and_capacities(self):
self.total_scores = {uid: 0.0 for uid in self.available_uid_to_axons.keys()}
self.score_counts = {uid: 0 for uid in self.available_uid_to_axons.keys()}

self.task_mgr = TaskMgr(uid_to_capacities=self.uid_to_capacity, dendrite=self.dendrite,
metagraph=self.metagraph, redis_client=self.redis_client)

async def update_and_refresh(self, last_update):
bt.logging.info(f"Setting weights, last update {last_update} blocks ago")
await self.update_weights()
Expand Down Expand Up @@ -435,7 +437,7 @@ async def _prompt(query_synapse: StreamPrompting, send: Send):
bt.logging.info(f"Sending {synapse} request to uid: {synapse.uid}")
start_time = time.time()

synapse.deserialize = False
synapse.deserialize_flag = False
synapse.streaming = True

task_id = self.task_mgr.assign_task(query_synapse)
Expand Down
67 changes: 23 additions & 44 deletions validators/workers.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,33 @@
import asyncio

import redis
import json
import bittensor as bt
from cortext import StreamPrompting, ImageResponse, REDIS_RESULT_STREAM, REDIS_RESULT
from cortext import REDIS_RESULT_STREAM, REDIS_RESULT


class Worker:
# Initialize Redis client
redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)

def __init__(self, task_id, dendrite, axon):
self.worker_id = task_id
def __init__(self, synapse, dendrite, axon, redis_client):
self.redis_client = redis_client
self.synapse = synapse
self.dendrite = dendrite
self.axon = axon


@staticmethod
def covert_json_to_synapse(task_obj):
if task_obj.get("streaming"):
synapse = StreamPrompting.parse_obj(task_obj)
else:
synapse = ImageResponse.parse_obj(task_obj)
return synapse

async def pull_and_run_task(self):
async def run_task(self):
# Pull task from worker-specific queue
while True:
task = json.loads(self.redis_client.lpop(f"tasks:{self.worker_id}") or "{}")
if task:
synapse = self.covert_json_to_synapse(task)
bt.logging.trace(f"Worker {self.worker_id} received task: {synapse}")
try:
responses = self.dendrite(
axons=[self.axon],
synapse=synapse,
deserialize=synapse.deserialize,
timeout=synapse.timeout,
streaming=synapse.streaming,
)
except Exception as err:
bt.logging.exception(err)
else:
if synapse.streaming:
async for chunk in responses[0]:
if isinstance(chunk, str):
await self.redis_client.xadd(REDIS_RESULT_STREAM, {"chunk": chunk})
else:
await self.redis_client.rpush(REDIS_RESULT, responses[0])
task_id = self.synapse.task_id
bt.logging.trace(f"Worker {task_id} received task: {self.synapse}")
try:
responses = await self.dendrite(
axons=[self.axon],
synapse=self.synapse,
deserialize=self.synapse.deserialize_flag,
timeout=self.synapse.timeout,
streaming=self.synapse.streaming,
)
except Exception as err:
bt.logging.exception(err)
else:
if self.synapse.streaming:
async for chunk in responses[0]:
if isinstance(chunk, str):
await self.redis_client.xadd(REDIS_RESULT_STREAM + f"{task_id}", {"chunk": chunk})
else:
# if there is no task then await 1sec.
bt.logging.info(f"no new task to consume")
break
await self.redis_client.rpush(REDIS_RESULT, responses[0])

0 comments on commit 9509794

Please sign in to comment.