From b582896d23698384e920ab515025c7c6f99cfa1f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Dec 2024 14:50:42 -0800 Subject: [PATCH 01/11] Add Structured Output functionality --- CHANGELOG.md | 4 + .../drivers/prompt-drivers.md | 31 +++++ .../src/prompt_drivers_structured_output.py | 35 +++++ .../prompt_drivers_structured_output_multi.py | 28 ++++ griptape/common/prompt_stack/prompt_stack.py | 5 +- .../prompt/amazon_bedrock_prompt_driver.py | 41 ++++-- .../drivers/prompt/anthropic_prompt_driver.py | 37 +++-- griptape/drivers/prompt/base_prompt_driver.py | 16 ++- .../drivers/prompt/cohere_prompt_driver.py | 23 ++- .../drivers/prompt/google_prompt_driver.py | 40 ++++-- .../prompt/huggingface_hub_prompt_driver.py | 45 +++++- .../drivers/prompt/ollama_prompt_driver.py | 25 ++-- .../prompt/openai_chat_prompt_driver.py | 31 +++-- griptape/tasks/actions_subtask.py | 11 +- griptape/tasks/prompt_task.py | 131 ++++++++++++------ griptape/tools/__init__.py | 2 + griptape/tools/structured_output/__init__.py | 0 griptape/tools/structured_output/tool.py | 20 +++ pyproject.toml | 2 +- tests/mocks/mock_prompt_driver.py | 12 ++ .../test_amazon_bedrock_drivers_config.py | 4 + .../drivers/test_anthropic_drivers_config.py | 2 + .../test_azure_openai_drivers_config.py | 2 + .../drivers/test_cohere_drivers_config.py | 2 + .../configs/drivers/test_drivers_config.py | 2 + .../drivers/test_google_drivers_config.py | 2 + .../drivers/test_openai_driver_config.py | 2 + .../test_amazon_bedrock_prompt_driver.py | 82 ++++++++++- .../prompt/test_anthropic_prompt_driver.py | 73 +++++++++- .../test_azure_openai_chat_prompt_driver.py | 78 ++++++++++- .../drivers/prompt/test_base_prompt_driver.py | 23 +++ .../prompt/test_cohere_prompt_driver.py | 107 +++++++++++++- .../prompt/test_google_prompt_driver.py | 52 +++++-- .../test_hugging_face_hub_prompt_driver.py | 42 +++++- .../prompt/test_ollama_prompt_driver.py | 109 ++++++++++++--- .../prompt/test_openai_chat_prompt_driver.py | 115 ++++++++++++++- tests/unit/structures/test_structure.py | 2 + tests/unit/tasks/test_actions_subtask.py | 68 ++++++++- tests/unit/tasks/test_prompt_task.py | 81 +++++++++++ tests/unit/tasks/test_tool_task.py | 2 + tests/unit/tasks/test_toolkit_task.py | 2 + .../unit/tools/test_structured_output_tool.py | 13 ++ 42 files changed, 1232 insertions(+), 172 deletions(-) create mode 100644 docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py create mode 100644 docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py create mode 100644 griptape/tools/structured_output/__init__.py create mode 100644 griptape/tools/structured_output/tool.py create mode 100644 tests/unit/tools/test_structured_output_tool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cb0a4e021..cc51c5124 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. +- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. +- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed @@ -39,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `PromptTask.prompt_driver` is now serialized. 
 - `PromptTask` can now do everything a `ToolkitTask` can do.
 - Loosten `numpy`s version constraint to `>=1.26.4,<3`.
+- `JsonSchemaRule`s can now take a `schema.Schema` instance. This is required for using a `JsonSchemaRule` with structured output.
+- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it.

 ### Fixed

diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md
index 6c51d2d01..f05647bb3 100644
--- a/docs/griptape-framework/drivers/prompt-drivers.md
+++ b/docs/griptape-framework/drivers/prompt-drivers.md
@@ -25,6 +25,37 @@ You can pass images to the Driver if the model supports it:
 --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py"
 ```

+## Structured Output
+
+Some LLMs provide functionality often referred to as "Structured Output": instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to produce parsable output that downstream systems can consume.
+
+Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output) field.
+
+If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of:
+
+- `native`: The Driver will use the structured output functionality provided by the LLM's API.
+- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md), and attempt to force the LLM to use it.
+
+### JSON Schema
+
+The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy.
+
+```python
+--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
+```
+
+### Multiple Schemas
+
+If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`.
+
+Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`:
+
+```python
+--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py"
+```
+
+Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options.
+
 ## Prompt Drivers

 Griptape offers the following Prompt Drivers for interacting with LLMs.
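Choosing an unsupported combination fails fast: the strategy is validated when the Driver is constructed, so Drivers that only support the `tool` strategy (such as `AmazonBedrockPromptDriver`) raise a `ValueError` if `native` is requested. Below is a minimal sketch of that behavior, assuming the Bedrock optional dependencies are installed and the placeholder model id is replaced with a real one:

```python
from griptape.drivers import AmazonBedrockPromptDriver

# Supported: Bedrock-based drivers use the `tool` strategy for structured output.
driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model id
    use_native_structured_output=True,
    native_structured_output_strategy="tool",
)

# Unsupported: requesting the `native` strategy raises at construction time.
try:
    AmazonBedrockPromptDriver(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model id
        use_native_structured_output=True,
        native_structured_output_strategy="native",
    )
except ValueError as error:
    print(error)  # AmazonBedrockPromptDriver does not support `native` structured output mode.
```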
diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py new file mode 100644 index 000000000..6613c3a3e --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -0,0 +1,35 @@ +import schema +from rich.pretty import pprint + +from griptape.drivers import OpenAiChatPromptDriver +from griptape.rules import JsonSchemaRule, Rule +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +pipeline = Pipeline( + tasks=[ + PromptTask( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4o", + use_native_structured_output=True, + native_structured_output_strategy="native", + ), + rules=[ + Rule("You are a helpful math tutor. Guide the user through the solution step by step."), + JsonSchemaRule( + schema.Schema( + { + "steps": [schema.Schema({"explanation": str, "output": str})], + "final_answer": str, + } + ) + ), + ], + ) + ] +) + +output = pipeline.run("How can I solve 8x + 7 = -23").output.value + + +pprint(output) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py new file mode 100644 index 000000000..0b85cee94 --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py @@ -0,0 +1,28 @@ +import schema +from rich.pretty import pprint + +from griptape.drivers import OpenAiChatPromptDriver +from griptape.rules import JsonSchemaRule +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +pipeline = Pipeline( + tasks=[ + PromptTask( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4o", + use_native_structured_output=True, + native_structured_output_strategy="tool", + ), + rules=[ + JsonSchemaRule(schema.Schema({"color": "red"})), + JsonSchemaRule(schema.Schema({"color": "blue"})), + ], + ) + ] +) + +output = pipeline.run("Pick a color").output.value + + +pprint(output) diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3b1b8ef74..752ce8a8d 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -24,6 +24,8 @@ from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: + from schema import Schema + from griptape.tools import BaseTool @@ -31,6 +33,7 @@ class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def system_messages(self) -> list[Message]: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 54278c895..9e754f6aa 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -55,9 +55,20 @@ 
class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("bedrock-runtime") @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] - messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) - return { + params = { "modelId": self.model, "messages": messages, "system": system_messages, @@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}), }, "additionalModelRequestFields": self.additional_model_request_fields, - **( - {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["toolConfig"] = { + "tools": [], + "toolChoice": self.tool_choice, + } + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["toolConfig"]["toolChoice"] = {"any": {}} + + params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) + + return params + def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: return [ { diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 060b8151d..a61b69232 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -68,6 +68,10 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) 
+ native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @@ -75,6 +79,13 @@ class AnthropicPromptDriver(BasePromptDriver): def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) @@ -110,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages system_message = system_messages[0].to_text() if system_messages else None - return { + params = { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, @@ -118,15 +129,25 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "top_k": self.top_k, "max_tokens": self.max_tokens, "messages": messages, - **( - {"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice} - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"system": system_message} if system_message else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["tool_choice"] = {"type": "any"} + + params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) + + return params + def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: return [ {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 707f67644..19109f55f 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from attrs import Factory, define, field @@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: @@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... 
@abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... + def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None: + from griptape.tools.structured_output.tool import StructuredOutputTool + + if prompt_stack.output_schema is None: + raise ValueError("PromptStack must have an output schema to use structured output.") + + structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3811db5cd..2695aba09 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -101,21 +102,31 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self.__to_cohere_messages(prompt_stack.messages) - return { + params = { "model": self.model, "messages": messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), - **( - {"tools": self.__to_cohere_tools(prompt_stack.tools)} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } + elif self.native_structured_output_strategy == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_cohere_tools(prompt_stack.tools) + + return params + def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2a6bdbf6d..23de1e42d 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -2,9 +2,9 @@ import json import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ActionArtifact, TextArtifact @@ -63,9 +63,20 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, 
metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("GooglePromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") @@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) - return { + params = { "generation_config": types.GenerationConfig( **{ # For some reason, providing stop sequences when streaming breaks native functions @@ -148,16 +159,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, }, ), - **( - { - "tools": self.__to_google_tools(prompt_stack.tools), - "tool_config": {"function_calling_config": {"mode": self.tool_choice}}, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), } + if prompt_stack.tools and self.use_native_tools: + params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + params["tool_config"]["function_calling_config"]["mode"] = "auto" + self._add_structured_output_tool(prompt_stack) + + params["tools"] = self.__to_google_tools(prompt_stack.tools) + + return params + def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index c2c45c3ae..f9acdeb1d 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable from griptape.configs import Defaults @@ -35,6 +35,10 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) tokenizer: HuggingFaceTokenizer = field( 
default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -51,11 +55,23 @@ def client(self) -> InferenceClient: token=self.api_token, ) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "tool": + raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation( prompt, @@ -75,7 +91,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation(prompt, **full_params) @@ -94,12 +115,26 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) def _base_params(self, prompt_stack: PromptStack) -> dict: - return { + params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.native_structured_output_strategy == "native" + ): + # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding + output_schema = prompt_stack.output_schema.json_schema("Output Schema") + # Grammar does not support $schema and $id + del output_schema["$schema"] + del output_schema["$id"] + params["grammar"] = {"type": "json", "value": output_schema} + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 5cbba1fdf..25756cc1c 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -79,7 +80,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) response = self.client.chat(**params) - logger.debug(response) + logger.debug(response.model_dump()) return Message( content=self.__to_prompt_stack_message_content(response), @@ -102,20 +103,26 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self._prompt_stack_to_messages(prompt_stack) - return { + params = { "messages": messages, "model": self.model, "options": self.options, - **( - {"tools": 
self.__to_ollama_tools(prompt_stack.tools)} - if prompt_stack.tools - and self.use_native_tools - and not self.stream # Tool calling is only supported when not streaming - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") + elif self.native_structured_output_strategy == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + # Tool calling is only supported when not streaming + if prompt_stack.tools and self.use_native_tools and not self.stream: + params["tools"] = self.__to_ollama_tools(prompt_stack.tools) + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: ollama_messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eed0e35f0..d8f61a3bf 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory( @@ -148,21 +149,30 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, - **( - { - "tools": self.__to_openai_tools(prompt_stack.tools), - "tool_choice": self.tool_choice, - "parallel_tool_calls": self.parallel_tool_calls, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + params["parallel_tool_calls"] = self.parallel_tool_calls + + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": prompt_stack.output_schema.json_schema("Output"), + "strict": True, + }, + } + elif self.native_structured_output_strategy == "tool" and self.use_native_tools: + params["tool_choice"] = "required" + self._add_structured_output_tool(prompt_stack) + if self.response_format is not None: if self.response_format == {"type": "json_object"}: params["response_format"] = self.response_format @@ -171,6 +181,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: else: params["response_format"] = self.response_format + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_openai_tools(prompt_stack.tools) + messages = self.__to_openai_messages(prompt_stack.messages) params["messages"] = messages diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 6f9d70053..c889554fd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -9,12 +9,13 @@ from attrs import define, field from griptape import utils -from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.configs import Defaults from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask +from griptape.tools.structured_output.tool import StructuredOutputTool from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars if TYPE_CHECKING: @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None: self.__init_from_prompt(self.input.to_text()) else: self.__init_from_artifacts(self.input) + + structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)] + if structured_outputs: + output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs] + if len(structured_outputs) > 1: + self.output = ListArtifact(output_values) + else: + self.output = output_values[0] except Exception as e: logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 5086636d0..bb5ca9667 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -2,24 +2,25 @@ import json import logging 
+import warnings from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field +from schema import Or, Schema from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import PromptStack, ToolAction from griptape.configs import Defaults from griptape.memory.structure import Run from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin -from griptape.rules import Ruleset +from griptape.rules import JsonSchemaRule, Ruleset from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 if TYPE_CHECKING: - from schema import Schema - from griptape.drivers import BasePromptDriver from griptape.memory import TaskMemory from griptape.memory.structure.base_conversation_memory import BaseConversationMemory @@ -92,54 +93,30 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None - system_template = self.generate_system_template(self) - if system_template: - stack.add_system_message(system_template) + rulesets = self.rulesets + system_artifacts = [TextArtifact(self.generate_system_template(self))] + if self.prompt_driver.use_native_structured_output: + self._add_native_schema_to_prompt_stack(stack, rulesets) + + # Ensure there is at least one Ruleset that has non-empty `rules`. + if any(len(ruleset.rules) for ruleset in rulesets): + system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets))) + + # Ensure there is at least one system Artifact that has a non-empty value. 
+ has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts) + if has_system_artifacts: + stack.add_system_message(ListArtifact(system_artifacts)) stack.add_user_message(self.input) if self.output: stack.add_assistant_message(self.output.to_text()) else: - for s in self.subtasks: - if self.prompt_driver.use_native_tools: - action_calls = [ - ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) - for action in s.actions - ] - action_results = [ - ToolAction( - name=action.name, - path=action.path, - tag=action.tag, - output=action.output if action.output is not None else s.output, - ) - for action in s.actions - ] - - stack.add_assistant_message( - ListArtifact( - [ - *([TextArtifact(s.thought)] if s.thought else []), - *[ActionArtifact(a) for a in action_calls], - ], - ), - ) - stack.add_user_message( - ListArtifact( - [ - *[ActionArtifact(a) for a in action_results], - *([] if s.output else [TextArtifact("Please keep going")]), - ], - ), - ) - else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) + self._add_subtasks_to_prompt_stack(stack) if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0) return stack @@ -218,11 +195,17 @@ def try_run(self) -> BaseArtifact: else: break - self.output = subtask.output + output = subtask.output else: - self.output = result.to_artifact() + output = result.to_artifact() - return self.output + if ( + self.prompt_driver.use_native_structured_output + and self.prompt_driver.native_structured_output_strategy == "native" + ): + return JsonArtifact(output.value) + else: + return output def preprocess(self, structure: Structure) -> BaseTask: super().preprocess(structure) @@ -243,7 +226,6 @@ def default_generate_system_template(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/prompt_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), @@ -324,3 +306,60 @@ def _process_task_input( return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: return self._process_task_input(TextArtifact(task_input)) + + def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None: + # Need to separate JsonSchemaRules from other rules, removing them in the process + json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)] + non_json_schema_rules = [ + [rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets + ] + for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules): + ruleset.rules = non_json_rules + + schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)] + + if len(json_schema_rules) != len(schemas): + warnings.warn( + "Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. 
These will be ignored with `use_native_structured_output`.", + stacklevel=2, + ) + + if schemas: + stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas)) + + def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: + for s in self.subtasks: + if self.prompt_driver.use_native_tools: + action_calls = [ + ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) + for action in s.actions + ] + action_results = [ + ToolAction( + name=action.name, + path=action.path, + tag=action.tag, + output=action.output if action.output is not None else s.output, + ) + for action in s.actions + ] + + stack.add_assistant_message( + ListArtifact( + [ + *([TextArtifact(s.thought)] if s.thought else []), + *[ActionArtifact(a) for a in action_calls], + ], + ), + ) + stack.add_user_message( + ListArtifact( + [ + *[ActionArtifact(a) for a in action_results], + *([] if s.output else [TextArtifact("Please keep going")]), + ], + ), + ) + else: + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 67a1712a1..ec9cbd5b7 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -23,6 +23,7 @@ from .extraction.tool import ExtractionTool from .prompt_summary.tool import PromptSummaryTool from .query.tool import QueryTool +from .structured_output.tool import StructuredOutputTool __all__ = [ "BaseTool", @@ -50,4 +51,5 @@ "ExtractionTool", "PromptSummaryTool", "QueryTool", + "StructuredOutputTool", ] diff --git a/griptape/tools/structured_output/__init__.py b/griptape/tools/structured_output/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structured_output/tool.py b/griptape/tools/structured_output/tool.py new file mode 100644 index 000000000..89e638f59 --- /dev/null +++ b/griptape/tools/structured_output/tool.py @@ -0,0 +1,20 @@ +from attrs import define, field +from schema import Schema + +from griptape.artifacts import BaseArtifact, JsonArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructuredOutputTool(BaseTool): + output_schema: Schema = field(kw_only=True) + + @activity( + config={ + "description": "Used to provide the final response which ends this conversation.", + "schema": lambda self: self.output_schema, + } + ) + def provide_output(self, params: dict) -> BaseArtifact: + return JsonArtifact(params["values"]) diff --git a/pyproject.toml b/pyproject.toml index f63dd396d..c45fbefa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -315,7 +315,7 @@ fixture-parentheses = true "ANN202", # missing-return-type-private-function ] "docs/*" = [ - "T20" # flake8-print + "T20", # flake8-print ] [tool.ruff.lint.flake8-tidy-imports.banned-api] diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index f308c9804..abef72227 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Callable, Union from attrs import define, field @@ -31,9 +32,20 @@ class MockPromptDriver(BasePromptDriver): tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) mock_input: Union[str, Callable[[], str]] = field(default="mock input", kw_only=True) mock_output: Union[str, Callable[[PromptStack], str]] = 
field(default="mock output", kw_only=True) + mock_structured_output: Union[dict, Callable[[PromptStack], dict]] = field(factory=dict, kw_only=True) def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + if prompt_stack.output_schema and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + return Message( + content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + elif self.native_structured_output_strategy == "tool": + self._add_structured_output_tool(prompt_stack) + if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. action_messages = [ diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 52408922c..59eb4ac61 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,8 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -106,6 +108,8 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 8a6f25ef2..66f987308 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "native_structured_output_strategy": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 4c44113a0..2281f4c11 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,6 +36,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 65295da52..6f371c5ba 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,8 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "native", "extra_params": {}, }, "embedding_driver": { diff --git 
a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index ca3cea60e..dd2e1736b 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,8 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index c1459a400..569e45561 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index c71774b26..603d9867a 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,6 +28,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 939b86c5e..a21690cd3 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ErrorArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,29 @@ class TestAmazonBedrockPromptDriver: + BEDROCK_STRUCTURED_OUTPUT_TOOL = { + "toolSpec": { + "description": "Used to provide the final response which ends this conversation.", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "name": "StructuredOutputTool_provide_output", + }, + } BEDROCK_TOOLS = [ { "toolSpec": { @@ -229,6 +253,7 @@ def mock_converse_stream(self, mocker): def prompt_stack(self, request): prompt_stack = PromptStack() prompt_stack.tools = [MockTool()] + prompt_stack.output_schema = Schema({"foo": str}) if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -359,10 +384,14 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, 
use_native_tools, use_native_structured_output): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -379,7 +408,19 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } if use_native_tools else {} ), @@ -396,10 +437,17 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -417,8 +465,20 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} - if prompt_stack.tools and use_native_tools + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } + if use_native_tools else {} ), foo="bar", @@ -441,3 +501,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_strategy(self): + assert AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises( + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." 
+ ): + AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index b611b5e1c..cc9179ae8 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact @@ -141,6 +142,24 @@ class TestAnthropicPromptDriver: }, ] + ANTHROPIC_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "input_schema": { + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "StructuredOutputTool_provide_output", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") @@ -199,6 +218,7 @@ def mock_stream_client(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -350,10 +370,15 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AnthropicPromptDriver( - model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="claude-3-haiku", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -369,7 +394,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -383,13 +422,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver 
= AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -408,7 +451,21 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert event.usage.input_tokens == 5 @@ -433,3 +490,9 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na event = next(stream) assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_strategy(self): + assert AnthropicPromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): + AnthropicPromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index c7dff9811..6ca7a423b 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,13 +67,25 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -87,11 +99,32 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ user=driver.user, messages=messages, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert 
isinstance(message.value[0], TextArtifact) @@ -103,7 +136,17 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", @@ -111,6 +154,8 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, model="gpt-4", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -126,11 +171,32 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream=True, messages=messages, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 58720bbc5..c57173c66 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,3 +1,5 @@ +import pytest + from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent @@ -65,3 +67,24 @@ def test_run_with_tools_and_stream(self, mock_config): output = pipeline.run().output_task.output assert isinstance(output, TextArtifact) assert output.value == "mock output" + + def test__add_structured_output_tool(self): + from schema import Schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + mock_prompt_driver = MockPromptDriver() + + prompt_stack = PromptStack() + + with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."): + mock_prompt_driver._add_structured_output_tool(prompt_stack) + + prompt_stack.output_schema = Schema({"foo": str}) + + mock_prompt_driver._add_structured_output_tool(prompt_stack) + # Ensure it doesn't get added twice + mock_prompt_driver._add_structured_output_tool(prompt_stack) + assert len(prompt_stack.tools) == 1 + assert isinstance(prompt_stack.tools[0], StructuredOutputTool) + assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py 
index 9b7c24a98..bc0c51203 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -12,6 +13,36 @@ class TestCoherePromptDriver: + COHERE_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + COHERE_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + } COHERE_TOOLS = [ { "function": { @@ -242,6 +273,7 @@ def mock_tokenizer(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -306,10 +338,25 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = CoherePromptDriver( - model="command", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="command", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -320,7 +367,26 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", @@ -340,13 +406,25 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + 
@pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = CoherePromptDriver( model="command", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -359,7 +437,26 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 72cf51d03..e4a71d24e 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -4,6 +4,7 @@ from google.generativeai.protos import FunctionCall, FunctionResponse, Part from google.generativeai.types import ContentDict, GenerationConfig from google.protobuf.json_format import MessageToDict +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -13,6 +14,15 @@ class TestGooglePromptDriver: + GOOGLE_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "properties": {"foo": {"type": "STRING"}}, + "required": ["foo"], + "type": "OBJECT", + }, + } GOOGLE_TOOLS = [ { "name": "MockTool_test", @@ -100,6 +110,7 @@ def mock_stream_generative_model(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -166,7 +177,10 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run( + self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -174,6 +188,8 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -195,9 +211,14 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, 
use_native ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -210,7 +231,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -219,6 +243,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -242,9 +267,14 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" assert event.usage.input_tokens == 5 @@ -259,3 +289,9 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, event = next(stream) assert event.usage.output_tokens == 5 + + def test_verify_native_structured_output_strategy(self): + assert GooglePromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."): + GooglePromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 4b7aa4d13..24a83c07b 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,10 +1,18 @@ import pytest +from schema import Schema from griptape.common import PromptStack, TextDeltaMessageContent from griptape.drivers import HuggingFaceHubPromptDriver class TestHuggingFaceHubPromptDriver: + HUGGINGFACE_HUB_OUTPUT_SCHEMA = { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = 
mocker.patch("huggingface_hub.InferenceClient").return_value @@ -31,6 +39,7 @@ def mock_client_stream(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_assistant_message("assistant-input") @@ -45,9 +54,15 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", extra_params={"foo": "bar"}) + driver = HuggingFaceHubPromptDriver( + api_token="api-token", + model="repo-id", + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -58,15 +73,23 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( - api_token="api-token", model="repo-id", stream=True, extra_params={"foo": "bar"} + api_token="api-token", + model="repo-id", + stream=True, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -79,6 +102,9 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) @@ -87,3 +113,11 @@ def test_try_stream(self, prompt_stack, mock_client_stream): event = next(stream) assert event.usage.input_tokens == 3 assert event.usage.output_tokens == 3 + + def test_verify_native_structured_output_strategy(self): + assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="native") + + with pytest.raises( + ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode." 
+ ): + HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="tool") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 51a3dbb77..bfc1111e0 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,27 @@ class TestOllamaPromptDriver: + OLLAMA_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + OLLAMA_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "type": "function", + } OLLAMA_TOOLS = [ { "function": { @@ -112,7 +134,9 @@ class TestOllamaPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = { + mock_response = mocker.MagicMock() + + data = { "message": { "content": "model-output", "tool_calls": [ @@ -126,6 +150,10 @@ def mock_client(self, mocker): }, } + mock_response.__getitem__.side_effect = lambda key: data[key] + mock_response.model_dump.return_value = data + mock_client.return_value.chat.return_value = mock_response + return mock_client @pytest.fixture() @@ -138,6 +166,7 @@ def mock_stream_client(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -202,10 +231,26 @@ def messages(self): def test_init(self): assert OllamaPromptDriver(model="llama") - @pytest.mark.parametrize("use_native_tools", [True]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given - driver = OllamaPromptDriver(model="llama", extra_params={"foo": "bar"}) + driver = OllamaPromptDriver( + model="llama", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -219,7 +264,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): "stop": [], "num_predict": driver.max_tokens, }, - **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, + **{ + "tools": [ + *self.OLLAMA_TOOLS, + *( + [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + 
} + if use_native_tools + else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -230,33 +289,39 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.value[1].value.path == "test" assert message.value[1].value.input == {"foo": "bar"} - def test_try_stream_run(self, mock_stream_client): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) + driver = OllamaPromptDriver( + model="llama", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, - {"role": "assistant", "content": "assistant-input"}, - ] - driver = OllamaPromptDriver(model="llama", stream=True, extra_params={"foo": "bar"}) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index c47c3e9c6..2b9c7e5b9 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -12,6 +12,36 @@ class TestOpenAiChatPromptDriverFixtureMixin: + OPENAI_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + OPENAI_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + } OPENAI_TOOLS = [ { "function": { @@ -239,6 +269,7 @@ def 
mock_chat_completion_stream_create(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = schema.Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -340,11 +371,23 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -359,12 +402,33 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ messages=messages, seed=driver.seed, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -445,12 +509,24 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -468,12 +544,33 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, seed=driver.seed, stream_options={"include_usage": True}, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + 
), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) @@ -500,7 +597,10 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, + max_tokens=1, + use_native_tools=False, + use_native_structured_output=False, ) # When @@ -535,6 +635,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 21a637ff6..5921d9e28 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,6 +83,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", }, } ], diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index e7d44b5af..764c3440c 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -4,9 +4,10 @@ from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import ToolAction from griptape.structures import Agent -from griptape.tasks import ActionsSubtask, PromptTask +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool @@ -257,3 +258,68 @@ def test_origin_task(self): with pytest.raises(Exception, match="ActionSubtask has no origin task."): assert ActionsSubtask("test").origin_task + + def test_structured_output_tool(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool", + path="provide_output", + input={"values": {"test": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, JsonArtifact) + assert subtask.output.value == {"test": "value"} + + def test_structured_output_tool_multiple(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool1", + path="provide_output", + input={"values": {"test1": "value"}}, + ) + ), + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool2", + 
path="provide_output", + input={"values": {"test2": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask( + tools=[ + StructuredOutputTool(name="StructuredOutputTool1", output_schema=schema.Schema({"test": str})), + StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})), + ] + ) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, ListArtifact) + assert len(subtask.output.value) == 2 + assert subtask.output.value[0].value == {"test1": "value"} + assert subtask.output.value[1].value == {"test2": "value"} diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index f457a4b55..30a7001f9 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,9 +1,15 @@ +import warnings + +import pytest + from griptape.artifacts.image_artifact import ImageArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.structure.run import Run from griptape.rules import Rule +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.rules.ruleset import Ruleset from griptape.structures import Pipeline from griptape.tasks import PromptTask @@ -172,6 +178,81 @@ def test_prompt_stack_empty_system_content(self): assert task.prompt_stack.messages[2].is_user() assert task.prompt_stack.messages[2].to_text() == "test value" + def test_prompt_stack_native_schema(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + mock_structured_output={"baz": "foo"}, + ), + rules=[JsonSchemaRule(output_schema)], + ) + output = task.run() + + assert isinstance(output, JsonArtifact) + assert output.value == {"baz": "foo"} + + assert task.prompt_stack.output_schema is output_schema + assert task.prompt_stack.messages[0].is_user() + assert "foo" in task.prompt_stack.messages[0].to_text() + + # Ensure no warnings were raised + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert task.prompt_stack + + def test_prompt_stack_mixed_native_schema(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)], + ) + + assert task.prompt_stack.output_schema is output_schema + assert task.prompt_stack.messages[0].is_system() + assert "foo" in task.prompt_stack.messages[0].to_text() + assert "bar" not in task.prompt_stack.messages[0].to_text() + with pytest.warns( + match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`." 
+ ): + assert task.prompt_stack + + def test_prompt_stack_empty_native_schema(self): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[JsonSchemaRule({"foo": {}})], + ) + + assert task.prompt_stack.output_schema is None + + def test_prompt_stack_multi_native_schema(self): + from schema import Or, Schema + + output_schema = Schema({"foo": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)], + ) + + assert isinstance(task.prompt_stack.output_schema, Schema) + assert task.prompt_stack.output_schema.json_schema("Output") == Schema( + Or(output_schema, output_schema) + ).json_schema("Output") + def test_rulesets(self): pipeline = Pipeline( rulesets=[Ruleset("Pipeline Ruleset")], diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ca0576ebe..d7050c8f6 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,6 +257,8 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", + "native_structured_output_strategy": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3c17ff479..2503f1174 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", }, "tools": [ { diff --git a/tests/unit/tools/test_structured_output_tool.py b/tests/unit/tools/test_structured_output_tool.py new file mode 100644 index 000000000..d310b2f9b --- /dev/null +++ b/tests/unit/tools/test_structured_output_tool.py @@ -0,0 +1,13 @@ +import pytest +import schema + +from griptape.tools import StructuredOutputTool + + +class TestStructuredOutputTool: + @pytest.fixture() + def tool(self): + return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"})) + + def test_provide_output(self, tool): + assert tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"} From 1599246f53804e19c92109172869d2450670dca8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 2 Jan 2025 14:32:12 -0800 Subject: [PATCH 02/11] Don't use JsonSchemaRules --- CHANGELOG.md | 2 - .../drivers/prompt-drivers.md | 13 +---- griptape/tasks/prompt_task.py | 48 ++++--------------- mise.toml | 2 + tests/mocks/mock_prompt_driver.py | 2 +- tests/unit/tasks/test_prompt_task.py | 42 +--------------- 6 files changed, 16 insertions(+), 93 deletions(-) create mode 100644 mise.toml diff --git a/CHANGELOG.md b/CHANGELOG.md index cc51c5124..4a85afb23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,8 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `PromptTask.prompt_driver` is now serialized. - `PromptTask` can now do everything a `ToolkitTask` can do. - Loosten `numpy`s version constraint to `>=1.26.4,<3`. -- `JsonSchemaRule`s can now take a `schema.Schema` instance. Required for using a `JsonSchemaRule` with structured output. -- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it. 
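For context on the `PromptTask.output_schema` change in this commit: the schema is now passed directly to the task rather than being pulled out of `JsonSchemaRule`s. A minimal usage sketch, assuming an OpenAI-backed driver with credentials configured; the model name is illustrative:

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# Pass the schema directly to the task instead of wrapping it in a JsonSchemaRule.
task = PromptTask(
    input="Pick a color",
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",  # illustrative model name
        use_native_structured_output=True,
    ),
    output_schema=schema.Schema({"color": str}),
)

# With the default "native" strategy the task returns a JsonArtifact whose value
# conforms to the schema, e.g. {"color": "red"}.
print(task.run().value)
```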
### Fixed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index f05647bb3..67b24026e 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -44,17 +44,8 @@ The easiest way to get started with structured output is by using a [JsonSchemaR --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` -### Multiple Schemas - -If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`. - -Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`: - -```python ---8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py" -``` - -Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options. +!!! warning + Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options. ## Prompt Drivers diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index bb5ca9667..f786b6d33 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -2,11 +2,9 @@ import json import logging -import warnings from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field -from schema import Or, Schema from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact @@ -16,11 +14,13 @@ from griptape.memory.structure import Run from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin -from griptape.rules import JsonSchemaRule, Ruleset +from griptape.rules import Ruleset from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 if TYPE_CHECKING: + from schema import Schema + from griptape.drivers import BasePromptDriver from griptape.memory import TaskMemory from griptape.memory.structure.base_conversation_memory import BaseConversationMemory @@ -39,6 +39,7 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin): prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True} ) + output_schema: Optional[Schema] = field(default=None, kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True, @@ -90,22 +91,12 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack(tools=self.tools) + stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None - rulesets = self.rulesets - system_artifacts = [TextArtifact(self.generate_system_template(self))] - if self.prompt_driver.use_native_structured_output: - self._add_native_schema_to_prompt_stack(stack, rulesets) - - # Ensure there is at least one Ruleset that has non-empty `rules`. - if any(len(ruleset.rules) for ruleset in rulesets): - system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets))) - - # Ensure there is at least one system Artifact that has a non-empty value. 
- has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts) - if has_system_artifacts: - stack.add_system_message(ListArtifact(system_artifacts)) + system_template = self.generate_system_template(self) + if system_template: + stack.add_system_message(system_template) stack.add_user_message(self.input) @@ -116,7 +107,7 @@ def prompt_stack(self) -> PromptStack: if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) return stack @@ -226,6 +217,7 @@ def default_generate_system_template(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/prompt_task/system.j2").render( + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), @@ -307,26 +299,6 @@ def _process_task_input( else: return self._process_task_input(TextArtifact(task_input)) - def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None: - # Need to separate JsonSchemaRules from other rules, removing them in the process - json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)] - non_json_schema_rules = [ - [rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets - ] - for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules): - ruleset.rules = non_json_rules - - schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)] - - if len(json_schema_rules) != len(schemas): - warnings.warn( - "Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. 
These will be ignored with `use_native_structured_output`.", - stacklevel=2, - ) - - if schemas: - stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas)) - def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: for s in self.subtasks: if self.prompt_driver.use_native_tools: diff --git a/mise.toml b/mise.toml new file mode 100644 index 000000000..e01d6ae46 --- /dev/null +++ b/mise.toml @@ -0,0 +1,2 @@ +[tools] +python = "3.9" diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index abef72227..af4b2c79a 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema and self.use_native_structured_output: + if self.use_native_structured_output and prompt_stack.output_schema: if self.native_structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 30a7001f9..e4d3060a5 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,7 +1,5 @@ import warnings -import pytest - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -188,7 +186,7 @@ def test_prompt_stack_native_schema(self): use_native_structured_output=True, mock_structured_output={"baz": "foo"}, ), - rules=[JsonSchemaRule(output_schema)], + output_schema=output_schema, ) output = task.run() @@ -204,27 +202,6 @@ def test_prompt_stack_native_schema(self): warnings.simplefilter("error") assert task.prompt_stack - def test_prompt_stack_mixed_native_schema(self): - from schema import Schema - - output_schema = Schema({"baz": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - use_native_structured_output=True, - ), - rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)], - ) - - assert task.prompt_stack.output_schema is output_schema - assert task.prompt_stack.messages[0].is_system() - assert "foo" in task.prompt_stack.messages[0].to_text() - assert "bar" not in task.prompt_stack.messages[0].to_text() - with pytest.warns( - match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`." 
- ): - assert task.prompt_stack - def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", @@ -236,23 +213,6 @@ def test_prompt_stack_empty_native_schema(self): assert task.prompt_stack.output_schema is None - def test_prompt_stack_multi_native_schema(self): - from schema import Or, Schema - - output_schema = Schema({"foo": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - use_native_structured_output=True, - ), - rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)], - ) - - assert isinstance(task.prompt_stack.output_schema, Schema) - assert task.prompt_stack.output_schema.json_schema("Output") == Schema( - Or(output_schema, output_schema) - ).json_schema("Output") - def test_rulesets(self): pipeline = Pipeline( rulesets=[Ruleset("Pipeline Ruleset")], From f49cda020275071e0ab8d1b5798692994627ed5a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 2 Jan 2025 14:55:40 -0800 Subject: [PATCH 03/11] PR feedback --- CHANGELOG.md | 2 +- .../drivers/prompt-drivers.md | 6 ++-- .../src/prompt_drivers_structured_output.py | 18 ++++++------ .../prompt_drivers_structured_output_multi.py | 28 ------------------- .../prompt/amazon_bedrock_prompt_driver.py | 8 +++--- .../drivers/prompt/anthropic_prompt_driver.py | 8 +++--- griptape/drivers/prompt/base_prompt_driver.py | 2 +- .../drivers/prompt/cohere_prompt_driver.py | 4 +-- .../drivers/prompt/google_prompt_driver.py | 8 +++--- .../prompt/huggingface_hub_prompt_driver.py | 8 +++--- .../drivers/prompt/ollama_prompt_driver.py | 4 +-- .../prompt/openai_chat_prompt_driver.py | 4 +-- griptape/tasks/prompt_task.py | 2 +- tests/mocks/mock_prompt_driver.py | 4 +-- .../test_amazon_bedrock_drivers_config.py | 4 +-- .../drivers/test_anthropic_drivers_config.py | 2 +- .../test_azure_openai_drivers_config.py | 2 +- .../drivers/test_cohere_drivers_config.py | 2 +- .../configs/drivers/test_drivers_config.py | 2 +- .../drivers/test_google_drivers_config.py | 2 +- .../drivers/test_openai_driver_config.py | 2 +- .../test_amazon_bedrock_prompt_driver.py | 10 +++---- .../prompt/test_anthropic_prompt_driver.py | 10 +++---- .../test_azure_openai_chat_prompt_driver.py | 24 ++++++++-------- .../prompt/test_cohere_prompt_driver.py | 20 ++++++------- .../prompt/test_google_prompt_driver.py | 8 +++--- .../test_hugging_face_hub_prompt_driver.py | 6 ++-- .../prompt/test_ollama_prompt_driver.py | 18 ++++++------ .../prompt/test_openai_chat_prompt_driver.py | 24 ++++++++-------- tests/unit/structures/test_structure.py | 2 +- tests/unit/tasks/test_tool_task.py | 2 +- tests/unit/tasks/test_toolkit_task.py | 2 +- 32 files changed, 110 insertions(+), 138 deletions(-) delete mode 100644 docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a85afb23..6e9defcc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. - `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. -- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`. +- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. 
### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 67b24026e..c890ef179 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -31,11 +31,13 @@ Some LLMs provide functionality often referred to as "Structured Output". This m Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output). -If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of: +If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. - `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool. +Each Driver may have a different default setting depending on the LLM provider's capabilities. + ### JSON Schema The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy. @@ -45,7 +47,7 @@ The easiest way to get started with structured output is by using a [JsonSchemaR ``` !!! warning - Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options. + Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options. ## Prompt Drivers diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index 6613c3a3e..b1f801341 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -2,7 +2,7 @@ from rich.pretty import pprint from griptape.drivers import OpenAiChatPromptDriver -from griptape.rules import JsonSchemaRule, Rule +from griptape.rules import Rule from griptape.structures import Pipeline from griptape.tasks import PromptTask @@ -12,18 +12,16 @@ prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", use_native_structured_output=True, - native_structured_output_strategy="native", + structured_output_strategy="native", + ), + output_schema=schema.Schema( + { + "steps": [schema.Schema({"explanation": str, "output": str})], + "final_answer": str, + } ), rules=[ Rule("You are a helpful math tutor. 
Guide the user through the solution step by step."), - JsonSchemaRule( - schema.Schema( - { - "steps": [schema.Schema({"explanation": str, "output": str})], - "final_answer": str, - } - ) - ), ], ) ] diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py deleted file mode 100644 index 0b85cee94..000000000 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py +++ /dev/null @@ -1,28 +0,0 @@ -import schema -from rich.pretty import pprint - -from griptape.drivers import OpenAiChatPromptDriver -from griptape.rules import JsonSchemaRule -from griptape.structures import Pipeline -from griptape.tasks import PromptTask - -pipeline = Pipeline( - tasks=[ - PromptTask( - prompt_driver=OpenAiChatPromptDriver( - model="gpt-4o", - use_native_structured_output=True, - native_structured_output_strategy="tool", - ), - rules=[ - JsonSchemaRule(schema.Schema({"color": "red"})), - JsonSchemaRule(schema.Schema({"color": "blue"})), - ], - ) - ] -) - -output = pipeline.run("Pick a color").output.value - - -pprint(output) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 9e754f6aa..eefee0ff2 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -56,14 +56,14 @@ class AmazonBedrockPromptDriver(BasePromptDriver): ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - native_structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: if value == "native": raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") @@ -137,7 +137,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None and self.use_native_structured_output - and self.native_structured_output_strategy == "tool" + and self.structured_output_strategy == "tool" ): self._add_structured_output_tool(prompt_stack) params["toolConfig"]["toolChoice"] = {"any": {}} diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index a61b69232..99053713a 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -69,7 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver): tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, 
metadata={"serializable": True}) use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - native_structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) @@ -79,8 +79,8 @@ class AnthropicPromptDriver(BasePromptDriver): def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) - @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: if value == "native": raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") @@ -139,7 +139,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None and self.use_native_structured_output - and self.native_structured_output_strategy == "tool" + and self.structured_output_strategy == "tool" ): self._add_structured_output_tool(prompt_stack) params["tool_choice"] = {"type": "any"} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 19109f55f..950c80cf8 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -57,7 +57,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - native_structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 2695aba09..a7121b440 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -113,12 +113,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } if prompt_stack.output_schema is not None and self.use_native_structured_output: - if self.native_structured_output_strategy == "native": + if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_object", "schema": prompt_stack.output_schema.json_schema("Output"), } - elif self.native_structured_output_strategy == "tool": + elif self.structured_output_strategy == "tool": # TODO: Implement tool choice once supported self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 23de1e42d..29c43a91e 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -64,14 +64,14 @@ class GooglePromptDriver(BasePromptDriver): top_k: 
Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - native_structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: if value == "native": raise ValueError("GooglePromptDriver does not support `native` structured output mode.") @@ -167,7 +167,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None and self.use_native_structured_output - and self.native_structured_output_strategy == "tool" + and self.structured_output_strategy == "tool" ): params["tool_config"]["function_calling_config"]["mode"] = "auto" self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index f9acdeb1d..5b24f083b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -36,7 +36,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - native_structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) tokenizer: HuggingFaceTokenizer = field( @@ -55,8 +55,8 @@ def client(self) -> InferenceClient: token=self.api_token, ) - @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: if value == "tool": raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") @@ -124,7 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema and self.use_native_structured_output - and self.native_structured_output_strategy == "native" + and self.structured_output_strategy == "native" ): # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py 
b/griptape/drivers/prompt/ollama_prompt_driver.py index 25756cc1c..295d926d1 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -111,9 +111,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } if prompt_stack.output_schema is not None and self.use_native_structured_output: - if self.native_structured_output_strategy == "native": + if self.structured_output_strategy == "native": params["format"] = prompt_stack.output_schema.json_schema("Output") - elif self.native_structured_output_strategy == "tool": + elif self.structured_output_strategy == "tool": # TODO: Implement tool choice once supported self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index d8f61a3bf..69e615585 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -160,7 +160,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["parallel_tool_calls"] = self.parallel_tool_calls if prompt_stack.output_schema is not None and self.use_native_structured_output: - if self.native_structured_output_strategy == "native": + if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", "json_schema": { @@ -169,7 +169,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "strict": True, }, } - elif self.native_structured_output_strategy == "tool" and self.use_native_tools: + elif self.structured_output_strategy == "tool" and self.use_native_tools: params["tool_choice"] = "required" self._add_structured_output_tool(prompt_stack) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index f786b6d33..4af20a6f9 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -192,7 +192,7 @@ def try_run(self) -> BaseArtifact: if ( self.prompt_driver.use_native_structured_output - and self.prompt_driver.native_structured_output_strategy == "native" + and self.prompt_driver.structured_output_strategy == "native" ): return JsonArtifact(output.value) else: diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index af4b2c79a..1bdbfac7e 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -37,13 +37,13 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output if self.use_native_structured_output and prompt_stack.output_schema: - if self.native_structured_output_strategy == "native": + if self.structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) - elif self.native_structured_output_strategy == "tool": + elif self.structured_output_strategy == "tool": self._add_structured_output_tool(prompt_stack) if self.use_native_tools and prompt_stack.tools: diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 59eb4ac61..77c2631f3 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -52,7 +52,7 @@ def test_to_dict(self, 
config): "tool_choice": {"auto": {}}, "use_native_tools": True, "use_native_structured_output": True, - "native_structured_output_strategy": "tool", + "structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -109,7 +109,7 @@ def test_to_dict_with_values(self, config_with_values): "tool_choice": {"auto": {}}, "use_native_tools": True, "use_native_structured_output": True, - "native_structured_output_strategy": "tool", + "structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 66f987308..f412e10cb 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,7 +25,7 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, - "native_structured_output_strategy": "tool", + "structured_output_strategy": "tool", "use_native_structured_output": True, "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 2281f4c11..45fbfd6ab 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,7 +36,7 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", "use_native_structured_output": True, "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 6f371c5ba..0c2e665a6 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -27,7 +27,7 @@ def test_to_dict(self, config): "force_single_step": False, "use_native_tools": True, "use_native_structured_output": True, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index dd2e1736b..f425913b5 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -19,7 +19,7 @@ def test_to_dict(self, config): "stream": False, "use_native_tools": False, "use_native_structured_output": False, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 569e45561..3c8ef0e0e 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -26,7 +26,7 @@ def test_to_dict(self, config): "tool_choice": "auto", "use_native_tools": True, "use_native_structured_output": True, - "native_structured_output_strategy": "tool", + "structured_output_strategy": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 603d9867a..bc9b02cd3 100644 --- 
a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,7 +28,7 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", "use_native_structured_output": True, "extra_params": {}, }, diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index a21690cd3..81c642814 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -414,7 +414,7 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.native_structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], @@ -471,7 +471,7 @@ def test_try_stream_run( *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.native_structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], @@ -502,10 +502,10 @@ def test_try_stream_run( assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 - def test_verify_native_structured_output_strategy(self): - assert AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="tool") + def test_verify_structured_output_strategy(self): + assert AmazonBedrockPromptDriver(model="foo", structured_output_strategy="tool") with pytest.raises( ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." 
): - AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="native") + AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index cc9179ae8..687db3b68 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -399,7 +399,7 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us *self.ANTHROPIC_TOOLS, *( [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.native_structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ] @@ -456,7 +456,7 @@ def test_try_stream_run( *self.ANTHROPIC_TOOLS, *( [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.native_structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ] @@ -491,8 +491,8 @@ def test_try_stream_run( event = next(stream) assert event.usage.output_tokens == 10 - def test_verify_native_structured_output_strategy(self): - assert AnthropicPromptDriver(model="foo", native_structured_output_strategy="tool") + def test_verify_structured_output_strategy(self): + assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool") with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): - AnthropicPromptDriver(model="foo", native_structured_output_strategy="native") + AnthropicPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 6ca7a423b..f7f153dd0 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -68,7 +68,7 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, mock_chat_completion_create, @@ -76,7 +76,7 @@ def test_try_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = AzureOpenAiChatPromptDriver( @@ -85,7 +85,7 @@ def test_try_run( model="gpt-4", use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -103,12 +103,12 @@ def test_try_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools @@ -123,7 +123,7 @@ def test_try_run( }, } } - if 
use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -137,7 +137,7 @@ def test_try_run( @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, mock_chat_completion_stream_create, @@ -145,7 +145,7 @@ def test_try_stream_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = AzureOpenAiChatPromptDriver( @@ -155,7 +155,7 @@ def test_try_stream_run( stream=True, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -175,12 +175,12 @@ def test_try_stream_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools @@ -195,7 +195,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index bc0c51203..ad417cac5 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -339,7 +339,7 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, mock_client, @@ -347,7 +347,7 @@ def test_try_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = CoherePromptDriver( @@ -355,7 +355,7 @@ def test_try_run( api_key="api-key", use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -372,7 +372,7 @@ def test_try_run( *self.COHERE_TOOLS, *( [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -385,7 +385,7 @@ def test_try_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else 
{}, stop_sequences=[], temperature=0.1, @@ -407,7 +407,7 @@ def test_try_run( @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, mock_stream_client, @@ -415,7 +415,7 @@ def test_try_stream_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = CoherePromptDriver( @@ -424,7 +424,7 @@ def test_try_stream_run( stream=True, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -442,7 +442,7 @@ def test_try_stream_run( *self.COHERE_TOOLS, *( [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -455,7 +455,7 @@ def test_try_stream_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index e4a71d24e..a0b68a6af 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -189,7 +189,7 @@ def test_try_run( top_k=50, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy="tool", + structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -290,8 +290,8 @@ def test_try_stream( event = next(stream) assert event.usage.output_tokens == 5 - def test_verify_native_structured_output_strategy(self): - assert GooglePromptDriver(model="foo", native_structured_output_strategy="tool") + def test_verify_structured_output_strategy(self): + assert GooglePromptDriver(model="foo", structured_output_strategy="tool") with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."): - GooglePromptDriver(model="foo", native_structured_output_strategy="native") + GooglePromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 24a83c07b..334c1649e 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -114,10 +114,10 @@ def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structure assert event.usage.input_tokens == 3 assert event.usage.output_tokens == 3 - def test_verify_native_structured_output_strategy(self): - assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="native") + def test_verify_structured_output_strategy(self): + assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="native") with pytest.raises( 
ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode." ): - HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="tool") + HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="tool") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index bfc1111e0..cffcd3954 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -233,7 +233,7 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, mock_client, @@ -241,14 +241,14 @@ def test_try_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -269,7 +269,7 @@ def test_try_run( *self.OLLAMA_TOOLS, *( [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -277,7 +277,7 @@ def test_try_run( if use_native_tools else {}, **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -291,7 +291,7 @@ def test_try_run( @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, mock_stream_client, @@ -299,7 +299,7 @@ def test_try_stream_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = OllamaPromptDriver( @@ -307,7 +307,7 @@ def test_try_stream_run( stream=True, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -320,7 +320,7 @@ def test_try_stream_run( model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stream=True, foo="bar", diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 2b9c7e5b9..ed6085538 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py 
+++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -372,7 +372,7 @@ def test_init(self): @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, mock_chat_completion_create, @@ -380,14 +380,14 @@ def test_try_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -406,12 +406,12 @@ def test_try_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } @@ -427,7 +427,7 @@ def test_try_run( }, } } - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -510,7 +510,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, @pytest.mark.parametrize("use_native_tools", [True, False]) @pytest.mark.parametrize("use_native_structured_output", [True, False]) - @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, mock_chat_completion_stream_create, @@ -518,7 +518,7 @@ def test_try_stream_run( messages, use_native_tools, use_native_structured_output, - native_structured_output_strategy, + structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( @@ -526,7 +526,7 @@ def test_try_stream_run( stream=True, use_native_tools=use_native_tools, use_native_structured_output=use_native_structured_output, - native_structured_output_strategy=native_structured_output_strategy, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -548,12 +548,12 @@ def test_try_stream_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and native_structured_output_strategy == "tool" + if use_native_structured_output and structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } @@ -569,7 +569,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and native_structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) diff --git 
a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 5921d9e28..3344644a3 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -84,7 +84,7 @@ def test_to_dict(self): "type": "MockPromptDriver", "use_native_tools": False, "use_native_structured_output": False, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", }, } ], diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index d7050c8f6..f3a18b1e2 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,7 +257,7 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", "use_native_structured_output": False, "use_native_tools": False, }, diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 2503f1174..082ccc466 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -400,7 +400,7 @@ def test_to_dict(self): "type": "MockPromptDriver", "use_native_tools": False, "use_native_structured_output": False, - "native_structured_output_strategy": "native", + "structured_output_strategy": "native", }, "tools": [ { From b526d251a4417d2a3870b02883ae3531741186f4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 2 Jan 2025 15:37:57 -0800 Subject: [PATCH 04/11] Remove use_native_structured_output toggle --- CHANGELOG.md | 1 - .../drivers/prompt-drivers.md | 6 ++-- .../src/prompt_drivers_structured_output.py | 1 - .../prompt/amazon_bedrock_prompt_driver.py | 7 +--- .../drivers/prompt/anthropic_prompt_driver.py | 7 +--- griptape/drivers/prompt/base_prompt_driver.py | 1 - .../drivers/prompt/cohere_prompt_driver.py | 3 +- .../drivers/prompt/google_prompt_driver.py | 7 +--- .../prompt/huggingface_hub_prompt_driver.py | 7 +--- .../drivers/prompt/ollama_prompt_driver.py | 3 +- .../prompt/openai_chat_prompt_driver.py | 3 +- griptape/schemas/base_schema.py | 3 ++ griptape/tasks/prompt_task.py | 5 +-- tests/mocks/mock_prompt_driver.py | 12 ++++++- .../test_amazon_bedrock_drivers_config.py | 2 -- .../drivers/test_anthropic_drivers_config.py | 1 - .../test_azure_openai_drivers_config.py | 1 - .../drivers/test_cohere_drivers_config.py | 1 - .../configs/drivers/test_drivers_config.py | 1 - .../drivers/test_google_drivers_config.py | 1 - .../drivers/test_openai_driver_config.py | 1 - .../test_amazon_bedrock_prompt_driver.py | 22 ++++++------ .../prompt/test_anthropic_prompt_driver.py | 26 ++++---------- .../test_azure_openai_chat_prompt_driver.py | 30 ++++------------ .../prompt/test_cohere_prompt_driver.py | 22 +++--------- .../prompt/test_google_prompt_driver.py | 20 ++++------- .../test_hugging_face_hub_prompt_driver.py | 16 +++------ .../prompt/test_ollama_prompt_driver.py | 20 ++--------- .../prompt/test_openai_chat_prompt_driver.py | 34 +++++-------------- tests/unit/structures/test_structure.py | 1 - tests/unit/tasks/test_prompt_task.py | 5 +-- tests/unit/tasks/test_tool_task.py | 1 - tests/unit/tasks/test_toolkit_task.py | 1 - 33 files changed, 72 insertions(+), 200 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e9defcc5..c906896ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an 
iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. -- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. - `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index c890ef179..0e8b8b9b9 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -29,9 +29,7 @@ You can pass images to the Driver if the model supports it: Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output). - -If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: +If an [output_schema](../../reference/griptape/tasks.md#griptape.tasks.PromptTask.output_schema) is provided to the Task, you can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. - `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool. @@ -47,7 +45,7 @@ The easiest way to get started with structured output is by using a [JsonSchemaR ``` !!! warning - Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options. + Not every LLM supports all `structured_output_strategy` options. 
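For a concrete picture of the two strategies documented above, here is a minimal sketch adapted from the `prompt_drivers_structured_output.py` example included in this patch. It assumes the usual `griptape.tasks.PromptTask` and `griptape.drivers.OpenAiChatPromptDriver` import paths, and the `{"answer": str}` schema is purely illustrative:

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# "native" asks the LLM API itself to emit JSON matching the schema;
# switching to "tool" would instead route the output through the
# StructuredOutputTool described above.
task = PromptTask(
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",
        structured_output_strategy="native",  # or "tool"
    ),
    output_schema=schema.Schema({"answer": str}),  # illustrative schema
)
```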
## Prompt Drivers diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index b1f801341..918725210 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,7 +11,6 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", - use_native_structured_output=True, structured_output_strategy="native", ), output_schema=schema.Schema( diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index eefee0ff2..7a8c1b470 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -55,7 +55,6 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -134,11 +133,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "toolChoice": self.tool_choice, } - if ( - prompt_stack.output_schema is not None - and self.use_native_structured_output - and self.structured_output_strategy == "tool" - ): + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": self._add_structured_output_tool(prompt_stack) params["toolConfig"]["toolChoice"] = {"any": {}} diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 99053713a..48e8ac18b 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -68,7 +68,6 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -136,11 +135,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_choice"] = self.tool_choice - if ( - prompt_stack.output_schema is not None - and self.use_native_structured_output - and self.structured_output_strategy == "tool" - ): + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": self._add_structured_output_tool(prompt_stack) params["tool_choice"] = {"type": "any"} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 950c80cf8..d13a045c3 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -56,7 +56,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, 
metadata={"serializable": True}) - use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index a7121b440..c7438aa99 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -112,7 +111,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_object", diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 29c43a91e..ff486167b 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -63,7 +63,6 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -164,11 +163,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} - if ( - prompt_stack.output_schema is not None - and self.use_native_structured_output - and self.structured_output_strategy == "tool" - ): + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["tool_config"]["function_calling_config"]["mode"] = "auto" self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 5b24f083b..62f463a1b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -35,7 +35,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": 
True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) @@ -121,11 +120,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if ( - prompt_stack.output_schema - and self.use_native_structured_output - and self.structured_output_strategy == "native" - ): + if prompt_stack.output_schema and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 295d926d1..734a73308 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,7 +68,6 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -110,7 +109,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": params["format"] = prompt_stack.output_schema.json_schema("Output") elif self.structured_output_strategy == "tool": diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 69e615585..56b1b3405 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,7 +76,6 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory( @@ -159,7 +158,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 7b23c620f..4432c1080 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from collections.abc import Sequence from typing import Any + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.common import ( BaseDeltaMessageContent, @@ -228,6 +230,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: if is_dependency_installed("mypy_boto3_bedrock") else Any, "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any, + "Schema": Schema, }, ) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 4af20a6f9..276c2c229 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -190,10 +190,7 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if ( - self.prompt_driver.use_native_structured_output - and self.prompt_driver.structured_output_strategy == "native" - ): + if self.output_schema is not None and self.prompt_driver.structured_output_strategy == "native": return JsonArtifact(output.value) else: return output diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 1bdbfac7e..782c8ecd4 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_structured_output and prompt_stack.output_schema: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], @@ -84,6 +84,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + if prompt_stack.output_schema is not None: + if self.structured_output_strategy == "native": + yield DeltaMessage( + content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + elif self.structured_output_strategy == "tool": + self._add_structured_output_tool(prompt_stack) + if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. 
action_messages = [ diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 77c2631f3..b2fd51d24 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,7 +51,6 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -108,7 +107,6 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index f412e10cb..fa13480c1 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,7 +26,6 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", - "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 45fbfd6ab..a30cea001 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,7 +37,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 0c2e665a6..94e258e36 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,7 +26,6 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, - "use_native_structured_output": True, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index f425913b5..15646cc1d 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,7 +18,6 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, - "use_native_structured_output": False, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 3c8ef0e0e..910ae3240 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,7 +25,6 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, - "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py 
b/tests/unit/configs/drivers/test_openai_driver_config.py index bc9b02cd3..344d14d99 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,7 +29,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 81c642814..b31776f63 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -384,13 +384,11 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -414,11 +412,13 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools @@ -437,16 +437,12 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream_run( - self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -471,11 +467,13 @@ def test_try_stream_run( *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 687db3b68..147c69103 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -370,14 +370,12 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - 
def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -397,15 +395,11 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us **{ "tools": [ *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" - else [] - ), + *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -422,17 +416,13 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream_run( - self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -454,15 +444,11 @@ def test_try_stream_run( **{ "tools": [ *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" - else [] - ), + *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index f7f153dd0..3c8d39475 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,7 +67,6 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -75,7 +74,6 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -84,7 +82,6 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -101,15 +98,9 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *( - 
[self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ], - "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -123,7 +114,7 @@ def test_try_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -136,7 +127,6 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -144,7 +134,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -154,7 +143,6 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -173,15 +161,9 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ], - "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -195,7 +177,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index ad417cac5..17e9251d3 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -338,7 +338,6 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -346,7 +345,6 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -354,7 +352,6 @@ def test_try_run( model="command", api_key="api-key", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -370,11 +367,7 @@ def test_try_run( **{ "tools": [ *self.COHERE_TOOLS, - *( - [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ] } if use_native_tools @@ -385,7 +378,7 @@ def 
test_try_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, @@ -406,7 +399,6 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -414,7 +406,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -423,7 +414,6 @@ def test_try_stream_run( api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -440,11 +430,7 @@ def test_try_stream_run( **{ "tools": [ *self.COHERE_TOOLS, - *( - [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ] } if use_native_tools @@ -455,7 +441,7 @@ def test_try_stream_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index a0b68a6af..cc17de3c1 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -177,10 +177,7 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run( - self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -188,7 +185,6 @@ def test_try_run( top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -213,11 +209,11 @@ def test_try_run( tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_native_structured_output: + if driver.structured_output_strategy == "tool": assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) @@ -231,10 +227,7 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream( - self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + def test_try_stream(self, 
mock_stream_generative_model, prompt_stack, messages, use_native_tools): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -243,7 +236,6 @@ def test_try_stream( top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -269,11 +261,11 @@ def test_try_stream( tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_native_structured_output: + if driver.structured_output_strategy == "tool": assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 334c1649e..a65befbce 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -54,13 +54,11 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): + def test_try_run(self, prompt_stack, mock_client): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -73,22 +71,18 @@ def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): return_full_text=False, max_new_tokens=250, foo="bar", - **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_native_structured_output - else {}, + grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): + def test_try_stream(self, prompt_stack, mock_client_stream): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", stream=True, - use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -102,9 +96,7 @@ def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structure return_full_text=False, max_new_tokens=250, foo="bar", - **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_native_structured_output - else {}, + grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index cffcd3954..46c3ef4af 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -232,7 +232,6 @@ def test_init(self): assert OllamaPromptDriver(model="llama") 
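The Ollama test changes below exercise both structured output paths for that driver. As a rough, standalone illustration of what the `native` strategy amounts to (a sketch only, assuming the `ollama` Python client's `chat()` accepts a JSON Schema via its `format` argument, which is what these tests pass through; the model name and schema are placeholders):

```python
import ollama  # assumes the official `ollama` Python client and a running local server

# Placeholder JSON Schema standing in for prompt_stack.output_schema.json_schema("Output").
output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# With the "native" strategy the schema is forwarded as `format`; with the "tool"
# strategy the driver appends a StructuredOutputTool to `tools` instead.
response = ollama.Client().chat(
    model="llama3.1",  # placeholder model name
    messages=[{"role": "user", "content": "Reply with JSON matching the schema."}],
    format=output_schema,
)
print(response)
```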
@pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -240,14 +239,12 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -267,18 +264,12 @@ def test_try_run( **{ "tools": [ *self.OLLAMA_TOOLS, - *( - [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.OLLAMA_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ] } if use_native_tools else {}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and structured_output_strategy == "native" - else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -290,7 +281,6 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -298,7 +288,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -306,7 +295,6 @@ def test_try_stream_run( model="llama", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -319,9 +307,7 @@ def test_try_stream_run( messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and structured_output_strategy == "native" - else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index ed6085538..44c3ecba4 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -371,7 +371,6 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -379,14 +378,12 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -404,15 
+401,9 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ], - "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -427,7 +418,7 @@ def test_try_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -509,7 +500,6 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -517,7 +507,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, structured_output_strategy, ): # Given @@ -525,7 +514,6 @@ def test_try_stream_run( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -546,15 +534,9 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" - else [] - ), + *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), ], - "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -569,7 +551,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -596,11 +578,11 @@ def test_try_stream_run( def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False, - use_native_structured_output=False, ) # When @@ -630,12 +612,12 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio assert e.value.args[0] == "Completion with more than one choice is not supported yet." 
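The expected-call assertions above encode two different request shapes for structured output with the OpenAI driver. A minimal sketch of that branching (illustrative only, not the driver's actual implementation; the schema and tool entries are placeholders for the test fixtures):

```python
# Placeholders standing in for the fixtures used in these tests.
OUTPUT_JSON_SCHEMA = {"type": "object", "properties": {"answer": {"type": "string"}}}
STRUCTURED_OUTPUT_TOOL = {"type": "function", "function": {"name": "StructuredOutputTool_provide_output"}}


def structured_output_kwargs(strategy: str, base_tools: list[dict]) -> dict:
    """Extra kwargs merged into chat.completions.create(...) depending on the strategy."""
    if strategy == "native":
        # Lean on the provider's own structured output support.
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {"name": "Output", "schema": OUTPUT_JSON_SCHEMA},
            }
        }
    if strategy == "tool":
        # Append a structured output tool and force the model to call a tool.
        return {"tools": [*base_tools, STRUCTURED_OUTPUT_TOOL], "tool_choice": "required"}
    return {}  # unrecognized strategies (the "foo" parametrize case) add nothing
```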
def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, - use_native_structured_output=False, ) # When
diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3344644a3..34471fb39 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,7 +83,6 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_native_structured_output": False, "structured_output_strategy": "native", }, }
diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index e4d3060a5..fba790470 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -183,7 +183,6 @@ def test_prompt_stack_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( - use_native_structured_output=True, mock_structured_output={"baz": "foo"}, ), output_schema=output_schema, @@ -205,9 +204,7 @@ def test_prompt_stack_native_schema(self): def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", - prompt_driver=MockPromptDriver( - use_native_structured_output=True, - ), + prompt_driver=MockPromptDriver(), rules=[JsonSchemaRule({"foo": {}})], )
diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index f3a18b1e2..ba419480d 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -258,7 +258,6 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "structured_output_strategy": "native", - "use_native_structured_output": False, "use_native_tools": False, }, "tool": {
diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 082ccc466..3a7476596 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,7 +399,6 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_native_structured_output": False, "structured_output_strategy": "native", }, "tools": [
From b7137ac52dc34e1e74acd2cbd33e62d7f4afecba Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 2 Jan 2025 15:41:27 -0800 Subject: [PATCH 05/11] Update docs --- docs/griptape-framework/drivers/prompt-drivers.md | 8 ++++---- .../drivers/src/prompt_drivers_structured_output.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 0e8b8b9b9..dfc2c8b56 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -29,16 +29,16 @@ You can pass images to the Driver if the model supports it: Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.
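As a quick sketch of the usage this documentation section describes (patterned on the `prompt_drivers_structured_output.py` example referenced below; the import paths, prompt, and schema here are assumptions for illustration, not the exact docs file):

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver  # import path assumed
from griptape.tasks import PromptTask

# The Task's output_schema requests structured output; the Driver's
# structured_output_strategy controls how it is enforced ("native" or "tool").
task = PromptTask(
    input="Give me a fun fact, formatted as JSON.",  # placeholder prompt
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",
        structured_output_strategy="native",  # optional; "tool" forces a StructuredOutputTool call instead
    ),
    output_schema=schema.Schema({"fun_fact": str}),  # placeholder schema
)
```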
-If an [output_schema](../../reference/griptape/tasks.md#griptape.tasks.PromptTask.output_schema) is provided to the Task, you can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: +You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. -- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool. +- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. Each Driver may have a different default setting depending on the LLM provider's capabilities. -### JSON Schema +### Prompt Task -The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy. +The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter. ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index 918725210..cb7eb5ceb 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,7 +11,7 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", - structured_output_strategy="native", + structured_output_strategy="native", # optional ), output_schema=schema.Schema( { From 1f5ab443d00957c28c43811b5530b75044a1c432 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 2 Jan 2025 16:05:38 -0800 Subject: [PATCH 06/11] Revert removal of use_native_structured_output, add fallback to JsonSchemaRule --- CHANGELOG.md | 1 + .../drivers/prompt-drivers.md | 8 +++-- .../src/prompt_drivers_structured_output.py | 1 + .../griptape-framework/structures/rulesets.md | 4 +++ .../prompt/amazon_bedrock_prompt_driver.py | 7 +++- .../drivers/prompt/anthropic_prompt_driver.py | 7 +++- griptape/drivers/prompt/base_prompt_driver.py | 1 + .../drivers/prompt/cohere_prompt_driver.py | 3 +- .../drivers/prompt/google_prompt_driver.py | 7 +++- .../prompt/huggingface_hub_prompt_driver.py | 7 +++- .../drivers/prompt/ollama_prompt_driver.py | 3 +- .../prompt/openai_chat_prompt_driver.py | 3 +- griptape/tasks/prompt_task.py | 11 +++++- .../templates/tasks/prompt_task/system.j2 | 4 +++ tests/mocks/mock_prompt_driver.py | 4 +-- .../test_amazon_bedrock_drivers_config.py | 2 ++ .../drivers/test_anthropic_drivers_config.py | 1 + .../test_azure_openai_drivers_config.py | 1 + .../drivers/test_cohere_drivers_config.py | 1 + .../configs/drivers/test_drivers_config.py | 1 + 
.../drivers/test_google_drivers_config.py | 1 + .../drivers/test_openai_driver_config.py | 1 + .../test_amazon_bedrock_prompt_driver.py | 22 ++++++------ .../prompt/test_anthropic_prompt_driver.py | 26 ++++++++++---- .../test_azure_openai_chat_prompt_driver.py | 30 ++++++++++++---- .../prompt/test_cohere_prompt_driver.py | 22 +++++++++--- .../prompt/test_google_prompt_driver.py | 20 +++++++---- .../test_hugging_face_hub_prompt_driver.py | 16 ++++++--- .../prompt/test_ollama_prompt_driver.py | 20 +++++++++-- .../prompt/test_openai_chat_prompt_driver.py | 34 ++++++++++++++----- tests/unit/structures/test_structure.py | 1 + tests/unit/tasks/test_prompt_task.py | 5 ++- tests/unit/tasks/test_tool_task.py | 1 + tests/unit/tasks/test_toolkit_task.py | 1 + 34 files changed, 217 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c906896ed..6e9defcc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. +- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. - `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index dfc2c8b56..aede7fe01 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -29,7 +29,9 @@ You can pass images to the Driver if the model supports it: Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: +Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output). + +If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. - `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. @@ -44,8 +46,10 @@ The easiest way to get started with structured output is by using a [PromptTask] --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` +If `use_native_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. + !!! warning - Not every LLM supports all `structured_output_strategy` options. 
+ Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options. ## Prompt Drivers diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index cb7eb5ceb..adc7ea7ad 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,6 +11,7 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", + use_native_structured_output=True, # optional structured_output_strategy="native", # optional ), output_schema=schema.Schema( diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index f7a1de482..93e5a4c2b 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -26,6 +26,10 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru ### Json Schema +!!! tip + [Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output. + And if an LLM does not natively support structured output, a `JsonSchemaRule` will automatically be added. + [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema. This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 7a8c1b470..eefee0ff2 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -55,6 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -133,7 +134,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "toolChoice": self.tool_choice, } - if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): self._add_structured_output_tool(prompt_stack) params["toolConfig"]["toolChoice"] = {"any": {}} diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 48e8ac18b..99053713a 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -68,6 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, 
metadata={"serializable": True} ) @@ -135,7 +136,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_choice"] = self.tool_choice - if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): self._add_structured_output_tool(prompt_stack) params["tool_choice"] = {"type": "any"} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index d13a045c3..950c80cf8 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -56,6 +56,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index c7438aa99..a7121b440 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -111,7 +112,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_object", diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index ff486167b..29c43a91e 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -63,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -163,7 +164,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_config"] = {"function_calling_config": {"mode": 
self.tool_choice}} - if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): params["tool_config"]["function_calling_config"]["mode"] = "auto" self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 62f463a1b..5b24f083b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -35,6 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) @@ -120,7 +121,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema and self.structured_output_strategy == "native": + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.structured_output_strategy == "native" + ): # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 734a73308..295d926d1 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -109,7 +110,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["format"] = prompt_stack.output_schema.json_schema("Output") elif self.structured_output_strategy == "tool": diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 56b1b3405..69e615585 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) 
ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -158,7 +159,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 276c2c229..cd00ec574 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -190,7 +190,10 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if self.output_schema is not None and self.prompt_driver.structured_output_strategy == "native": + if ( + self.prompt_driver.use_native_structured_output + and self.prompt_driver.structured_output_strategy == "native" + ): return JsonArtifact(output.value) else: return output @@ -210,6 +213,8 @@ def preprocess(self, structure: Structure) -> BaseTask: return self def default_generate_system_template(self, _: PromptTask) -> str: + from griptape.rules import JsonSchemaRule + schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -219,6 +224,10 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, + use_native_structured_output=self.prompt_driver.use_native_structured_output, + json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema")) + if self.output_schema is not None + else None, stop_sequence=self.response_stop_sequence, ) diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index b262e7c72..4dcd34ee5 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,3 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} +{% if not use_native_structured_output and json_schema_rule %} + +{{ json_schema_rule }} +{% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 782c8ecd4..01824af06 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema is not None: + if self.use_native_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], @@ -84,7 +84,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema is not None: + if self.use_native_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": yield DeltaMessage( content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index b2fd51d24..77c2631f3 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,7 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -107,6 +108,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index fa13480c1..f412e10cb 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,6 +26,7 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index a30cea001..45fbfd6ab 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,6 +37,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 94e258e36..0c2e665a6 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ 
b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 15646cc1d..f425913b5 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 910ae3240..3c8ef0e0e 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 344d14d99..bc9b02cd3 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,6 +29,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index b31776f63..81c642814 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -384,11 +384,13 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -412,13 +414,11 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if driver.structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} - if driver.structured_output_strategy == "tool" - else driver.tool_choice, + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, } } if use_native_tools @@ -437,12 +437,16 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + 
def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -467,13 +471,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if driver.structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} - if driver.structured_output_strategy == "tool" - else driver.tool_choice, + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, } } if use_native_tools diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 147c69103..687db3b68 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -370,12 +370,14 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -395,11 +397,15 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): **{ "tools": [ *self.ANTHROPIC_TOOLS, - *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.structured_output_strategy == "tool" + else [] + ), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, } if use_native_tools else {}, @@ -416,13 +422,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -444,11 +454,15 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na **{ "tools": [ *self.ANTHROPIC_TOOLS, - *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.structured_output_strategy == "tool" + else [] + ), ] if 
use_native_tools else {}, - "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, } if use_native_tools else {}, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 3c8d39475..f7f153dd0 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,6 +67,7 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -74,6 +75,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -82,6 +84,7 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -98,9 +101,15 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, @@ -114,7 +123,7 @@ def test_try_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -127,6 +136,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -134,6 +144,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -143,6 +154,7 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -161,9 +173,15 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, @@ -177,7 +195,7 @@ def test_try_stream_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and 
structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 17e9251d3..ad417cac5 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -338,6 +338,7 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -345,6 +346,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -352,6 +354,7 @@ def test_try_run( model="command", api_key="api-key", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -367,7 +370,11 @@ def test_try_run( **{ "tools": [ *self.COHERE_TOOLS, - *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools @@ -378,7 +385,7 @@ def test_try_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, @@ -399,6 +406,7 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -406,6 +414,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -414,6 +423,7 @@ def test_try_stream_run( api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -430,7 +440,11 @@ def test_try_stream_run( **{ "tools": [ *self.COHERE_TOOLS, - *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools @@ -441,7 +455,7 @@ def test_try_stream_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index cc17de3c1..a0b68a6af 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -177,7 +177,10 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", 
[True, False]) + def test_try_run( + self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -185,6 +188,7 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -209,11 +213,11 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if driver.structured_output_strategy == "tool": + if use_native_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) @@ -227,7 +231,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -236,6 +243,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -261,11 +269,11 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if driver.structured_output_strategy == "tool": + if use_native_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index a65befbce..334c1649e 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -54,11 +54,13 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", + 
use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -71,18 +73,22 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", - grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", stream=True, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -96,7 +102,9 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", - grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 46c3ef4af..cffcd3954 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -232,6 +232,7 @@ def test_init(self): assert OllamaPromptDriver(model="llama") @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -239,12 +240,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -264,12 +267,18 @@ def test_try_run( **{ "tools": [ *self.OLLAMA_TOOLS, - *([self.OLLAMA_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools else {}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -281,6 +290,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -288,6 +298,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -295,6 +306,7 @@ def test_try_stream_run( model="llama", stream=True, 
use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -307,7 +319,9 @@ def test_try_stream_run( messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and structured_output_strategy == "native" + else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 44c3ecba4..ed6085538 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -371,6 +371,7 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -378,12 +379,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -401,9 +404,15 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -418,7 +427,7 @@ def test_try_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -500,6 +509,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -507,6 +517,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -514,6 +525,7 @@ def test_try_stream_run( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -534,9 +546,15 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + 
[self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -551,7 +569,7 @@ def test_try_stream_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -578,11 +596,11 @@ def test_try_stream_run( def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given - prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When @@ -612,12 +630,12 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio assert e.value.args[0] == "Completion with more than one choice is not supported yet." def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): - prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 34471fb39..3344644a3 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,6 +83,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", }, } diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index fba790470..e4d3060a5 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -183,6 +183,7 @@ def test_prompt_stack_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( + use_native_structured_output=True, mock_structured_output={"baz": "foo"}, ), output_schema=output_schema, @@ -204,7 +205,9 @@ def test_prompt_stack_native_schema(self): def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", - prompt_driver=MockPromptDriver(), + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), rules=[JsonSchemaRule({"foo": {}})], ) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ba419480d..f3a18b1e2 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -258,6 +258,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "structured_output_strategy": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3a7476596..082ccc466 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", }, "tools": [ From 
06e0e4cf92ae2ea71b157a9ee95d1515fa4ac413 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 3 Jan 2025 10:06:57 -0800 Subject: [PATCH 07/11] Drop "native" from structured output fields --- CHANGELOG.md | 2 +- .../drivers/prompt-drivers.md | 8 +++--- .../src/prompt_drivers_structured_output.py | 2 +- .../prompt/amazon_bedrock_prompt_driver.py | 4 +-- .../drivers/prompt/anthropic_prompt_driver.py | 4 +-- griptape/drivers/prompt/base_prompt_driver.py | 2 +- .../drivers/prompt/cohere_prompt_driver.py | 4 +-- .../drivers/prompt/google_prompt_driver.py | 4 +-- .../prompt/huggingface_hub_prompt_driver.py | 8 ++---- .../drivers/prompt/ollama_prompt_driver.py | 4 +-- .../prompt/openai_chat_prompt_driver.py | 4 +-- griptape/tasks/prompt_task.py | 7 ++--- .../templates/tasks/prompt_task/system.j2 | 2 +- tests/mocks/mock_prompt_driver.py | 4 +-- .../test_amazon_bedrock_drivers_config.py | 4 +-- .../drivers/test_anthropic_drivers_config.py | 2 +- .../test_azure_openai_drivers_config.py | 2 +- .../drivers/test_cohere_drivers_config.py | 2 +- .../configs/drivers/test_drivers_config.py | 2 +- .../drivers/test_google_drivers_config.py | 2 +- .../drivers/test_openai_driver_config.py | 2 +- .../test_amazon_bedrock_prompt_driver.py | 20 ++++++------- .../prompt/test_anthropic_prompt_driver.py | 22 +++++++-------- .../test_azure_openai_chat_prompt_driver.py | 24 ++++++++-------- .../prompt/test_cohere_prompt_driver.py | 20 ++++++------- .../prompt/test_google_prompt_driver.py | 22 +++++++-------- .../test_hugging_face_hub_prompt_driver.py | 16 +++++------ .../prompt/test_ollama_prompt_driver.py | 18 ++++++------ .../prompt/test_openai_chat_prompt_driver.py | 28 +++++++++---------- tests/unit/structures/test_structure.py | 2 +- tests/unit/tasks/test_prompt_task.py | 4 +-- tests/unit/tasks/test_tool_task.py | 2 +- tests/unit/tasks/test_toolkit_task.py | 2 +- 33 files changed, 122 insertions(+), 133 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e9defcc5..b20d4c280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. -- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. +- `BasePromptDriver.use_structured_output` for enabling or disabling structured output. - `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index aede7fe01..22c3dd4ff 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -29,9 +29,9 @@ You can pass images to the Driver if the model supports it: Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output). 
+Structured output can be enabled or disabled for a Prompt Driver by setting the [use_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_structured_output). -If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: +If `use_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. - `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. @@ -46,10 +46,10 @@ The easiest way to get started with structured output is by using a [PromptTask] --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` -If `use_native_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. +If `use_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. !!! warning - Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options. + Not every LLM supports `use_structured_output` or all `structured_output_strategy` options. ## Prompt Drivers diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index adc7ea7ad..8f5d0b77b 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,7 +11,7 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", - use_native_structured_output=True, # optional + use_structured_output=True, # optional structured_output_strategy="native", # optional ), output_schema=schema.Schema( diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index eefee0ff2..ff370e2f9 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -55,7 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -136,7 +136,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None - and self.use_native_structured_output + and self.use_structured_output and self.structured_output_strategy == "tool" ): self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 99053713a..17492e8d0 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ 
b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -68,7 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -138,7 +138,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None - and self.use_native_structured_output + and self.use_structured_output and self.structured_output_strategy == "tool" ): self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 950c80cf8..b46be4822 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -56,7 +56,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index a7121b440..b7421381c 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,7 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -112,7 +112,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None and self.use_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_object", diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 29c43a91e..bf91a5b30 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -63,7 +63,7 @@ class 
GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -166,7 +166,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if ( prompt_stack.output_schema is not None - and self.use_native_structured_output + and self.use_structured_output and self.structured_output_strategy == "tool" ): params["tool_config"]["function_calling_config"]["mode"] = "auto" diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 5b24f083b..e0a35048f 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -35,7 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) @@ -121,11 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if ( - prompt_stack.output_schema - and self.use_native_structured_output - and self.structured_output_strategy == "native" - ): + if prompt_stack.output_schema and self.use_structured_output and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 295d926d1..da7c51a9a 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,7 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -110,7 +110,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None and self.use_structured_output: if self.structured_output_strategy == "native": params["format"] = prompt_stack.output_schema.json_schema("Output") elif 
self.structured_output_strategy == "tool": diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 69e615585..aaf954da0 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,7 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -159,7 +159,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None and self.use_native_structured_output: + if prompt_stack.output_schema is not None and self.use_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index cd00ec574..ae80effcb 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -190,10 +190,7 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if ( - self.prompt_driver.use_native_structured_output - and self.prompt_driver.structured_output_strategy == "native" - ): + if self.prompt_driver.use_structured_output and self.prompt_driver.structured_output_strategy == "native": return JsonArtifact(output.value) else: return output @@ -224,7 +221,7 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, - use_native_structured_output=self.prompt_driver.use_native_structured_output, + use_structured_output=self.prompt_driver.use_structured_output, json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema")) if self.output_schema is not None else None, diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index 4dcd34ee5..e1a8bb21b 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,7 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} -{% if not use_native_structured_output and json_schema_rule %} +{% if not use_structured_output and json_schema_rule %} {{ json_schema_rule }} {% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 01824af06..243b29281 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_structured_output and prompt_stack.output_schema is not None: + if self.use_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], @@ -84,7 +84,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_structured_output and prompt_stack.output_schema is not None: + if self.use_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": yield DeltaMessage( content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 77c2631f3..d9a4f4cb3 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,7 +51,7 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_native_structured_output": True, + "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -108,7 +108,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_native_structured_output": True, + "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index f412e10cb..1df66b534 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,7 +26,7 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", - "use_native_structured_output": True, + "use_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 45fbfd6ab..c63f8bdbc 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,7 +37,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_native_structured_output": True, + "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git 
a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 0c2e665a6..11a39ba4c 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,7 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, - "use_native_structured_output": True, + "use_structured_output": True, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index f425913b5..fa8c07c8c 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,7 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, - "use_native_structured_output": False, + "use_structured_output": False, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 3c8ef0e0e..1f53ae59f 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,7 +25,7 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, - "use_native_structured_output": True, + "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index bc9b02cd3..a77f9ab46 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,7 +29,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_native_structured_output": True, + "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 81c642814..d7e642b39 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -384,13 +384,13 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_structured_output): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -414,11 +414,11 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if use_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} if use_native_structured_output else 
driver.tool_choice, + "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, } } if use_native_tools @@ -437,16 +437,16 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) def test_try_stream_run( - self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_structured_output ): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -471,11 +471,11 @@ def test_try_stream_run( *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if use_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, } } if use_native_tools diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 687db3b68..38b8c8bbb 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -370,14 +370,14 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_structured_output): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -399,13 +399,13 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us *self.ANTHROPIC_TOOLS, *( [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if use_structured_output and driver.structured_output_strategy == "tool" else [] ), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, } if use_native_tools else {}, @@ -422,17 +422,15 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream_run( - self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def 
test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools, use_structured_output): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -456,13 +454,13 @@ def test_try_stream_run( *self.ANTHROPIC_TOOLS, *( [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and driver.structured_output_strategy == "tool" + if use_structured_output and driver.structured_output_strategy == "tool" else [] ), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, } if use_native_tools else {}, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index f7f153dd0..d97c16ba3 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,7 +67,7 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -75,7 +75,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -84,7 +84,7 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -103,12 +103,12 @@ def test_try_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools @@ -123,7 +123,7 @@ def test_try_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -136,7 +136,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -144,7 +144,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -154,7 +154,7 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, - 
use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -175,12 +175,12 @@ def test_try_stream_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools @@ -195,7 +195,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index ad417cac5..858aa5bee 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -338,7 +338,7 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -346,7 +346,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -354,7 +354,7 @@ def test_try_run( model="command", api_key="api-key", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -372,7 +372,7 @@ def test_try_run( *self.COHERE_TOOLS, *( [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -385,7 +385,7 @@ def test_try_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, @@ -406,7 +406,7 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -414,7 +414,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -423,7 +423,7 @@ def test_try_stream_run( api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -442,7 +442,7 @@ def test_try_stream_run( *self.COHERE_TOOLS, *( 
[self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -455,7 +455,7 @@ def test_try_stream_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index a0b68a6af..53c33735e 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -177,10 +177,8 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run( - self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output - ): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, use_structured_output): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -188,7 +186,7 @@ def test_try_run( top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -213,11 +211,11 @@ def test_try_run( tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_native_structured_output: + if use_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) @@ -231,9 +229,9 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) def test_try_stream( - self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_structured_output ): # Given driver = GooglePromptDriver( @@ -243,7 +241,7 @@ def test_try_stream( top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -269,11 +267,11 @@ def test_try_stream( tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_native_structured_output: + if use_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, 
TextDeltaMessageContent) assert event.content.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 334c1649e..763a4f7b1 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -54,13 +54,13 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -74,21 +74,21 @@ def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): max_new_tokens=250, foo="bar", **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_native_structured_output + if use_structured_output else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - @pytest.mark.parametrize("use_native_structured_output", [True, False]) - def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): + @pytest.mark.parametrize("use_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", stream=True, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, extra_params={"foo": "bar"}, ) @@ -103,7 +103,7 @@ def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structure max_new_tokens=250, foo="bar", **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_native_structured_output + if use_structured_output else {}, stream=True, ) diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index cffcd3954..d638e84e2 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -232,7 +232,7 @@ def test_init(self): assert OllamaPromptDriver(model="llama") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -240,14 +240,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -269,7 +269,7 @@ def test_try_run( *self.OLLAMA_TOOLS, *( [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] - if 
use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ] @@ -277,7 +277,7 @@ def test_try_run( if use_native_tools else {}, **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -290,7 +290,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -298,7 +298,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -306,7 +306,7 @@ def test_try_stream_run( model="llama", stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -320,7 +320,7 @@ def test_try_stream_run( model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, stream=True, foo="bar", diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index ed6085538..eff9fda66 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -371,7 +371,7 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -379,14 +379,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -406,12 +406,12 @@ def test_try_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } @@ -427,7 +427,7 @@ def test_try_run( }, } } - if use_native_structured_output and structured_output_strategy 
== "native" + if use_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -509,7 +509,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -517,7 +517,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_native_structured_output, + use_structured_output, structured_output_strategy, ): # Given @@ -525,7 +525,7 @@ def test_try_stream_run( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, - use_native_structured_output=use_native_structured_output, + use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -548,12 +548,12 @@ def test_try_stream_run( *self.OPENAI_TOOLS, *( [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else [] ), ], "tool_choice": "required" - if use_native_structured_output and structured_output_strategy == "tool" + if use_structured_output and structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } @@ -569,7 +569,7 @@ def test_try_stream_run( }, } } - if use_native_structured_output and structured_output_strategy == "native" + if use_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -600,7 +600,7 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False, - use_native_structured_output=False, + use_structured_output=False, ) # When @@ -635,7 +635,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, - use_native_structured_output=False, + use_structured_output=False, ) # When diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3344644a3..807e78f0b 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,7 +83,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_native_structured_output": False, + "use_structured_output": False, "structured_output_strategy": "native", }, } diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index e4d3060a5..60a10f1a4 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -183,7 +183,7 @@ def test_prompt_stack_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( - use_native_structured_output=True, + use_structured_output=True, mock_structured_output={"baz": "foo"}, ), output_schema=output_schema, @@ -206,7 +206,7 @@ def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( - use_native_structured_output=True, + use_structured_output=True, ), rules=[JsonSchemaRule({"foo": {}})], ) 
diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index f3a18b1e2..00bbadc45 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -258,7 +258,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "structured_output_strategy": "native", - "use_native_structured_output": False, + "use_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 082ccc466..70c59e1f8 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,7 +399,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_native_structured_output": False, + "use_structured_output": False, "structured_output_strategy": "native", }, "tools": [ From 9832874390fe5b176beda608de7fae9dcbe54e5b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 3 Jan 2025 10:18:28 -0800 Subject: [PATCH 08/11] Remove redundant doc --- docs/griptape-framework/structures/rulesets.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index 93e5a4c2b..0104a94d3 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -28,7 +28,6 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru !!! tip [Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output. - And if an LLM does not natively support structured output, a `JsonSchemaRule` will automatically be added. [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema. This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. 
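As a quick illustration of the `JsonSchemaRule` behavior described above, here is a minimal sketch of attaching one to a `PromptTask` (the schema contents and task input are made up for the example, and the rule only asks the LLM to comply rather than guaranteeing it):

```python
from schema import Schema

from griptape.rules import JsonSchemaRule
from griptape.tasks import PromptTask

# Illustrative schema; JsonSchemaRule accepts any JSON-schema dict.
sentiment_schema = Schema({"sentiment": str}).json_schema("Output")

task = PromptTask(
    "Classify the sentiment of this sentence: I love pizza!",
    rules=[JsonSchemaRule(sentiment_schema)],
)

# The rule injects the schema into the system prompt before the task runs.
print(task.run().value)
```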
From 0925b38a4b9963a509a8cb3a17429e1b33f7b076 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 3 Jan 2025 10:19:50 -0800 Subject: [PATCH 09/11] Rename method for clarity --- griptape/drivers/prompt/amazon_bedrock_prompt_driver.py | 2 +- griptape/drivers/prompt/anthropic_prompt_driver.py | 2 +- griptape/drivers/prompt/base_prompt_driver.py | 2 +- griptape/drivers/prompt/cohere_prompt_driver.py | 2 +- griptape/drivers/prompt/google_prompt_driver.py | 2 +- griptape/drivers/prompt/ollama_prompt_driver.py | 2 +- griptape/drivers/prompt/openai_chat_prompt_driver.py | 2 +- tests/unit/drivers/prompt/test_base_prompt_driver.py | 6 +++--- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index ff370e2f9..f4837bdeb 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -139,7 +139,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: and self.use_structured_output and self.structured_output_strategy == "tool" ): - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) params["toolConfig"]["toolChoice"] = {"any": {}} params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 17492e8d0..22eaf0d30 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -141,7 +141,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: and self.use_structured_output and self.structured_output_strategy == "tool" ): - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) params["tool_choice"] = {"type": "any"} params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index b46be4822..eb00adee4 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -126,7 +126,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... 
- def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None: + def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None: from griptape.tools.structured_output.tool import StructuredOutputTool if prompt_stack.output_schema is None: diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index b7421381c..4810aad65 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -120,7 +120,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } elif self.structured_output_strategy == "tool": # TODO: Implement tool choice once supported - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) if prompt_stack.tools and self.use_native_tools: params["tools"] = self.__to_cohere_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index bf91a5b30..cb7ac47b5 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -170,7 +170,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: and self.structured_output_strategy == "tool" ): params["tool_config"]["function_calling_config"]["mode"] = "auto" - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) params["tools"] = self.__to_google_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index da7c51a9a..fd3b24524 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -115,7 +115,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["format"] = prompt_stack.output_schema.json_schema("Output") elif self.structured_output_strategy == "tool": # TODO: Implement tool choice once supported - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) # Tool calling is only supported when not streaming if prompt_stack.tools and self.use_native_tools and not self.stream: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index aaf954da0..5a1029eee 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -171,7 +171,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } elif self.structured_output_strategy == "tool" and self.use_native_tools: params["tool_choice"] = "required" - self._add_structured_output_tool(prompt_stack) + self._add_structured_output_tool_if_absent(prompt_stack) if self.response_format is not None: if self.response_format == {"type": "json_object"}: diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index c57173c66..985cc3d31 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -78,13 +78,13 @@ def test__add_structured_output_tool(self): prompt_stack = PromptStack() with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."): - mock_prompt_driver._add_structured_output_tool(prompt_stack) + mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) prompt_stack.output_schema = 
Schema({"foo": str}) - mock_prompt_driver._add_structured_output_tool(prompt_stack) + mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) # Ensure it doesn't get added twice - mock_prompt_driver._add_structured_output_tool(prompt_stack) + mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) assert len(prompt_stack.tools) == 1 assert isinstance(prompt_stack.tools[0], StructuredOutputTool) assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema From c9bcefa881dea502e42c4b885547c8443d2e2d69 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 3 Jan 2025 12:28:25 -0800 Subject: [PATCH 10/11] Move logic from driver to task, remove flag --- CHANGELOG.md | 6 +- .../drivers/prompt-drivers.md | 25 ++--- .../src/prompt_drivers_structured_output.py | 1 - .../prompt/amazon_bedrock_prompt_driver.py | 17 +-- ...mazon_sagemaker_jumpstart_prompt_driver.py | 11 ++ .../drivers/prompt/anthropic_prompt_driver.py | 17 +-- griptape/drivers/prompt/base_prompt_driver.py | 17 +-- .../drivers/prompt/cohere_prompt_driver.py | 15 +-- .../drivers/prompt/google_prompt_driver.py | 17 +-- .../prompt/huggingface_hub_prompt_driver.py | 13 ++- .../huggingface_pipeline_prompt_driver.py | 14 ++- .../drivers/prompt/ollama_prompt_driver.py | 9 +- .../prompt/openai_chat_prompt_driver.py | 8 +- griptape/schemas/base_schema.py | 2 + griptape/structures/agent.py | 4 + griptape/tasks/prompt_task.py | 18 ++- .../templates/tasks/prompt_task/system.j2 | 2 +- tests/mocks/mock_prompt_driver.py | 103 ++++++++++-------- .../test_amazon_bedrock_drivers_config.py | 2 - .../drivers/test_anthropic_drivers_config.py | 1 - .../test_azure_openai_drivers_config.py | 1 - .../drivers/test_cohere_drivers_config.py | 3 +- .../configs/drivers/test_drivers_config.py | 3 +- .../drivers/test_google_drivers_config.py | 1 - .../drivers/test_openai_driver_config.py | 1 - .../test_amazon_bedrock_prompt_driver.py | 63 +++-------- ...mazon_sagemaker_jumpstart_prompt_driver.py | 9 ++ .../prompt/test_anthropic_prompt_driver.py | 62 +++-------- .../test_azure_openai_chat_prompt_driver.py | 36 +----- .../drivers/prompt/test_base_prompt_driver.py | 23 ---- .../prompt/test_cohere_prompt_driver.py | 58 +--------- .../prompt/test_google_prompt_driver.py | 40 +++---- .../test_hugging_face_hub_prompt_driver.py | 34 +++--- ...est_hugging_face_pipeline_prompt_driver.py | 24 +++- .../prompt/test_ollama_prompt_driver.py | 40 +------ .../prompt/test_openai_chat_prompt_driver.py | 66 ++--------- tests/unit/structures/test_agent.py | 12 ++ tests/unit/structures/test_structure.py | 3 +- tests/unit/tasks/test_prompt_task.py | 32 ++++-- tests/unit/tasks/test_tool_task.py | 3 +- tests/unit/tasks/test_toolkit_task.py | 3 +- 41 files changed, 315 insertions(+), 504 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b20d4c280..62b069cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors. +- Structured Output support for all Prompt Drivers. +- `PromptTask.output_schema` for setting an output schema to be used with Structured Output. +- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task. +- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`. 
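As a minimal sketch of the new `Agent.output_schema` entry above (the model defaults, question, and schema are illustrative; the Agent simply forwards `output_schema` to its Prompt Task):

```python
from schema import Schema

from griptape.structures import Agent

# output_schema is passed through to the Agent's underlying PromptTask.
agent = Agent(output_schema=Schema({"answer": str}))

agent.run("What is the capital of France?")
print(agent.output.value)
```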
## [1.1.1] - 2025-01-03 @@ -31,8 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. -- `BasePromptDriver.use_structured_output` for enabling or disabling structured output. -- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 22c3dd4ff..a6694726b 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -27,30 +27,27 @@ You can pass images to the Driver if the model supports it: ## Structured Output -Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. +Some LLMs provide functionality often referred to as "Structured Output". +This means instructing the LLM to output data in a particular format, usually JSON. +This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -Structured output can be enabled or disabled for a Prompt Driver by setting the [use_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_structured_output). - -If `use_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - -- `native`: The Driver will use the LLM's structured output functionality provided by the API. -- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. - -Each Driver may have a different default setting depending on the LLM provider's capabilities. +!!! warning + Each Driver may have a different default setting depending on the LLM provider's capabilities. ### Prompt Task The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter. +You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: + +- `native`: The Driver will use the LLM's structured output functionality provided by the API. +- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool. +- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort. 
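For comparison with the bundled `native` example referenced below, a minimal sketch of opting into the `tool` strategy could look like the following (driver, model, and schema choices are illustrative):

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# With the "tool" strategy, the Task adds a StructuredOutputTool built from
# output_schema and the Driver forces the LLM to call it rather than relying
# on the provider's native JSON response format.
task = PromptTask(
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",
        structured_output_strategy="tool",
    ),
    output_schema=schema.Schema({"answer": str}),
)

print(task.run().value)
```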
+ ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` -If `use_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. - -!!! warning - Not every LLM supports `use_structured_output` or all `structured_output_strategy` options. - ## Prompt Drivers Griptape offers the following Prompt Drivers for interacting with LLMs. diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index 8f5d0b77b..cb7eb5ceb 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,7 +11,6 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", - use_structured_output=True, # optional structured_output_strategy="native", # optional ), output_schema=schema.Schema( diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index f4837bdeb..12ea13ad5 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from attrs import Attribute, Factory, define, field from schema import Schema @@ -41,6 +41,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -55,17 +56,16 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -134,12 +134,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "toolChoice": self.tool_choice, } - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["toolConfig"]["toolChoice"] = {"any": {}} params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) diff --git 
a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index d98ac9fd4..bc0e28266 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -20,6 +20,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value != "rule": + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("sagemaker-runtime") diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 22eaf0d30..9a558e7cf 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Optional from attrs import Attribute, Factory, define, field from schema import Schema @@ -42,6 +42,7 @@ from anthropic import Client from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools.base_tool import BaseTool @@ -68,8 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) @@ -80,9 +80,9 @@ def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -136,12 +136,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and 
self.use_native_tools: params["tool_choice"] = self.tool_choice - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["tool_choice"] = {"type": "any"} params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index eb00adee4..c5ffb7259 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -32,6 +32,8 @@ from griptape.tokenizers import BaseTokenizer +StructuredOutputStrategy = Literal["native", "tool", "rule"] + @define(kw_only=True) class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): @@ -56,9 +58,8 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( - default="native", kw_only=True, metadata={"serializable": True} + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) @@ -126,16 +127,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... 
- def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None: - from griptape.tools.structured_output.tool import StructuredOutputTool - - if prompt_stack.output_schema is None: - raise ValueError("PromptStack must have an output schema to use structured output.") - - structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_output_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_output_tool) - def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 4810aad65..9158c4ad1 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -112,15 +111,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_structured_output: - if self.structured_output_strategy == "native": - params["response_format"] = { - "type": "json_object", - "schema": prompt_stack.output_schema.json_schema("Output"), - } - elif self.structured_output_strategy == "tool": - # TODO: Implement tool choice once supported - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } if prompt_stack.tools and self.use_native_tools: params["tools"] = self.__to_cohere_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index cb7ac47b5..46a721b08 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Optional from attrs import Attribute, Factory, define, field from schema import Schema @@ -37,6 +37,7 @@ from google.generativeai.protos import Part from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -63,17 +64,16 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) 
- structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise ValueError("GooglePromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -164,13 +164,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["tool_config"]["function_calling_config"]["mode"] = "auto" - self._add_structured_output_tool_if_absent(prompt_stack) params["tools"] = self.__to_google_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index e0a35048f..57a487450 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from attrs import Attribute, Factory, define, field @@ -17,6 +17,8 @@ from huggingface_hub import InferenceClient + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -35,8 +37,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="native", kw_only=True, metadata={"serializable": True} ) tokenizer: HuggingFaceTokenizer = field( @@ -56,9 +57,9 @@ def client(self) -> InferenceClient: ) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "tool": - raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -121,7 +122,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema and 
self.use_structured_output and self.structured_output_strategy == "native": + if prompt_stack.output_schema and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index a197523df..866f033ec 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import TextArtifact from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable @@ -18,6 +18,8 @@ from transformers import TextGenerationPipeline + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -38,10 +40,20 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _pipeline: TextGenerationPipeline = field( default=None, kw_only=True, alias="pipeline", metadata={"serializable": False} ) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value in ("native", "tool"): + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index fd3b24524..1c4ae3fd1 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,7 +68,6 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -110,12 +109,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_structured_output: - if self.structured_output_strategy == "native": - params["format"] = prompt_stack.output_schema.json_schema("Output") - elif self.structured_output_strategy == "tool": - # TODO: Implement tool choice once supported - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") # Tool calling is only supported when not streaming if prompt_stack.tools and self.use_native_tools and not self.stream: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 5a1029eee..03390d687 100644 --- 
a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -35,6 +35,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.chat.chat_completion_message import ChatCompletionMessage + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool @@ -76,7 +77,9 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="native", kw_only=True, metadata={"serializable": True} + ) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -159,7 +162,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None and self.use_structured_output: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", @@ -171,7 +174,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } elif self.structured_output_strategy == "tool" and self.use_native_tools: params["tool_choice"] = "required" - self._add_structured_output_tool_if_absent(prompt_stack) if self.response_format is not None: if self.response_format == {"type": "json_object"}: diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 4432c1080..fa622bd05 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -172,6 +172,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseTextToSpeechDriver, BaseVectorStoreDriver, ) + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.events import EventListener from griptape.memory import TaskMemory from griptape.memory.structure import BaseConversationMemory, Run @@ -216,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseArtifactStorage": BaseArtifactStorage, "BaseRule": BaseRule, "Ruleset": Ruleset, + "StructuredOutputStrategy": StructuredOutputStrategy, # Third party modules "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index baf36108f..9b70b7fb1 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -12,6 +12,8 @@ from griptape.tasks import PromptTask if TYPE_CHECKING: + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask @@ -25,6 +27,7 @@ class Agent(Structure): ) stream: bool = field(default=None, kw_only=True) prompt_driver: BasePromptDriver = field(default=None, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) 
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -98,6 +101,7 @@ def _init_task(self) -> None: self.input, prompt_driver=self.prompt_driver, tools=self.tools, + output_schema=self.output_schema, max_meta_memory_entries=self.max_meta_memory_entries, ) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index ae80effcb..15c0f7457 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -15,6 +15,7 @@ from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Ruleset +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 @@ -91,9 +92,16 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack(tools=self.tools, output_schema=self.output_schema) + from griptape.tools.structured_output.tool import StructuredOutputTool + + stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None + if self.output_schema is not None: + stack.output_schema = self.output_schema + if self.prompt_driver.structured_output_strategy == "tool": + stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema)) + system_template = self.generate_system_template(self) if system_template: stack.add_system_message(system_template) @@ -190,7 +198,7 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if self.prompt_driver.use_structured_output and self.prompt_driver.structured_output_strategy == "native": + if self.output_schema is not None and self.prompt_driver.structured_output_strategy in ("native", "rule"): return JsonArtifact(output.value) else: return output @@ -210,8 +218,6 @@ def preprocess(self, structure: Structure) -> BaseTask: return self def default_generate_system_template(self, _: PromptTask) -> str: - from griptape.rules import JsonSchemaRule - schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -221,8 +227,8 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, - use_structured_output=self.prompt_driver.use_structured_output, - json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema")) + structured_output_strategy=self.prompt_driver.structured_output_strategy, + json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output")) if self.output_schema is not None else None, stop_sequence=self.response_stop_sequence, diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index e1a8bb21b..8e89e13c7 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,7 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} -{% if not use_structured_output and json_schema_rule %} +{% if json_schema_rule and structured_output_strategy == 'rule' %} {{ json_schema_rule }} {% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 243b29281..3310a952e 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,15 +36,6 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_structured_output and prompt_stack.output_schema is not None: - if self.structured_output_strategy == "native": - return Message( - content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) - elif self.structured_output_strategy == "tool": - self._add_structured_output_tool(prompt_stack) if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. @@ -58,41 +49,42 @@ def try_run(self, prompt_stack: PromptStack) -> Message: usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: + if self.structured_output_strategy == "tool": + tool_action = ToolAction( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + input={"values": self.mock_structured_output}, + ) + else: + tool_action = ToolAction( + tag="mock-tag", + name="MockTool", + path="test", + input={"values": {"test": "test-value"}}, + ) + return Message( - content=[ - ActionCallMessageContent( - ActionArtifact( - ToolAction( - tag="mock-tag", - name="MockTool", - path="test", - input={"values": {"test": "test-value"}}, - ) - ) - ) - ], + content=[ActionCallMessageContent(ActionArtifact(tool_action))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: - return Message( - content=[TextMessageContent(TextArtifact(output))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: - output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - - if self.use_structured_output and prompt_stack.output_schema is not None: - if self.structured_output_strategy == "native": - yield DeltaMessage( - content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + if prompt_stack.output_schema is not None: + return Message( + content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + else: + return Message( + content=[TextMessageContent(TextArtifact(output))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) - elif self.structured_output_strategy == "tool": - self._add_structured_output_tool(prompt_stack) + + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. 
@@ -103,15 +95,36 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: yield DeltaMessage(content=TextDeltaMessageContent(f"Answer: {output}")) yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100)) else: - yield DeltaMessage( - content=ActionCallDeltaMessageContent( - tag="mock-tag", - name="MockTool", - path="test", + if self.structured_output_strategy == "tool": + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + ) ) - ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + partial_input=json.dumps({"values": self.mock_structured_output}) + ) + ) + else: + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="MockTool", + path="test", + ) + ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + ) + else: + if prompt_stack.output_schema is not None: yield DeltaMessage( - content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), ) - else: - yield DeltaMessage(content=TextDeltaMessageContent(output)) + else: + yield DeltaMessage(content=TextDeltaMessageContent(output)) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index d9a4f4cb3..b2fd51d24 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,7 +51,6 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -108,7 +107,6 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 1df66b534..fa13480c1 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,7 +26,6 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", - "use_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index c63f8bdbc..a30cea001 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,7 +37,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 11a39ba4c..d5e05c9bd 100644 --- 
a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,8 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, - "use_structured_output": True, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index fa8c07c8c..5adec7c6d 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,8 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, - "use_structured_output": False, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 1f53ae59f..910ae3240 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,7 +25,6 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index a77f9ab46..344d14d99 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,7 +29,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index d7e642b39..2dcb4bf02 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -8,29 +8,6 @@ class TestAmazonBedrockPromptDriver: - BEDROCK_STRUCTURED_OUTPUT_TOOL = { - "toolSpec": { - "description": "Used to provide the final response which ends this conversation.", - "inputSchema": { - "json": { - "$id": "http://json-schema.org/draft-07/schema#", - "$schema": "http://json-schema.org/draft-07/schema#", - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - }, - "name": "StructuredOutputTool_provide_output", - }, - } BEDROCK_TOOLS = [ { "toolSpec": { @@ -384,13 +361,13 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + 
structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -410,15 +387,10 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, **( { "toolConfig": { - "tools": [ - *self.BEDROCK_TOOLS, - *( - [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ], - "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools @@ -437,16 +409,16 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) def test_try_stream_run( - self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_structured_output + self, mock_converse_stream, prompt_stack, messages, use_native_tools, structured_output_strategy ): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -467,15 +439,10 @@ def test_try_stream_run( **( { "toolConfig": { - "tools": [ - *self.BEDROCK_TOOLS, - *( - [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ], - "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools @@ -506,6 +473,6 @@ def test_verify_structured_output_strategy(self): assert AmazonBedrockPromptDriver(model="foo", structured_output_strategy="tool") with pytest.raises( - ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output strategy." 
): AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index c7b0682c2..7b2d38398 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -138,3 +138,12 @@ def test_try_run_throws_on_empty_response(self, mock_client): # Then assert e.value.args[0] == "model response is empty" + + def test_verify_structured_output_strategy(self): + assert AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="rule") + + with pytest.raises( + ValueError, + match="AmazonSageMakerJumpstartPromptDriver does not support `native` structured output strategy.", + ): + AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 38b8c8bbb..fbdf1e55d 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -142,24 +142,6 @@ class TestAnthropicPromptDriver: }, ] - ANTHROPIC_STRUCTURED_OUTPUT_TOOL = { - "description": "Used to provide the final response which ends this conversation.", - "input_schema": { - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - "name": "StructuredOutputTool_provide_output", - } - @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") @@ -370,14 +352,14 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -395,17 +377,8 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{ - "tools": [ - *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ] - if use_native_tools - else {}, - "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -422,15 +395,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", 
[True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, structured_output_strategy + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -450,17 +425,8 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{ - "tools": [ - *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ] - if use_native_tools - else {}, - "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -492,5 +458,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na def test_verify_structured_output_strategy(self): assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool") - with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): + with pytest.raises( + ValueError, match="AnthropicPromptDriver does not support `native` structured output strategy." 
+ ): AnthropicPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index d97c16ba3..8f0da735a 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,7 +67,6 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -75,7 +74,6 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -84,7 +82,6 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -99,17 +96,8 @@ def test_try_run( user=driver.user, messages=messages, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -123,7 +111,7 @@ def test_try_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -136,7 +124,6 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -144,7 +131,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -154,7 +140,6 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -171,17 +156,8 @@ def test_try_stream_run( stream=True, messages=messages, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -195,7 +171,7 @@ def test_try_stream_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 985cc3d31..58720bbc5 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ 
b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,5 +1,3 @@ -import pytest - from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent @@ -67,24 +65,3 @@ def test_run_with_tools_and_stream(self, mock_config): output = pipeline.run().output_task.output assert isinstance(output, TextArtifact) assert output.value == "mock output" - - def test__add_structured_output_tool(self): - from schema import Schema - - from griptape.tools.structured_output.tool import StructuredOutputTool - - mock_prompt_driver = MockPromptDriver() - - prompt_stack = PromptStack() - - with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."): - mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) - - prompt_stack.output_schema = Schema({"foo": str}) - - mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) - # Ensure it doesn't get added twice - mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack) - assert len(prompt_stack.tools) == 1 - assert isinstance(prompt_stack.tools[0], StructuredOutputTool) - assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 858aa5bee..8b51940c8 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -21,28 +21,6 @@ class TestCoherePromptDriver: "required": ["foo"], "type": "object", } - COHERE_STRUCTURED_OUTPUT_TOOL = { - "function": { - "description": "Used to provide the final response which ends this conversation.", - "name": "StructuredOutputTool_provide_output", - "parameters": { - "$id": "Parameters Schema", - "$schema": "http://json-schema.org/draft-07/schema#", - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - }, - "type": "function", - } COHERE_TOOLS = [ { "function": { @@ -338,7 +316,6 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -346,7 +323,6 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -354,7 +330,6 @@ def test_try_run( model="command", api_key="api-key", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -367,25 +342,14 @@ def test_try_run( model="command", messages=messages, max_tokens=None, - **{ - "tools": [ - *self.COHERE_TOOLS, - *( - [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ] - } - if use_native_tools - else {}, + **{"tools": self.COHERE_TOOLS} if use_native_tools else {}, **{ "response_format": { "type": "json_object", "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" 
else {}, stop_sequences=[], temperature=0.1, @@ -406,7 +370,6 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -414,7 +377,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -423,7 +385,6 @@ def test_try_stream_run( api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -437,25 +398,14 @@ def test_try_stream_run( model="command", messages=messages, max_tokens=None, - **{ - "tools": [ - *self.COHERE_TOOLS, - *( - [self.COHERE_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ] - } - if use_native_tools - else {}, + **{"tools": self.COHERE_TOOLS} if use_native_tools else {}, **{ "response_format": { "type": "json_object", "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 53c33735e..aacc207b9 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -14,15 +14,6 @@ class TestGooglePromptDriver: - GOOGLE_STRUCTURED_OUTPUT_TOOL = { - "description": "Used to provide the final response which ends this conversation.", - "name": "StructuredOutputTool_provide_output", - "parameters": { - "properties": {"foo": {"type": "STRING"}}, - "required": ["foo"], - "type": "OBJECT", - }, - } GOOGLE_TOOLS = [ { "name": "MockTool_test", @@ -177,8 +168,8 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -186,8 +177,7 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, - structured_output_strategy="tool", + structured_output_strategy=structured_output_strategy, extra_params={"max_output_tokens": 10}, ) @@ -209,13 +199,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - tools = [ - *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []), - ] + tools = self.GOOGLE_TOOLS assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_structured_output: + if driver.structured_output_strategy == "tool": assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert 
isinstance(message.value[0], TextArtifact) @@ -229,9 +216,9 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) def test_try_stream( - self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_structured_output + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy ): # Given driver = GooglePromptDriver( @@ -241,7 +228,7 @@ def test_try_stream( top_p=0.5, top_k=50, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"max_output_tokens": 10}, ) @@ -265,13 +252,10 @@ def test_try_stream( ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - tools = [ - *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []), - ] + tools = self.GOOGLE_TOOLS assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if use_structured_output: + if driver.structured_output_strategy == "tool": assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" @@ -291,5 +275,7 @@ def test_try_stream( def test_verify_structured_output_strategy(self): assert GooglePromptDriver(model="foo", structured_output_strategy="tool") - with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."): + with pytest.raises( + ValueError, match="GooglePromptDriver does not support `native` structured output strategy." 
+ ): GooglePromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 763a4f7b1..b757dbcea 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -54,14 +54,14 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, prompt_stack, mock_client, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"]) + def test_try_run(self, prompt_stack, mock_client, structured_output_strategy): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", - use_structured_output=use_structured_output, extra_params={"foo": "bar"}, + structured_output_strategy=structured_output_strategy, ) # When @@ -73,23 +73,27 @@ def test_try_run(self, prompt_stack, mock_client, use_structured_output): return_full_text=False, max_new_tokens=250, foo="bar", - **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_structured_output - else {}, + **( + { + "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + } + if structured_output_strategy == "native" + else {} + ), ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_stream(self, prompt_stack, mock_client_stream, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"]) + def test_try_stream(self, prompt_stack, mock_client_stream, structured_output_strategy): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", stream=True, - use_structured_output=use_structured_output, extra_params={"foo": "bar"}, + structured_output_strategy=structured_output_strategy, ) # When @@ -102,9 +106,13 @@ def test_try_stream(self, prompt_stack, mock_client_stream, use_structured_outpu return_full_text=False, max_new_tokens=250, foo="bar", - **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} - if use_structured_output - else {}, + **( + { + "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + } + if structured_output_strategy == "native" + else {} + ), stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) @@ -118,6 +126,6 @@ def test_verify_structured_output_strategy(self): assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="native") with pytest.raises( - ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode." + ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output strategy." 
): HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="tool") diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index af52ca4e9..e03604aaf 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -42,10 +42,15 @@ def messages(self): def test_init(self, mock_pipeline): assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=mock_pipeline) - def test_try_run(self, prompt_stack, messages, mock_pipeline): + @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"]) + def test_try_run(self, prompt_stack, messages, mock_pipeline, structured_output_strategy): # Given driver = HuggingFacePipelinePromptDriver( - model="foo", max_tokens=42, extra_params={"foo": "bar"}, pipeline=mock_pipeline + model="foo", + max_tokens=42, + extra_params={"foo": "bar"}, + pipeline=mock_pipeline, + structured_output_strategy=structured_output_strategy, ) # When @@ -57,9 +62,12 @@ def test_try_run(self, prompt_stack, messages, mock_pipeline): assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_pipeline): + @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"]) + def test_try_stream(self, prompt_stack, mock_pipeline, structured_output_strategy): # Given - driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline) + driver = HuggingFacePipelinePromptDriver( + model="foo", max_tokens=42, pipeline=mock_pipeline, structured_output_strategy=structured_output_strategy + ) # When with pytest.raises(Exception) as e: @@ -101,3 +109,11 @@ def test_prompt_stack_to_string(self, prompt_stack, mock_pipeline): # Then assert result == "model-output" + + def test_verify_structured_output_strategy(self): + assert HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="rule") + + with pytest.raises( + ValueError, match="HuggingFacePipelinePromptDriver does not support `native` structured output strategy." 
+ ): + HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index d638e84e2..02f284b76 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -16,19 +16,6 @@ class TestOllamaPromptDriver: "required": ["foo"], "type": "object", } - OLLAMA_STRUCTURED_OUTPUT_TOOL = { - "function": { - "description": "Used to provide the final response which ends this conversation.", - "name": "StructuredOutputTool_provide_output", - "parameters": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "type": "function", - } OLLAMA_TOOLS = [ { "function": { @@ -232,22 +219,19 @@ def test_init(self): assert OllamaPromptDriver(model="llama") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) def test_try_run( self, mock_client, prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -265,20 +249,11 @@ def test_try_run( "num_predict": driver.max_tokens, }, **{ - "tools": [ - *self.OLLAMA_TOOLS, - *( - [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ] + "tools": self.OLLAMA_TOOLS, } if use_native_tools else {}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_structured_output and structured_output_strategy == "native" - else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -290,15 +265,13 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) def test_try_stream_run( self, mock_stream_client, prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -306,7 +279,6 @@ def test_try_stream_run( model="llama", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -319,9 +291,7 @@ def test_try_stream_run( messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} - if use_structured_output and structured_output_strategy == "native" - else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py 
b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index eff9fda66..496560529 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -20,28 +20,6 @@ class TestOpenAiChatPromptDriverFixtureMixin: "required": ["foo"], "type": "object", } - OPENAI_STRUCTURED_OUTPUT_TOOL = { - "function": { - "description": "Used to provide the final response which ends this conversation.", - "name": "StructuredOutputTool_provide_output", - "parameters": { - "$id": "Parameters Schema", - "$schema": "http://json-schema.org/draft-07/schema#", - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - }, - "type": "function", - } OPENAI_TOOLS = [ { "function": { @@ -371,22 +349,19 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) def test_try_run( self, mock_chat_completion_create, prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -402,17 +377,8 @@ def test_try_run( messages=messages, seed=driver.seed, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -427,7 +393,7 @@ def test_try_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if prompt_stack.output_schema is not None and structured_output_strategy == "native" else {}, foo="bar", ) @@ -509,15 +475,13 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) def test_try_stream_run( self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -525,7 +489,6 @@ def test_try_stream_run( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -544,17 +507,8 @@ def test_try_stream_run( seed=driver.seed, 
stream_options={"include_usage": True}, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -569,7 +523,7 @@ def test_try_stream_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -596,11 +550,11 @@ def test_try_stream_run( def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False, - use_structured_output=False, ) # When @@ -630,12 +584,12 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio assert e.value.args[0] == "Completion with more than one choice is not supported yet." def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, - use_structured_output=False, ) # When diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 809d174b5..442f654d5 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +import schema from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory @@ -316,3 +317,14 @@ def test_field_hierarchy(self): assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is True + + def test_output_schema(self): + agent = Agent() + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].output_schema is None + + agent = Agent(output_schema=schema.Schema({"foo": str})) + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].output_schema is agent.output_schema diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 807e78f0b..da277e81e 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,8 +83,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_structured_output": False, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", }, } ], diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 60a10f1a4..2cd102bf8 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,5 +1,3 @@ -import warnings - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -183,8 +181,8 @@ def test_prompt_stack_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( - use_structured_output=True, mock_structured_output={"baz": "foo"}, + 
structured_output_strategy="native", ), output_schema=output_schema, ) @@ -197,17 +195,33 @@ def test_prompt_stack_native_schema(self): assert task.prompt_stack.messages[0].is_user() assert "foo" in task.prompt_stack.messages[0].to_text() - # Ensure no warnings were raised - with warnings.catch_warnings(): - warnings.simplefilter("error") - assert task.prompt_stack + def test_prompt_stack_tool_schema(self): + from schema import Schema - def test_prompt_stack_empty_native_schema(self): + output_schema = Schema({"baz": str}) task = PromptTask( input="foo", prompt_driver=MockPromptDriver( - use_structured_output=True, + mock_structured_output={"baz": "foo"}, + structured_output_strategy="tool", + use_native_tools=True, ), + output_schema=output_schema, + ) + output = task.run() + + assert isinstance(output, JsonArtifact) + assert output.value == {"baz": "foo"} + + assert task.prompt_stack.output_schema is output_schema + assert task.prompt_stack.messages[0].is_system() + assert task.prompt_stack.messages[1].is_user() + assert "foo" in task.prompt_stack.messages[1].to_text() + + def test_prompt_stack_empty_native_schema(self): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver(), rules=[JsonSchemaRule({"foo": {}})], ) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 00bbadc45..5c7f6b394 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,8 +257,7 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", - "structured_output_strategy": "native", - "use_structured_output": False, + "structured_output_strategy": "rule", "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 70c59e1f8..a5e95f4d1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,8 +399,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, - "use_structured_output": False, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", }, "tools": [ { From d0689677ec36819dec3bcd73a4ff50bd515311b5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 3 Jan 2025 14:57:20 -0800 Subject: [PATCH 11/11] Move logic from task to base prompt driver --- .../drivers/prompt-drivers.md | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 32 +++++++- griptape/tasks/prompt_task.py | 14 +--- .../templates/tasks/prompt_task/system.j2 | 4 - tests/mocks/mock_prompt_driver.py | 2 - .../drivers/prompt/test_base_prompt_driver.py | 76 ++++++++++++++++++- tests/unit/tasks/test_prompt_task.py | 64 +++++----------- 7 files changed, 128 insertions(+), 68 deletions(-) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index a6694726b..e45d28d71 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -41,8 +41,8 @@ The easiest way to get started with structured output is by using a [PromptTask] You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. 
-- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool. -- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort. +- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool. +- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort. ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index c5ffb7259..a6d769021 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field -from griptape.artifacts.base_artifact import BaseArtifact +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.common import ( ActionCallDeltaMessageContent, ActionCallMessageContent, @@ -26,6 +26,7 @@ ) from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin +from griptape.rules.json_schema_rule import JsonSchemaRule if TYPE_CHECKING: from collections.abc import Iterator @@ -64,6 +65,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: + self._init_structured_output(prompt_stack) EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: @@ -127,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... 
+ def _init_structured_output(self, prompt_stack: PromptStack) -> None: + from griptape.tools import StructuredOutputTool + + if (output_schema := prompt_stack.output_schema) is not None: + if self.structured_output_strategy == "tool": + structured_output_tool = StructuredOutputTool(output_schema=output_schema) + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + elif self.structured_output_strategy == "rule": + output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text()) + system_messages = prompt_stack.system_messages + if system_messages: + last_system_message = prompt_stack.system_messages[-1] + last_system_message.content.extend( + [ + TextMessageContent(TextArtifact("\n\n")), + TextMessageContent(output_artifact), + ] + ) + else: + prompt_stack.messages.insert( + 0, + Message( + content=[TextMessageContent(output_artifact)], + role=Message.SYSTEM_ROLE, + ), + ) + def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 15c0f7457..b70b1eac7 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -15,7 +15,6 @@ from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Ruleset -from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 @@ -92,16 +91,9 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - from griptape.tools.structured_output.tool import StructuredOutputTool - - stack = PromptStack(tools=self.tools) + stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None - if self.output_schema is not None: - stack.output_schema = self.output_schema - if self.prompt_driver.structured_output_strategy == "tool": - stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema)) - system_template = self.generate_system_template(self) if system_template: stack.add_system_message(system_template) @@ -227,10 +219,6 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, - structured_output_strategy=self.prompt_driver.structured_output_strategy, - json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output")) - if self.output_schema is not None - else None, stop_sequence=self.response_stop_sequence, ) diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index 8e89e13c7..b262e7c72 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,7 +26,3 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} -{% if json_schema_rule and structured_output_strategy == 'rule' %} - -{{ json_schema_rule }} -{% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 3310a952e..1b481067b 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,6 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. action_messages = [ @@ -85,7 +84,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. action_messages = [ diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 58720bbc5..3ffcebce4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,9 +1,12 @@ -from griptape.artifacts import ErrorArtifact, TextArtifact +import json + +from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask +from griptape.tools.structured_output.tool import StructuredOutputTool from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -65,3 +68,74 @@ def test_run_with_tools_and_stream(self, mock_config): output = pipeline.run().output_task.output assert isinstance(output, TextArtifact) assert output.value == "mock output" + + def test_native_structured_output_strategy(self): + from schema import Schema + + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="native", + ) + + output_schema = Schema({"baz": str}) + output = prompt_driver.run(PromptStack(messages=[], output_schema=output_schema)).to_artifact() + + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) + + def test_tool_structured_output_strategy(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="tool", + use_native_tools=True, + ) + prompt_stack = PromptStack(messages=[], output_schema=output_schema) + output = prompt_driver.run(prompt_stack).to_artifact() + output = prompt_driver.run(prompt_stack).to_artifact() + + assert isinstance(output, ActionArtifact) + assert isinstance(prompt_stack.tools[0], StructuredOutputTool) + assert prompt_stack.tools[0].output_schema == output_schema + assert output.value.input == {"values": {"baz": "foo"}} + + def test_rule_structured_output_strategy_empty(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = 
MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="rule", + ) + prompt_stack = PromptStack(messages=[], output_schema=output_schema) + output = prompt_driver.run(prompt_stack).to_artifact() + + assert len(prompt_stack.system_messages) == 1 + assert prompt_stack.messages[0].is_system() + assert "baz" in prompt_stack.messages[0].content[0].to_text() + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) + + def test_rule_structured_output_strategy_populated(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="rule", + ) + prompt_stack = PromptStack( + messages=[ + Message(content="foo", role=Message.SYSTEM_ROLE), + ], + output_schema=output_schema, + ) + output = prompt_driver.run(prompt_stack).to_artifact() + assert len(prompt_stack.system_messages) == 1 + assert prompt_stack.messages[0].is_system() + assert prompt_stack.messages[0].content[1].to_text() == "\n\n" + assert "baz" in prompt_stack.messages[0].content[2].to_text() + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 2cd102bf8..d146d2249 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,5 +1,7 @@ +import pytest +import schema + from griptape.artifacts.image_artifact import ImageArtifact -from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory @@ -174,50 +176,6 @@ def test_prompt_stack_empty_system_content(self): assert task.prompt_stack.messages[2].is_user() assert task.prompt_stack.messages[2].to_text() == "test value" - def test_prompt_stack_native_schema(self): - from schema import Schema - - output_schema = Schema({"baz": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - mock_structured_output={"baz": "foo"}, - structured_output_strategy="native", - ), - output_schema=output_schema, - ) - output = task.run() - - assert isinstance(output, JsonArtifact) - assert output.value == {"baz": "foo"} - - assert task.prompt_stack.output_schema is output_schema - assert task.prompt_stack.messages[0].is_user() - assert "foo" in task.prompt_stack.messages[0].to_text() - - def test_prompt_stack_tool_schema(self): - from schema import Schema - - output_schema = Schema({"baz": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - mock_structured_output={"baz": "foo"}, - structured_output_strategy="tool", - use_native_tools=True, - ), - output_schema=output_schema, - ) - output = task.run() - - assert isinstance(output, JsonArtifact) - assert output.value == {"baz": "foo"} - - assert task.prompt_stack.output_schema is output_schema - assert task.prompt_stack.messages[0].is_system() - assert task.prompt_stack.messages[1].is_user() - assert "foo" in task.prompt_stack.messages[1].to_text() - def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", @@ -282,3 +240,19 @@ def test_subtasks(self): task.run() assert len(task.subtasks) == 2 + + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule"]) + def test_parse_output(self, structured_output_strategy): + task = PromptTask( + input="foo", + 
prompt_driver=MockPromptDriver( + structured_output_strategy=structured_output_strategy, + mock_structured_output={"foo": "bar"}, + ), + output_schema=schema.Schema({"foo": str}), + ) + + task.run() + + assert task.output is not None + assert task.output.value == {"foo": "bar"}
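A minimal usage sketch of the `structured_output_strategy` setting documented in the prompt-drivers change above, assuming the top-level `OpenAiChatPromptDriver` and `PromptTask` exports and an `OPENAI_API_KEY` in the environment; the model name and input text are illustrative, not taken from the patch.

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# "native" routes the schema through the provider's own structured output
# support; per the docs above, "tool" or "rule" can be substituted instead.
task = PromptTask(
    input="Extract the name and age from: Alice is 30 years old.",
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",  # illustrative model name
        structured_output_strategy="native",
    ),
    output_schema=schema.Schema({"name": str, "age": int}),
)

output = task.run()
# Expected to be a dict matching the schema, e.g. {"name": "Alice", "age": 30}
print(output.value)
```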