Skip to content

Commit

Permalink
Merge pull request #106 from Datura-ai/hotfix-main-bittensor
Browse files Browse the repository at this point in the history
Hotfix main bittensor
  • Loading branch information
surcyf123 authored Nov 12, 2024
2 parents 91cf064 + 5f88168 commit 150b54d
Show file tree
Hide file tree
Showing 18 changed files with 398 additions and 62 deletions.
1 change: 1 addition & 0 deletions cortext/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ async def process_streaming_response(self, response: StreamingResponse, organic=
self.completion += tokens
yield tokens
except asyncio.TimeoutError as err:
self.completion += remain_chunk
yield remain_chunk


Expand Down
18 changes: 18 additions & 0 deletions cursor/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Use the official Python 3.10 (Debian-based) image as the runtime.
FROM python:3.10

# All subsequent paths are relative to /app inside the image.
WORKDIR /app

# Install dependencies first so this layer is cached unless requirements.txt changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application source into the image.
COPY . .

# Document the port uvicorn listens on (EXPOSE is informational only).
EXPOSE 8000

# Launch the FastAPI app; `app.main:app` resolves to /app/app/main.py.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
30 changes: 0 additions & 30 deletions cursor/app.py

This file was deleted.

Empty file added cursor/app/__init__.py
Empty file.
Empty file added cursor/app/core/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions cursor/app/core/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import time
from fastapi import HTTPException


async def verify_api_key_rate_limit(config, api_key):
    """Validate an API key and enforce its per-minute rate limit via Redis.

    Args:
        config: app config exposing ``prod``, ``redis_db`` (async Redis) and
            ``psql_db`` (async Postgres pool).
        api_key: the caller's API key.

    Returns:
        True when the request is allowed.

    Raises:
        HTTPException: 403 for an invalid key, 429 when the per-minute
            request budget is exhausted.
    """
    # NOTE: a bit dangerous but very useful — bypass for local testing only.
    if not config.prod and api_key == "test":
        return True

    rate_limit_key = f"rate_limit:{api_key}"
    rate_limit = await config.redis_db.get(rate_limit_key)
    if rate_limit is None:
        async with await config.psql_db.connection() as connection:
            # rate_limit = await get_api_key_rate_limit(connection, api_key)
            # NOTE(review): with the lookup above commented out, rate_limit is
            # always None here, so any key not already cached in Redis is
            # rejected — confirm this is intentional.
            if rate_limit is None:
                raise HTTPException(status_code=403, detail="Invalid API key")
        await config.redis_db.set(rate_limit_key, rate_limit, ex=30)
    else:
        rate_limit = int(rate_limit)

    minute = time.time() // 60
    current_rate_limit_key = f"current_rate_limit:{api_key}:{minute}"
    # INCR first: it atomically creates the key when missing and returns the
    # post-increment count. (The previous EXPIRE-before-INCR ordering was a
    # no-op on a missing key, so counter keys were created without a TTL and
    # leaked — one stale key per key/minute.)
    current_count = await config.redis_db.incr(current_rate_limit_key)
    if current_count == 1:
        # First request this minute: attach the TTL to the freshly created key.
        await config.redis_db.expire(current_rate_limit_key, 60)
    # Post-increment count > limit  <=>  pre-increment count >= limit,
    # preserving the original threshold semantics.
    if current_count > rate_limit:
        raise HTTPException(status_code=429, detail="Too many requests")
21 changes: 21 additions & 0 deletions cursor/app/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import psycopg2
import os

# Connection string comes from the environment (set by docker-compose for the
# `web` service).
DATABASE_URL = os.getenv("DATABASE_URL")
# NOTE(review): name is a typo for TABLE_NAME, kept as-is because other
# modules may import it by this spelling.
TABEL_NAME = 'query_resp_data'
# PostgreSQL connection parameters — the connection is opened at import time,
# so importing this module requires a reachable database.
conn = psycopg2.connect(DATABASE_URL)

# Create a cursor object to interact with the database
cur = conn.cursor()


async def create_table(app):
    """Create the query/response table on startup.

    Currently a stub (the body is ``pass``); awaited from the FastAPI
    lifespan handler in main.py. The ``app`` argument is unused.
    """
    global conn, cur, TABEL_NAME
    try:
        pass

    except Exception as e:
        print(f"Error creating table: {e}")

# Removed the module-level ``create_table(None)`` call: calling an async
# function without awaiting it only creates an unawaited coroutine (a no-op
# plus a RuntimeWarning). main.py's lifespan handler awaits it properly.
Empty file.
173 changes: 173 additions & 0 deletions cursor/app/endpoints/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import json
from typing import Any, AsyncGenerator
import uuid
from fastapi import Depends, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from redis.asyncio import Redis
from fastapi.routing import APIRouter
from cursor.app.models import RequestModel
import asyncio
from redis.asyncio.client import PubSub
import time

# NOTE(review): `metrics` is not imported anywhere in this file — presumably an
# OpenTelemetry-style metrics module; confirm the missing import before deploy.
# Counters for failed / successful organic text generations.
COUNTER_TEXT_GENERATION_ERROR = metrics.get_meter(__name__).create_counter("validator.entry_node.text.error")
COUNTER_TEXT_GENERATION_SUCCESS = metrics.get_meter(__name__).create_counter("validator.entry_node.text.success")
# Gauge tracking average tokens/sec for organic LLM streaming queries.
GAUGE_TOKENS_PER_SEC = metrics.get_meter(__name__).create_gauge(
    "validator.entry_node.text.tokens_per_sec",
    description="Average tokens per second metric for LLM streaming for an organic LLM query"
)


def _construct_organic_message(payload: dict, job_id: str, task: str) -> str:
    """Serialize an organic query envelope for the worker queue."""
    envelope = {
        "query_type": gcst.ORGANIC,
        "query_payload": payload,
        "task": task,
        "job_id": job_id,
    }
    return json.dumps(envelope)


async def _wait_for_acknowledgement(pubsub: PubSub, job_id: str) -> bool:
    """Block until the worker acknowledges *job_id*, then unsubscribe.

    Always returns True; callers bound the wait with ``asyncio.wait_for``.
    """
    ack_channel = f"{gcst.ACKNLOWEDGED}:{job_id}"
    async for message in pubsub.listen():
        if message["type"] == "message" and message["channel"].decode() == ack_channel:
            logger.info(f"Job {job_id} confirmed by worker")
            break
    await pubsub.unsubscribe(ack_channel)
    return True


async def _stream_results(pubsub: PubSub, job_id: str, task: str, first_chunk: str, start_time: float) -> \
        AsyncGenerator[str, str]:
    """Re-yield result chunks for a job from Redis pub/sub.

    Args:
        pubsub: subscription already listening on the job's results channel.
        job_id: identifier of the organic query job.
        task: task/model label attached to metrics.
        first_chunk: chunk already consumed by _get_first_chunk; yielded first.
        start_time: wall-clock start used to compute tokens/sec.

    Raises:
        HTTPException: when a result message carries a status code >= 400.
    """
    yield first_chunk
    num_tokens = 0
    async for message in pubsub.listen():
        channel = message["channel"].decode()

        if channel == f"{rcst.JOB_RESULTS}:{job_id}" and message["type"] == "message":
            result = json.loads(message["data"].decode())
            # Late acknowledgement payloads can arrive on this channel; skip them.
            if gcst.ACKNLOWEDGED in result:
                continue
            status_code = result[gcst.STATUS_CODE]
            if status_code >= 400:
                COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "nth_chunk_timeout",
                                                      "status_code": status_code})
                raise HTTPException(status_code=status_code, detail=result[gcst.ERROR_MESSAGE])

            content = result[gcst.CONTENT]
            num_tokens += 1
            yield content
            # The sentinel may be embedded inside a chunk, hence `in` not `==`.
            if "[DONE]" in content:
                break
    COUNTER_TEXT_GENERATION_SUCCESS.add(1, {"task": task, "status_code": 200})
    completion_time = time.time() - start_time

    # NOTE(review): could raise ZeroDivisionError if completion happens within
    # clock resolution — confirm acceptable.
    tps = num_tokens / completion_time
    GAUGE_TOKENS_PER_SEC.set(tps, {"task": task})
    logger.info(f"Tokens per second for job_id: {job_id}, task: {task}: {tps}")

    await pubsub.unsubscribe(f"{rcst.JOB_RESULTS}:{job_id}")


async def _get_first_chunk(pubsub: PubSub, job_id: str) -> str | None:
    """Return the first content chunk published for *job_id*.

    Raises HTTPException if the first matching result carries an error
    status (>= 400). Returns None only if the pub/sub stream ends without
    any matching message.
    """
    results_channel = f"{rcst.JOB_RESULTS}:{job_id}"
    async for message in pubsub.listen():
        if message["type"] != "message" or message["channel"].decode() != results_channel:
            continue
        result = json.loads(message["data"].decode())
        if gcst.STATUS_CODE in result and result[gcst.STATUS_CODE] >= 400:
            raise HTTPException(status_code=result[gcst.STATUS_CODE], detail=result[gcst.ERROR_MESSAGE])
        return result[gcst.CONTENT]
    return None


async def make_stream_organic_query(
    redis_db: Redis,
    payload: dict[str, Any],
    task: str,
) -> AsyncGenerator[str, str]:
    """Queue an organic query and return an async generator of result chunks.

    Flow: push the job onto the worker queue, wait (<= 1s) for the worker's
    acknowledgement, wait (<= 2s) for the first result chunk, then hand back
    a generator that streams the remainder.

    Raises:
        HTTPException(500): acknowledgement timeout, first-chunk timeout, or
            a missing first chunk.
    """
    job_id = uuid.uuid4().hex
    organic_message = _construct_organic_message(payload=payload, job_id=job_id, task=task)

    pubsub = redis_db.pubsub()
    # Subscribe before pushing the job so the acknowledgement cannot be missed.
    await pubsub.subscribe(f"{gcst.ACKNLOWEDGED}:{job_id}")
    await redis_db.lpush(rcst.QUERY_QUEUE_KEY, organic_message)  # type: ignore

    first_chunk = None
    try:
        await asyncio.wait_for(_wait_for_acknowledgement(pubsub, job_id), timeout=1)
    except asyncio.TimeoutError:
        logger.error(
            f"Query node down? No confirmation received for job {job_id} within timeout period. Task: {task}, model: {payload['model']}"
        )
        COUNTER_TEXT_GENERATION_ERROR.add(1,
                                          {"task": task, "kind": "redis_acknowledgement_timeout", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; redis_acknowledgement_timeout")

    # Only now subscribe to results: the worker publishes them after the ack.
    await pubsub.subscribe(f"{rcst.JOB_RESULTS}:{job_id}")
    logger.info("Here waiting for a message!")
    start_time = time.time()
    try:
        first_chunk = await asyncio.wait_for(_get_first_chunk(pubsub, job_id), timeout=2)
    except asyncio.TimeoutError:
        logger.error(
            f"Query node down? Timed out waiting for the first chunk of results for job {job_id}. Task: {task}, model: {payload['model']}"
        )
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "first_chunk_timeout", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; first_chunk_timeout")

    if first_chunk is None:
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "first_chunk_missing", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; first_chunk_missing")
    return _stream_results(pubsub, job_id, task, first_chunk, start_time)


async def _handle_no_stream(text_generator: AsyncGenerator[str, str]) -> JSONResponse:
    """Drain an SSE text generator and return the accumulated content.

    Each generated chunk is parsed with load_sse_jsons; when the parse yields
    a list of OpenAI-style delta objects their contents are concatenated.
    Returns a single JSONResponse shaped like one streaming delta.
    """
    all_content = ""
    async for chunk in text_generator:
        chunks = load_sse_jsons(chunk)
        if isinstance(chunks, list):
            # Renamed from `chunk`, which shadowed the outer loop variable.
            for chunk_json in chunks:
                content = chunk_json["choices"][0]["delta"]["content"]
                all_content += content
                if content == "":
                    # NOTE(review): this only exits the inner loop; the outer
                    # async-for keeps consuming the generator — confirm whether
                    # the intent was to stop streaming entirely.
                    break

    return JSONResponse({"choices": [{"delta": {"content": all_content}}]})


async def chat(
    chat_request: request_models.ChatRequest,
    config: Config = Depends(get_config),
) -> StreamingResponse | JSONResponse:
    """Handle /v1/chat/completions: run an organic LLM query.

    Returns a StreamingResponse (SSE) when the client requested streaming,
    otherwise drains the stream and returns a single JSONResponse.

    Raises:
        HTTPException: re-raised from the query pipeline, or 500 for any
            unexpected error.
    """
    payload = request_models.chat_to_payload(chat_request)
    # NOTE(review): temperature is forcibly overridden to 0.5, discarding any
    # client-requested value — confirm this is intentional.
    payload.temperature = 0.5

    try:
        text_generator = await make_stream_organic_query(
            redis_db=config.redis_db,
            payload=payload.model_dump(),
            task=payload.model,
        )

        logger.info("Here returning a response!")

        if chat_request.stream:
            return StreamingResponse(text_generator, media_type="text/event-stream")
        else:
            return await _handle_no_stream(text_generator)

    except HTTPException as http_exc:
        # Count and propagate pipeline errors unchanged so clients see the
        # original status code.
        COUNTER_TEXT_GENERATION_ERROR.add(1,
                                          {"task": payload.model, "kind": type(http_exc).__name__, "status_code": 500})
        logger.info(f"HTTPException in chat endpoint: {str(http_exc)}")
        raise http_exc

    except Exception as e:
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": payload.model, "kind": type(e).__name__, "status_code": 500})
        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")


# Expose the chat endpoint on an APIRouter so main.py can include it.
router = APIRouter()
# NOTE(review): verify_api_key_rate_limit takes (config, api_key) positionally;
# used directly as a FastAPI dependency those become request parameters unless
# a Depends-compatible wrapper supplies them — confirm at integration time.
router.add_api_route(
    "/v1/chat/completions",
    chat,
    methods=["POST", "OPTIONS"],
    tags=["StreamPrompting"],
    response_model=None,
    dependencies=[Depends(verify_api_key_rate_limit)],
)
49 changes: 49 additions & 0 deletions cursor/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from . import curd, models, schemas
from .database import create_table, conn, cur
from typing import List
from .endpoints.text import router as chat_router

@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: ensure DB tables exist on startup; nothing on shutdown here."""
    # create_table currently ignores its argument; pass None explicitly.
    await create_table(None)
    yield


app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): only allow_origins is set; CORSMiddleware's other options
    # keep their defaults (e.g. allow_methods defaults to ["GET"]) — confirm
    # cross-origin POST/OPTIONS requests are expected to work.
    allow_origins=["*"],  # Allows all origins
)
app.include_router(chat_router)


# NOTE(review): on_event is deprecated in favour of the lifespan handler above;
# consider moving this cleanup there.
@app.on_event("shutdown")
async def shutdown_event():
    # Close the module-level psycopg2 cursor and connection from .database.
    cur.close()
    conn.close()


# Create an item
@app.post("/items")
def create_item(items: List[schemas.ItemCreate]):
    """Insert a batch of items and return the persisted rows."""
    created = curd.create_items(items=items)
    return created


# Read all items
@app.post("/items/search")
def read_items(req_body: models.RequestBody):
    """Return all items matching the filters in the request body."""
    return curd.get_items(req_body)


# Read a single item by ID
@app.get("/items/{p_key}", response_model=schemas.Item)
def read_item(p_key: int):
    """Fetch one item by primary key; respond 404 when it does not exist."""
    item = curd.get_item(p_key=p_key)
    if item is None:
        raise HTTPException(status_code=404, detail="Item not found")
    return item
5 changes: 5 additions & 0 deletions cursor/app/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from cortext.protocol import StreamPrompting


class RequestModel(StreamPrompting):
    """Alias of cortext.protocol.StreamPrompting used as the cursor app's request schema."""
    pass
30 changes: 30 additions & 0 deletions cursor/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# NOTE(review): the top-level `version` key is obsolete in Compose v2+ and ignored.
version: "3.9"

services:
  # PostgreSQL backing store; credentials come from the host environment / .env.
  db:
    image: postgres:13
    restart: always
    environment:
      POSTGRES_USER: ${POSTGRES_USER}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
      POSTGRES_DB: ${POSTGRES_DB}
    volumes:
      - postgres_data_score:/var/lib/postgresql/data
    ports:
      - "5432:5432"

  # FastAPI application container built from ./Dockerfile.
  web:
    build: .
    restart: always
    ports:
      - "8000:8000"
    depends_on:
      - db
    environment:
      # libpq accepts both postgres:// and postgresql:// URL schemes.
      DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
      POSTGRES_DB: ${POSTGRES_DB}
    volumes:
      # Bind-mount the source for development; overrides the code baked into
      # the image at build time.
      - .:/app

volumes:
  postgres_data_score:
5 changes: 5 additions & 0 deletions cursor/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fastapi
uvicorn[standard]
psycopg2-binary
sqlalchemy
pydantic
# redis is imported by app/endpoints/text.py (redis.asyncio) but was missing.
redis
Loading

0 comments on commit 150b54d

Please sign in to comment.