Add Structured Output #1443

Open · wants to merge 3 commits into base: main
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed

@@ -29,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PromptTask.prompt_driver` is now serialized.
- `PromptTask` can now do everything a `ToolkitTask` can do.
- Loosen `numpy`'s version constraint to `>=1.26.4,<3`.
- `JsonSchemaRule`s can now take a `schema.Schema` instance. Required for using a `JsonSchemaRule` with structured output.
- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it.

### Fixed

31 changes: 31 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -25,6 +25,37 @@ You can pass images to the Driver if the model supports it:
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py"
```

## Structured Output

Some LLMs provide functionality often referred to as "Structured Output", which means instructing the LLM to output data in a particular format, usually JSON. This is useful for forcing the LLM to produce output in a parsable format that downstream systems can consume.

Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output) field.

If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of:
> **Review comment (Contributor):** What is the default strategy? (Maybe mention it here?)


- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md), and attempt to force the LLM to use it (see the sketch below).
> **Review comment (Contributor):** In what cases would you use the `tool` strategy? Can it result in multiple "tool" calls/outputs for a single LLM call?

> **Review comment (Contributor):** Actually, why do I need to pass the boolean in addition to the strategy name? If I just pass the strategy name, is that not sufficient to infer that it should be used? In other words, can you just remove the `use_native_structured_output` flag and rely on `native_structured_output_strategy` exclusively?


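For illustration, here is a minimal sketch of configuring these fields on a Prompt Driver (the parameter values are illustrative, taken from the examples later in this PR, not documented defaults):

```python
from griptape.drivers import OpenAiChatPromptDriver

# Enable structured output on this driver and pick a strategy.
driver = OpenAiChatPromptDriver(
    model="gpt-4o",
    use_native_structured_output=True,
    native_structured_output_strategy="native",  # or "tool"
)
```
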
### JSON Schema

The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy.

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

### Multiple Schemas

If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`.
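
As a rough sketch, the merged schema combines the individual schemas under a top-level `anyOf` (the exact shape shown here is an assumption, built with the `schema` library's `json_schema()` helper):

```python
import schema

# The two JsonSchemaRule schemas from the example below.
red = schema.Schema({"color": "red"}).json_schema("Red")
blue = schema.Schema({"color": "blue"}).json_schema("Blue")

# Assumed merged form: both schemas combined under a single top-level "anyOf".
merged = {"anyOf": [red, blue]}
```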

Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`:

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py"
```

Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.

## Prompt Drivers

Griptape offers the following Prompt Drivers for interacting with LLMs.
35 changes: 35 additions & 0 deletions docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py
@@ -0,0 +1,35 @@
import schema
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import JsonSchemaRule, Rule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

pipeline = Pipeline(
    tasks=[
        PromptTask(
            prompt_driver=OpenAiChatPromptDriver(
                model="gpt-4o",
                use_native_structured_output=True,
                native_structured_output_strategy="native",
> **Review comment on lines +14 to +15 (Contributor):** Seems weird to see `use_native_...` then `native...strategy="native"`. Maybe just remove `native` from the param names?

            ),
            rules=[
                Rule("You are a helpful math tutor. Guide the user through the solution step by step."),
                JsonSchemaRule(
                    schema.Schema(
                        {
                            "steps": [schema.Schema({"explanation": str, "output": str})],
                            "final_answer": str,
                        }
                    )
                ),
            ],
        )
    ]
)

output = pipeline.run("How can I solve 8x + 7 = -23").output.value


pprint(output)
28 changes: 28 additions & 0 deletions docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py
@@ -0,0 +1,28 @@
import schema
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import JsonSchemaRule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

pipeline = Pipeline(
    tasks=[
        PromptTask(
            prompt_driver=OpenAiChatPromptDriver(
                model="gpt-4o",
                use_native_structured_output=True,
                native_structured_output_strategy="tool",
            ),
            rules=[
                JsonSchemaRule(schema.Schema({"color": "red"})),
                JsonSchemaRule(schema.Schema({"color": "blue"})),
            ],
> **Review comment on lines +12 to +20 (Contributor):** This is really hard for me to imagine what this would do (let alone decide to use it). How do the rules interact with the `tool` native output strategy?

        )
    ]
)

output = pipeline.run("Pick a color").output.value


pprint(output)
3 changes: 2 additions & 1 deletion docs/griptape-framework/misc/events.md
@@ -79,7 +79,8 @@ Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.

!!! tip
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

    Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

```python
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"
7 changes: 7 additions & 0 deletions docs/griptape-framework/structures/rulesets.md
@@ -29,6 +29,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru
[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s define a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.

If the Prompt Driver supports [Structured Output](../drivers/prompt-drivers.md#structured-output), Griptape will use the schema provided to the `JsonSchemaRule` to ensure JSON output.
If the Prompt Driver does not support Structured Output, Griptape will include the schema in the system prompt using [this template](https://github.com/griptape-ai/griptape/blob/main/griptape/templates/rules/json_schema.j2).
> **Review comment on lines +32 to +33 (Contributor):** I'd suggest not using "Griptape" to refer to the thing performing the logic. Instead, how about getting more specific about what is responsible? Or, if only the effect is important, use the passive voice, something like "the schema will be used to…".


```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule.py"
```
@@ -47,6 +50,10 @@ Although Griptape leverages the `schema` library, you're free to use any JSON sc

For example, using `pydantic`:

!!! warning

    Griptape does not yet support using `pydantic` schemas for [Structured Output](../drivers/prompt-drivers.md#structured-output). It is recommended to pass a `schema.Schema` instance instead.

```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule_pydantic.py"
```
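
For structured output specifically, a minimal sketch of the recommended alternative is to pass a `schema.Schema` instance directly (the field name below is illustrative):

```python
import schema

from griptape.rules import JsonSchemaRule

# Passing a schema.Schema instance lets the schema be used for structured output.
rule = JsonSchemaRule(schema.Schema({"answer": str}))
```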
8 changes: 1 addition & 7 deletions docs/griptape-framework/structures/src/json_schema_rule.py
@@ -5,13 +5,7 @@
from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.structures import Agent

agent = Agent(
    rules=[
        JsonSchemaRule(
            schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}).json_schema("Output Format")
        )
    ]
)
agent = Agent(rules=[JsonSchemaRule(schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}))])

output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output

3 changes: 2 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

@@ -31,6 +31,7 @@
class PromptStack(SerializableMixin):
    messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
    tools: list[BaseTool] = field(factory=list, kw_only=True)
    output_schema: Optional[Any] = field(default=None, kw_only=True)

    @property
    def system_messages(self) -> list[Message]:
7 changes: 7 additions & 0 deletions griptape/drivers/__init__.py
@@ -1,3 +1,7 @@
from .schema.base_schema_driver import BaseSchemaDriver
from .schema.schema_schema_driver import SchemaSchemaDriver
from .schema.pydantic_schema_driver import PydanticSchemaDriver

from .prompt.base_prompt_driver import BasePromptDriver
from .prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver
from .prompt.azure_openai_chat_prompt_driver import AzureOpenAiChatPromptDriver
@@ -240,4 +244,7 @@
"BaseAssistantDriver",
"GriptapeCloudAssistantDriver",
"OpenAiAssistantDriver",
"BaseSchemaDriver",
"SchemaSchemaDriver",
"PydanticSchemaDriver",
]
45 changes: 33 additions & 12 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
@@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    native_structured_output_strategy: Literal["native", "tool"] = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @native_structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")

        return value

    @lazy_property()
    def client(self) -> Any:
        return self.session.client("bedrock-runtime")
@@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]

        messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

        return {
        params = {
            "modelId": self.model,
            "messages": messages,
            "system": system_messages,
@@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

        if prompt_stack.tools and self.use_native_tools:
            params["toolConfig"] = {
                "tools": [],
                "toolChoice": self.tool_choice,
            }

            if (
                prompt_stack.output_schema is not None
                and self.use_native_structured_output
                and self.native_structured_output_strategy == "tool"
            ):
                self._add_structured_output_tool(prompt_stack)
                params["toolConfig"]["toolChoice"] = {"any": {}}

            params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

        return params

    def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
        return [
            {
@@ -145,9 +168,7 @@ def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
"name": tool.to_native_tool_name(activity),
"description": tool.activity_description(activity),
"inputSchema": {
"json": (tool.activity_schema(activity) or Schema({})).json_schema(
"http://json-schema.org/draft-07/schema#",
),
"json": self.schema_driver.to_json_schema(tool.activity_schema(activity) or Schema({})),
},
},
}
37 changes: 29 additions & 8 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
@@ -68,13 +68,24 @@ class AnthropicPromptDriver(BasePromptDriver):
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    native_structured_output_strategy: Literal["native", "tool"] = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
    _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

    @native_structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
@@ -110,23 +121,33 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
        system_messages = prompt_stack.system_messages
        system_message = system_messages[0].to_text() if system_messages else None

        return {
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "messages": messages,
            **(
                {"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
                if prompt_stack.tools and self.use_native_tools
                else {}
            ),
            **({"system": system_message} if system_message else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice

            if (
                prompt_stack.output_schema is not None
                and self.use_native_structured_output
                and self.native_structured_output_strategy == "tool"
            ):
                self._add_structured_output_tool(prompt_stack)
                params["tool_choice"] = {"type": "any"}

            params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

        return params

    def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
        return [
            {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}