Skip to content

Commit

Permalink
[Feature] Support default input index when collect package tools for …
Browse files Browse the repository at this point in the history
…custom-llm tool (#1502)

# Description
### Reason
Previously, customers complained that the order of custom-llm tool
inputs displayed in the UI was inconsistent with the order defined in
tool interface or yaml file. This was because the UI received the inputs
as a dict, which resulted in an alphabetical order display.
In this PR, we will automatically add "ui_hints": {"index": XX} to every
input based on the order in tool interface or yaml file for custom_llm
tool. When UI received tool inputs, they can display the inputs
according to this ui_hints["index"] field to make the order consistent.
Both portal and extension have completed the code change to support
ui_hints["index"].

### Cases which will not be corrected by this pr:
- Existing flow with flow.tools.json
- All examples in promptflow gallery which also have flow.tools.json
file

### Test 
For more test cases, please go to [this
link](https://microsoft.sharepoint.com/:o:/r/teams/STCASharedDataTeam/_layouts/15/Doc.aspx?sourcedoc=%7Bde8c0032-8386-492d-af0e-104af9de317a%7D&action=edit&wd=target(Pipelines%2FDCM%2FDevs%2Fjiazeng.one%7C0FA273ED-FDCC-4644-82B0-C96606E278CB%2FTest%20default%20input%20index-3%7CDC170A9F-4572-4931-BFAD-CEECD018616B%2F)&share=IgEyAIzehoMtSa8OEEr53jF6AYDaZzIchXLytsUQwUbc29U)
- custom-llm tool

![image](https://github.com/microsoft/promptflow/assets/95729303/942b03a2-ae12-44bb-ba03-ccad98df5c69)
    - in portal

![image](https://github.com/microsoft/promptflow/assets/95729303/2f5cd399-de70-4eb9-b26f-cc44fd9f45c5)
    - in extension

![image](https://github.com/microsoft/promptflow/assets/95729303/f22e2b6f-71df-4787-bab0-16645f9eb264)

- built in custom-llm tool
    - in portal:

![image](https://github.com/microsoft/promptflow/assets/95729303/98faab67-e59a-4e48-b843-58f373638cd5)
    - in extension:

![image](https://github.com/microsoft/promptflow/assets/95729303/f5311e36-beb8-4be7-ba29-abec5148f2c7)

- Test cases
The behavior of the following test cases are as expected. Only
custom_llm tool's inputs' order are corrected, other tool's inputs'
order keep same as before.
    - Create a new flow
    - Clone a flow
    - Upload a flow without .promptflow folder



  



# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.

---------

Co-authored-by: cs_lucky <si.chen@microsoft.com>
Co-authored-by: jiazeng <jiazeng@microsoft.com>
  • Loading branch information
3 people authored Jan 31, 2024
1 parent a2aec32 commit 500eee0
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
load_function_from_function_path,
validate_dynamic_list_func_response_type,
validate_tool_func_result,
assign_tool_input_index_for_ux_order_if_needed,
)
from promptflow._utils.yaml_utils import load_yaml
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
Expand All @@ -53,6 +54,10 @@ def collect_tools_from_directory(base_dir) -> dict:
tools = {}
for f in Path(base_dir).glob("**/*.yaml"):
with open(f, "r") as f:
# The feature that automatically assigns indexes to inputs based on their order in the tool YAML,
# relying on the feature of ruamel.yaml that maintains key order when load YAML file.
# For more information on ruamel.yaml's feature, please
# visit https://yaml.readthedocs.io/en/latest/overview/#overview.
tools_in_file = load_yaml(f)
for identifier, tool in tools_in_file.items():
tools[identifier] = tool
Expand Down Expand Up @@ -96,6 +101,7 @@ def collect_package_tools(keys: Optional[List[str]] = None) -> dict:
importlib.import_module(m) # Import the module to make sure it is valid
tool["package"] = entry_point.dist.metadata["Name"]
tool["package_version"] = entry_point.dist.version
assign_tool_input_index_for_ux_order_if_needed(tool)
all_package_tools[identifier] = tool
except Exception as e:
msg = (
Expand Down Expand Up @@ -126,6 +132,7 @@ def collect_package_tools_and_connections(keys: Optional[List[str]] = None) -> d
module = importlib.import_module(m) # Import the module to make sure it is valid
tool["package"] = entry_point.dist.metadata["Name"]
tool["package_version"] = entry_point.dist.version
assign_tool_input_index_for_ux_order_if_needed(tool)
all_package_tools[identifier] = tool

# Get custom strong type connection definition
Expand Down
65 changes: 64 additions & 1 deletion src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from promptflow._utils.utils import is_json_serializable
from promptflow.exceptions import ErrorTarget, UserErrorException

from ..contracts.tool import ConnectionType, InputDefinition, Tool, ToolFuncCallScenario, ValueType
from ..contracts.tool import ConnectionType, InputDefinition, Tool, ToolFuncCallScenario, ToolType, ValueType
from ..contracts.types import PromptTemplate

module_logger = logging.getLogger(__name__)

_DEPRECATED_TOOLS = "deprecated_tools"
UI_HINTS = "ui_hints"


def value_to_str(val):
Expand Down Expand Up @@ -305,6 +306,68 @@ def load_function_from_function_path(func_path: str):
)


def assign_tool_input_index_for_ux_order_if_needed(tool):
"""
Automatically adds an index to the inputs of a tool based on their order in the tool's YAML.
This function directly modifies the tool without returning any value.
Example:
- tool (dict): A dictionary representing a tool configuration. Inputs do not contain 'ui_hints':
{
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs':
{
'input1': {'type': 'string'},
'input2': {'type': 'string'},
'input3': {'type': 'string'}
}
}
>>> assign_tool_input_index_for_ux_order_if_needed(tool)
- tool (dict): Tool inputs are modified to include 'ui_hints' with an 'index', indicating the order.
{
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs':
{
'input1': {'type': 'string', 'ui_hints': {'index': 0}},
'input2': {'type': 'string', 'ui_hints': {'index': 1}},
'input3': {'type': 'string', 'ui_hints': {'index': 2}}
}
}
"""
tool_type = tool.get("type")
if should_preserve_tool_inputs_order(tool_type) and "inputs" in tool:
inputs_dict = tool["inputs"]
input_index = 0
# The keys can keep order because the tool YAML is loaded by ruamel.yaml and
# ruamel.yaml has the feature of preserving the order of keys.
# For more information on ruamel.yaml's feature, please
# visit https://yaml.readthedocs.io/en/latest/overview/#overview.
for input_name, settings in inputs_dict.items():
# 'uionly_hidden' indicates that the inputs are not the tool's inputs.
# They are not displayed on the main interface but appear in a popup window.
# These inputs are passed to UX as a list, maintaining the same order as generated by func parameters.
# Skip the 'uionly_hidden' input type because the 'ui_hints: index' is not needed.
if "input_type" in settings.keys() and settings["input_type"] == "uionly_hidden":
continue
settings.setdefault(UI_HINTS, {})
settings[UI_HINTS]["index"] = input_index
input_index += 1


def should_preserve_tool_inputs_order(tool_type):
"""
Currently, we only automatically add input indexes for the custom_llm tool,
following the order specified in the tool interface or YAML.
As of now, only the custom_llm tool requires the order of its inputs displayed on the UI
to be consistent with the order in the YAML, because its inputs are shown in parameter style.
To avoid extensive changes, other types of tools will remain as they are.
"""
return tool_type == ToolType.CUSTOM_LLM


# Handling backward compatibility and generating a mapping between the previous and new tool IDs.
def _find_deprecated_tools(package_tools) -> Dict[str, str]:
_deprecated_tools = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
custom_llm_tool.TestCustomLLMTool.call:
name: Test Custom LLM Tool
type: custom_llm
function: call
module: TestCustomLLMTool
inputs:
connection:
type:
- AzureOpenAIConnection
deployment_name:
type:
- string
api:
type:
- string
temperature:
type:
- double
top_p:
type:
- double
max_tokens:
type:
- int
stop:
default: ""
type:
- list
presence_penalty:
default: 0
type:
- double
frequency_penalty:
default: 0
type:
- double
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path
from ruamel.yaml import YAML


def collect_tools_from_directory(base_dir) -> dict:
tools = {}
yaml = YAML()
for f in Path(base_dir).glob("**/*.yaml"):
with open(f, "r") as f:
tools_in_file = yaml.load(f)
for identifier, tool in tools_in_file.items():
tools[identifier] = tool
return tools


def list_package_tools():
"""List package tools"""
yaml_dir = Path(__file__).parent
return collect_tools_from_directory(yaml_dir)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import textwrap
from pathlib import Path
from unittest.mock import patch
Expand All @@ -19,6 +20,16 @@
from promptflow.exceptions import UserErrorException


@pytest.fixture
def mock_entry_point():
from ...package_tools.custom_llm_tool_multi_inputs_without_index.list import list_package_tools
entry_point = MagicMock()
entry_point.load.return_value = list_package_tools
entry_point.dist.metadata.return_value = "TestCustomLLMTool"
entry_point.dist.version.return_value = "0.0.1"
return entry_point


@pytest.mark.unittest
class TestToolLoader:
def test_load_tool_for_node_with_invalid_node(self):
Expand Down Expand Up @@ -159,6 +170,30 @@ def test_collect_package_tools_if_node_source_tool_is_legacy(self):
package_tools = collect_package_tools(legacy_node_source_tools)
assert "promptflow.tools.azure_content_safety.analyze_text" in package_tools.keys()

def test_collect_package_tools_set_defaut_input_index(self, mocker, mock_entry_point):
entry_point = mock_entry_point
entry_points = (entry_point, )
mocker.patch("promptflow._core.tools_manager._get_entry_points_by_group", return_value=entry_points)
mocker.patch.object(importlib, 'import_module', return_value=MagicMock())
tool = "custom_llm_tool.TestCustomLLMTool.call"
package_tools = collect_package_tools([tool])
inputs_order = ["connection", "deployment_name", "api", "temperature", "top_p", "max_tokens",
"stop", "presence_penalty", "frequency_penalty"]
for index, input_name in enumerate(inputs_order):
assert package_tools[tool]['inputs'][input_name]['ui_hints']['index'] == index

def test_collect_package_tools_and_connections_set_defaut_input_index(self, mocker, mock_entry_point):
entry_point = mock_entry_point
entry_points = (entry_point, )
mocker.patch("promptflow._core.tools_manager._get_entry_points_by_group", return_value=entry_points)
mocker.patch.object(importlib, 'import_module', return_value=MagicMock())
tool = "custom_llm_tool.TestCustomLLMTool.call"
package_tools, _, _ = collect_package_tools_and_connections([tool])
inputs_order = ["connection", "deployment_name", "api", "temperature", "top_p", "max_tokens",
"stop", "presence_penalty", "frequency_penalty"]
for index, input_name in enumerate(inputs_order):
assert package_tools[tool]['inputs'][input_name]['ui_hints']['index'] == index

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
keys = ["my_tool_package.tools.my_tool_2.MyTool.my_tool"]
tools, specs, templates = collect_package_tools_and_connections(keys)
Expand Down

0 comments on commit 500eee0

Please sign in to comment.