Skip to content

Commit

Permalink
Don't create subtasks when not using tools
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 26, 2024
1 parent af1e5eb commit 984629e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
29 changes: 16 additions & 13 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,24 @@ def try_run(self) -> BaseArtifact:
self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence])

result = self.prompt_driver.run(self.prompt_stack)
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))

while True:
if subtask.output is None:
if len(self.subtasks) >= self.max_subtasks:
subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task")
if self.tools:
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))

while True:
if subtask.output is None:
if len(self.subtasks) >= self.max_subtasks:
subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task")
else:
subtask.run()

result = self.prompt_driver.run(self.prompt_stack)
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))
else:
subtask.run()

result = self.prompt_driver.run(self.prompt_stack)
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))
else:
break
break

Check warning on line 219 in griptape/tasks/prompt_task.py

View check run for this annotation

Codecov / codecov/patch

griptape/tasks/prompt_task.py#L219

Added line #L219 was not covered by tests

self.output = subtask.output
self.output = subtask.output
else:
self.output = result.to_artifact()

return self.output

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/tasks/test_prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from griptape.structures import Pipeline
from griptape.tasks import PromptTask
from tests.mocks.mock_prompt_driver import MockPromptDriver
from tests.mocks.mock_tool.tool import MockTool


class TestPromptTask:
Expand Down Expand Up @@ -212,3 +213,17 @@ def test_conversation_memory(self):
task.run()

assert len(conversation_memory.runs) == 2

def test_subtasks(self):
task = PromptTask(
input="foo",
prompt_driver=MockPromptDriver(),
)

task.run()
assert len(task.subtasks) == 0

task = PromptTask(input="foo", prompt_driver=MockPromptDriver(use_native_tools=True), tools=[MockTool()])

task.run()
assert len(task.subtasks) == 2

0 comments on commit 984629e

Please sign in to comment.