Skip to content

Commit

Permalink
Add Structured Output functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 26, 2024
1 parent af1e5eb commit 896e534
Show file tree
Hide file tree
Showing 46 changed files with 1,237 additions and 178 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `BranchTask` in `StructureVisualizer`.
- `EvalEngine` for evaluating the performance of an LLM's output against a given input.
- `BaseFileLoader.save()` method for saving an Artifact to a destination.
- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed

- Rulesets can now be serialized and deserialized.
- `ToolkitTask` now serializes its `tools` field.
- `PromptTask.prompt_driver` is now serialized.
- `PromptTask` can now do everything a `ToolkitTask` can do.
- `JsonSchemaRule`s can now take a `schema.Schema` instance. Required for using a `JsonSchemaRule` with structured output.
- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it.

### Fixed

Expand Down
31 changes: 31 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,37 @@ You can pass images to the Driver if the model supports it:
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py"
```

## Structured Output

Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.

Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output).

If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool.

### JSON Schema

The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy.

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

### Multiple Schemas

If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`.

Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`:

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py"
```

Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.

## Prompt Drivers

Griptape offers the following Prompt Drivers for interacting with LLMs.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json

import schema
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import JsonSchemaRule, Rule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

pipeline = Pipeline(
tasks=[
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_native_structured_output=True,
native_structured_output_strategy="native",
),
rules=[
Rule("You are a helpful math tutor. Guide the user through the solution step by step."),
JsonSchemaRule(
schema.Schema(
{
"steps": [schema.Schema({"explanation": str, "output": str})],
"final_answer": str,
}
)
),
],
)
]
)

output = pipeline.run("How can I solve 8x + 7 = -23").output.value
parsed_output = json.loads(output)


pprint(parsed_output)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import schema
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import JsonSchemaRule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

pipeline = Pipeline(
tasks=[
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_native_structured_output=True,
native_structured_output_strategy="tool",
),
rules=[
JsonSchemaRule(schema.Schema({"color": "red"})),
JsonSchemaRule(schema.Schema({"color": "blue"})),
],
)
]
)

output = pipeline.run("Pick a color").output.value


pprint(output)
7 changes: 7 additions & 0 deletions docs/griptape-framework/structures/rulesets.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru
[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.

If the Prompt Driver supports [Structured Output](../drivers/prompt-drivers.md#structured-output), Griptape will use the schema provided to the `JsonSchemaRule` to ensure JSON output.
If the Prompt Driver does not support Structured Output, Griptape will include the schema in the system prompt using [this template](https://github.com/griptape-ai/griptape/blob/main/griptape/templates/rules/json_schema.j2).

```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule.py"
```
Expand All @@ -47,6 +50,10 @@ Although Griptape leverages the `schema` library, you're free to use any JSON sc

For example, using `pydantic`:

!!! warning

Griptape does not yet support using `pydantic` schemas for[Structured Output](../drivers/prompt-drivers.md#structured-output). It is recommended to pass a `schema.Schema` instance.

```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule_pydantic.py"
```
Expand Down
8 changes: 1 addition & 7 deletions docs/griptape-framework/structures/src/json_schema_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.structures import Agent

agent = Agent(
rules=[
JsonSchemaRule(
schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}).json_schema("Output Format")
)
]
)
agent = Agent(rules=[JsonSchemaRule(schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}))])

output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output

Expand Down
5 changes: 4 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from attrs import define, field

Expand All @@ -24,13 +24,16 @@
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from schema import Schema

from griptape.tools import BaseTool


@define
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)

@property
def system_messages(self) -> list[Message]:
Expand Down
41 changes: 32 additions & 9 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("bedrock-runtime")
Expand Down Expand Up @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]

messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

return {
params = {
"modelId": self.model,
"messages": messages,
"system": system_messages,
Expand All @@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["toolConfig"] = {
"tools": [],
"toolChoice": self.tool_choice,
}

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

return params

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
return [
{
Expand Down
37 changes: 29 additions & 8 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -68,13 +68,24 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")

return value

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
Expand Down Expand Up @@ -110,23 +121,33 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = prompt_stack.system_messages
system_message = system_messages[0].to_text() if system_messages else None

return {
params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"messages": messages,
**(
{"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
if prompt_stack.tools and self.use_native_tools
else {}
),
**({"system": system_message} if system_message else {}),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

return params

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
return [
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
Expand Down
16 changes: 15 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field

Expand Down Expand Up @@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
Expand Down Expand Up @@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if prompt_stack.output_schema is None:
raise ValueError("PromptStack must have an output schema to use structured output.")

structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

def __process_run(self, prompt_stack: PromptStack) -> Message:
return self.try_run(prompt_stack)

Expand Down
Loading

0 comments on commit 896e534

Please sign in to comment.