Skip to content

Commit

Permalink
add workers and task manager and add stream with celery
Browse files Browse the repository at this point in the history
  • Loading branch information
acer-king committed Sep 23, 2024
1 parent bc40be2 commit fa3163a
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 84 deletions.
11 changes: 7 additions & 4 deletions cortext/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ class StreamPrompting(bt.StreamingSynapse):
title="streaming",
description="whether to stream the output",
)
deserialize: bool = pydantic.Field(
default=True
)
task_id: int = pydantic.Field(
default=0
)

async def process_streaming_response(self, response: StreamingResponse) -> AsyncIterator[str]:
if self.completion is None:
Expand All @@ -303,9 +309,6 @@ async def process_streaming_response(self, response: StreamingResponse) -> Async
self.completion += token
yield tokens

def deserialize(self) -> str:
return self.completion

def extract_response_json(self, response: StreamingResponse) -> dict:
headers = {
k.decode("utf-8"): v.decode("utf-8")
Expand Down Expand Up @@ -338,4 +341,4 @@ def extract_info(prefix: str) -> dict[str, str]:
"timeout": self.timeout,
"streaming": self.streaming,
"uid": self.uid,
}
}
7 changes: 1 addition & 6 deletions validators/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,4 @@ async def call(
timeout: float = 12.0,
deserialize: bool = True,
) -> bittensor.Synapse:
uid, remain_cap = Dendrite.get_remaining_capacity(target_axon, synapse)
if remain_cap > 0:
return await super().call(target_axon, synapse, timeout, deserialize)
else:
bt.logging.debug(f"remain_cap is {remain_cap} for this uid {uid}. so can't send request.")
return synapse
pass
34 changes: 0 additions & 34 deletions validators/services/worker_manager.py

This file was deleted.

32 changes: 0 additions & 32 deletions validators/services/workers.py

This file was deleted.

40 changes: 40 additions & 0 deletions validators/task_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import redis
import bittensor as bt

from cortext import ALL_SYNAPSE_TYPE
from validators.workers import Worker


class TaskMgr:
def __init__(self, uid_to_capacities, config):
# Initialize Redis client
self.redis_client = redis.StrictRedis(host='redis', port=6379, db=0)
self.workers = []
self.create_workers(uid_to_capacities)
self.config = config

def assign_task(self, task: ALL_SYNAPSE_TYPE):

# Find the worker with the most available resources (simplified logic)
selected_worker = max(self.workers,
key=lambda w: self.workers[w]) # Example: Assign to worker with max remaining bandwidth
if self.workers[selected_worker] <= 0:
bt.logging.debug(f"no available resources to assign this task.")
return None

bt.logging.trace(f"Assigning task {task} to {selected_worker}")
# decrease remaining capacity after sending request.
self.workers[selected_worker] -= 1
# Push task to the selected worker's task queue
self.redis_client.rpush(f"tasks:{selected_worker}", task.json())

def create_workers(self, uid_to_capacities):
# create worker for each uid, provider, model
workers = []
for uid, cap_info in uid_to_capacities.items():
for provider, model_to_cap in cap_info.items():
for model, cap in model_to_cap.items():
worker_id = f"{uid}_{provider}_{model}"
worker = Worker(worker_id, cap, config=self.config)
workers.append(worker)
self.workers = workers
19 changes: 11 additions & 8 deletions validators/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from validators.services import CapacityService, BaseValidator
from validators.services.cache import QueryResponseCache
from validators.utils import handle_response, error_handler
from validators.task_manager import TaskMgr

scoring_organic_timeout = 60

Expand Down Expand Up @@ -75,10 +76,12 @@ def __init__(self, config, cache: QueryResponseCache):
self.tempo = self.subtensor.tempo(self.netuid)
self.weights_rate_limit = self.get_weights_rate_limit()

asyncio.run(self.initialize_uids_and_capacities())
self.task_mgr = TaskMgr(uid_to_capacities=self.uid_to_capacity, config=config)
# Set up async tasks
self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
self.loop.create_task(self.consume_organic_queries())
self.loop.create_task(self.perform_synthetic_queries())
# self.loop.create_task(self.perform_synthetic_queries())
self.loop.create_task(self.process_queries_from_database())

async def run_sync_in_async(self, fn):
Expand Down Expand Up @@ -181,7 +184,6 @@ async def perform_synthetic_queries(self):
# remove processing uids
self.uids_to_query = self.uids_to_query[self.batch_size:]


for selected_validator in self.get_validators():
# Perform synthetic queries
bt.logging.info("start querying to miners")
Expand Down Expand Up @@ -381,8 +383,8 @@ async def images(self, synapse: ImageResponse) -> ImageResponse:

axon = self.metagraph.axons[synapse.uid]
start_time = time.time()
synapse_response:ImageResponse = await self.dendrite(axon, synapse, deserialize=False,
timeout=synapse.timeout)
synapse_response: ImageResponse = await self.dendrite(axon, synapse, deserialize=False,
timeout=synapse.timeout)
synapse_response.process_time = time.time() - start_time

bt.logging.info(f"New synapse = {synapse_response}")
Expand Down Expand Up @@ -426,13 +428,14 @@ async def prompt(self, synapse: StreamPrompting) -> StreamingSynapse.BTStreaming
bt.logging.info(f"Received {synapse}")

# Return the streaming response
async def _prompt(synapse, send: Send):
async def _prompt(synapse: StreamPrompting, send: Send):
bt.logging.info(f"Sending {synapse} request to uid: {synapse.uid}")

axon = self.metagraph.axons[synapse.uid]
start_time = time.time()

await self.dendrite.aclose_session()
synapse.deserialize = False
synapse.streaming = True

self.task_mgr.assign_task(synapse)
responses = await self.dendrite(
axons=[axon],
synapse=synapse,
Expand Down
56 changes: 56 additions & 0 deletions validators/workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import asyncio

import redis
import json
import bittensor as bt
from cortext import ALL_SYNAPSE_TYPE, StreamPrompting, ImageResponse


class Worker:
# Initialize Redis client
redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)
TASK_STREAM = 'task_stream'
RESULT_STREAM = 'result_stream'

def __init__(self, worker_id, bandwidth, config, axon):
self.worker_id = worker_id
self.bandwidth = bandwidth
self.dendrite = bt.dendrite(config.wallet)
self.axon = axon
self.report_resources()

def report_resources(self):
# Store worker's resource info in Redis hash
self.redis_client.hset("workers", self.worker_id, self.bandwidth)

@staticmethod
def covert_json_to_synapse(task_obj):
if task_obj.get("streaming"):
synapse = StreamPrompting.parse_obj(task_obj)
else:
synapse = ImageResponse.parse_obj(task_obj)
return synapse

async def pull_and_run_task(self):
# Pull task from worker-specific queue
while True:
task = json.loads(self.redis_client.lpop(f"tasks:{self.worker_id}") or "{}")
if task:
synapse = self.covert_json_to_synapse(task)
bt.logging.trace(f"Worker {self.worker_id} received task: {synapse}")
task_id = synapse.task_id
try:
responses = await self.dendrite(
axons=[self.axon],
synapse=synapse,
deserialize=synapse.deserialize,
timeout=synapse.timeout,
streaming=synapse.streaming,
)
except Exception as err:
bt.logging.exception(err)
else:
async for chunk in responses[0]:
if isinstance(chunk, str):
await self.redis_client.xadd(Worker.RESULT_STREAM, {'task_id': task_id, 'chunk': chunk})
await asyncio.sleep(0.1)

0 comments on commit fa3163a

Please sign in to comment.