Adjust tool limit per request
TaoChenOSU committed Dec 5, 2024
1 parent 30b67ec commit 14a9627
Showing 13 changed files with 635 additions and 10 deletions.
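
The limits being adjusted are plain Pydantic field constraints on each connector's tools field. As a rough, self-contained sketch of the pattern (ToyChatSettings is a hypothetical stand-in model, not one of the semantic-kernel connector classes), a min_length/max_length constraint on the annotated list makes validation fail when the number of tools attached to a single request falls outside the allowed range:

# Minimal sketch of the constraint pattern used in this change.
# ToyChatSettings is a stand-in, not an actual connector class.
from typing import Annotated, Any

from pydantic import BaseModel, Field, ValidationError


class ToyChatSettings(BaseModel):
    tools: Annotated[
        list[dict[str, Any]] | None,
        Field(
            min_length=1,    # mirrors the Bedrock settings change below
            max_length=128,  # mirrors the OpenAI / Google AI / Vertex AI changes below
            description="Do not set this manually. It is set by the service based "
            "on the function choice configuration.",
        ),
    ] = None


ToyChatSettings(tools=[{"function": {}}])  # within bounds: validates fine

try:
    ToyChatSettings(tools=[])  # below min_length
except ValidationError as exc:
    print(exc)  # "List should have at least 1 item after validation, not 0"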
@@ -32,7 +32,6 @@ class AnthropicChatPromptExecutionSettings(AnthropicPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
             description=(
                 "Do not set this manually. It is set by the service based on the function choice configuration."
             ),
@@ -3,6 +3,7 @@
 from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_prompt_execution_settings import (
     AzureAIInferenceChatPromptExecutionSettings,
     AzureAIInferenceEmbeddingPromptExecutionSettings,
+    AzureAIInferencePromptExecutionSettings,
 )
 from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_settings import AzureAIInferenceSettings
 from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_chat_completion import (
@@ -16,6 +17,7 @@
     "AzureAIInferenceChatCompletion",
     "AzureAIInferenceChatPromptExecutionSettings",
     "AzureAIInferenceEmbeddingPromptExecutionSettings",
+    "AzureAIInferencePromptExecutionSettings",
     "AzureAIInferenceSettings",
     "AzureAIInferenceTextEmbedding",
 ]
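
With the export added above, the base settings class can be imported straight from the connector package root, the same way the new unit tests further down import it. A minimal usage sketch (parameter values are arbitrary):

from semantic_kernel.connectors.ai.azure_ai_inference import AzureAIInferencePromptExecutionSettings

# Base (non-chat) execution settings; unset generation parameters default to None.
settings = AzureAIInferencePromptExecutionSettings(temperature=0.5, max_tokens=128)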
@@ -24,7 +24,7 @@ class BedrockChatPromptExecutionSettings(BedrockPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
+            min_length=1,
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -38,7 +38,10 @@ class GoogleAIChatPromptExecutionSettings(GoogleAIPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
+            # There is no official documentation on the maximum length of the tools list.
+            # Using the limit stated on the Vertex AI documentation:
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#function-declarations
+            max_length=128,
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -7,6 +7,7 @@
     from typing import override  # pragma: no cover
 else:
     from typing_extensions import override  # pragma: no cover
+
 from pydantic import Field
 from vertexai.generative_models import Tool, ToolConfig

@@ -38,7 +39,7 @@ class VertexAIChatPromptExecutionSettings(VertexAIPromptExecutionSettings):
     tools: Annotated[
         list[Tool] | None,
         Field(
-            max_length=64,
+            max_length=128,  # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#function-declarations
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -45,7 +45,6 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -13,9 +13,6 @@ class OllamaPromptExecutionSettings(PromptExecutionSettings):
     format: Literal["json"] | None = None
     options: dict[str, Any] | None = None

-    # TODO(@taochen): Add individual properties for execution settings and
-    # convert them to the appropriate types in the options dictionary.
-

 class OllamaTextPromptExecutionSettings(OllamaPromptExecutionSettings):
     """Settings for Ollama text prompt execution."""
@@ -32,7 +29,6 @@ class OllamaChatPromptExecutionSettings(OllamaPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -76,7 +76,7 @@ class OpenAIChatPromptExecutionSettings(OpenAIPromptExecutionSettings):
     tools: Annotated[
         list[dict[str, Any]] | None,
         Field(
-            max_length=64,
+            max_length=128,  # https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
             description="Do not set this manually. It is set by the service based "
             "on the function choice configuration.",
         ),
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft. All rights reserved.

from semantic_kernel.connectors.ai.azure_ai_inference import (
AzureAIInferenceChatPromptExecutionSettings,
AzureAIInferenceEmbeddingPromptExecutionSettings,
AzureAIInferencePromptExecutionSettings,
)
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings


def test_default_azure_ai_inference_prompt_execution_settings():
settings = AzureAIInferencePromptExecutionSettings()

assert settings.frequency_penalty is None
assert settings.max_tokens is None
assert settings.presence_penalty is None
assert settings.seed is None
assert settings.stop is None
assert settings.temperature is None
assert settings.top_p is None
assert settings.extra_parameters is None


def test_custom_azure_ai_inference_prompt_execution_settings():
settings = AzureAIInferencePromptExecutionSettings(
frequency_penalty=0.5,
max_tokens=128,
presence_penalty=0.5,
seed=1,
stop="world",
temperature=0.5,
top_p=0.5,
extra_parameters={"key": "value"},
)

assert settings.frequency_penalty == 0.5
assert settings.max_tokens == 128
assert settings.presence_penalty == 0.5
assert settings.seed == 1
assert settings.stop == "world"
assert settings.temperature == 0.5
assert settings.top_p == 0.5
assert settings.extra_parameters == {"key": "value"}


def test_azure_ai_inference_prompt_execution_settings_from_default_completion_config():
settings = PromptExecutionSettings(service_id="test_service")
chat_settings = AzureAIInferenceChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.service_id == "test_service"
assert chat_settings.frequency_penalty is None
assert chat_settings.max_tokens is None
assert chat_settings.presence_penalty is None
assert chat_settings.seed is None
assert chat_settings.stop is None
assert chat_settings.temperature is None
assert chat_settings.top_p is None
assert chat_settings.extra_parameters is None


def test_azure_ai_inference_prompt_execution_settings_from_openai_prompt_execution_settings():
chat_settings = AzureAIInferenceChatPromptExecutionSettings(service_id="test_service", temperature=1.0)
new_settings = AzureAIInferencePromptExecutionSettings(service_id="test_2", temperature=0.0)
chat_settings.update_from_prompt_execution_settings(new_settings)

assert chat_settings.service_id == "test_2"
assert chat_settings.temperature == 0.0


def test_azure_ai_inference_prompt_execution_settings_from_custom_completion_config():
settings = PromptExecutionSettings(
service_id="test_service",
extension_data={
"frequency_penalty": 0.5,
"max_tokens": 128,
"presence_penalty": 0.5,
"seed": 1,
"stop": "world",
"temperature": 0.5,
"top_p": 0.5,
"extra_parameters": {"key": "value"},
},
)
chat_settings = AzureAIInferenceChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.service_id == "test_service"
assert chat_settings.frequency_penalty == 0.5
assert chat_settings.max_tokens == 128
assert chat_settings.presence_penalty == 0.5
assert chat_settings.seed == 1
assert chat_settings.stop == "world"
assert chat_settings.temperature == 0.5
assert chat_settings.top_p == 0.5
assert chat_settings.extra_parameters == {"key": "value"}


def test_azure_ai_inference_chat_prompt_execution_settings_from_custom_completion_config_with_functions():
settings = PromptExecutionSettings(
service_id="test_service",
extension_data={
"tools": [{"function": {}}],
},
)
chat_settings = AzureAIInferenceChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.tools == [{"function": {}}]


def test_create_options():
settings = AzureAIInferenceChatPromptExecutionSettings(
service_id="test_service",
extension_data={
"frequency_penalty": 0.5,
"max_tokens": 128,
"presence_penalty": 0.5,
"seed": 1,
"stop": "world",
"temperature": 0.5,
"top_p": 0.5,
"extra_parameters": {"key": "value"},
},
)
options = settings.prepare_settings_dict()

assert options["frequency_penalty"] == 0.5
assert options["max_tokens"] == 128
assert options["presence_penalty"] == 0.5
assert options["seed"] == 1
assert options["stop"] == "world"
assert options["temperature"] == 0.5
assert options["top_p"] == 0.5
assert options["extra_parameters"] == {"key": "value"}
assert "tools" not in options
assert "tool_config" not in options


def test_default_azure_ai_inference_embedding_prompt_execution_settings():
settings = AzureAIInferenceEmbeddingPromptExecutionSettings()

assert settings.dimensions is None
assert settings.encoding_format is None
assert settings.input_type is None
assert settings.extra_parameters is None
@@ -0,0 +1,118 @@
# Copyright (c) Microsoft. All rights reserved.

import pytest
from pydantic import ValidationError

from semantic_kernel.connectors.ai.bedrock import BedrockChatPromptExecutionSettings, BedrockPromptExecutionSettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings


def test_default_bedrock_prompt_execution_settings():
settings = BedrockPromptExecutionSettings()

assert settings.temperature is None
assert settings.top_p is None
assert settings.top_k is None
assert settings.max_tokens is None
assert settings.stop == []


def test_custom_bedrock_prompt_execution_settings():
settings = BedrockPromptExecutionSettings(
temperature=0.5,
top_p=0.5,
top_k=10,
max_tokens=128,
stop=["world"],
)

assert settings.temperature == 0.5
assert settings.top_p == 0.5
assert settings.top_k == 10
assert settings.max_tokens == 128
assert settings.stop == ["world"]


def test_bedrock_prompt_execution_settings_from_default_completion_config():
settings = PromptExecutionSettings(service_id="test_service")
chat_settings = BedrockChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.service_id == "test_service"
assert chat_settings.temperature is None
assert chat_settings.top_p is None
assert chat_settings.top_k is None
assert chat_settings.max_tokens is None
assert chat_settings.stop == []


def test_bedrock_prompt_execution_settings_from_openai_prompt_execution_settings():
chat_settings = BedrockChatPromptExecutionSettings(service_id="test_service", temperature=1.0)
new_settings = BedrockPromptExecutionSettings(service_id="test_2", temperature=0.0)
chat_settings.update_from_prompt_execution_settings(new_settings)

assert chat_settings.service_id == "test_2"
assert chat_settings.temperature == 0.0


def test_bedrock_prompt_execution_settings_from_custom_completion_config():
settings = PromptExecutionSettings(
service_id="test_service",
extension_data={
"temperature": 0.5,
"top_p": 0.5,
"top_k": 10,
"max_tokens": 128,
"stop": ["world"],
},
)
chat_settings = BedrockChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.temperature == 0.5
assert chat_settings.top_p == 0.5
assert chat_settings.top_k == 10
assert chat_settings.max_tokens == 128
assert chat_settings.stop == ["world"]


def test_bedrock_chat_prompt_execution_settings_from_custom_completion_config_with_functions():
settings = PromptExecutionSettings(
service_id="test_service",
extension_data={
"tools": [{"function": {}}],
},
)
chat_settings = BedrockChatPromptExecutionSettings.from_prompt_execution_settings(settings)

assert chat_settings.tools == [{"function": {}}]


def test_bedrock_chat_prompt_execution_settings_with_functions_exception():
settings = PromptExecutionSettings(
service_id="test_service",
extension_data={
"tools": [],
},
)

with pytest.raises(ValidationError, match="List should have at least 1 item after validation"):
BedrockChatPromptExecutionSettings.from_prompt_execution_settings(settings)


def test_create_options():
settings = BedrockPromptExecutionSettings(
service_id="test_service",
extension_data={
"temperature": 0.5,
"top_p": 0.5,
"top_k": 10,
"max_tokens": 128,
"stop": ["world"],
},
)
options = settings.prepare_settings_dict()

assert options["temperature"] == 0.5
assert options["top_p"] == 0.5
assert options["top_k"] == 10
assert options["max_tokens"] == 128
assert options["stop"] == ["world"]