Don't use JsonSchemaRules
collindutter committed Jan 2, 2025
1 parent acd9f11 commit 13e59f5
Showing 7 changed files with 17 additions and 95 deletions.
CHANGELOG.md (2 changes: 0 additions & 2 deletions)
@@ -31,8 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `PromptTask.prompt_driver` is now serialized.
 - `PromptTask` can now do everything a `ToolkitTask` can do.
 - Loosen `numpy`'s version constraint to `>=1.26.4,<3`.
-- `JsonSchemaRule`s can now take a `schema.Schema` instance. Required for using a `JsonSchemaRule` with structured output.
-- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it.

 ### Fixed

docs/griptape-framework/drivers/prompt-drivers.md (13 changes: 2 additions & 11 deletions)
@@ -44,17 +44,8 @@ The easiest way to get started with structured output is by using a [JsonSchemaR
 --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
 ```

-### Multiple Schemas
-
-If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`.
-
-Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`:
-
-```python
---8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py"
-```
-
-Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.
+!!! warning
+    Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.

 ## Prompt Drivers

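For context, the docs above now route structured output through the task rather than through a rule. A minimal post-commit usage sketch, assuming `OpenAiChatPromptDriver` exposes the same `use_native_structured_output` flag as the mock driver in this commit; the model name and prompt are illustrative:

```python
from schema import Schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# The output schema is now a first-class, kw-only field on PromptTask
# (added in this commit) instead of a JsonSchemaRule mixed into the rulesets.
task = PromptTask(
    input="Extract the name and age from: 'Collin is 30 years old.'",
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",  # illustrative model name
        use_native_structured_output=True,
    ),
    output_schema=Schema({"name": str, "age": int}),
)

print(task.run().to_text())
```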
docs/griptape-framework/misc/events.md (3 changes: 1 addition & 2 deletions)
@@ -79,8 +79,7 @@ Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
 You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.

 !!! tip
-
-    Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.
+    Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

 ```python
 --8<-- "docs/griptape-framework/misc/src/events_streaming.py"
griptape/tasks/prompt_task.py (48 changes: 10 additions & 38 deletions)
@@ -2,11 +2,9 @@

 import json
 import logging
-import warnings
 from typing import TYPE_CHECKING, Callable, Optional, Union

 from attrs import NOTHING, Attribute, Factory, NothingType, define, field
-from schema import Or, Schema

 from griptape import utils
 from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
@@ -16,11 +14,13 @@
 from griptape.memory.structure import Run
 from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
 from griptape.mixins.rule_mixin import RuleMixin
-from griptape.rules import JsonSchemaRule, Ruleset
+from griptape.rules import Ruleset
 from griptape.tasks import ActionsSubtask, BaseTask
 from griptape.utils import J2

 if TYPE_CHECKING:
+    from schema import Schema
+
     from griptape.drivers import BasePromptDriver
     from griptape.memory import TaskMemory
     from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
@@ -39,6 +39,7 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin):
     prompt_driver: BasePromptDriver = field(
         default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True}
     )
+    output_schema: Optional[Schema] = field(default=None, kw_only=True)
     generate_system_template: Callable[[PromptTask], str] = field(
         default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
         kw_only=True,
@@ -90,22 +91,12 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],

     @property
     def prompt_stack(self) -> PromptStack:
-        stack = PromptStack(tools=self.tools)
+        stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
         memory = self.structure.conversation_memory if self.structure is not None else None

-        rulesets = self.rulesets
-        system_artifacts = [TextArtifact(self.generate_system_template(self))]
-        if self.prompt_driver.use_native_structured_output:
-            self._add_native_schema_to_prompt_stack(stack, rulesets)
-
-        # Ensure there is at least one Ruleset that has non-empty `rules`.
-        if any(len(ruleset.rules) for ruleset in rulesets):
-            system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets)))
-
-        # Ensure there is at least one system Artifact that has a non-empty value.
-        has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts)
-        if has_system_artifacts:
-            stack.add_system_message(ListArtifact(system_artifacts))
+        system_template = self.generate_system_template(self)
+        if system_template:
+            stack.add_system_message(system_template)

         stack.add_user_message(self.input)

@@ -116,7 +107,7 @@ def prompt_stack(self) -> PromptStack:

         if memory is not None:
             # inserting at index 1 to place memory right after system prompt
-            memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0)
+            memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0)

         return stack

@@ -226,6 +217,7 @@ def default_generate_system_template(self, _: PromptTask) -> str:
         schema["minItems"] = 1  # The `schema` library doesn't support `minItems` so we must add it manually.

         return J2("tasks/prompt_task/system.j2").render(
+            rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
             action_names=str.join(", ", [tool.name for tool in self.tools]),
             actions_schema=utils.minify_json(json.dumps(schema)),
             meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
@@ -307,26 +299,6 @@ def _process_task_input(
         else:
             return self._process_task_input(TextArtifact(task_input))

-    def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None:
-        # Need to separate JsonSchemaRules from other rules, removing them in the process
-        json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)]
-        non_json_schema_rules = [
-            [rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets
-        ]
-        for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules):
-            ruleset.rules = non_json_rules
-
-        schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)]
-
-        if len(json_schema_rules) != len(schemas):
-            warnings.warn(
-                "Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`.",
-                stacklevel=2,
-            )
-
-        if schemas:
-            stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas))
-
     def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None:
         for s in self.subtasks:
             if self.prompt_driver.use_native_tools:
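The net effect of the `prompt_stack` refactor above, sketched under the assumption that building the stack makes no API calls (the tests in this commit access `task.prompt_stack` directly): the schema no longer travels through the rulesets, and `PromptStack` receives it as a constructor argument.

```python
from schema import Schema

from griptape.tasks import PromptTask

schema = Schema({"baz": str})

# Before this commit, the schema rode along as a rule and PromptTask fished
# JsonSchemaRules back out of the rulesets while building the prompt stack:
#   task = PromptTask(input="foo", rules=[JsonSchemaRule(schema)])

# After this commit, the schema is an explicit kw-only field that is passed
# straight through to PromptStack:
task = PromptTask(input="foo", output_schema=schema)
assert task.prompt_stack.output_schema is schema
```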
mise.toml (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
+[tools]
+python = "3.9"
tests/mocks/mock_prompt_driver.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver):

     def try_run(self, prompt_stack: PromptStack) -> Message:
         output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
-        if prompt_stack.output_schema and self.use_native_structured_output:
+        if self.use_native_structured_output and prompt_stack.output_schema:
             if self.native_structured_output_strategy == "native":
                 return Message(
                     content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))],
tests/unit/tasks/test_prompt_task.py (42 changes: 1 addition & 41 deletions)
@@ -1,7 +1,5 @@
-import warnings
-
 import pytest

 from griptape.artifacts.image_artifact import ImageArtifact
 from griptape.artifacts.json_artifact import JsonArtifact
 from griptape.artifacts.list_artifact import ListArtifact
@@ -188,7 +186,7 @@ def test_prompt_stack_native_schema(self):
                 use_native_structured_output=True,
                 mock_structured_output={"baz": "foo"},
             ),
-            rules=[JsonSchemaRule(output_schema)],
+            output_schema=output_schema,
         )
         output = task.run()

@@ -204,27 +202,6 @@ def test_prompt_stack_native_schema(self):
             warnings.simplefilter("error")
             assert task.prompt_stack

-    def test_prompt_stack_mixed_native_schema(self):
-        from schema import Schema
-
-        output_schema = Schema({"baz": str})
-        task = PromptTask(
-            input="foo",
-            prompt_driver=MockPromptDriver(
-                use_native_structured_output=True,
-            ),
-            rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)],
-        )
-
-        assert task.prompt_stack.output_schema is output_schema
-        assert task.prompt_stack.messages[0].is_system()
-        assert "foo" in task.prompt_stack.messages[0].to_text()
-        assert "bar" not in task.prompt_stack.messages[0].to_text()
-        with pytest.warns(
-            match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`."
-        ):
-            assert task.prompt_stack
-
     def test_prompt_stack_empty_native_schema(self):
         task = PromptTask(
             input="foo",
@@ -236,23 +213,6 @@ def test_prompt_stack_empty_native_schema(self):

         assert task.prompt_stack.output_schema is None

-    def test_prompt_stack_multi_native_schema(self):
-        from schema import Or, Schema
-
-        output_schema = Schema({"foo": str})
-        task = PromptTask(
-            input="foo",
-            prompt_driver=MockPromptDriver(
-                use_native_structured_output=True,
-            ),
-            rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)],
-        )
-
-        assert isinstance(task.prompt_stack.output_schema, Schema)
-        assert task.prompt_stack.output_schema.json_schema("Output") == Schema(
-            Or(output_schema, output_schema)
-        ).json_schema("Output")
-
     def test_rulesets(self):
         pipeline = Pipeline(
             rulesets=[Ruleset("Pipeline Ruleset")],
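The two deleted tests covered the `anyOf` merging behavior this commit drops. A small sketch of what that merge produced, using only the `schema` library; the `"Output"` identifier mirrors the deleted test, and the root `anyOf` shape is as the removed docs described it:

```python
from schema import Or, Schema

s1 = Schema({"foo": str})
s2 = Schema({"bar": int})

# Merging multiple schemas the way the deleted code did: wrap them in Or().
merged = Schema(Or(s1, s2))

# Renders a top-level {"anyOf": [...]} JSON Schema, the root shape the
# removed docs warned some LLMs do not accept.
print(merged.json_schema("Output"))
```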
