Skip to content

Commit

Permalink
Refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Dec 15, 2023
1 parent 2e33df7 commit 39bd23e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
39 changes: 16 additions & 23 deletions src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,44 +392,37 @@ def load_tool_for_node(self, node: Node) -> Tool:
raise NotImplementedError(f"Tool type {node.type} is not supported yet.")

def load_tool_for_package_node(self, node: Node) -> Tool:
return self.load_tool_for_package(node.source.tool)

def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
if node.source.path is None:
raise UserErrorException(f"Node {node.name} does not have source path defined.")
return self._load_tool_for_source_path(node.source.path)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
return BuiltinsManager._load_llm_api(api_name)

def load_tool_for_package(self, package: str) -> Tool:
if package in self._package_tools:
return Tool.deserialize(self._package_tools[package])
if node.source.tool in self._package_tools:
return Tool.deserialize(self._package_tools[node.source.tool])

# If node source tool is not in package tools, try to find the tool ID in deprecated tools.
# If found, load the tool with the new tool ID for backward compatibility.
if package in self._deprecated_tools:
new_tool_id = self._deprecated_tools[package]
if node.source.tool in self._deprecated_tools:
new_tool_id = self._deprecated_tools[node.source.tool]
# Used to collect deprecated tool usage and warn user to replace the deprecated tool with the new one.
module_logger.warning(
f"Tool ID '{package}' is deprecated. Please use '{new_tool_id}' instead."
)
module_logger.warning(f"Tool ID '{node.source.tool}' is deprecated. Please use '{new_tool_id}' instead.")
return Tool.deserialize(self._package_tools[new_tool_id])

raise PackageToolNotFoundError(
f"Package tool '{package}' is not found in the current environment. "
f"Package tool '{node.source.tool}' is not found in the current environment. "
f"All available package tools are: {list(self._package_tools.keys())}.",
target=ErrorTarget.EXECUTOR,
)

def _load_tool_for_source_path(self, source_path: str) -> Tuple[types.ModuleType, Callable, Tool]:
m = load_python_module_from_file(self._working_dir / source_path)
def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
if node.source.path is None:
raise UserErrorException(f"Node {node.name} does not have source path defined.")
path = node.source.path
m = load_python_module_from_file(self._working_dir / path)
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {source_path}.")
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f, init_inputs = collect_tool_function_in_module(m)
return m, _parse_tool_from_function(f, init_inputs, gen_custom_type_conn=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
return BuiltinsManager._load_llm_api(api_name)


builtins = {}
apis = {}
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/executor/_tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _invoke_tool_with_timer(self, node: Node, f: Callable, kwargs):
raise ToolExecutionError(node_name=node_name, module=module) from e

def bypass_node(self, node: Node):
"""Update teh bypassed node run info."""
"""Update the bypassed node run info."""
node_run_id = self._generate_node_run_id(node)
flow_logger.info(f"Bypassing node {node.name}. node run id: {node_run_id}")
parent_run_id = f"{self._run_id}_{self._line_number}" if self._line_number is not None else self._run_id
Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def _stringify_generator_output(self, outputs: dict):

return outputs

def _submit_to_scheduler(self, context: DefaultToolInvoker, inputs, nodes: List[Node]) -> Tuple[dict, dict]:
def _submit_to_scheduler(self, invoker: DefaultToolInvoker, inputs, nodes: List[Node]) -> Tuple[dict, dict]:
if not isinstance(self._node_concurrency, int):
raise UnexpectedError(
message_format=(
Expand All @@ -896,7 +896,7 @@ def _submit_to_scheduler(self, context: DefaultToolInvoker, inputs, nodes: List[
),
current_value=self._node_concurrency,
)
return FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context).execute()
return FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, invoker).execute()

@staticmethod
def apply_inputs_mapping(
Expand Down

0 comments on commit 39bd23e

Please sign in to comment.