From 3a074302ee3e9180e69584f428855eeb208d2b17 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Dec 2024 14:50:42 -0800 Subject: [PATCH] Add Structured Output functionality --- CHANGELOG.md | 4 + .../drivers/prompt-drivers.md | 31 ++++ .../src/prompt_drivers_structured_output.py | 35 ++++ .../prompt_drivers_structured_output_multi.py | 28 ++++ .../griptape-framework/structures/rulesets.md | 7 + .../structures/src/json_schema_rule.py | 8 +- griptape/common/prompt_stack/prompt_stack.py | 5 +- .../prompt/amazon_bedrock_prompt_driver.py | 41 ++++- .../drivers/prompt/anthropic_prompt_driver.py | 37 ++++- griptape/drivers/prompt/base_prompt_driver.py | 16 +- .../drivers/prompt/cohere_prompt_driver.py | 23 ++- .../drivers/prompt/google_prompt_driver.py | 40 +++-- .../prompt/huggingface_hub_prompt_driver.py | 45 ++++- .../drivers/prompt/ollama_prompt_driver.py | 25 ++- .../prompt/openai_chat_prompt_driver.py | 31 +++- griptape/rules/json_schema_rule.py | 9 +- griptape/schemas/base_schema.py | 3 + griptape/tasks/actions_subtask.py | 11 +- griptape/tasks/prompt_task.py | 156 +++++++++++------- griptape/tools/__init__.py | 2 + griptape/tools/structured_output/__init__.py | 0 griptape/tools/structured_output/tool.py | 20 +++ pyproject.toml | 2 +- tests/mocks/mock_prompt_driver.py | 12 ++ .../test_amazon_bedrock_drivers_config.py | 4 + .../drivers/test_anthropic_drivers_config.py | 2 + .../test_azure_openai_drivers_config.py | 2 + .../drivers/test_cohere_drivers_config.py | 2 + .../configs/drivers/test_drivers_config.py | 2 + .../drivers/test_google_drivers_config.py | 2 + .../drivers/test_openai_driver_config.py | 2 + .../test_amazon_bedrock_prompt_driver.py | 82 ++++++++- .../prompt/test_anthropic_prompt_driver.py | 73 +++++++- .../test_azure_openai_chat_prompt_driver.py | 78 ++++++++- .../drivers/prompt/test_base_prompt_driver.py | 23 +++ .../prompt/test_cohere_prompt_driver.py | 107 +++++++++++- .../prompt/test_google_prompt_driver.py | 52 +++++- .../test_hugging_face_hub_prompt_driver.py | 42 ++++- .../prompt/test_ollama_prompt_driver.py | 109 +++++++++--- .../prompt/test_openai_chat_prompt_driver.py | 115 ++++++++++++- tests/unit/structures/test_structure.py | 2 + tests/unit/tasks/test_actions_subtask.py | 68 +++++++- tests/unit/tasks/test_prompt_task.py | 81 +++++++++ tests/unit/tasks/test_tool_task.py | 2 + tests/unit/tasks/test_toolkit_task.py | 2 + .../unit/tools/test_structured_output_tool.py | 13 ++ 46 files changed, 1264 insertions(+), 192 deletions(-) create mode 100644 docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py create mode 100644 docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py create mode 100644 griptape/tools/structured_output/__init__.py create mode 100644 griptape/tools/structured_output/tool.py create mode 100644 tests/unit/tools/test_structured_output_tool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8cb3296..bb6a26a4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `BranchTask` in `StructureVisualizer`. - `EvalEngine` for evaluating the performance of an LLM's output against a given input. - `BaseFileLoader.save()` method for saving an Artifact to a destination. +- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. +- `BasePromptDriver.native_structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed @@ -23,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ToolkitTask` now serializes its `tools` field. - `PromptTask.prompt_driver` is now serialized. - `PromptTask` can now do everything a `ToolkitTask` can do. +- `JsonSchemaRule`s can now take a `schema.Schema` instance. Required for using a `JsonSchemaRule` with structured output. +- `JsonSchemaRule`s will now be used for structured output if the Prompt Driver supports it. ### Fixed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 6c51d2d01..f05647bb3 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -25,6 +25,37 @@ You can pass images to the Driver if the model supports it: --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py" ``` +## Structured Output + +Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. + +Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output). + +If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [native_structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.native_structured_output_strategy) to one of: + +- `native`: The Driver will use the LLM's structured output functionality provided by the API. +- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool. + +### JSON Schema + +The easiest way to get started with structured output is by using a [JsonSchemaRule](../structures/rulesets.md#json-schema). If a [schema.Schema](https://pypi.org/project/schema/) instance is provided to the Rule, Griptape will convert it to a JSON Schema and provide it to the LLM using the selected structured output strategy. + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" +``` + +### Multiple Schemas + +If multiple `JsonSchemaRule`s are provided, Griptape will merge them into a single JSON Schema using `anyOf`. + +Some LLMs may not support `anyOf` as a top-level JSON Schema. To work around this, you can try using another `native_structured_output_strategy`: + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py" +``` + +Not every LLM supports `use_native_structured_output` or all `native_structured_output_strategy` options. + ## Prompt Drivers Griptape offers the following Prompt Drivers for interacting with LLMs. diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py new file mode 100644 index 000000000..6613c3a3e --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -0,0 +1,35 @@ +import schema +from rich.pretty import pprint + +from griptape.drivers import OpenAiChatPromptDriver +from griptape.rules import JsonSchemaRule, Rule +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +pipeline = Pipeline( + tasks=[ + PromptTask( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4o", + use_native_structured_output=True, + native_structured_output_strategy="native", + ), + rules=[ + Rule("You are a helpful math tutor. Guide the user through the solution step by step."), + JsonSchemaRule( + schema.Schema( + { + "steps": [schema.Schema({"explanation": str, "output": str})], + "final_answer": str, + } + ) + ), + ], + ) + ] +) + +output = pipeline.run("How can I solve 8x + 7 = -23").output.value + + +pprint(output) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py new file mode 100644 index 000000000..0b85cee94 --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output_multi.py @@ -0,0 +1,28 @@ +import schema +from rich.pretty import pprint + +from griptape.drivers import OpenAiChatPromptDriver +from griptape.rules import JsonSchemaRule +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +pipeline = Pipeline( + tasks=[ + PromptTask( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4o", + use_native_structured_output=True, + native_structured_output_strategy="tool", + ), + rules=[ + JsonSchemaRule(schema.Schema({"color": "red"})), + JsonSchemaRule(schema.Schema({"color": "blue"})), + ], + ) + ] +) + +output = pipeline.run("Pick a color").output.value + + +pprint(output) diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index f7a1de482..181ea2478 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -29,6 +29,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema. This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. +If the Prompt Driver supports [Structured Output](../drivers/prompt-drivers.md#structured-output), Griptape will use the schema provided to the `JsonSchemaRule` to ensure JSON output. +If the Prompt Driver does not support Structured Output, Griptape will include the schema in the system prompt using [this template](https://github.com/griptape-ai/griptape/blob/main/griptape/templates/rules/json_schema.j2). + ```python --8<-- "docs/griptape-framework/structures/src/json_schema_rule.py" ``` @@ -47,6 +50,10 @@ Although Griptape leverages the `schema` library, you're free to use any JSON sc For example, using `pydantic`: +!!! warning + +Griptape does not yet support using `pydantic` schemas for[Structured Output](../drivers/prompt-drivers.md#structured-output). It is recommended to pass a `schema.Schema` instance. + ```python --8<-- "docs/griptape-framework/structures/src/json_schema_rule_pydantic.py" ``` diff --git a/docs/griptape-framework/structures/src/json_schema_rule.py b/docs/griptape-framework/structures/src/json_schema_rule.py index 1f78de928..ebee0d882 100644 --- a/docs/griptape-framework/structures/src/json_schema_rule.py +++ b/docs/griptape-framework/structures/src/json_schema_rule.py @@ -5,13 +5,7 @@ from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.structures import Agent -agent = Agent( - rules=[ - JsonSchemaRule( - schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}).json_schema("Output Format") - ) - ] -) +agent = Agent(rules=[JsonSchemaRule(schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}))]) output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3b1b8ef74..752ce8a8d 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -24,6 +24,8 @@ from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: + from schema import Schema + from griptape.tools import BaseTool @@ -31,6 +33,7 @@ class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def system_messages(self) -> list[Message]: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b108180d2..9e467c532 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("bedrock-runtime") @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] - messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) - return { + params = { "modelId": self.model, "messages": messages, "system": system_messages, @@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}), }, "additionalModelRequestFields": self.additional_model_request_fields, - **( - {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["toolConfig"] = { + "tools": [], + "toolChoice": self.tool_choice, + } + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["toolConfig"]["toolChoice"] = {"any": {}} + + params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) + + return params + def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: return [ { diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 3341006a1..bc4ff84f0 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -68,6 +68,10 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @@ -75,6 +79,13 @@ class AnthropicPromptDriver(BasePromptDriver): def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) @@ -110,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages system_message = system_messages[0].to_text() if system_messages else None - return { + params = { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, @@ -118,15 +129,25 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "top_k": self.top_k, "max_tokens": self.max_tokens, "messages": messages, - **( - {"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice} - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"system": system_message} if system_message else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["tool_choice"] = {"type": "any"} + + params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) + + return params + def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: return [ {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 707f67644..19109f55f 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from attrs import Factory, define, field @@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: @@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... + def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None: + from griptape.tools.structured_output.tool import StructuredOutputTool + + if prompt_stack.output_schema is None: + raise ValueError("PromptStack must have an output schema to use structured output.") + + structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3811db5cd..2695aba09 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -101,21 +102,31 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self.__to_cohere_messages(prompt_stack.messages) - return { + params = { "model": self.model, "messages": messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), - **( - {"tools": self.__to_cohere_tools(prompt_stack.tools)} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } + elif self.native_structured_output_strategy == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_cohere_tools(prompt_stack.tools) + + return params + def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2a6bdbf6d..23de1e42d 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -2,9 +2,9 @@ import json import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ActionArtifact, TextArtifact @@ -63,9 +63,20 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("GooglePromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") @@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) - return { + params = { "generation_config": types.GenerationConfig( **{ # For some reason, providing stop sequences when streaming breaks native functions @@ -148,16 +159,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, }, ), - **( - { - "tools": self.__to_google_tools(prompt_stack.tools), - "tool_config": {"function_calling_config": {"mode": self.tool_choice}}, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), } + if prompt_stack.tools and self.use_native_tools: + params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_strategy == "tool" + ): + params["tool_config"]["function_calling_config"]["mode"] = "auto" + self._add_structured_output_tool(prompt_stack) + + params["tools"] = self.__to_google_tools(prompt_stack.tools) + + return params + def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index c2c45c3ae..f9acdeb1d 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable from griptape.configs import Defaults @@ -35,6 +35,10 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_strategy: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -51,11 +55,23 @@ def client(self) -> InferenceClient: token=self.api_token, ) + @native_structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + if value == "tool": + raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation( prompt, @@ -75,7 +91,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation(prompt, **full_params) @@ -94,12 +115,26 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) def _base_params(self, prompt_stack: PromptStack) -> dict: - return { + params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.native_structured_output_strategy == "native" + ): + # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding + output_schema = prompt_stack.output_schema.json_schema("Output Schema") + # Grammar does not support $schema and $id + del output_schema["$schema"] + del output_schema["$id"] + params["grammar"] = {"type": "json", "value": output_schema} + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 5cbba1fdf..25756cc1c 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -79,7 +80,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) response = self.client.chat(**params) - logger.debug(response) + logger.debug(response.model_dump()) return Message( content=self.__to_prompt_stack_message_content(response), @@ -102,20 +103,26 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self._prompt_stack_to_messages(prompt_stack) - return { + params = { "messages": messages, "model": self.model, "options": self.options, - **( - {"tools": self.__to_ollama_tools(prompt_stack.tools)} - if prompt_stack.tools - and self.use_native_tools - and not self.stream # Tool calling is only supported when not streaming - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") + elif self.native_structured_output_strategy == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + # Tool calling is only supported when not streaming + if prompt_stack.tools and self.use_native_tools and not self.stream: + params["tools"] = self.__to_ollama_tools(prompt_stack.tools) + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: ollama_messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eed0e35f0..d8f61a3bf 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -148,21 +149,30 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, - **( - { - "tools": self.__to_openai_tools(prompt_stack.tools), - "tool_choice": self.tool_choice, - "parallel_tool_calls": self.parallel_tool_calls, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + params["parallel_tool_calls"] = self.parallel_tool_calls + + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": prompt_stack.output_schema.json_schema("Output"), + "strict": True, + }, + } + elif self.native_structured_output_strategy == "tool" and self.use_native_tools: + params["tool_choice"] = "required" + self._add_structured_output_tool(prompt_stack) + if self.response_format is not None: if self.response_format == {"type": "json_object"}: params["response_format"] = self.response_format @@ -171,6 +181,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: else: params["response_format"] = self.response_format + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_openai_tools(prompt_stack.tools) + messages = self.__to_openai_messages(prompt_stack.messages) params["messages"] = messages diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py index c068eb4a1..84a700ce8 100644 --- a/griptape/rules/json_schema_rule.py +++ b/griptape/rules/json_schema_rule.py @@ -1,17 +1,22 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING, Union from attrs import Factory, define, field from griptape.rules import BaseRule from griptape.utils import J2 +if TYPE_CHECKING: + from schema import Schema + @define() class JsonSchemaRule(BaseRule): - value: dict = field(metadata={"serializable": True}) + value: Union[dict, Schema] = field(metadata={"serializable": True}) generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2"))) def to_text(self) -> str: - return self.generate_template.render(json_schema=json.dumps(self.value)) + value = self.value if isinstance(self.value, dict) else self.value.json_schema("Output Schema") + return self.generate_template.render(json_schema=json.dumps(value)) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 7b23c620f..9217f26c2 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from collections.abc import Sequence from typing import Any + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.common import ( BaseDeltaMessageContent, @@ -215,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseRule": BaseRule, "Ruleset": Ruleset, # Third party modules + "Schema": Schema, "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 6f9d70053..c889554fd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -9,12 +9,13 @@ from attrs import define, field from griptape import utils -from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.configs import Defaults from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask +from griptape.tools.structured_output.tool import StructuredOutputTool from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars if TYPE_CHECKING: @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None: self.__init_from_prompt(self.input.to_text()) else: self.__init_from_artifacts(self.input) + + structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)] + if structured_outputs: + output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs] + if len(structured_outputs) > 1: + self.output = ListArtifact(output_values) + else: + self.output = output_values[0] except Exception as e: logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 4ed313bf0..bb5ca9667 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -2,24 +2,25 @@ import json import logging +import warnings from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field +from schema import Or, Schema from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import PromptStack, ToolAction from griptape.configs import Defaults from griptape.memory.structure import Run from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin -from griptape.rules import Ruleset +from griptape.rules import JsonSchemaRule, Ruleset from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 if TYPE_CHECKING: - from schema import Schema - from griptape.drivers import BasePromptDriver from griptape.memory import TaskMemory from griptape.memory.structure.base_conversation_memory import BaseConversationMemory @@ -92,54 +93,30 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None - system_template = self.generate_system_template(self) - if system_template: - stack.add_system_message(system_template) + rulesets = self.rulesets + system_artifacts = [TextArtifact(self.generate_system_template(self))] + if self.prompt_driver.use_native_structured_output: + self._add_native_schema_to_prompt_stack(stack, rulesets) + + # Ensure there is at least one Ruleset that has non-empty `rules`. + if any(len(ruleset.rules) for ruleset in rulesets): + system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets))) + + # Ensure there is at least one system Artifact that has a non-empty value. + has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts) + if has_system_artifacts: + stack.add_system_message(ListArtifact(system_artifacts)) stack.add_user_message(self.input) if self.output: stack.add_assistant_message(self.output.to_text()) else: - for s in self.subtasks: - if self.prompt_driver.use_native_tools: - action_calls = [ - ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) - for action in s.actions - ] - action_results = [ - ToolAction( - name=action.name, - path=action.path, - tag=action.tag, - output=action.output if action.output is not None else s.output, - ) - for action in s.actions - ] - - stack.add_assistant_message( - ListArtifact( - [ - *([TextArtifact(s.thought)] if s.thought else []), - *[ActionArtifact(a) for a in action_calls], - ], - ), - ) - stack.add_user_message( - ListArtifact( - [ - *[ActionArtifact(a) for a in action_results], - *([] if s.output else [TextArtifact("Please keep going")]), - ], - ), - ) - else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) + self._add_subtasks_to_prompt_stack(stack) if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0) return stack @@ -203,23 +180,32 @@ def try_run(self) -> BaseArtifact: self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence]) result = self.prompt_driver.run(self.prompt_stack) - subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) - - while True: - if subtask.output is None: - if len(self.subtasks) >= self.max_subtasks: - subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task") + if self.tools: + subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) + + while True: + if subtask.output is None: + if len(self.subtasks) >= self.max_subtasks: + subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task") + else: + subtask.run() + + result = self.prompt_driver.run(self.prompt_stack) + subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) else: - subtask.run() - - result = self.prompt_driver.run(self.prompt_stack) - subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) - else: - break + break - self.output = subtask.output + output = subtask.output + else: + output = result.to_artifact() - return self.output + if ( + self.prompt_driver.use_native_structured_output + and self.prompt_driver.native_structured_output_strategy == "native" + ): + return JsonArtifact(output.value) + else: + return output def preprocess(self, structure: Structure) -> BaseTask: super().preprocess(structure) @@ -240,7 +226,6 @@ def default_generate_system_template(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/prompt_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), @@ -321,3 +306,60 @@ def _process_task_input( return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: return self._process_task_input(TextArtifact(task_input)) + + def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None: + # Need to separate JsonSchemaRules from other rules, removing them in the process + json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)] + non_json_schema_rules = [ + [rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets + ] + for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules): + ruleset.rules = non_json_rules + + schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)] + + if len(json_schema_rules) != len(schemas): + warnings.warn( + "Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`.", + stacklevel=2, + ) + + if schemas: + stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas)) + + def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: + for s in self.subtasks: + if self.prompt_driver.use_native_tools: + action_calls = [ + ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) + for action in s.actions + ] + action_results = [ + ToolAction( + name=action.name, + path=action.path, + tag=action.tag, + output=action.output if action.output is not None else s.output, + ) + for action in s.actions + ] + + stack.add_assistant_message( + ListArtifact( + [ + *([TextArtifact(s.thought)] if s.thought else []), + *[ActionArtifact(a) for a in action_calls], + ], + ), + ) + stack.add_user_message( + ListArtifact( + [ + *[ActionArtifact(a) for a in action_results], + *([] if s.output else [TextArtifact("Please keep going")]), + ], + ), + ) + else: + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 67a1712a1..ec9cbd5b7 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -23,6 +23,7 @@ from .extraction.tool import ExtractionTool from .prompt_summary.tool import PromptSummaryTool from .query.tool import QueryTool +from .structured_output.tool import StructuredOutputTool __all__ = [ "BaseTool", @@ -50,4 +51,5 @@ "ExtractionTool", "PromptSummaryTool", "QueryTool", + "StructuredOutputTool", ] diff --git a/griptape/tools/structured_output/__init__.py b/griptape/tools/structured_output/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structured_output/tool.py b/griptape/tools/structured_output/tool.py new file mode 100644 index 000000000..89e638f59 --- /dev/null +++ b/griptape/tools/structured_output/tool.py @@ -0,0 +1,20 @@ +from attrs import define, field +from schema import Schema + +from griptape.artifacts import BaseArtifact, JsonArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructuredOutputTool(BaseTool): + output_schema: Schema = field(kw_only=True) + + @activity( + config={ + "description": "Used to provide the final response which ends this conversation.", + "schema": lambda self: self.output_schema, + } + ) + def provide_output(self, params: dict) -> BaseArtifact: + return JsonArtifact(params["values"]) diff --git a/pyproject.toml b/pyproject.toml index b7e52c6f7..737ee639a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -315,7 +315,7 @@ fixture-parentheses = true "ANN202", # missing-return-type-private-function ] "docs/*" = [ - "T20" # flake8-print + "T20", # flake8-print ] [tool.ruff.lint.flake8-tidy-imports.banned-api] diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index f308c9804..abef72227 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Callable, Union from attrs import define, field @@ -31,9 +32,20 @@ class MockPromptDriver(BasePromptDriver): tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) mock_input: Union[str, Callable[[], str]] = field(default="mock input", kw_only=True) mock_output: Union[str, Callable[[PromptStack], str]] = field(default="mock output", kw_only=True) + mock_structured_output: Union[dict, Callable[[PromptStack], dict]] = field(factory=dict, kw_only=True) def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + if prompt_stack.output_schema and self.use_native_structured_output: + if self.native_structured_output_strategy == "native": + return Message( + content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + elif self.native_structured_output_strategy == "tool": + self._add_structured_output_tool(prompt_stack) + if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. action_messages = [ diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 52408922c..59eb4ac61 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,8 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -106,6 +108,8 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 8a6f25ef2..66f987308 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "native_structured_output_strategy": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 4c44113a0..2281f4c11 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,6 +36,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 65295da52..6f371c5ba 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,8 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "native", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index ca3cea60e..dd2e1736b 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,8 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index c1459a400..569e45561 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_strategy": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index c71774b26..603d9867a 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,6 +28,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index ada192aae..f5e3fe67d 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ErrorArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,29 @@ class TestAmazonBedrockPromptDriver: + BEDROCK_STRUCTURED_OUTPUT_TOOL = { + "toolSpec": { + "description": "Used to provide the final response which ends this conversation.", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "name": "StructuredOutputTool_provide_output", + }, + } BEDROCK_TOOLS = [ { "toolSpec": { @@ -229,6 +253,7 @@ def mock_converse_stream(self, mocker): def prompt_stack(self, request): prompt_stack = PromptStack() prompt_stack.tools = [MockTool()] + prompt_stack.output_schema = Schema({"foo": str}) if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -357,10 +382,14 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -377,7 +406,19 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } if use_native_tools else {} ), @@ -394,10 +435,17 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -415,8 +463,20 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} - if prompt_stack.tools and use_native_tools + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } + if use_native_tools else {} ), foo="bar", @@ -439,3 +499,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_strategy(self): + assert AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises( + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." + ): + AmazonBedrockPromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 2b84b5a17..cc802481f 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact @@ -141,6 +142,24 @@ class TestAnthropicPromptDriver: }, ] + ANTHROPIC_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "input_schema": { + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "StructuredOutputTool_provide_output", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") @@ -199,6 +218,7 @@ def mock_stream_client(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -343,10 +363,15 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AnthropicPromptDriver( - model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="claude-3-haiku", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -362,7 +387,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -376,13 +415,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -401,7 +444,21 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_strategy == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert event.usage.input_tokens == 5 @@ -426,3 +483,9 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na event = next(stream) assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_strategy(self): + assert AnthropicPromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): + AnthropicPromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index c7dff9811..6ca7a423b 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,13 +67,25 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -87,11 +99,32 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ user=driver.user, messages=messages, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -103,7 +136,17 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool"]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", @@ -111,6 +154,8 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, model="gpt-4", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -126,11 +171,32 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream=True, messages=messages, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 58720bbc5..c57173c66 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,3 +1,5 @@ +import pytest + from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent @@ -65,3 +67,24 @@ def test_run_with_tools_and_stream(self, mock_config): output = pipeline.run().output_task.output assert isinstance(output, TextArtifact) assert output.value == "mock output" + + def test__add_structured_output_tool(self): + from schema import Schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + mock_prompt_driver = MockPromptDriver() + + prompt_stack = PromptStack() + + with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."): + mock_prompt_driver._add_structured_output_tool(prompt_stack) + + prompt_stack.output_schema = Schema({"foo": str}) + + mock_prompt_driver._add_structured_output_tool(prompt_stack) + # Ensure it doesn't get added twice + mock_prompt_driver._add_structured_output_tool(prompt_stack) + assert len(prompt_stack.tools) == 1 + assert isinstance(prompt_stack.tools[0], StructuredOutputTool) + assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 9b7c24a98..bc0c51203 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -12,6 +13,36 @@ class TestCoherePromptDriver: + COHERE_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + COHERE_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + } COHERE_TOOLS = [ { "function": { @@ -242,6 +273,7 @@ def mock_tokenizer(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -306,10 +338,25 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = CoherePromptDriver( - model="command", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="command", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -320,7 +367,26 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", @@ -340,13 +406,25 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = CoherePromptDriver( model="command", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -359,7 +437,26 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 72cf51d03..e4a71d24e 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -4,6 +4,7 @@ from google.generativeai.protos import FunctionCall, FunctionResponse, Part from google.generativeai.types import ContentDict, GenerationConfig from google.protobuf.json_format import MessageToDict +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -13,6 +14,15 @@ class TestGooglePromptDriver: + GOOGLE_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "properties": {"foo": {"type": "STRING"}}, + "required": ["foo"], + "type": "OBJECT", + }, + } GOOGLE_TOOLS = [ { "name": "MockTool_test", @@ -100,6 +110,7 @@ def mock_stream_generative_model(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -166,7 +177,10 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run( + self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -174,6 +188,8 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -195,9 +211,14 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -210,7 +231,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -219,6 +243,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -242,9 +267,14 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" assert event.usage.input_tokens == 5 @@ -259,3 +289,9 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, event = next(stream) assert event.usage.output_tokens == 5 + + def test_verify_native_structured_output_strategy(self): + assert GooglePromptDriver(model="foo", native_structured_output_strategy="tool") + + with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."): + GooglePromptDriver(model="foo", native_structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 4b7aa4d13..24a83c07b 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,10 +1,18 @@ import pytest +from schema import Schema from griptape.common import PromptStack, TextDeltaMessageContent from griptape.drivers import HuggingFaceHubPromptDriver class TestHuggingFaceHubPromptDriver: + HUGGINGFACE_HUB_OUTPUT_SCHEMA = { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value @@ -31,6 +39,7 @@ def mock_client_stream(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_assistant_message("assistant-input") @@ -45,9 +54,15 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", extra_params={"foo": "bar"}) + driver = HuggingFaceHubPromptDriver( + api_token="api-token", + model="repo-id", + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -58,15 +73,23 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( - api_token="api-token", model="repo-id", stream=True, extra_params={"foo": "bar"} + api_token="api-token", + model="repo-id", + stream=True, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -79,6 +102,9 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) @@ -87,3 +113,11 @@ def test_try_stream(self, prompt_stack, mock_client_stream): event = next(stream) assert event.usage.input_tokens == 3 assert event.usage.output_tokens == 3 + + def test_verify_native_structured_output_strategy(self): + assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="native") + + with pytest.raises( + ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode." + ): + HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_strategy="tool") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 51a3dbb77..bfc1111e0 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,27 @@ class TestOllamaPromptDriver: + OLLAMA_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + OLLAMA_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "type": "function", + } OLLAMA_TOOLS = [ { "function": { @@ -112,7 +134,9 @@ class TestOllamaPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = { + mock_response = mocker.MagicMock() + + data = { "message": { "content": "model-output", "tool_calls": [ @@ -126,6 +150,10 @@ def mock_client(self, mocker): }, } + mock_response.__getitem__.side_effect = lambda key: data[key] + mock_response.model_dump.return_value = data + mock_client.return_value.chat.return_value = mock_response + return mock_client @pytest.fixture() @@ -138,6 +166,7 @@ def mock_stream_client(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -202,10 +231,26 @@ def messages(self): def test_init(self): assert OllamaPromptDriver(model="llama") - @pytest.mark.parametrize("use_native_tools", [True]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given - driver = OllamaPromptDriver(model="llama", extra_params={"foo": "bar"}) + driver = OllamaPromptDriver( + model="llama", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -219,7 +264,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): "stop": [], "num_predict": driver.max_tokens, }, - **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, + **{ + "tools": [ + *self.OLLAMA_TOOLS, + *( + [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -230,33 +289,39 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.value[1].value.path == "test" assert message.value[1].value.input == {"foo": "bar"} - def test_try_stream_run(self, mock_stream_client): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) + driver = OllamaPromptDriver( + model="llama", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, + extra_params={"foo": "bar"}, ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, - {"role": "assistant", "content": "assistant-input"}, - ] - driver = OllamaPromptDriver(model="llama", stream=True, extra_params={"foo": "bar"}) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index c47c3e9c6..2b9c7e5b9 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -12,6 +12,36 @@ class TestOpenAiChatPromptDriverFixtureMixin: + OPENAI_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + OPENAI_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + } OPENAI_TOOLS = [ { "function": { @@ -239,6 +269,7 @@ def mock_chat_completion_stream_create(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = schema.Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -340,11 +371,23 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -359,12 +402,33 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ messages=messages, seed=driver.seed, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -445,12 +509,24 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_strategy=native_structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -468,12 +544,33 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, seed=driver.seed, stream_options={"include_usage": True}, **{ - "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tools": [ + *self.OPENAI_TOOLS, + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_strategy == "tool" + else [] + ), + ], + "tool_choice": "required" + if use_native_structured_output and native_structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if use_native_structured_output and native_structured_output_strategy == "native" + else {}, foo="bar", ) @@ -500,7 +597,10 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, + max_tokens=1, + use_native_tools=False, + use_native_structured_output=False, ) # When @@ -535,6 +635,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3b0e508e0..320641757 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -82,6 +82,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", }, } ], diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index e7d44b5af..764c3440c 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -4,9 +4,10 @@ from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import ToolAction from griptape.structures import Agent -from griptape.tasks import ActionsSubtask, PromptTask +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool @@ -257,3 +258,68 @@ def test_origin_task(self): with pytest.raises(Exception, match="ActionSubtask has no origin task."): assert ActionsSubtask("test").origin_task + + def test_structured_output_tool(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool", + path="provide_output", + input={"values": {"test": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, JsonArtifact) + assert subtask.output.value == {"test": "value"} + + def test_structured_output_tool_multiple(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool1", + path="provide_output", + input={"values": {"test1": "value"}}, + ) + ), + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool2", + path="provide_output", + input={"values": {"test2": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask( + tools=[ + StructuredOutputTool(name="StructuredOutputTool1", output_schema=schema.Schema({"test": str})), + StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})), + ] + ) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, ListArtifact) + assert len(subtask.output.value) == 2 + assert subtask.output.value[0].value == {"test1": "value"} + assert subtask.output.value[1].value == {"test2": "value"} diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index a5d4521cf..ee5902f82 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,9 +1,15 @@ +import warnings + +import pytest + from griptape.artifacts.image_artifact import ImageArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.structure.run import Run from griptape.rules import Rule +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.rules.ruleset import Ruleset from griptape.structures import Pipeline from griptape.tasks import PromptTask @@ -171,6 +177,81 @@ def test_prompt_stack_empty_system_content(self): assert task.prompt_stack.messages[2].is_user() assert task.prompt_stack.messages[2].to_text() == "test value" + def test_prompt_stack_native_schema(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + mock_structured_output={"baz": "foo"}, + ), + rules=[JsonSchemaRule(output_schema)], + ) + output = task.run() + + assert isinstance(output, JsonArtifact) + assert output.value == {"baz": "foo"} + + assert task.prompt_stack.output_schema is output_schema + assert task.prompt_stack.messages[0].is_user() + assert "foo" in task.prompt_stack.messages[0].to_text() + + # Ensure no warnings were raised + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert task.prompt_stack + + def test_prompt_stack_mixed_native_schema(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)], + ) + + assert task.prompt_stack.output_schema is output_schema + assert task.prompt_stack.messages[0].is_system() + assert "foo" in task.prompt_stack.messages[0].to_text() + assert "bar" not in task.prompt_stack.messages[0].to_text() + with pytest.warns( + match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`." + ): + assert task.prompt_stack + + def test_prompt_stack_empty_native_schema(self): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[JsonSchemaRule({"foo": {}})], + ) + + assert task.prompt_stack.output_schema is None + + def test_prompt_stack_multi_native_schema(self): + from schema import Or, Schema + + output_schema = Schema({"foo": str}) + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), + rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)], + ) + + assert isinstance(task.prompt_stack.output_schema, Schema) + assert task.prompt_stack.output_schema.json_schema("Output") == Schema( + Or(output_schema, output_schema) + ).json_schema("Output") + def test_rulesets(self): pipeline = Pipeline( rulesets=[Ruleset("Pipeline Ruleset")], diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ca0576ebe..d7050c8f6 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,6 +257,8 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", + "native_structured_output_strategy": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3c17ff479..2503f1174 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_strategy": "native", }, "tools": [ { diff --git a/tests/unit/tools/test_structured_output_tool.py b/tests/unit/tools/test_structured_output_tool.py new file mode 100644 index 000000000..d310b2f9b --- /dev/null +++ b/tests/unit/tools/test_structured_output_tool.py @@ -0,0 +1,13 @@ +import pytest +import schema + +from griptape.tools import StructuredOutputTool + + +class TestStructuredOutputTool: + @pytest.fixture() + def tool(self): + return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"})) + + def test_provide_output(self, tool): + assert tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"}