From 984629ed1a050fa16ea047cb295ea4228b772418 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 26 Dec 2024 12:00:02 -0800 Subject: [PATCH] Don't create subtasks when not using tools --- griptape/tasks/prompt_task.py | 29 +++++++++++++++------------- tests/unit/tasks/test_prompt_task.py | 15 ++++++++++++++ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 4ed313bf0..5086636d0 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -203,21 +203,24 @@ def try_run(self) -> BaseArtifact: self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence]) result = self.prompt_driver.run(self.prompt_stack) - subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) - - while True: - if subtask.output is None: - if len(self.subtasks) >= self.max_subtasks: - subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task") + if self.tools: + subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) + + while True: + if subtask.output is None: + if len(self.subtasks) >= self.max_subtasks: + subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task") + else: + subtask.run() + + result = self.prompt_driver.run(self.prompt_stack) + subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) else: - subtask.run() - - result = self.prompt_driver.run(self.prompt_stack) - subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) - else: - break + break - self.output = subtask.output + self.output = subtask.output + else: + self.output = result.to_artifact() return self.output diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index a5d4521cf..f457a4b55 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -8,6 +8,7 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool class TestPromptTask: @@ -212,3 +213,17 @@ def test_conversation_memory(self): task.run() assert len(conversation_memory.runs) == 2 + + def test_subtasks(self): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver(), + ) + + task.run() + assert len(task.subtasks) == 0 + + task = PromptTask(input="foo", prompt_driver=MockPromptDriver(use_native_tools=True), tools=[MockTool()]) + + task.run() + assert len(task.subtasks) == 2