Skip to content

Commit

Permalink
Add * before keyword args for ChatCompletionClient (#4822)
Browse files Browse the repository at this point in the history
add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 27, 2024
1 parent edad1b6 commit 9a2dbb4
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
Expand All @@ -41,6 +42,7 @@ async def create(
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
Expand All @@ -56,10 +58,10 @@ def actual_usage(self) -> RequestUsage: ...
def total_usage(self) -> RequestUsage: ...

@abstractmethod
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

@abstractmethod
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

@property
@abstractmethod
Expand Down
6 changes: 4 additions & 2 deletions python/packages/autogen-core/tests/test_tool_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class MockChatCompletionClient(ChatCompletionClient):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
Expand All @@ -116,6 +117,7 @@ async def create(
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
Expand All @@ -129,10 +131,10 @@ def actual_usage(self) -> RequestUsage:
def total_usage(self) -> RequestUsage:
return RequestUsage(prompt_tokens=0, completion_tokens=0)

def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0

def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def __init__(
self,
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
*,
create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None,
):
Expand Down Expand Up @@ -389,6 +390,7 @@ def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
Expand Down Expand Up @@ -581,11 +583,11 @@ async def create(
async def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
*,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Expand Down Expand Up @@ -800,7 +802,7 @@ def actual_usage(self) -> RequestUsage:
def total_usage(self) -> RequestUsage:
return self._total_usage

def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
model = self._create_args["model"]
try:
encoding = tiktoken.encoding_for_model(model)
Expand Down Expand Up @@ -889,9 +891,9 @@ def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | To
num_tokens += 12
return num_tokens

def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
token_limit = _model_info.get_token_limit(self._create_args["model"])
return token_limit - self.count_tokens(messages, tools)
return token_limit - self.count_tokens(messages, tools=tools)

@property
def capabilities(self) -> ModelCapabilities:
Expand Down Expand Up @@ -974,7 +976,7 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
super().__init__(client, create_args, model_capabilities)
super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
Expand Down Expand Up @@ -1059,7 +1061,7 @@ def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
super().__init__(client, create_args, model_capabilities)
super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
Expand Down Expand Up @@ -155,6 +156,7 @@ async def create(
async def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
Expand Down Expand Up @@ -191,11 +193,11 @@ def actual_usage(self) -> RequestUsage:
def total_usage(self) -> RequestUsage:
return self._total_usage

def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
_, token_count = self._tokenize(messages)
return token_count

def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return max(
0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
)
Expand Down

0 comments on commit 9a2dbb4

Please sign in to comment.