[Core] Introduce asyncio within Ray Actors handling LLMClient (#8)
* initial async io implementation

* timeout

* single client, multiple concurrent requests

* asyncio with ray actors

* Simplifying logic

* Remove print statements

* Use num-ray-clients and requests-per-client

* Minor bug fixes and cleanup

* Removed send_llm_request_ in between

* Initial fixes with print logs

* Removed print and test asyncio.sleep

* make format

* Error handling with streaming OpenAI client

* make format

* Removing timeout for large requests

* Update stop logic to use completed requests

* make format

* Update pbar once more

* Instantiate client only once and close after run.

* Fixing requests and concurrency (prefill_profiler)

* make format

* make format
anmolagarwalcp810 authored Jul 17, 2024
1 parent 852be91 commit c009611
Showing 19 changed files with 396 additions and 162 deletions.
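
Taken together, the changes replace the single `--num-concurrent-requests` knob with a two-level scheme: `--num-ray-clients` Ray actors are spawned, and each actor keeps `--num-concurrent-requests-per-client` requests in flight with asyncio against a now-async `send_llm_request`. A minimal sketch of that pattern, with actor and helper names that are illustrative rather than the ones used in the repo:

```python
import asyncio

import ray


@ray.remote
class RequestsLauncher:
    """Sketch of one benchmark client: a Ray actor whose async method keeps
    up to `num_concurrent_requests_per_client` requests in flight."""

    def __init__(self, llm_client, num_concurrent_requests_per_client: int):
        # In practice the LLM client might be built inside the actor instead.
        self.llm_client = llm_client
        self.num_concurrent_requests_per_client = num_concurrent_requests_per_client

    async def run(self, request_configs):
        semaphore = asyncio.Semaphore(self.num_concurrent_requests_per_client)

        async def run_one(request_config):
            async with semaphore:
                # send_llm_request is now a coroutine (see base_llm_client.py below)
                return await self.llm_client.send_llm_request(request_config)

        return await asyncio.gather(*(run_one(rc) for rc in request_configs))


# Driver side (also hypothetical): --num-ray-clients actors, each bounded by
# --num-concurrent-requests-per-client, e.g. 2 clients x 5 requests each.
# clients = [RequestsLauncher.remote(make_llm_client(), 5) for _ in range(2)]
# results = ray.get([c.run.remote(batch) for c, batch in zip(clients, batches)])
```
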
8 changes: 2 additions & 6 deletions README.md
@@ -63,7 +63,8 @@ python -m metron.run_benchmark \
--model "meta-llama/Meta-Llama-3-8B-Instruct" \
--max-num-completed-requests 150 \
--timeout 600 \
--num-concurrent-requests 10 \
--num-ray-clients 2 \
--num-concurrent-requests-per-client 5 \
--output-dir "result_outputs" \
--request-interval-generator-provider "poisson" \
--poisson-request-interval-generator-qps 0.5 \
@@ -114,9 +115,7 @@ Launch any open source system and setup API keys and URL as shown for [vLLM](#ru
```bash
python -m metron.prefill_profiler \
--model "meta-llama/Meta-Llama-3-8B-Instruct" \
--max-num-completed-requests 1 \
--timeout 600 \
--num-concurrent-requests 1 \
--fixed-request-generator-decode-tokens 16 \
--output-dir "prefill_experiments/prefill_profiler_vllm_llama-3-8b" \
--should-use-given-dir true
@@ -126,10 +125,7 @@ To modify range of prompt tokens for which prefill times get profiled, use the f
```bash
python -m metron.prefill_profiler \
--model "meta-llama/Meta-Llama-3-8B-Instruct" \
--max-num-completed-requests 1 \
--timeout 600 \
--num-concurrent-requests 1 \
--fixed-request-generator-decode-tokens 16 \
--output-dir "prefill_experiments/prefill_profiler_vllm_llama-3-8b" \
--should-use-given-dir true \
--prefill-lengths 256 512 1024 2048 4096 8192 16384 32768 65536
4 changes: 0 additions & 4 deletions docs/tutorials/prefill_profiler.rst
@@ -17,9 +17,7 @@ And, then run the following command:
python -m metron.prefill_profiler \
--model "meta-llama/Meta-Llama-3-8B-Instruct" \
--max-num-completed-requests 1 \
--timeout 600 \
--num-concurrent-requests 1 \
--fixed-request-generator-decode-tokens 16 \
--output-dir "prefill_experiments/prefill_profiler_vllm_llama-3-8b"
@@ -38,9 +36,7 @@ To profile a custom range of prompt lengths, use the flag ``--prefill-lengths``
python -m metron.prefill_profiler \
--model "meta-llama/Meta-Llama-3-8B-Instruct" \
--max-num-completed-requests 1 \
--timeout 600 \
--num-concurrent-requests 1 \
--fixed-request-generator-decode-tokens 16 \
--output-dir "prefill_experiments/prefill_profiler_vllm_llama-3-8b" \
--prefill-lengths 256 512 1024 2048 4096 8192 16384 32768 65536
10 changes: 6 additions & 4 deletions metron/capacity_search/config/config.py
@@ -179,7 +179,8 @@ def to_args(self):

@dataclass
class RequestConfig:
num_concurrent_requests: Optional[int] = None
num_ray_clients: Optional[int] = None
num_concurrent_requests_per_client: Optional[int] = None
timeout: Optional[int] = None
max_num_completed_requests: Optional[int] = None
additional_sampling_params: Optional[Dict[str, Any]] = None
@@ -188,7 +189,8 @@ class RequestConfig:

def to_config_dict(self):
return {
"num-concurrent-requests": self.num_concurrent_requests,
"num-ray-clients": self.num_ray_clients,
"num-concurrent-requests-per-client": self.num_concurrent_requests_per_client,
"timeout": self.timeout,
"max-num-completed-requests": self.max_num_completed_requests,
"additional-sampling-params": self.additional_sampling_params,
@@ -208,10 +210,10 @@ def to_args(self):
return " ".join(args)

def get_key(self):
return f"{self.num_concurrent_requests}_{self.timeout}_{self.max_num_completed_requests}_{self.llm_api}"
return f"{self.num_ray_clients}_{self.timeout}_{self.max_num_completed_requests}_{self.llm_api}"

def to_human_readable_name(self):
return f"Num concurrent requests: {self.num_concurrent_requests}, Timeout: {self.timeout}, Max num completed requests: {self.max_num_completed_requests}, LLM API: {self.llm_api}"
return f"Num ray clients: {self.num_ray_clients}, Num concurrent requests per client: {self.num_concurrent_requests_per_client}, Timeout: {self.timeout}, Max num completed requests: {self.max_num_completed_requests}, LLM API: {self.llm_api}"


@dataclass
6 changes: 4 additions & 2 deletions metron/capacity_search/config/default.yml
@@ -42,13 +42,15 @@ request_generator_configs:
trace_file_name: "sharegpt"

request_configs:
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
llm_api: "openai"
request_generator_max_tokens: 8192
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
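
In these configs the old single setting appears to be factored into the two new ones, so the effective number of in-flight requests stays the same if each Ray client independently keeps its own batch of asyncio tasks running. A quick check against the previous `num_concurrent_requests: 100` default (the README example works out the same way: 2 × 5 = 10, matching the removed `--num-concurrent-requests 10`):

```python
num_ray_clients = 10
num_concurrent_requests_per_client = 10

# Assumption: effective concurrency is the product of the two new settings.
effective_concurrency = num_ray_clients * num_concurrent_requests_per_client
assert effective_concurrency == 100  # matches the previous num_concurrent_requests
```
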
3 changes: 2 additions & 1 deletion metron/capacity_search/config/llama_70b.yml
@@ -31,7 +31,8 @@ request_generator_configs:
trace_file_name: "arxiv"

request_configs:
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
6 changes: 4 additions & 2 deletions metron/capacity_search/config/llama_8b.yml
@@ -33,13 +33,15 @@ request_generator_configs:
trace_file_name: "arxiv"

request_configs:
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
llm_api: "openai"
request_generator_max_tokens: 8192
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
3 changes: 2 additions & 1 deletion metron/capacity_search/config/mixtral.yml
@@ -34,7 +34,8 @@ request_generator_configs:
trace_file_name: "arxiv"

request_configs:
- num_concurrent_requests: 100
- num_ray_clients: 10
num_concurrent_requests_per_client: 10
timeout: 1200
max_num_completed_requests: 1000
additional_sampling_params: {}
2 changes: 1 addition & 1 deletion metron/core/llm_clients/base_llm_client.py
@@ -19,7 +19,7 @@ def get_token_length(self, text: str) -> int:
return len(self.tokenizer.encode(text))

@abc.abstractmethod
def send_llm_request(
async def send_llm_request(
self, request_config: RequestConfig
) -> Tuple[RequestMetrics, str]:
"""Make a single completion request to a LLM API
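
With the abstract method declared `async`, every concrete client now defines `send_llm_request` as a coroutine and callers must `await` it, typically gathering many at once, instead of blocking per request. A toy illustration of the contract, with a made-up client and dict-based metrics standing in for the repo's `RequestMetrics`:

```python
import asyncio
import time
from types import SimpleNamespace
from typing import Tuple


class EchoLLMClient:
    """Hypothetical client that satisfies the new async contract."""

    async def send_llm_request(self, request_config) -> Tuple[dict, str]:
        start = time.monotonic()
        await asyncio.sleep(0.01)  # stand-in for the network round trip
        generated_text = f"echo: {request_config.prompt}"
        metrics = {"e2e_latency_s": time.monotonic() - start}
        return metrics, generated_text


async def main():
    client = EchoLLMClient()
    request_configs = [SimpleNamespace(prompt=f"prompt {i}") for i in range(5)]
    # Many requests can share one client; no threads or extra processes needed.
    return await asyncio.gather(*(client.send_llm_request(rc) for rc in request_configs))


if __name__ == "__main__":
    print(asyncio.run(main()))
```
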
42 changes: 0 additions & 42 deletions metron/core/llm_clients/common.py

This file was deleted.

3 changes: 1 addition & 2 deletions metron/core/llm_clients/litellm_client.py
@@ -11,11 +11,10 @@
logger = init_logger(__name__)


@ray.remote
class LiteLLMClient(BaseLLMClient):
"""Client for LiteLLM Completions API."""

def send_llm_request(
async def send_llm_request(
self, request_config: RequestConfig
) -> Tuple[RequestMetrics, str]:
# litellm package isn't serializable, so we import it within the function
40 changes: 25 additions & 15 deletions metron/core/llm_clients/openai_chat_completions_client.py
@@ -3,8 +3,7 @@
import time
from typing import List, Tuple

import ray
import requests
import httpx

from metron.core.llm_clients.base_llm_client import BaseLLMClient
from metron.core.request_config import RequestConfig
@@ -17,10 +16,13 @@
MAX_RESPONSES_ALLOWED_TO_STORE = 5


@ray.remote
class OpenAIChatCompletionsClient(BaseLLMClient):
"""Client for OpenAI Chat Completions API."""

def __init__(self, model_name: str) -> None:
super().__init__(model_name)
self.client = httpx.AsyncClient()

def total_tokens(self, response_list: List[str]) -> int:
merged_content = "".join(response_list)
return self.get_token_length(merged_content)
@@ -40,7 +42,11 @@ def get_current_tokens_received(
previous_token_count = self.total_tokens(previous_responses)
return current_tokens_received, previous_token_count

def send_llm_request(
async def close_client(self):
# Close the client
await self.client.aclose()

async def send_llm_request(
self, request_config: RequestConfig
) -> Tuple[RequestMetrics, str]:
prompt = request_config.prompt
@@ -82,29 +88,33 @@ def send_llm_request(
most_recent_received_token_time = time.monotonic()

try:
with requests.post(
address,
json=body,
stream=True,
timeout=180,
headers=headers,
async with self.client.stream(
"POST", address, json=body, timeout=None, headers=headers
) as response:
if response.status_code != 200:
error_msg = response.text
error_response_code = response.status_code
logger.error(f"Request Error: {response.content}")
error_content = []
async for error_line in response.aiter_lines():
error_content.append(error_line)
error_msg = "".join(error_content)
logger.error(f"Request Error: {error_msg}")
response.raise_for_status()

for chunk in response.iter_lines(chunk_size=None):
async for chunk in response.aiter_lines():
chunk = chunk.strip()

if not chunk:
continue
stem = "data: "
chunk = chunk[len(stem) :]
if chunk == b"[DONE]":
if chunk in [b"[DONE]", "[DONE]"]:
continue
data = json.loads(chunk)

try:
data = json.loads(chunk)
except json.JSONDecodeError:
logger.error(f"JSON decode error with chunk: {chunk}")
continue # Skip malformed JSON

if "error" in data:
error_msg = data["error"]["message"]
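
The OpenAI client drops `requests` in favor of a single long-lived `httpx.AsyncClient`, streams the response with `aiter_lines`, strips the SSE `data: ` prefix, skips `[DONE]` markers, and tolerates malformed JSON chunks instead of failing the whole request. A standalone sketch of the same pattern, assuming a standard OpenAI-compatible chat-completions delta payload (not the repo's exact client):

```python
import json

import httpx


async def stream_chat_completion(address: str, body: dict, headers: dict) -> str:
    """Stream one chat completion and return the concatenated text."""
    pieces = []
    client = httpx.AsyncClient()
    try:
        async with client.stream(
            "POST", address, json=body, timeout=None, headers=headers
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                line = line.strip()
                if not line:
                    continue
                if line.startswith("data: "):
                    line = line[len("data: "):]
                if line == "[DONE]":
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed chunks rather than aborting
                delta = data["choices"][0].get("delta", {})
                if delta.get("content"):
                    pieces.append(delta["content"])
    finally:
        # Mirrors close_client(): release the connection pool when done.
        await client.aclose()
    return "".join(pieces)
```

One design note visible in the diff: `timeout=None` removes the earlier 180-second cap, so very long prompts or slow decodes are no longer cut off mid-stream ("Removing timeout for large requests").
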
3 changes: 1 addition & 2 deletions metron/core/llm_clients/sagemaker_client.py
@@ -15,11 +15,10 @@
logger = init_logger(__name__)


@ray.remote
class SageMakerClient(BaseLLMClient):
"""Client for OpenAI Chat Completions API."""

def send_llm_request(
async def send_llm_request(
self, request_config: RequestConfig
) -> Tuple[RequestMetrics, str]:
if not os.environ.get("AWS_ACCESS_KEY_ID"):
3 changes: 1 addition & 2 deletions metron/core/llm_clients/vertexai_client.py
@@ -14,11 +14,10 @@
logger = init_logger(__name__)


@ray.remote
class VertexAIClient(BaseLLMClient):
"""Client for VertexAI API."""

def send_llm_request(
async def send_llm_request(
self, request_config: RequestConfig
) -> Tuple[RequestMetrics, str]:
project_id = os.environ.get("GCLOUD_PROJECT_ID")
1 change: 1 addition & 0 deletions metron/core/request_config.py
@@ -21,3 +21,4 @@ class RequestConfig(BaseModel):
llm_api: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
address_append_value: Optional[str] = None
id: Optional[int] = None
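
The new optional `id` field presumably lets each result be matched back to the request that produced it once many requests are in flight at the same time. A hypothetical bookkeeping sketch (nothing beyond the `id` field itself comes from the diff):

```python
# Hypothetical: stamp an id onto each prepared RequestConfig before dispatch,
# then pair results, which may complete out of order, with their requests.
for i, request_config in enumerate(request_configs):
    request_config.id = i

# Later, when a (metrics, text) pair comes back tagged with its request id:
# results_by_id[request_id] = (metrics, text)
```
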