Skip to content

Commit

Permalink
fix: Add validation of which tools were requested by get_tools (#1107)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushar-composio authored Dec 31, 2024
1 parent f69fe36 commit f496f7f
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 2 deletions.
23 changes: 23 additions & 0 deletions python/composio/tools/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ def _limit_file_search_response(response: t.Dict) -> t.Dict:
)
self.max_retries = max_retries

# To be populated by get_tools(), from within subclasses like
# composio_openai's Toolset.
self._requested_actions: t.Optional[t.List[str]] = None

def _validating_connection_ids(
self,
connected_account_ids: t.Dict[AppType, str],
Expand Down Expand Up @@ -797,6 +801,16 @@ def execute_action(
:return: Output object from the function call
"""
action = Action(action)
if (
self._requested_actions is not None
and action.slug not in self._requested_actions
):
raise ComposioSDKError(
f"Action {action.slug} is being called, but was never requested by the toolset. "
"Make sure that the actions you are trying to execute are requested in your "
"`get_tools()` call."
)

params = self._serialize_execute_params(param=params)
if processors is not None:
self._merge_processors(processors)
Expand Down Expand Up @@ -932,6 +946,7 @@ def validate_tools(
# NOTE: This an experimental, can convert to decorator for more convinience
if not apps and not actions and not tags:
return

self.workspace.check_for_missing_dependencies(
apps=apps,
actions=actions,
Expand All @@ -945,6 +960,7 @@ def get_action_schemas(
tags: t.Optional[t.Sequence[TagType]] = None,
*,
check_connected_accounts: bool = True,
_populate_requested: bool = False,
) -> t.List[ActionModel]:
runtime_actions = t.cast(
t.List[t.Type[LocalAction]],
Expand Down Expand Up @@ -1010,6 +1026,13 @@ def get_action_schemas(
if item.name == Action.ANTHROPIC_TEXT_EDITOR.slug:
item.name = "str_replace_editor"

if _populate_requested:
action_names = [item.name for item in items]
if self._requested_actions is None:
self._requested_actions = []

self._requested_actions += action_names

return items

def _process_schema(self, action_item: ActionModel) -> ActionModel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_openai import ChatOpenAI


@action(toolname="math", requires=["smtplib"])
def multiply(a: int, b: int, c: int) -> int:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/plugins/autogen/composio_autogen/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def register_tools(
self,
caller: ConversableAgent,
executor: ConversableAgent,
apps: t.Optional[t.Sequence[AppType]] = None,
actions: t.Optional[t.Sequence[ActionType]] = None,
apps: t.Optional[t.Sequence[AppType]] = None,
tags: t.Optional[t.List[TagType]] = None,
entity_id: t.Optional[str] = None,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions python/plugins/camel/composio_camel/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/claude/composio_claude/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]

Expand Down
1 change: 1 addition & 0 deletions python/plugins/crew_ai/composio_crewai/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]

Expand Down
1 change: 1 addition & 0 deletions python/plugins/griptape/composio_griptape/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/langchain/composio_langchain/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/llamaindex/composio_llamaindex/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/lyzr/composio_lyzr/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/openai/composio_openai/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]

Expand Down
1 change: 1 addition & 0 deletions python/plugins/phidata/composio_phidata/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
1 change: 1 addition & 0 deletions python/plugins/praisonai/composio_praisonai/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,6 @@ def get_tools(
apps=apps,
tags=tags,
check_connected_accounts=check_connected_accounts,
_populate_requested=True,
)
]
2 changes: 1 addition & 1 deletion python/tests/test_cli/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def test_list_one(self) -> None:
assert (
"Id : 6f4f4191-7fe9-4b5c-b491-4b7ec56ebf5d" in result.stdout
), result.stderr
assert "Status: ACTIVE" in result.stdout, result.stderr
assert "Status: EXPIRED" in result.stdout, result.stderr
25 changes: 25 additions & 0 deletions python/tests/test_tools/test_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,28 @@ class SomeToolsetExtention(ComposioToolSet, action_name_char_limit=char_limit):
]
)
assert len(t.cast(str, schema.name)) == char_limit


def test_invalid_handle_tool_calls() -> None:
"""Test edge case where the Agent tries to call a tool that wasn't requested from get_tools()."""
toolset = LangchainToolSet()

toolset.get_tools(actions=[Action.GMAIL_FETCH_EMAILS])
with pytest.raises(ComposioSDKError) as exc:
with mock.patch.object(toolset, "_execute_remote"):
toolset.execute_action(Action.HACKERNEWS_GET_FRONTPAGE, {})

assert (
"Action HACKERNEWS_GET_FRONTPAGE is being called, but was never requested by the toolset."
in exc.value.message
)

# Ensure it does NOT fail if a subsequent get_tools added that action
toolset.get_tools(actions=[Action.HACKERNEWS_GET_FRONTPAGE])
with mock.patch.object(toolset, "_execute_remote"):
toolset.execute_action(Action.HACKERNEWS_GET_FRONTPAGE, {})

# Ensure it DOES NOT fail if get_tools is never called
toolset = LangchainToolSet()
with mock.patch.object(toolset, "_execute_remote"):
toolset.execute_action(Action.HACKERNEWS_GET_FRONTPAGE, {})

0 comments on commit f496f7f

Please sign in to comment.