From 1d118a0c4c26676d805a56fbf2dce408496746b5 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Sun, 8 Dec 2024 14:56:08 +0800 Subject: [PATCH] fix: test cases error --- .../workflow/graph_engine/graph_engine.py | 26 ++++---- .../workflow/nodes/test_continue_on_error.py | 66 +++++++------------ 2 files changed, 37 insertions(+), 55 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0730a2732008d0..e03d4a7194a11e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -339,6 +339,7 @@ def _run( next_node_id = edge.target_node_id else: final_node_id = None + if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results condition_edge_mappings = {} @@ -701,19 +702,18 @@ def _run_node( run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) - event_args = { - "id": node_instance.id, - "node_id": node_instance.node_id, - "node_type": node_instance.node_type, - "node_data": node_instance.node_data, - "route_node_state": route_node_state, - "parallel_id": parallel_id, - "parallel_start_node_id": parallel_start_node_id, - "parent_parallel_id": parent_parallel_id, - "parent_parallel_start_node_id": parent_parallel_start_node_id, - } - event = NodeRunSucceededEvent(**event_args) - yield event + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) break elif isinstance(item, RunStreamChunkEvent): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ec6d067a7165ab..30751fc104fdb1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -33,8 +33,25 @@ def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: return node @staticmethod - def get_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + def get_http_node( + error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False + ): """Helper method to create a http node configuration""" + authorization = ( + { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + } + if authorization_success + else { + "type": "api-key", + # missing config field + } + ) node = { "id": "node", "data": { @@ -42,10 +59,7 @@ def get_http_node(error_strategy: str = "fail-branch", default_value: dict | Non "desc": "", "method": "get", "url": "http://example.com", - "authorization": { - "type": "api-key", - # missing config field - }, + "authorization": authorization, "headers": "X-Header:123", "params": "A:b", "body": None, @@ -214,7 +228,7 @@ def main() -> dict: {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_code_node( - error_code, "default-value", [{"key": "result", "type": "Number", "value": 132123}] + error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] ), ], } @@ -259,38 +273,6 @@ def main() -> dict: ) -def test_code_success_branch_continue_on_error(): - success_code = """ - def main() -> dict: - return { - "result": 1 / 1, - } - """ - - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - ContinueOnErrorTestHelper.get_code_node(success_code), - { - "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, - "id": "error", - }, - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert any( - isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run successfully"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - def test_http_node_default_value_continue_on_error(): """Test HTTP node with default value error strategy""" graph_config = { @@ -299,7 +281,7 @@ def test_http_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_http_node( - "default-value", [{"key": "response", "type": "String", "value": "http node got error response"}] + "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] ), ], } @@ -351,7 +333,7 @@ def test_tool_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_tool_node( - "default-value", [{"key": "result", "type": "String", "value": "default tool result"}] + "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] ), ], } @@ -402,7 +384,7 @@ def test_llm_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_llm_node( - "default-value", [{"key": "answer", "type": "String", "value": "default LLM response"}] + "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] ), ], } @@ -531,7 +513,7 @@ def main() -> dict: "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, "id": "error", }, - ContinueOnErrorTestHelper.get_code_node(code=success_code), + ContinueOnErrorTestHelper.get_http_node(authorization_success=True), { "id": "code", "data": {