Commit
add workers and task manager and add stream with celery
acer-king committed on Sep 23, 2024 · 1 parent bc40be2 · commit fa3163a
Showing 7 changed files with 115 additions and 84 deletions.
This file was deleted.
This file was deleted.
@@ -0,0 +1,40 @@
import redis
import bittensor as bt

from cortext import ALL_SYNAPSE_TYPE
from validators.workers import Worker


class TaskMgr:
    def __init__(self, uid_to_capacities, config):
        # Initialize Redis client
        self.redis_client = redis.StrictRedis(host='redis', port=6379, db=0)
        self.config = config
        self.workers = {}
        self.create_workers(uid_to_capacities)

    def assign_task(self, task: ALL_SYNAPSE_TYPE):
        # Find the worker with the most remaining bandwidth (simplified logic)
        selected_worker = max(self.workers, key=lambda worker_id: self.workers[worker_id])
        if self.workers[selected_worker] <= 0:
            bt.logging.debug("no available resources to assign this task.")
            return None

        bt.logging.trace(f"Assigning task {task} to {selected_worker}")
        # decrease remaining capacity after sending request.
        self.workers[selected_worker] -= 1
        # Push task to the selected worker's task queue
        self.redis_client.rpush(f"tasks:{selected_worker}", task.json())

    def create_workers(self, uid_to_capacities):
        # create a worker for each (uid, provider, model) and track its remaining capacity by worker_id
        workers = {}
        for uid, cap_info in uid_to_capacities.items():
            for provider, model_to_cap in cap_info.items():
                for model, cap in model_to_cap.items():
                    worker_id = f"{uid}_{provider}_{model}"
                    worker = Worker(worker_id, cap, config=self.config)
                    workers[worker.worker_id] = cap
        self.workers = workers
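
Aside from the diff itself, the queueing above reduces to two Redis structures: a "workers" hash mapping each worker id to its reported bandwidth, and one list per worker (tasks:<worker_id>) holding serialized synapses. The snippet below is an illustrative, self-contained sketch of that bookkeeping, not part of the commit; the worker id and payload are hypothetical and a Redis server is assumed on localhost.

import json
import redis

r = redis.StrictRedis(host='localhost', port=6379, db=0)

worker_id = "1_OpenAI_gpt-4o"              # hypothetical uid_provider_model id
r.hset("workers", worker_id, 5)            # what Worker.report_resources() records

# What TaskMgr.assign_task() effectively does once a worker is selected:
task_payload = {"task_id": "abc", "streaming": True}   # stand-in for synapse.json()
r.rpush(f"tasks:{worker_id}", json.dumps(task_payload))

# And what Worker.pull_and_run_task() does on the other side:
raw = r.lpop(f"tasks:{worker_id}")
print(json.loads(raw or "{}"))             # -> {'task_id': 'abc', 'streaming': True}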
@@ -0,0 +1,56 @@
import asyncio

import redis
import json
import bittensor as bt
from cortext import ALL_SYNAPSE_TYPE, StreamPrompting, ImageResponse


class Worker:
    # Initialize Redis client
    redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)
    TASK_STREAM = 'task_stream'
    RESULT_STREAM = 'result_stream'

    def __init__(self, worker_id, bandwidth, config, axon):
        self.worker_id = worker_id
        self.bandwidth = bandwidth
        self.dendrite = bt.dendrite(config.wallet)
        self.axon = axon
        self.report_resources()

    def report_resources(self):
        # Store worker's resource info in Redis hash
        self.redis_client.hset("workers", self.worker_id, self.bandwidth)

    @staticmethod
    def convert_json_to_synapse(task_obj):
        # Rebuild the synapse from its JSON form: streaming tasks are StreamPrompting, the rest ImageResponse
        if task_obj.get("streaming"):
            synapse = StreamPrompting.parse_obj(task_obj)
        else:
            synapse = ImageResponse.parse_obj(task_obj)
        return synapse

    async def pull_and_run_task(self):
        # Pull task from worker-specific queue
        while True:
            task = json.loads(self.redis_client.lpop(f"tasks:{self.worker_id}") or "{}")
            if task:
                synapse = self.convert_json_to_synapse(task)
                bt.logging.trace(f"Worker {self.worker_id} received task: {synapse}")
                task_id = synapse.task_id
                try:
                    responses = await self.dendrite(
                        axons=[self.axon],
                        synapse=synapse,
                        deserialize=synapse.deserialize,
                        timeout=synapse.timeout,
                        streaming=synapse.streaming,
                    )
                except Exception as err:
                    bt.logging.exception(err)
                else:
                    async for chunk in responses[0]:
                        if isinstance(chunk, str):
                            # redis_client is synchronous, so xadd is not awaited
                            self.redis_client.xadd(Worker.RESULT_STREAM, {'task_id': task_id, 'chunk': chunk})
            await asyncio.sleep(0.1)
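
Aside from the diff, each worker publishes streamed chunks to the shared result_stream via xadd, keyed by task_id. Whatever dispatches tasks has to read that stream back; the snippet below is an illustrative consumer sketch, not part of the commit, assuming a Redis server on localhost with string-decoded responses.

import redis

r = redis.StrictRedis(host='localhost', port=6379, db=0, decode_responses=True)

last_id = '0-0'
while True:
    # Block for up to 1s waiting for entries newer than last_id on the result stream.
    for _stream, entries in r.xread({'result_stream': last_id}, block=1000, count=100):
        for entry_id, fields in entries:
            last_id = entry_id
            print(f"task {fields['task_id']} chunk: {fields['chunk']}")

A dispatcher with multiple reader processes would more likely use a consumer group (XGROUP CREATE / XREADGROUP) so each chunk is delivered to exactly one reader.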