Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 2, 2025
1 parent 13e59f5 commit 3762fc7
Show file tree
Hide file tree
Showing 32 changed files with 110 additions and 138 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`.
- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed

Expand Down
6 changes: 4 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ Some LLMs provide functionality often referred to as "Structured Output". This m

Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output).

If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of:
If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool.

Each Driver may have a different default setting depending on the LLM provider's capabilities.

### JSON Schema

The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy.
Expand All @@ -45,7 +47,7 @@ The easiest way to get started with structured output is by using a [JsonSchemaR
```

!!! warning
Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.
Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options.

## Prompt Drivers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import JsonSchemaRule, Rule
from griptape.rules import Rule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

Expand All @@ -12,18 +12,16 @@
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_native_structured_output=True,
native_structured_output_strategy="native",
structured_output_strategy="native",
),
output_schema=schema.Schema(
{
"steps": [schema.Schema({"explanation": str, "output": str})],
"final_answer": str,
}
),
rules=[
Rule("You are a helpful math tutor. Guide the user through the solution step by step."),
JsonSchemaRule(
schema.Schema(
{
"steps": [schema.Schema({"explanation": str, "output": str})],
"final_answer": str,
}
)
),
],
)
]
Expand Down

This file was deleted.

8 changes: 4 additions & 4 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")

Expand Down Expand Up @@ -137,7 +137,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_strategy == "tool"
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver):
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
Expand All @@ -79,8 +79,8 @@ class AnthropicPromptDriver(BasePromptDriver):
def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")

Expand Down Expand Up @@ -139,7 +139,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_strategy == "tool"
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["tool_choice"] = {"type": "any"}
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
}

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_strategy == "native":
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
elif self.native_structured_output_strategy == "tool":
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool(prompt_stack)

Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class GooglePromptDriver(BasePromptDriver):
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("GooglePromptDriver does not support `native` structured output mode.")

Expand Down Expand Up @@ -167,7 +167,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_strategy == "tool"
and self.structured_output_strategy == "tool"
):
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool(prompt_stack)
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_strategy: Literal["native", "tool"] = field(
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
tokenizer: HuggingFaceTokenizer = field(
Expand All @@ -55,8 +55,8 @@ def client(self) -> InferenceClient:
token=self.api_token,
)

@native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str:
if value == "tool":
raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.")

Expand Down Expand Up @@ -124,7 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if (
prompt_stack.output_schema
and self.use_native_structured_output
and self.native_structured_output_strategy == "native"
and self.structured_output_strategy == "native"
):
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
output_schema = prompt_stack.output_schema.json_schema("Output Schema")
Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
}

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_strategy == "native":
if self.structured_output_strategy == "native":
params["format"] = prompt_stack.output_schema.json_schema("Output")
elif self.native_structured_output_strategy == "tool":
elif self.structured_output_strategy == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool(prompt_stack)

Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
params["parallel_tool_calls"] = self.parallel_tool_calls

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_strategy == "native":
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_schema",
"json_schema": {
Expand All @@ -169,7 +169,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"strict": True,
},
}
elif self.native_structured_output_strategy == "tool" and self.use_native_tools:
elif self.structured_output_strategy == "tool" and self.use_native_tools:
params["tool_choice"] = "required"
self._add_structured_output_tool(prompt_stack)

Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def try_run(self) -> BaseArtifact:

if (
self.prompt_driver.use_native_structured_output
and self.prompt_driver.native_structured_output_strategy == "native"
and self.prompt_driver.structured_output_strategy == "native"
):
return JsonArtifact(output.value)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/mocks/mock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class MockPromptDriver(BasePromptDriver):
def try_run(self, prompt_stack: PromptStack) -> Message:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
if self.use_native_structured_output and prompt_stack.output_schema:
if self.native_structured_output_strategy == "native":
if self.structured_output_strategy == "native":
return Message(
content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))],
role=Message.ASSISTANT_ROLE,
usage=Message.Usage(input_tokens=100, output_tokens=100),
)
elif self.native_structured_output_strategy == "tool":
elif self.structured_output_strategy == "tool":
self._add_structured_output_tool(prompt_stack)

if self.use_native_tools and prompt_stack.tools:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_to_dict(self, config):
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"structured_output_strategy": "tool",
"extra_params": {},
},
"vector_store_driver": {
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_to_dict_with_values(self, config_with_values):
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"structured_output_strategy": "tool",
"extra_params": {},
},
"vector_store_driver": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_to_dict(self, config):
"top_p": 0.999,
"top_k": 250,
"use_native_tools": True,
"native_structured_output_strategy": "tool",
"structured_output_strategy": "tool",
"use_native_structured_output": True,
"extra_params": {},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"native_structured_output_strategy": "native",
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/configs/drivers/test_cohere_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_to_dict(self, config):
"force_single_step": False,
"use_native_tools": True,
"use_native_structured_output": True,
"native_structured_output_strategy": "native",
"structured_output_strategy": "native",
"extra_params": {},
},
"embedding_driver": {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/configs/drivers/test_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_to_dict(self, config):
"stream": False,
"use_native_tools": False,
"use_native_structured_output": False,
"native_structured_output_strategy": "native",
"structured_output_strategy": "native",
"extra_params": {},
},
"conversation_memory_driver": {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/configs/drivers/test_google_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_to_dict(self, config):
"tool_choice": "auto",
"use_native_tools": True,
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"structured_output_strategy": "tool",
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/configs/drivers/test_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_to_dict(self, config):
"stream": False,
"user": "",
"use_native_tools": True,
"native_structured_output_strategy": "native",
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
Expand Down
Loading

0 comments on commit 3762fc7

Please sign in to comment.