From 9d22f0eff12ef3f7e12e258187ddfe56fc07edc8 Mon Sep 17 00:00:00 2001
From: Georgiy Tarasov
Date: Fri, 3 Jan 2025 16:40:02 +0100
Subject: [PATCH] test(product-assistant): add full trends and funnel flow tests (#27214)

---
 ee/hogai/test/test_assistant.py | 120 +++++++++++++++++++++++++++++++-
 1 file changed, 119 insertions(+), 1 deletion(-)

diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py
index 48bb9b05d9b7e..d9a95660a4846 100644
--- a/ee/hogai/test/test_assistant.py
+++ b/ee/hogai/test/test_assistant.py
@@ -10,8 +10,21 @@
 from langgraph.types import StateSnapshot
 from pydantic import BaseModel
 
+from ee.hogai.funnels.nodes import FunnelsSchemaGeneratorOutput
+from ee.hogai.router.nodes import RouterOutput
+from ee.hogai.trends.nodes import TrendsSchemaGeneratorOutput
 from ee.models.assistant import Conversation
-from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, ReasoningMessage
+from posthog.schema import (
+    AssistantFunnelsEventsNode,
+    AssistantFunnelsQuery,
+    AssistantMessage,
+    AssistantTrendsQuery,
+    FailureMessage,
+    HumanMessage,
+    ReasoningMessage,
+    RouterMessage,
+    VisualizationMessage,
+)
 from posthog.test.base import NonAtomicBaseTest
 
 from ..assistant import Assistant
@@ -408,3 +421,108 @@ def node_handler(state):
         async for message in assistant._astream():
             actual_output.append(self._parse_stringified_message(message))
         self.assertConversationEqual(actual_output, expected_output)
+
+    @patch("ee.hogai.summarizer.nodes.SummarizerNode._model")
+    @patch("ee.hogai.schema_generator.nodes.SchemaGeneratorNode._model")
+    @patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model")
+    @patch("ee.hogai.router.nodes.RouterNode._model")
+    def test_full_trends_flow(self, router_mock, planner_mock, generator_mock, summarizer_mock):  # @patch applies bottom-up: last decorator -> first mock arg
+        router_mock.return_value = RunnableLambda(lambda _: RouterOutput(visualization_type="trends"))
+        planner_mock.return_value = RunnableLambda(
+            lambda _: messages.AIMessage(
+                content="""
+                Thought: Done.
+                Action:
+                ```
+                {
+                    "action": "final_answer",
+                    "action_input": "Plan"
+                }
+                ```
+                """
+            )
+        )
+        query = AssistantTrendsQuery(series=[])
+        generator_mock.return_value = RunnableLambda(lambda _: TrendsSchemaGeneratorOutput(query=query))
+        summarizer_mock.return_value = RunnableLambda(lambda _: AssistantMessage(content="Summary"))
+
+        # First run
+        actual_output = self._run_assistant_graph(is_new_conversation=True)
+        expected_output = [
+            ("conversation", {"id": str(self.conversation.id)}),
+            ("message", HumanMessage(content="Hello")),
+            ("message", ReasoningMessage(content="Identifying type of analysis")),
+            ("message", RouterMessage(content="trends")),
+            ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])),
+            ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])),  # NOTE(review): duplicate looks intentional — presumably one per planner iteration; confirm
+            ("message", ReasoningMessage(content="Creating trends query")),
+            ("message", VisualizationMessage(answer=query, plan="Plan")),
+            ("message", AssistantMessage(content="Summary")),
+        ]
+        self.assertConversationEqual(actual_output, expected_output)
+        self.assertEqual(actual_output[1][1]["id"], actual_output[7][1]["initiator"])  # HumanMessage id (index 1) must seed the VisualizationMessage initiator (index 7)
+
+        # Second run
+        actual_output = self._run_assistant_graph(is_new_conversation=False)
+        self.assertConversationEqual(actual_output, expected_output[1:])
+        self.assertEqual(actual_output[0][1]["id"], actual_output[6][1]["initiator"])  # no "conversation" tuple on reuse, so indices shift down by 1
+
+        # Third run
+        actual_output = self._run_assistant_graph(is_new_conversation=False)
+        self.assertConversationEqual(actual_output, expected_output[1:])
+        self.assertEqual(actual_output[0][1]["id"], actual_output[6][1]["initiator"])
+
+    @patch("ee.hogai.summarizer.nodes.SummarizerNode._model")
+    @patch("ee.hogai.schema_generator.nodes.SchemaGeneratorNode._model")
+    @patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model")
+    @patch("ee.hogai.router.nodes.RouterNode._model")
+    def test_full_funnel_flow(self, router_mock, planner_mock, generator_mock, summarizer_mock):  # @patch applies bottom-up: last decorator -> first mock arg
+        router_mock.return_value = RunnableLambda(lambda _: RouterOutput(visualization_type="funnel"))
+        planner_mock.return_value = RunnableLambda(
+            lambda _: messages.AIMessage(
+                content="""
+                Thought: Done.
+                Action:
+                ```
+                {
+                    "action": "final_answer",
+                    "action_input": "Plan"
+                }
+                ```
+                """
+            )
+        )
+        query = AssistantFunnelsQuery(
+            series=[
+                AssistantFunnelsEventsNode(event="$pageview"),
+                AssistantFunnelsEventsNode(event="$pageleave"),
+            ]
+        )
+        generator_mock.return_value = RunnableLambda(lambda _: FunnelsSchemaGeneratorOutput(query=query))
+        summarizer_mock.return_value = RunnableLambda(lambda _: AssistantMessage(content="Summary"))
+
+        # First run
+        actual_output = self._run_assistant_graph(is_new_conversation=True)
+        expected_output = [
+            ("conversation", {"id": str(self.conversation.id)}),
+            ("message", HumanMessage(content="Hello")),
+            ("message", ReasoningMessage(content="Identifying type of analysis")),
+            ("message", RouterMessage(content="funnel")),
+            ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])),
+            ("message", ReasoningMessage(content="Picking relevant events and properties", substeps=[])),
+            ("message", ReasoningMessage(content="Creating funnel query")),
+            ("message", VisualizationMessage(answer=query, plan="Plan")),
+            ("message", AssistantMessage(content="Summary")),
+        ]
+        self.assertConversationEqual(actual_output, expected_output)
+        self.assertEqual(actual_output[1][1]["id"], actual_output[7][1]["initiator"])  # HumanMessage id (index 1) must seed the VisualizationMessage initiator (index 7)
+
+        # Second run
+        actual_output = self._run_assistant_graph(is_new_conversation=False)
+        self.assertConversationEqual(actual_output, expected_output[1:])
+        self.assertEqual(actual_output[0][1]["id"], actual_output[6][1]["initiator"])  # no "conversation" tuple on reuse, so indices shift down by 1
+
+        # Third run
+        actual_output = self._run_assistant_graph(is_new_conversation=False)
+        self.assertConversationEqual(actual_output, expected_output[1:])
+        self.assertEqual(actual_output[0][1]["id"], actual_output[6][1]["initiator"])