-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #106 from Datura-ai/hotfix-main-bittensor
Hotfix main bittensor
- Loading branch information
Showing
18 changed files
with
398 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Use official Python image
FROM python:3.10

# Set working directory
WORKDIR /app

# Copy and install dependencies first, so Docker's layer cache reuses the
# installed packages when only application code (copied below) changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the app files to the container
COPY . .

# Expose the FastAPI port
EXPOSE 8000

# Start FastAPI app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
This file was deleted.
Oops, something went wrong.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import time
from fastapi import HTTPException


async def verify_api_key_rate_limit(config, api_key):
    """Check `api_key` against its per-minute request limit.

    Looks up the key's allowed requests-per-minute (cached in Redis for
    30 seconds, falling back to PostgreSQL) and counts requests made in
    the current minute window.

    Raises:
        HTTPException: 403 if the API key is unknown, 429 if the key has
            exceeded its per-minute limit.
    """
    # NOTE: a bit dangerous but very useful — accept a fixed "test" key
    # outside production so local development needs no DB entry.
    if not config.prod:
        if api_key == "test":
            return True

    rate_limit_key = f"rate_limit:{api_key}"
    rate_limit = await config.redis_db.get(rate_limit_key)
    if rate_limit is None:
        async with await config.psql_db.connection() as connection:
            # TODO(review): the DB lookup below is commented out, so any key
            # that is not already cached in Redis is rejected with 403 here —
            # restore the lookup or document why it was disabled.
            # rate_limit = await get_api_key_rate_limit(connection, api_key)
            if rate_limit is None:
                raise HTTPException(status_code=403, detail="Invalid API key")
            await config.redis_db.set(rate_limit_key, rate_limit, ex=30)
    else:
        rate_limit = int(rate_limit)

    minute = time.time() // 60
    current_rate_limit_key = f"current_rate_limit:{api_key}:{minute}"
    # Count this request first, then set the TTL the first time the window
    # key is created.  (The original called EXPIRE *before* INCR created the
    # key — EXPIRE on a missing key is a no-op, so the counter key was left
    # without a TTL and accumulated forever.)  INCR returns the post-increment
    # value, so `> rate_limit` here matches the original `pre_count >= limit`.
    current_rate_limit = await config.redis_db.incr(current_rate_limit_key)
    if current_rate_limit == 1:
        await config.redis_db.expire(current_rate_limit_key, 60)
    if current_rate_limit > rate_limit:
        raise HTTPException(status_code=429, detail="Too many requests")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import psycopg2
import os

# Connection URL, e.g. postgres://user:pass@host:5432/dbname (set by docker-compose)
DATABASE_URL = os.getenv("DATABASE_URL")
# NOTE(review): keeps the existing misspelling ("TABEL") because other
# modules may import this name; rename in a coordinated change.
TABEL_NAME = 'query_resp_data'
# PostgreSQL connection parameters — a single module-level connection shared
# by the whole process; main.py closes it on shutdown.
conn = psycopg2.connect(DATABASE_URL)

# Create a cursor object to interact with the database
cur = conn.cursor()


async def create_table(app):
    """Create the query/response table if it does not exist.

    Currently a stub (the DDL has not been written yet); `app` is unused
    and exists so the FastAPI lifespan hook can pass the application in.
    """
    global conn, cur, TABEL_NAME
    try:
        pass

    except Exception as e:
        print(f"Error creating table: {e}")

# Fixed: the original called `create_table(None)` here without awaiting it.
# Calling a coroutine function without `await` never executes its body and
# only emits a "coroutine was never awaited" RuntimeWarning at import time;
# the FastAPI lifespan handler already awaits create_table on startup.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import json | ||
from typing import Any, AsyncGenerator | ||
import uuid | ||
from fastapi import Depends, HTTPException | ||
from fastapi.responses import JSONResponse, StreamingResponse | ||
from redis.asyncio import Redis | ||
from fastapi.routing import APIRouter | ||
from cursor.app.models import RequestModel | ||
import asyncio | ||
from redis.asyncio.client import PubSub | ||
import time | ||
|
||
# Counters/gauge for organic text-generation queries (error count, success
# count, and average streaming throughput).
# NOTE(review): `metrics` is not imported anywhere in this file — unless it
# arrives via a star-import elsewhere, this module raises NameError at import
# time; confirm the missing `from opentelemetry import metrics` (or similar).
COUNTER_TEXT_GENERATION_ERROR = metrics.get_meter(__name__).create_counter("validator.entry_node.text.error")
COUNTER_TEXT_GENERATION_SUCCESS = metrics.get_meter(__name__).create_counter("validator.entry_node.text.success")
GAUGE_TOKENS_PER_SEC = metrics.get_meter(__name__).create_gauge(
    "validator.entry_node.text.tokens_per_sec",
    description="Average tokens per second metric for LLM streaming for an organic LLM query"
)
|
||
|
||
def _construct_organic_message(payload: dict, job_id: str, task: str) -> str:
    """Serialize an organic query envelope for the worker queue."""
    envelope = {
        "query_type": gcst.ORGANIC,
        "query_payload": payload,
        "task": task,
        "job_id": job_id,
    }
    return json.dumps(envelope)
|
||
|
||
async def _wait_for_acknowledgement(pubsub: PubSub, job_id: str) -> bool:
    """Block until a worker acknowledges `job_id` on its ack channel.

    Always returns True; callers bound the wait with `asyncio.wait_for`.
    Unsubscribes from the ack channel before returning.
    """
    ack_channel = f"{gcst.ACKNLOWEDGED}:{job_id}"
    async for msg in pubsub.listen():
        is_ack = msg["channel"].decode() == ack_channel and msg["type"] == "message"
        if is_ack:
            logger.info(f"Job {job_id} confirmed by worker")
            break
    await pubsub.unsubscribe(ack_channel)
    return True
|
||
|
||
async def _stream_results(pubsub: PubSub, job_id: str, task: str, first_chunk: str, start_time: float) -> \
        AsyncGenerator[str, str]:
    """Yield result chunks for `job_id` from Redis pub/sub, starting with `first_chunk`.

    Stops when a chunk containing "[DONE]" arrives, records success and
    tokens-per-second metrics, then unsubscribes from the job's results channel.

    Raises:
        HTTPException: when a result message carries a status code >= 400.
    """
    yield first_chunk
    num_tokens = 0
    async for message in pubsub.listen():
        channel = message["channel"].decode()

        if channel == f"{rcst.JOB_RESULTS}:{job_id}" and message["type"] == "message":
            result = json.loads(message["data"].decode())
            # Acknowledgement payloads can also arrive on this subscription; skip them.
            if gcst.ACKNLOWEDGED in result:
                continue
            status_code = result[gcst.STATUS_CODE]
            if status_code >= 400:
                COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "nth_chunk_timeout",
                                                      "status_code": status_code})
                raise HTTPException(status_code=status_code, detail=result[gcst.ERROR_MESSAGE])

            content = result[gcst.CONTENT]
            # NOTE(review): counts one "token" per chunk — a chunk may contain
            # multiple model tokens, so the tokens/sec gauge is approximate.
            num_tokens += 1
            yield content
            if "[DONE]" in content:
                break
    COUNTER_TEXT_GENERATION_SUCCESS.add(1, {"task": task, "status_code": 200})
    completion_time = time.time() - start_time

    tps = num_tokens / completion_time
    GAUGE_TOKENS_PER_SEC.set(tps, {"task": task})
    logger.info(f"Tokens per second for job_id: {job_id}, task: {task}: {tps}")

    # NOTE(review): not in a finally block — if the consumer abandons the
    # generator or an HTTPException is raised above, this unsubscribe is skipped.
    await pubsub.unsubscribe(f"{rcst.JOB_RESULTS}:{job_id}")
|
||
|
||
async def _get_first_chunk(pubsub: PubSub, job_id: str) -> str | None:
    """Return the first result chunk for `job_id`, or None if the stream ends.

    Raises:
        HTTPException: when the first message reports a status code >= 400.
    """
    results_channel = f"{rcst.JOB_RESULTS}:{job_id}"
    async for msg in pubsub.listen():
        if msg["type"] != "message":
            continue
        if msg["channel"].decode() != results_channel:
            continue
        result = json.loads(msg["data"].decode())
        if gcst.STATUS_CODE in result and result[gcst.STATUS_CODE] >= 400:
            raise HTTPException(status_code=result[gcst.STATUS_CODE], detail=result[gcst.ERROR_MESSAGE])
        return result[gcst.CONTENT]
    return None
|
||
|
||
async def make_stream_organic_query(
    redis_db: Redis,
    payload: dict[str, Any],
    task: str,
) -> AsyncGenerator[str, str]:
    """Enqueue an organic query and return an async generator of result chunks.

    Flow: push the job onto the worker queue, wait up to 1s for a worker
    acknowledgement, then wait up to 2s for the first result chunk before
    handing the stream off to `_stream_results`.

    Raises:
        HTTPException: 500 on acknowledgement timeout, first-chunk timeout,
            or a missing first chunk.
    """
    job_id = uuid.uuid4().hex
    organic_message = _construct_organic_message(payload=payload, job_id=job_id, task=task)

    # Subscribe to the ack channel *before* enqueuing so the worker's
    # acknowledgement cannot slip past us.
    pubsub = redis_db.pubsub()
    await pubsub.subscribe(f"{gcst.ACKNLOWEDGED}:{job_id}")
    await redis_db.lpush(rcst.QUERY_QUEUE_KEY, organic_message)  # type: ignore

    first_chunk = None
    try:
        await asyncio.wait_for(_wait_for_acknowledgement(pubsub, job_id), timeout=1)
    except asyncio.TimeoutError:
        logger.error(
            f"Query node down? No confirmation received for job {job_id} within timeout period. Task: {task}, model: {payload['model']}"
        )
        COUNTER_TEXT_GENERATION_ERROR.add(1,
                                          {"task": task, "kind": "redis_acknowledgement_timeout", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; redis_acknowledgement_timeout")

    await pubsub.subscribe(f"{rcst.JOB_RESULTS}:{job_id}")
    logger.info("Here waiting for a message!")
    start_time = time.time()
    try:
        first_chunk = await asyncio.wait_for(_get_first_chunk(pubsub, job_id), timeout=2)
    except asyncio.TimeoutError:
        logger.error(
            f"Query node down? Timed out waiting for the first chunk of results for job {job_id}. Task: {task}, model: {payload['model']}"
        )
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "first_chunk_timeout", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; first_chunk_timeout")

    # `_get_first_chunk` returns None only if the pub/sub stream ended
    # without a result message.
    if first_chunk is None:
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": task, "kind": "first_chunk_missing", "status_code": 500})
        raise HTTPException(status_code=500, detail="Unable to process request ; first_chunk_missing")
    return _stream_results(pubsub, job_id, task, first_chunk, start_time)
|
||
|
||
async def _handle_no_stream(text_generator: AsyncGenerator[str, str]) -> JSONResponse:
    """Drain a streaming generator and return the whole completion as one JSON body.

    Concatenates the `delta.content` of every SSE event and wraps the result
    in an OpenAI-shaped response.
    """
    all_content = ""
    async for chunk in text_generator:
        # Fixed: the inner loop reused the name `chunk`, shadowing the outer
        # iteration variable; renamed to `event` for clarity (behavior unchanged).
        events = load_sse_jsons(chunk)
        # Non-list payloads are skipped silently — presumably load_sse_jsons
        # returns a list of event dicts on success; TODO confirm.
        if isinstance(events, list):
            for event in events:
                content = event["choices"][0]["delta"]["content"]
                all_content += content
                if content == "":
                    # NOTE(review): only exits the inner loop; the outer
                    # stream keeps draining — confirm this is intended.
                    break

    return JSONResponse({"choices": [{"delta": {"content": all_content}}]})
|
||
|
||
async def chat(
    chat_request: request_models.ChatRequest,
    config: Config = Depends(get_config),
) -> StreamingResponse | JSONResponse:
    """Chat-completions endpoint backed by the organic pub/sub query queue.

    Streams SSE when `chat_request.stream` is set; otherwise drains the
    stream and returns a single JSON body.

    Raises:
        HTTPException: propagated from the query pipeline, or 500 for any
            unexpected error.
    """
    payload = request_models.chat_to_payload(chat_request)
    # NOTE(review): overrides whatever temperature the client requested with
    # a fixed 0.5 — confirm this is intentional.
    payload.temperature = 0.5

    try:
        text_generator = await make_stream_organic_query(
            redis_db=config.redis_db,
            payload=payload.model_dump(),
            task=payload.model,
        )

        logger.info("Here returning a response!")

        if chat_request.stream:
            return StreamingResponse(text_generator, media_type="text/event-stream")
        else:
            return await _handle_no_stream(text_generator)

    except HTTPException as http_exc:
        # Count the error, then re-raise unchanged so the client sees the
        # original status code and detail.
        COUNTER_TEXT_GENERATION_ERROR.add(1,
                                          {"task": payload.model, "kind": type(http_exc).__name__, "status_code": 500})
        logger.info(f"HTTPException in chat endpoint: {str(http_exc)}")
        raise http_exc

    except Exception as e:
        # Anything unexpected becomes an opaque 500 (details stay in the logs).
        COUNTER_TEXT_GENERATION_ERROR.add(1, {"task": payload.model, "kind": type(e).__name__, "status_code": 500})
        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||
|
||
# Expose the chat endpoint under the OpenAI-compatible completions path.
router = APIRouter()
router.add_api_route(
    "/v1/chat/completions",
    chat,
    methods=["POST", "OPTIONS"],
    tags=["StreamPrompting"],
    response_model=None,
    # NOTE(review): verify_api_key_rate_limit is defined as
    # `(config, api_key)` with plain parameters, not FastAPI dependencies —
    # confirm it actually resolves when used with Depends here.
    dependencies=[Depends(verify_api_key_rate_limit)],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from contextlib import asynccontextmanager | ||
from fastapi import FastAPI, Depends, HTTPException | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from . import curd, models, schemas | ||
from .database import create_table, conn, cur | ||
from typing import List | ||
from .endpoints.text import router as chat_router | ||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: ensure the DB table exists on startup; nothing on shutdown."""
    # Create the query/response table (the original comment said "Load the
    # ML model", which was wrong — create_table touches PostgreSQL only).
    await create_table(None)
    yield
|
||
|
||
# FastAPI application wired to the lifespan handler above.
app = FastAPI(lifespan=lifespan)
# NOTE(review): only allow_origins is set; CORSMiddleware's defaults restrict
# allow_methods to ["GET"] and allow_headers to [] — confirm the POST
# endpoints below are reachable cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
)
app.include_router(chat_router)
|
||
|
||
# NOTE(review): on_event is deprecated in favor of the lifespan handler used
# above — this cleanup could move after its `yield` instead.
@app.on_event("shutdown")
async def shutdown_event():
    """Release the module-level psycopg2 cursor and connection from .database."""
    cur.close()
    conn.close()
|
||
|
||
# Create an item
@app.post("/items")
def create_item(items: List[schemas.ItemCreate]):
    """Persist a batch of items and return the result of the insert."""
    created = curd.create_items(items=items)
    return created
|
||
|
||
# Read all items
@app.post("/items/search")
def read_items(req_body: models.RequestBody):
    """Return all items matching the filters in the request body."""
    return curd.get_items(req_body)
|
||
|
||
# Read a single item by ID
@app.get("/items/{p_key}", response_model=schemas.Item)
def read_item(p_key: int):
    """Fetch one item by primary key; respond 404 when it does not exist."""
    db_item = curd.get_item(p_key=p_key)
    if db_item is not None:
        return db_item
    raise HTTPException(status_code=404, detail="Item not found")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from cortext.protocol import StreamPrompting | ||
|
||
|
||
class RequestModel(StreamPrompting): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
version: "3.9"

services:
  # PostgreSQL database holding the query/response data
  db:
    image: postgres:13
    restart: always
    environment:
      POSTGRES_USER: ${POSTGRES_USER}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
      POSTGRES_DB: ${POSTGRES_DB}
    volumes:
      # Persist database files across container restarts
      - postgres_data_score:/var/lib/postgresql/data
    ports:
      - "5432:5432"

  # FastAPI application, built from the local Dockerfile
  web:
    build: .
    restart: always
    ports:
      - "8000:8000"
    depends_on:
      - db
    environment:
      # Connection string consumed by app/database.py via os.getenv
      DATABASE_URL: postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
      POSTGRES_DB: ${POSTGRES_DB}
    volumes:
      # Dev-mode bind mount: local source shadows the code baked into the image
      - .:/app

volumes:
  postgres_data_score:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
fastapi
uvicorn[standard]
psycopg2-binary
sqlalchemy
pydantic
# Added: the application imports redis.asyncio (rate limiting and pub/sub
# streaming) but the package was missing from this file.
redis
Oops, something went wrong.