Skip to content

Commit

Permalink
test mock entry point
Browse files Browse the repository at this point in the history
update
  • Loading branch information
jiazengcindy committed Jan 30, 2024
1 parent 97af29b commit 168b6cb
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 38 deletions.
10 changes: 10 additions & 0 deletions src/promptflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,13 @@ def enable_logger_propagate():
logger.propagate = True
yield
logger.propagate = original_value


@pytest.fixture
def mock_entry_point():
from executor.package_tools.custom_llm_tool_multi_inputs_without_index.list import list_package_tools
entry_point = MagicMock()
entry_point.load.return_value = list_package_tools
entry_point.dist.metadata.return_value = "TestCustomLLMTool"
entry_point.dist.version.return_value = "0.0.1"
return entry_point
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ custom_llm_tool.TestCustomLLMTool.call:
name: Test Custom LLM Tool
type: custom_llm
function: call
module: TestCustomLLMTool
inputs:
connection:
type:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path
from ruamel.yaml import YAML


def collect_tools_from_directory(base_dir) -> dict:
tools = {}
yaml = YAML()
for f in Path(base_dir).glob("**/*.yaml"):
with open(f, "r") as f:
tools_in_file = yaml.load(f)
for identifier, tool in tools_in_file.items():
tools[identifier] = tool
return tools


def list_package_tools():
"""List package tools"""
yaml_dir = Path(__file__).parent
return collect_tools_from_directory(yaml_dir)
62 changes: 24 additions & 38 deletions src/promptflow/tests/executor/unittests/_core/test_tools_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import textwrap
from pathlib import Path
from unittest.mock import patch
Expand All @@ -10,18 +11,14 @@
from promptflow._core.tools_manager import (
BuiltinsManager,
ToolLoader,
assign_tool_input_index_for_ux_order_if_needed,
collect_package_tools,
collect_package_tools_and_connections,
collect_tools_from_directory,
)
from promptflow._utils.yaml_utils import load_yaml_string
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType
from promptflow.contracts.tool import Tool, ToolType
from promptflow.exceptions import UserErrorException

PACKAGE_TOOL_BASE = Path(__file__).parent.parent.parent / "package_tools"


@pytest.mark.unittest
class TestToolLoader:
Expand Down Expand Up @@ -163,40 +160,29 @@ def test_collect_package_tools_if_node_source_tool_is_legacy(self):
package_tools = collect_package_tools(legacy_node_source_tools)
assert "promptflow.tools.azure_content_safety.analyze_text" in package_tools.keys()

def test_assign_tool_input_index_for_ux_order_if_needed(self):
tool = {
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs': {
'input2': {'type': 'string'},
'input1': {'type': 'string'},
'input3': {'type': 'string'}
}
}
assign_tool_input_index_for_ux_order_if_needed(tool)
assert tool == {
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs': {
'input2': {'type': 'string', 'ui_hints': {'index': 0}},
'input1': {'type': 'string', 'ui_hints': {'index': 1}},
'input3': {'type': 'string', 'ui_hints': {'index': 2}}
}
}

def test_collect_tools_from_directory_keeps_keys_order(self):
"""
Test that it can keep the order of keys when loading tools from a directory.
This is important because the feature automatically assigns indexes to inputs based on their order
in the tool's YAML, relying on ruamel.yaml's ability to maintain key order when loading a YAML file.
If ruamel.yaml were to break this feature, such a breaking change could be detected by this test.
"""
tool_yaml_folder = PACKAGE_TOOL_BASE / "custom_llm_tool_multi_inputs_without_index"
collected_tools = collect_tools_from_directory(tool_yaml_folder)
tool = collected_tools["custom_llm_tool.TestCustomLLMTool.call"]
expected_keys_order = ["connection", "deployment_name", "api", "temperature", "top_p", "max_tokens",
"stop", "presence_penalty", "frequency_penalty"]
assert list(tool["inputs"]) == expected_keys_order
def test_collect_package_tools_set_defaut_input_index(self, mocker, mock_entry_point):
entry_point = mock_entry_point
entry_points = (entry_point, )
mocker.patch("promptflow._core.tools_manager._get_entry_points_by_group", return_value=entry_points)
mocker.patch.object(importlib, 'import_module', return_value=MagicMock())
tool = "custom_llm_tool.TestCustomLLMTool.call"
package_tools = collect_package_tools([tool])
inputs_order = ["connection", "deployment_name", "api", "temperature", "top_p", "max_tokens",
"stop", "presence_penalty", "frequency_penalty"]
for index, input_name in enumerate(inputs_order):
assert package_tools[tool]['inputs'][input_name]['ui_hints']['index'] == index

def test_collect_package_tools_and_connections_set_defaut_input_index(self, mocker, mock_entry_point):
entry_point = mock_entry_point
entry_points = (entry_point, )
mocker.patch("promptflow._core.tools_manager._get_entry_points_by_group", return_value=entry_points)
mocker.patch.object(importlib, 'import_module', return_value=MagicMock())
tool = "custom_llm_tool.TestCustomLLMTool.call"
package_tools, _, _ = collect_package_tools_and_connections([tool])
inputs_order = ["connection", "deployment_name", "api", "temperature", "top_p", "max_tokens",
"stop", "presence_penalty", "frequency_penalty"]
for index, input_name in enumerate(inputs_order):
assert package_tools[tool]['inputs'][input_name]['ui_hints']['index'] == index

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
keys = ["my_tool_package.tools.my_tool_2.MyTool.my_tool"]
Expand Down

0 comments on commit 168b6cb

Please sign in to comment.