Refine code
Lina Tang committed Dec 15, 2023
1 parent 2e33df7 commit 76e7f7a
Showing 5 changed files with 47 additions and 44 deletions.
39 changes: 16 additions & 23 deletions src/promptflow/promptflow/_core/tools_manager.py
@@ -392,44 +392,37 @@ def load_tool_for_node(self, node: Node) -> Tool:
         raise NotImplementedError(f"Tool type {node.type} is not supported yet.")
 
     def load_tool_for_package_node(self, node: Node) -> Tool:
-        return self.load_tool_for_package(node.source.tool)
-
-    def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
-        if node.source.path is None:
-            raise UserErrorException(f"Node {node.name} does not have source path defined.")
-        return self._load_tool_for_source_path(node.source.path)
-
-    def load_tool_for_llm_node(self, node: Node) -> Tool:
-        api_name = f"{node.provider}.{node.api}"
-        return BuiltinsManager._load_llm_api(api_name)
-
-    def load_tool_for_package(self, package: str) -> Tool:
-        if package in self._package_tools:
-            return Tool.deserialize(self._package_tools[package])
+        if node.source.tool in self._package_tools:
+            return Tool.deserialize(self._package_tools[node.source.tool])
 
         # If node source tool is not in package tools, try to find the tool ID in deprecated tools.
         # If found, load the tool with the new tool ID for backward compatibility.
-        if package in self._deprecated_tools:
-            new_tool_id = self._deprecated_tools[package]
+        if node.source.tool in self._deprecated_tools:
+            new_tool_id = self._deprecated_tools[node.source.tool]
             # Used to collect deprecated tool usage and warn user to replace the deprecated tool with the new one.
-            module_logger.warning(
-                f"Tool ID '{package}' is deprecated. Please use '{new_tool_id}' instead."
-            )
+            module_logger.warning(f"Tool ID '{node.source.tool}' is deprecated. Please use '{new_tool_id}' instead.")
             return Tool.deserialize(self._package_tools[new_tool_id])
 
         raise PackageToolNotFoundError(
-            f"Package tool '{package}' is not found in the current environment. "
+            f"Package tool '{node.source.tool}' is not found in the current environment. "
             f"All available package tools are: {list(self._package_tools.keys())}.",
             target=ErrorTarget.EXECUTOR,
         )
 
-    def _load_tool_for_source_path(self, source_path: str) -> Tuple[types.ModuleType, Callable, Tool]:
-        m = load_python_module_from_file(self._working_dir / source_path)
+    def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
+        if node.source.path is None:
+            raise UserErrorException(f"Node {node.name} does not have source path defined.")
+        path = node.source.path
+        m = load_python_module_from_file(self._working_dir / path)
         if m is None:
-            raise CustomToolSourceLoadError(f"Cannot load module from {source_path}.")
+            raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
         f, init_inputs = collect_tool_function_in_module(m)
         return m, _parse_tool_from_function(f, init_inputs, gen_custom_type_conn=True)
+
+    def load_tool_for_llm_node(self, node: Node) -> Tool:
+        api_name = f"{node.provider}.{node.api}"
+        return BuiltinsManager._load_llm_api(api_name)
 
 
 builtins = {}
 apis = {}
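The tools_manager.py change above folds load_tool_for_package into load_tool_for_package_node and keeps the deprecated-ID fallback. A standalone sketch of that lookup behavior (not part of the diff; the registries and tool IDs here are hypothetical, and the real code deserializes Tool objects rather than plain dicts):

import logging

module_logger = logging.getLogger(__name__)

# Hypothetical registries standing in for the ones collected from installed packages.
package_tools = {"my_pkg.tools.search": {"name": "search"}}
deprecated_tools = {"my_pkg.legacy.search": "my_pkg.tools.search"}

def resolve_package_tool(tool_id: str) -> dict:
    """Return the tool spec for tool_id, following deprecated-ID redirects."""
    if tool_id in package_tools:
        return package_tools[tool_id]
    if tool_id in deprecated_tools:
        new_tool_id = deprecated_tools[tool_id]
        # Warn so users migrate to the new tool ID, mirroring the warning in the diff.
        module_logger.warning(f"Tool ID '{tool_id}' is deprecated. Please use '{new_tool_id}' instead.")
        return package_tools[new_tool_id]
    raise ValueError(f"Package tool '{tool_id}' is not found. Available: {list(package_tools)}.")

print(resolve_package_tool("my_pkg.legacy.search"))  # resolves via the deprecated-ID mapping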
16 changes: 12 additions & 4 deletions src/promptflow/promptflow/contracts/types.py
@@ -33,10 +33,18 @@ class FilePath(str):
 class AssistantDefinition:
     """This class is used to hint a parameter is an assistant override."""
 
-    def __init__(self, value: dict):
-        self.model = value["module"]
-        self.instructions = value["instructions"]
-        self.tools = value["tools"]
+    def __init__(self, model: str, instructions: str, tools: list):
+        self.model = model
+        self.instructions = instructions
+        self.tools = tools
+
+    @staticmethod
+    def deserialize(data: dict) -> "AssistantDefinition":
+        return AssistantDefinition(
+            model=data.get("module"),
+            instructions=data.get("instructions"),
+            tools=data.get("tools")
+        )
 
     def serialize(self):
         return {
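The types.py change replaces the dict-based constructor with explicit fields plus a deserialize factory. A minimal standalone sketch of that pairing (not part of the diff), using the same payload keys the diff reads ("module", "instructions", "tools") with made-up values:

class AssistantDefinitionSketch:
    """Simplified stand-in for promptflow's AssistantDefinition."""

    def __init__(self, model: str, instructions: str, tools: list):
        self.model = model
        self.instructions = instructions
        self.tools = tools

    @staticmethod
    def deserialize(data: dict) -> "AssistantDefinitionSketch":
        # Same key names as the diff; missing keys simply come back as None.
        return AssistantDefinitionSketch(
            model=data.get("module"),
            instructions=data.get("instructions"),
            tools=data.get("tools"),
        )

definition = AssistantDefinitionSketch.deserialize(
    {"module": "gpt-4", "instructions": "You are a helpful assistant.", "tools": []}
)
print(definition.model, definition.instructions, definition.tools)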
27 changes: 13 additions & 14 deletions src/promptflow/promptflow/executor/_tool_invoker.py
@@ -41,6 +41,7 @@ def __init__(
         name,
         run_tracker: RunTracker,
         cache_manager: AbstractCacheManager,
+        working_dir: Optional[Path] = None,
         connections: Optional[dict] = None,
         run_id=None,
         flow_id=None,
@@ -50,6 +51,7 @@ def __init__(
         self._name = name
         self._run_tracker = run_tracker
         self._cache_manager = cache_manager
+        self._working_dir = working_dir
         self._connections = connections or {}
         self._run_id = run_id or str(uuid.uuid4())
         self._flow_id = flow_id or self._run_id
@@ -63,13 +65,14 @@ def start_invoker(
         name,
         run_tracker: RunTracker,
         cache_manager: AbstractCacheManager,
+        working_dir: Optional[Path] = None,
         connections: Optional[dict] = None,
         run_id=None,
         flow_id=None,
         line_number=None,
         variant_id=None
     ):
-        invoker = cls(name, run_tracker, cache_manager, connections, run_id, flow_id, line_number, variant_id)
+        invoker = cls(name, run_tracker, cache_manager, working_dir, connections, run_id, flow_id, line_number, variant_id)
         active_invoker = cls.active_instance()
         if active_invoker:
             active_invoker._deactivate_in_context()
@@ -92,7 +95,7 @@ def load_assistant_tools(cls, tools: list):
                 inputs=updated_inputs,
                 source=ToolSource.deserialize(tool["source"])
             )
-            tool_resolver = ToolResolver(working_dir=Path(os.getcwd()), connections=invoker._connections)
+            tool_resolver = ToolResolver(working_dir=invoker._working_dir, connections=invoker._connections)
             resolved_tool = tool_resolver._resolve_script_node(node, convert_input_types=True)
             if resolved_tool.node.inputs:
                 inputs = {name: value.value for name, value in resolved_tool.node.inputs.items()}
@@ -101,18 +104,20 @@ def load_assistant_tools(cls, tools: list):
             invoker._assistant_tools[resolved_tool.definition.function] = resolved_tool
         return invoker
 
+    def invoke_assistant_tool(self, func_name, kwargs):
+        return self._assistant_tools[func_name].callable(**kwargs)
+
     def to_openai_tools(self):
         openai_tools = []
         for name, tool in self._assistant_tools.items():
             preset_inputs = [name for name, _ in tool.node.inputs.items()]
-            description = self._get_tool_description(name, tool.definition.description, preset_inputs)
+            description = self._get_openai_tool_description(name, tool.definition.description, preset_inputs)
             openai_tools.append(description)
         return openai_tools
 
-    def invoke_assistant_tool(self, func_name, kwargs):
-        return self._assistant_tools[func_name].callable(**kwargs)
+    def _get_openai_tool_description(self, func_name: str, docstring: str, preset_inputs: Optional[list] = None):
+        to_openai_type = {"str": "string", "int": "number"}
 
-    def _get_tool_description(self, func_name: str, docstring: str, preset_inputs: Optional[list] = None):
         doctree = publish_doctree(docstring)
         params = {}
 
@@ -133,7 +138,7 @@ def _get_tool_description(self, func_name: str, docstring: str, preset_inputs: O
                    continue
                if param_name not in params:
                    params[param_name] = {}
-               params[param_name]["type"] = self._convert_type(field_body)
+               params[param_name]["type"] = to_openai_type[field_body] if field_body in to_openai_type else field_body
 
        return {
            "type": "function",
@@ -148,12 +153,6 @@ def _get_tool_description(self, func_name: str, docstring: str, preset_inputs: O
                }
            }
        }
 
-    def _convert_type(self, type: str):
-        if type == "str":
-            return "string"
-        if type == "int":
-            return "number"
-
     def _update_operation_context(self):
         flow_context_info = {"flow-id": self._flow_id, "root-run-id": self._run_id}
         OperationContext.get_instance().update(flow_context_info)
@@ -288,7 +287,7 @@ def _invoke_tool_with_timer(self, node: Node, f: Callable, kwargs):
             raise ToolExecutionError(node_name=node_name, module=module) from e
 
     def bypass_node(self, node: Node):
-        """Update teh bypassed node run info."""
+        """Update the bypassed node run info."""
         node_run_id = self._generate_node_run_id(node)
         flow_logger.info(f"Bypassing node {node.name}. node run id: {node_run_id}")
         parent_run_id = f"{self._run_id}_{self._line_number}" if self._line_number is not None else self._run_id
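In _tool_invoker.py, the old _convert_type helper becomes the inline to_openai_type mapping inside the renamed _get_openai_tool_description. A rough standalone sketch of the schema that method assembles (docstring parsing with docutils is omitted, so parameter metadata is passed in directly; the field layout follows the generic OpenAI function-tool shape, and the truncated diff may differ in detail):

from typing import Optional

# Inline replacement for the removed _convert_type helper.
to_openai_type = {"str": "string", "int": "number"}

def build_openai_tool_description(func_name: str, description: str, params: dict,
                                  preset_inputs: Optional[list] = None) -> dict:
    """Assemble an OpenAI function-tool entry, skipping inputs already preset on the node."""
    preset_inputs = preset_inputs or []
    properties = {
        name: {"type": to_openai_type.get(py_type, py_type), "description": desc}
        for name, (py_type, desc) in params.items()
        if name not in preset_inputs
    }
    return {
        "type": "function",
        "function": {
            "name": func_name,
            "description": description,
            "parameters": {"type": "object", "properties": properties},
        },
    }

print(build_openai_tool_description(
    "get_weather", "Look up the weather.",
    {"city": ("str", "City name"), "days": ("int", "Forecast days")},
))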
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/executor/_tool_resolver.py
@@ -109,7 +109,7 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type
                     updated_inputs[k].value = create_image(v.value)
                 elif value_type == ValueType.ASSISTANT_DEFINITION:
                     definition = self._load_json_from_file(v.value, k, node.name)
-                    updated_inputs[k].value = AssistantDefinition(definition)
+                    updated_inputs[k].value = AssistantDefinition.deserialize(definition)
                 elif isinstance(value_type, ValueType):
                     try:
                         updated_inputs[k].value = value_type.parse(v.value)
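The one-line _tool_resolver.py change routes ASSISTANT_DEFINITION literals through the new deserialize factory: the resolver loads the JSON file referenced by the input and hands the resulting dict to AssistantDefinition.deserialize. A toy version of that step (file name, payload, and the stand-in function are all hypothetical):

import json
import tempfile

def deserialize_definition(data: dict) -> dict:
    # Stand-in for AssistantDefinition.deserialize: pull out the fields the class keeps.
    return {"model": data.get("module"), "instructions": data.get("instructions"), "tools": data.get("tools")}

# Hypothetical assistant-definition file; keys mirror those read by deserialize in types.py above.
payload = {"module": "gpt-4", "instructions": "Answer questions about the flow.", "tools": []}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(payload, f)
    definition_path = f.name

# What the resolver now does conceptually: load the referenced JSON, then deserialize the dict.
with open(definition_path) as f:
    definition = deserialize_definition(json.load(f))
print(definition)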
7 changes: 5 additions & 2 deletions src/promptflow/promptflow/executor/flow_executor.py
@@ -345,6 +345,7 @@ def load_and_exec_node(
             name=flow.name,
             run_tracker=run_tracker,
             cache_manager=AbstractCacheManager.init_from_env(),
+            working_dir=working_dir,
             connections=connections
         )
 
@@ -598,6 +599,7 @@ def _exec_aggregation(
             name=self._flow.name,
             run_tracker=run_tracker,
             cache_manager=self._cache_manager,
+            working_dir=self._working_dir,
             connections=self._connections,
             run_id=run_id,
             flow_id=self._flow_id
@@ -773,6 +775,7 @@ def _exec(
             name=self._flow.name,
             run_tracker=run_tracker,
             cache_manager=self._cache_manager,
+            working_dir=self._working_dir,
             connections=self._connections,
             run_id=run_id,
             flow_id=self._flow_id,
@@ -887,7 +890,7 @@ def _stringify_generator_output(self, outputs: dict):
 
         return outputs
 
-    def _submit_to_scheduler(self, context: DefaultToolInvoker, inputs, nodes: List[Node]) -> Tuple[dict, dict]:
+    def _submit_to_scheduler(self, invoker: DefaultToolInvoker, inputs, nodes: List[Node]) -> Tuple[dict, dict]:
         if not isinstance(self._node_concurrency, int):
             raise UnexpectedError(
                 message_format=(
@@ -896,7 +899,7 @@ def _submit_to_scheduler(self, context: DefaultToolInvoker, inputs, nodes: List[
                 ),
                 current_value=self._node_concurrency,
             )
-        return FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context).execute()
+        return FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, invoker).execute()
 
     @staticmethod
     def apply_inputs_mapping(
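flow_executor.py now threads the flow's working directory into DefaultToolInvoker.start_invoker, so load_assistant_tools resolves script tool sources against the flow directory instead of the process CWD (the Path(os.getcwd()) call removed above). A small sketch of why that matters, with hypothetical paths:

import os
from pathlib import Path

flow_working_dir = Path("/flows/my_flow")   # hypothetical flow directory
tool_source = "tools/echo_tool.py"          # relative path stored on the node source

# After this commit: resolved against the working_dir handed down by the executor.
print(flow_working_dir / tool_source)       # /flows/my_flow/tools/echo_tool.py

# Before this commit: resolved against wherever the process happened to be started.
print(Path(os.getcwd()) / tool_source)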
