diff --git a/CHANGELOG.md b/CHANGELOG.md index cb0a4e021..62b069cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors. +- Structured Output support for all Prompt Drivers. +- `PromptTask.output_schema` for setting an output schema to be used with Structured Output. +- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task. +- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`. ## [1.1.1] - 2025-01-03 diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 6c51d2d01..e45d28d71 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -25,6 +25,29 @@ You can pass images to the Driver if the model supports it: --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py" ``` +## Structured Output + +Some LLMs provide functionality often referred to as "Structured Output". +This means instructing the LLM to output data in a particular format, usually JSON. +This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. + +!!! warning + Each Driver may have a different default setting depending on the LLM provider's capabilities. + +### Prompt Task + +The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter. + +You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: + +- `native`: The Driver will use the LLM's structured output functionality provided by the API. +- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool. +- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort. + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" +``` + ## Prompt Drivers Griptape offers the following Prompt Drivers for interacting with LLMs. 
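To make the Agent-level convenience from the CHANGELOG concrete, here is a minimal hedged sketch (the schema shape and prompt are illustrative only; it assumes `Agent.output_schema` is forwarded to the Agent's Prompt Task as described above and that the configured default Prompt Driver supports its structured output strategy):

```python
import schema
from rich.pretty import pprint

from griptape.structures import Agent

# Sketch only: the Agent passes output_schema through to its Prompt Task,
# so the run's output value is expected to conform to this schema.
agent = Agent(
    output_schema=schema.Schema({"answer": str, "confidence": float}),
)

result = agent.run("What is the capital of France?")

pprint(result.output.value)
```

Which strategy is actually used still comes from the Prompt Driver's `structured_output_strategy`, so the same Agent-level schema can be exercised against the `native`, `tool`, or `rule` paths without changing the Task code.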
diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py new file mode 100644 index 000000000..cb7eb5ceb --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -0,0 +1,32 @@ +import schema +from rich.pretty import pprint + +from griptape.drivers import OpenAiChatPromptDriver +from griptape.rules import Rule +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +pipeline = Pipeline( + tasks=[ + PromptTask( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4o", + structured_output_strategy="native", # optional + ), + output_schema=schema.Schema( + { + "steps": [schema.Schema({"explanation": str, "output": str})], + "final_answer": str, + } + ), + rules=[ + Rule("You are a helpful math tutor. Guide the user through the solution step by step."), + ], + ) + ] +) + +output = pipeline.run("How can I solve 8x + 7 = -23").output.value + + +pprint(output) diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index f7a1de482..0104a94d3 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -26,6 +26,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru ### Json Schema +!!! tip + [Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output. + [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s define a structured format for the LLM's output by providing a JSON schema. This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. 
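For contrast with the tip added above, here is a hedged sketch of the plain `JsonSchemaRule` approach it refers to; the schema and prompt are illustrative only. A rule merely asks the model to comply, so the caller still has to parse and validate the response, which is exactly why the docs now point to Structured Output instead:

```python
import json

import schema

from griptape.rules import JsonSchemaRule
from griptape.structures import Agent

# Build the JSON schema dict that JsonSchemaRule expects.
json_schema = schema.Schema({"city": str, "country": str}).json_schema("Output")

agent = Agent(
    rules=[JsonSchemaRule(json_schema)],
)

result = agent.run("Name one city in France and its country.")

# The output is plain text that should be JSON; parse it defensively.
print(json.loads(result.output.value))
```

Under the hood, the new `rule` strategy builds an equivalent system-prompt rule from `PromptStack.output_schema` automatically (see `BasePromptDriver._init_structured_output` later in this diff), so manual `JsonSchemaRule` wiring remains useful mainly when you want direct control over the prompt.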
diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3b1b8ef74..752ce8a8d 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -24,6 +24,8 @@ from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: + from schema import Schema + from griptape.tools import BaseTool @@ -31,6 +33,7 @@ class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def system_messages(self) -> list[Message]: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 54278c895..12ea13ad5 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING, Any -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -41,6 +41,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -55,9 +56,19 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value == "native": + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("bedrock-runtime") @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] - messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) - return { + params = { "modelId": self.model, "messages": messages, "system": system_messages, @@ -115,14 +125,22 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}), }, "additionalModelRequestFields": self.additional_model_request_fields, - **( - {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + 
params["toolConfig"] = { + "tools": [], + "toolChoice": self.tool_choice, + } + + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + params["toolConfig"]["toolChoice"] = {"any": {}} + + params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) + + return params + def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: return [ { diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index d98ac9fd4..bc0e28266 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -20,6 +20,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value != "rule": + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("sagemaker-runtime") diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 060b8151d..9a558e7cf 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -42,6 +42,7 @@ from anthropic import Client from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools.base_tool import BaseTool @@ -68,6 +69,9 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @@ -75,6 +79,13 @@ class AnthropicPromptDriver(BasePromptDriver): def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value == "native": + raise 
ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) @@ -110,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages system_message = system_messages[0].to_text() if system_messages else None - return { + params = { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, @@ -118,15 +129,20 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "top_k": self.top_k, "max_tokens": self.max_tokens, "messages": messages, - **( - {"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice} - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"system": system_message} if system_message else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + params["tool_choice"] = {"type": "any"} + + params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) + + return params + def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: return [ {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 707f67644..a6d769021 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from attrs import Factory, define, field -from griptape.artifacts.base_artifact import BaseArtifact +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.common import ( ActionCallDeltaMessageContent, ActionCallMessageContent, @@ -26,12 +26,15 @@ ) from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin +from griptape.rules.json_schema_rule import JsonSchemaRule if TYPE_CHECKING: from collections.abc import Iterator from griptape.tokenizers import BaseTokenizer +StructuredOutputStrategy = Literal["native", "tool", "rule"] + @define(kw_only=True) class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): @@ -56,9 +59,13 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: + self._init_structured_output(prompt_stack) EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: @@ -122,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... 
+ def _init_structured_output(self, prompt_stack: PromptStack) -> None: + from griptape.tools import StructuredOutputTool + + if (output_schema := prompt_stack.output_schema) is not None: + if self.structured_output_strategy == "tool": + structured_output_tool = StructuredOutputTool(output_schema=output_schema) + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + elif self.structured_output_strategy == "rule": + output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text()) + system_messages = prompt_stack.system_messages + if system_messages: + last_system_message = prompt_stack.system_messages[-1] + last_system_message.content.extend( + [ + TextMessageContent(TextArtifact("\n\n")), + TextMessageContent(output_artifact), + ] + ) + else: + prompt_stack.messages.insert( + 0, + Message( + content=[TextMessageContent(output_artifact)], + role=Message.SYSTEM_ROLE, + ), + ) + def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3811db5cd..9158c4ad1 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -101,21 +101,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self.__to_cohere_messages(prompt_stack.messages) - return { + params = { "model": self.model, "messages": messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), - **( - {"tools": self.__to_cohere_tools(prompt_stack.tools)} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } + + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_cohere_tools(prompt_stack.tools) + + return params + def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2a6bdbf6d..46a721b08 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -4,7 +4,7 @@ import logging from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ActionArtifact, TextArtifact @@ -37,6 +37,7 @@ from google.generativeai.protos import Part from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -63,9 +64,19 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + 
default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value == "native": + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") @@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) - return { + params = { "generation_config": types.GenerationConfig( **{ # For some reason, providing stop sequences when streaming breaks native functions @@ -148,16 +159,18 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, }, ), - **( - { - "tools": self.__to_google_tools(prompt_stack.tools), - "tool_config": {"function_calling_config": {"mode": self.tool_choice}}, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), } + if prompt_stack.tools and self.use_native_tools: + params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} + + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + params["tool_config"]["function_calling_config"]["mode"] = "auto" + + params["tools"] = self.__to_google_tools(prompt_stack.tools) + + return params + def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index c2c45c3ae..57a487450 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable from griptape.configs import Defaults @@ -17,6 +17,8 @@ from huggingface_hub import InferenceClient + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -35,6 +37,9 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="native", kw_only=True, metadata={"serializable": True} + ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -51,11 +56,23 @@ def client(self) -> InferenceClient: token=self.api_token, ) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value == "tool": + raise 
ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation( prompt, @@ -75,7 +92,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation(prompt, **full_params) @@ -94,12 +116,22 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) def _base_params(self, prompt_stack: PromptStack) -> dict: - return { + params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } + if prompt_stack.output_schema and self.structured_output_strategy == "native": + # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding + output_schema = prompt_stack.output_schema.json_schema("Output Schema") + # Grammar does not support $schema and $id + del output_schema["$schema"] + del output_schema["$id"] + params["grammar"] = {"type": "json", "value": output_schema} + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index a197523df..866f033ec 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import TextArtifact from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable @@ -18,6 +18,8 @@ from transformers import TextGenerationPipeline + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -38,10 +40,20 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _pipeline: TextGenerationPipeline = field( default=None, kw_only=True, alias="pipeline", metadata={"serializable": False} ) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value in ("native", "tool"): + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 5cbba1fdf..1c4ae3fd1 100644 --- 
a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -79,7 +79,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) response = self.client.chat(**params) - logger.debug(response) + logger.debug(response.model_dump()) return Message( content=self.__to_prompt_stack_message_content(response), @@ -102,20 +102,22 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self._prompt_stack_to_messages(prompt_stack) - return { + params = { "messages": messages, "model": self.model, "options": self.options, - **( - {"tools": self.__to_ollama_tools(prompt_stack.tools)} - if prompt_stack.tools - and self.use_native_tools - and not self.stream # Tool calling is only supported when not streaming - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") + + # Tool calling is only supported when not streaming + if prompt_stack.tools and self.use_native_tools and not self.stream: + params["tools"] = self.__to_ollama_tools(prompt_stack.tools) + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: ollama_messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eed0e35f0..03390d687 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -35,6 +35,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.chat.chat_completion_message import ChatCompletionMessage + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool @@ -76,6 +77,9 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="native", kw_only=True, metadata={"serializable": True} + ) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory( @@ -148,21 +152,29 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, - **( - { - "tools": self.__to_openai_tools(prompt_stack.tools), - "tool_choice": self.tool_choice, - "parallel_tool_calls": self.parallel_tool_calls, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + params["parallel_tool_calls"] = self.parallel_tool_calls + + if prompt_stack.output_schema is not None: + if self.structured_output_strategy == "native": + params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": prompt_stack.output_schema.json_schema("Output"), + "strict": True, + }, + } + elif self.structured_output_strategy == "tool" and self.use_native_tools: + params["tool_choice"] = "required" + if self.response_format is not None: if self.response_format == {"type": "json_object"}: params["response_format"] = self.response_format @@ -171,6 +183,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: else: params["response_format"] = self.response_format + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_openai_tools(prompt_stack.tools) + messages = self.__to_openai_messages(prompt_stack.messages) params["messages"] = messages diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 7b23c620f..fa622bd05 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from collections.abc import Sequence from typing import Any + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.common import ( BaseDeltaMessageContent, @@ -170,6 +172,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseTextToSpeechDriver, BaseVectorStoreDriver, ) + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.events import EventListener from griptape.memory import TaskMemory from griptape.memory.structure import BaseConversationMemory, Run @@ -214,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseArtifactStorage": BaseArtifactStorage, "BaseRule": BaseRule, "Ruleset": Ruleset, + "StructuredOutputStrategy": StructuredOutputStrategy, # Third party modules "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, @@ -228,6 +232,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: if is_dependency_installed("mypy_boto3_bedrock") else Any, "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any, + "Schema": Schema, }, ) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index baf36108f..9b70b7fb1 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -12,6 +12,8 @@ from griptape.tasks import PromptTask if TYPE_CHECKING: + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from 
griptape.tasks import BaseTask @@ -25,6 +27,7 @@ class Agent(Structure): ) stream: bool = field(default=None, kw_only=True) prompt_driver: BasePromptDriver = field(default=None, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -98,6 +101,7 @@ def _init_task(self) -> None: self.input, prompt_driver=self.prompt_driver, tools=self.tools, + output_schema=self.output_schema, max_meta_memory_entries=self.max_meta_memory_entries, ) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 6f9d70053..c889554fd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -9,12 +9,13 @@ from attrs import define, field from griptape import utils -from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.configs import Defaults from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask +from griptape.tools.structured_output.tool import StructuredOutputTool from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars if TYPE_CHECKING: @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None: self.__init_from_prompt(self.input.to_text()) else: self.__init_from_artifacts(self.input) + + structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)] + if structured_outputs: + output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs] + if len(structured_outputs) > 1: + self.output = ListArtifact(output_values) + else: + self.output = output_values[0] except Exception as e: logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 5086636d0..b70b1eac7 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -8,6 +8,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import PromptStack, ToolAction from griptape.configs import Defaults from griptape.memory.structure import Run @@ -38,6 +39,7 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin): prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True} ) + output_schema: Optional[Schema] = field(default=None, kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True, @@ -89,7 +91,7 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack(tools=self.tools) + stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None system_template = 
self.generate_system_template(self) @@ -101,41 +103,7 @@ def prompt_stack(self) -> PromptStack: if self.output: stack.add_assistant_message(self.output.to_text()) else: - for s in self.subtasks: - if self.prompt_driver.use_native_tools: - action_calls = [ - ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) - for action in s.actions - ] - action_results = [ - ToolAction( - name=action.name, - path=action.path, - tag=action.tag, - output=action.output if action.output is not None else s.output, - ) - for action in s.actions - ] - - stack.add_assistant_message( - ListArtifact( - [ - *([TextArtifact(s.thought)] if s.thought else []), - *[ActionArtifact(a) for a in action_calls], - ], - ), - ) - stack.add_user_message( - ListArtifact( - [ - *[ActionArtifact(a) for a in action_results], - *([] if s.output else [TextArtifact("Please keep going")]), - ], - ), - ) - else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) + self._add_subtasks_to_prompt_stack(stack) if memory is not None: # inserting at index 1 to place memory right after system prompt @@ -218,11 +186,14 @@ def try_run(self) -> BaseArtifact: else: break - self.output = subtask.output + output = subtask.output else: - self.output = result.to_artifact() + output = result.to_artifact() - return self.output + if self.output_schema is not None and self.prompt_driver.structured_output_strategy in ("native", "rule"): + return JsonArtifact(output.value) + else: + return output def preprocess(self, structure: Structure) -> BaseTask: super().preprocess(structure) @@ -324,3 +295,40 @@ def _process_task_input( return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: return self._process_task_input(TextArtifact(task_input)) + + def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: + for s in self.subtasks: + if self.prompt_driver.use_native_tools: + action_calls = [ + ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) + for action in s.actions + ] + action_results = [ + ToolAction( + name=action.name, + path=action.path, + tag=action.tag, + output=action.output if action.output is not None else s.output, + ) + for action in s.actions + ] + + stack.add_assistant_message( + ListArtifact( + [ + *([TextArtifact(s.thought)] if s.thought else []), + *[ActionArtifact(a) for a in action_calls], + ], + ), + ) + stack.add_user_message( + ListArtifact( + [ + *[ActionArtifact(a) for a in action_results], + *([] if s.output else [TextArtifact("Please keep going")]), + ], + ), + ) + else: + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 67a1712a1..ec9cbd5b7 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -23,6 +23,7 @@ from .extraction.tool import ExtractionTool from .prompt_summary.tool import PromptSummaryTool from .query.tool import QueryTool +from .structured_output.tool import StructuredOutputTool __all__ = [ "BaseTool", @@ -50,4 +51,5 @@ "ExtractionTool", "PromptSummaryTool", "QueryTool", + "StructuredOutputTool", ] diff --git a/griptape/tools/structured_output/__init__.py b/griptape/tools/structured_output/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structured_output/tool.py b/griptape/tools/structured_output/tool.py new file 
mode 100644 index 000000000..89e638f59 --- /dev/null +++ b/griptape/tools/structured_output/tool.py @@ -0,0 +1,20 @@ +from attrs import define, field +from schema import Schema + +from griptape.artifacts import BaseArtifact, JsonArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructuredOutputTool(BaseTool): + output_schema: Schema = field(kw_only=True) + + @activity( + config={ + "description": "Used to provide the final response which ends this conversation.", + "schema": lambda self: self.output_schema, + } + ) + def provide_output(self, params: dict) -> BaseArtifact: + return JsonArtifact(params["values"]) diff --git a/mise.toml b/mise.toml new file mode 100644 index 000000000..e01d6ae46 --- /dev/null +++ b/mise.toml @@ -0,0 +1,2 @@ +[tools] +python = "3.9" diff --git a/pyproject.toml b/pyproject.toml index f63dd396d..c45fbefa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -315,7 +315,7 @@ fixture-parentheses = true "ANN202", # missing-return-type-private-function ] "docs/*" = [ - "T20" # flake8-print + "T20", # flake8-print ] [tool.ruff.lint.flake8-tidy-imports.banned-api] diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index f308c9804..1b481067b 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Callable, Union from attrs import define, field @@ -31,6 +32,7 @@ class MockPromptDriver(BasePromptDriver): tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) mock_input: Union[str, Callable[[], str]] = field(default="mock input", kw_only=True) mock_output: Union[str, Callable[[PromptStack], str]] = field(default="mock output", kw_only=True) + mock_structured_output: Union[dict, Callable[[PromptStack], dict]] = field(factory=dict, kw_only=True) def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output @@ -46,32 +48,42 @@ def try_run(self, prompt_stack: PromptStack) -> Message: usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: + if self.structured_output_strategy == "tool": + tool_action = ToolAction( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + input={"values": self.mock_structured_output}, + ) + else: + tool_action = ToolAction( + tag="mock-tag", + name="MockTool", + path="test", + input={"values": {"test": "test-value"}}, + ) + return Message( - content=[ - ActionCallMessageContent( - ActionArtifact( - ToolAction( - tag="mock-tag", - name="MockTool", - path="test", - input={"values": {"test": "test-value"}}, - ) - ) - ) - ], + content=[ActionCallMessageContent(ActionArtifact(tool_action))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: - return Message( - content=[TextMessageContent(TextArtifact(output))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) + if prompt_stack.output_schema is not None: + return Message( + content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + else: + return Message( + content=[TextMessageContent(TextArtifact(output))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, 
output_tokens=100), + ) def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. action_messages = [ @@ -81,15 +93,36 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: yield DeltaMessage(content=TextDeltaMessageContent(f"Answer: {output}")) yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100)) else: - yield DeltaMessage( - content=ActionCallDeltaMessageContent( - tag="mock-tag", - name="MockTool", - path="test", + if self.structured_output_strategy == "tool": + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + ) ) - ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + partial_input=json.dumps({"values": self.mock_structured_output}) + ) + ) + else: + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="MockTool", + path="test", + ) + ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + ) + else: + if prompt_stack.output_schema is not None: yield DeltaMessage( - content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), ) - else: - yield DeltaMessage(content=TextDeltaMessageContent(output)) + else: + yield DeltaMessage(content=TextDeltaMessageContent(output)) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 52408922c..b2fd51d24 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,7 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -106,6 +107,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "structured_output_strategy": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 8a6f25ef2..fa13480c1 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "structured_output_strategy": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 4c44113a0..a30cea001 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,6 +36,7 @@ def test_to_dict(self, config): "stream": False, "user": "", 
"use_native_tools": True, + "structured_output_strategy": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 65295da52..d5e05c9bd 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "structured_output_strategy": "rule", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index ca3cea60e..5adec7c6d 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "structured_output_strategy": "rule", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index c1459a400..910ae3240 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "structured_output_strategy": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index c71774b26..344d14d99 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,6 +28,7 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "structured_output_strategy": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 939b86c5e..2dcb4bf02 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ErrorArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction @@ -229,6 +230,7 @@ def mock_converse_stream(self, mocker): def prompt_stack(self, request): prompt_stack = PromptStack() prompt_stack.tools = [MockTool()] + prompt_stack.output_schema = Schema({"foo": str}) if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -359,10 +361,14 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", 
+ use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -379,7 +385,14 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + { + "toolConfig": { + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, + } + } if use_native_tools else {} ), @@ -396,10 +409,17 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, structured_output_strategy + ): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + stream=True, + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -417,8 +437,15 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} - if prompt_stack.tools and use_native_tools + { + "toolConfig": { + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, + } + } + if use_native_tools else {} ), foo="bar", @@ -441,3 +468,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + + def test_verify_structured_output_strategy(self): + assert AmazonBedrockPromptDriver(model="foo", structured_output_strategy="tool") + + with pytest.raises( + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output strategy." 
+ ): + AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index c7b0682c2..7b2d38398 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -138,3 +138,12 @@ def test_try_run_throws_on_empty_response(self, mock_client): # Then assert e.value.args[0] == "model response is empty" + + def test_verify_structured_output_strategy(self): + assert AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="rule") + + with pytest.raises( + ValueError, + match="AmazonSageMakerJumpstartPromptDriver does not support `native` structured output strategy.", + ): + AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index b611b5e1c..fbdf1e55d 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact @@ -199,6 +200,7 @@ def mock_stream_client(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -350,10 +352,15 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AnthropicPromptDriver( - model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="claude-3-haiku", + api_key="api-key", + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -369,7 +376,12 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -383,13 +395,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", 
["tool", "rule", "foo"]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, structured_output_strategy + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -408,7 +424,12 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert event.usage.input_tokens == 5 @@ -433,3 +454,11 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na event = next(stream) assert event.usage.output_tokens == 10 + + def test_verify_structured_output_strategy(self): + assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool") + + with pytest.raises( + ValueError, match="AnthropicPromptDriver does not support `native` structured output strategy." + ): + AnthropicPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index c7dff9811..8f0da735a 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,13 +67,22 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -88,10 +97,22 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ messages=messages, **{ "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -103,7 +124,15 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) + def test_try_stream_run( + 
self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", @@ -111,6 +140,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, model="gpt-4", stream=True, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -127,10 +157,22 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages=messages, **{ "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if structured_output_strategy == "native" + else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 58720bbc5..3ffcebce4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,9 +1,12 @@ -from griptape.artifacts import ErrorArtifact, TextArtifact +import json + +from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask +from griptape.tools.structured_output.tool import StructuredOutputTool from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -65,3 +68,74 @@ def test_run_with_tools_and_stream(self, mock_config): output = pipeline.run().output_task.output assert isinstance(output, TextArtifact) assert output.value == "mock output" + + def test_native_structured_output_strategy(self): + from schema import Schema + + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="native", + ) + + output_schema = Schema({"baz": str}) + output = prompt_driver.run(PromptStack(messages=[], output_schema=output_schema)).to_artifact() + + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) + + def test_tool_structured_output_strategy(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="tool", + use_native_tools=True, + ) + prompt_stack = PromptStack(messages=[], output_schema=output_schema) + output = prompt_driver.run(prompt_stack).to_artifact() + output = prompt_driver.run(prompt_stack).to_artifact() + + assert isinstance(output, ActionArtifact) + assert isinstance(prompt_stack.tools[0], StructuredOutputTool) + assert prompt_stack.tools[0].output_schema == output_schema + assert output.value.input == {"values": {"baz": "foo"}} + + def test_rule_structured_output_strategy_empty(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="rule", + ) + prompt_stack = PromptStack(messages=[], 
output_schema=output_schema) + output = prompt_driver.run(prompt_stack).to_artifact() + + assert len(prompt_stack.system_messages) == 1 + assert prompt_stack.messages[0].is_system() + assert "baz" in prompt_stack.messages[0].content[0].to_text() + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) + + def test_rule_structured_output_strategy_populated(self): + from schema import Schema + + output_schema = Schema({"baz": str}) + prompt_driver = MockPromptDriver( + mock_structured_output={"baz": "foo"}, + structured_output_strategy="rule", + ) + prompt_stack = PromptStack( + messages=[ + Message(content="foo", role=Message.SYSTEM_ROLE), + ], + output_schema=output_schema, + ) + output = prompt_driver.run(prompt_stack).to_artifact() + assert len(prompt_stack.system_messages) == 1 + assert prompt_stack.messages[0].is_system() + assert prompt_stack.messages[0].content[1].to_text() == "\n\n" + assert "baz" in prompt_stack.messages[0].content[2].to_text() + assert isinstance(output, TextArtifact) + assert output.value == json.dumps({"baz": "foo"}) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 9b7c24a98..8b51940c8 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -12,6 +13,14 @@ class TestCoherePromptDriver: + COHERE_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } COHERE_TOOLS = [ { "function": { @@ -242,6 +251,7 @@ def mock_tokenizer(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -306,10 +316,22 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = CoherePromptDriver( - model="command", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="command", + api_key="api-key", + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, ) # When @@ -320,7 +342,15 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{"tools": self.COHERE_TOOLS} if use_native_tools else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", @@ -340,13 +370,22 @@ def test_try_run(self, mock_client, prompt_stack, messages, 
use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = CoherePromptDriver( model="command", api_key="api-key", stream=True, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -359,7 +398,15 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{"tools": self.COHERE_TOOLS} if use_native_tools else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if structured_output_strategy == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 72cf51d03..aacc207b9 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -4,6 +4,7 @@ from google.generativeai.protos import FunctionCall, FunctionResponse, Part from google.generativeai.types import ContentDict, GenerationConfig from google.protobuf.json_format import MessageToDict +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -100,6 +101,7 @@ def mock_stream_generative_model(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -166,7 +168,8 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -174,6 +177,7 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"max_output_tokens": 10}, ) @@ -195,9 +199,11 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = self.GOOGLE_TOOLS + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if driver.structured_output_strategy == "tool": + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -210,7 +216,10 @@ def 
test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -219,6 +228,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"max_output_tokens": 10}, ) @@ -242,9 +252,11 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = self.GOOGLE_TOOLS + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if driver.structured_output_strategy == "tool": + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" assert event.usage.input_tokens == 5 @@ -259,3 +271,11 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, event = next(stream) assert event.usage.output_tokens == 5 + + def test_verify_structured_output_strategy(self): + assert GooglePromptDriver(model="foo", structured_output_strategy="tool") + + with pytest.raises( + ValueError, match="GooglePromptDriver does not support `native` structured output strategy." 
+ ): + GooglePromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 4b7aa4d13..b757dbcea 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,10 +1,18 @@ import pytest +from schema import Schema from griptape.common import PromptStack, TextDeltaMessageContent from griptape.drivers import HuggingFaceHubPromptDriver class TestHuggingFaceHubPromptDriver: + HUGGINGFACE_HUB_OUTPUT_SCHEMA = { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value @@ -31,6 +39,7 @@ def mock_client_stream(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_assistant_message("assistant-input") @@ -45,9 +54,15 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"]) + def test_try_run(self, prompt_stack, mock_client, structured_output_strategy): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", extra_params={"foo": "bar"}) + driver = HuggingFaceHubPromptDriver( + api_token="api-token", + model="repo-id", + extra_params={"foo": "bar"}, + structured_output_strategy=structured_output_strategy, + ) # When message = driver.try_run(prompt_stack) @@ -58,15 +73,27 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", + **( + { + "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + } + if structured_output_strategy == "native" + else {} + ), ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"]) + def test_try_stream(self, prompt_stack, mock_client_stream, structured_output_strategy): # Given driver = HuggingFaceHubPromptDriver( - api_token="api-token", model="repo-id", stream=True, extra_params={"foo": "bar"} + api_token="api-token", + model="repo-id", + stream=True, + extra_params={"foo": "bar"}, + structured_output_strategy=structured_output_strategy, ) # When @@ -79,6 +106,13 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", + **( + { + "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + } + if structured_output_strategy == "native" + else {} + ), stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) @@ -87,3 +121,11 @@ def test_try_stream(self, prompt_stack, mock_client_stream): event = next(stream) assert event.usage.input_tokens == 3 assert event.usage.output_tokens == 3 + + def test_verify_structured_output_strategy(self): + assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="native") + + with 
pytest.raises( + ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output strategy." + ): + HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="tool") diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index af52ca4e9..e03604aaf 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -42,10 +42,15 @@ def messages(self): def test_init(self, mock_pipeline): assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=mock_pipeline) - def test_try_run(self, prompt_stack, messages, mock_pipeline): + @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"]) + def test_try_run(self, prompt_stack, messages, mock_pipeline, structured_output_strategy): # Given driver = HuggingFacePipelinePromptDriver( - model="foo", max_tokens=42, extra_params={"foo": "bar"}, pipeline=mock_pipeline + model="foo", + max_tokens=42, + extra_params={"foo": "bar"}, + pipeline=mock_pipeline, + structured_output_strategy=structured_output_strategy, ) # When @@ -57,9 +62,12 @@ def test_try_run(self, prompt_stack, messages, mock_pipeline): assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_pipeline): + @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"]) + def test_try_stream(self, prompt_stack, mock_pipeline, structured_output_strategy): # Given - driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline) + driver = HuggingFacePipelinePromptDriver( + model="foo", max_tokens=42, pipeline=mock_pipeline, structured_output_strategy=structured_output_strategy + ) # When with pytest.raises(Exception) as e: @@ -101,3 +109,11 @@ def test_prompt_stack_to_string(self, prompt_stack, mock_pipeline): # Then assert result == "model-output" + + def test_verify_structured_output_strategy(self): + assert HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="rule") + + with pytest.raises( + ValueError, match="HuggingFacePipelinePromptDriver does not support `native` structured output strategy." 
+ ): + HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 51a3dbb77..02f284b76 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,14 @@ class TestOllamaPromptDriver: + OLLAMA_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } OLLAMA_TOOLS = [ { "function": { @@ -112,7 +121,9 @@ class TestOllamaPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = { + mock_response = mocker.MagicMock() + + data = { "message": { "content": "model-output", "tool_calls": [ @@ -126,6 +137,10 @@ def mock_client(self, mocker): }, } + mock_response.__getitem__.side_effect = lambda key: data[key] + mock_response.model_dump.return_value = data + mock_client.return_value.chat.return_value = mock_response + return mock_client @pytest.fixture() @@ -138,6 +153,7 @@ def mock_stream_client(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -202,10 +218,23 @@ def messages(self): def test_init(self): assert OllamaPromptDriver(model="llama") - @pytest.mark.parametrize("use_native_tools", [True]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given - driver = OllamaPromptDriver(model="llama", extra_params={"foo": "bar"}) + driver = OllamaPromptDriver( + model="llama", + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -219,7 +248,12 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): "stop": [], "num_predict": driver.max_tokens, }, - **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, + **{ + "tools": self.OLLAMA_TOOLS, + } + if use_native_tools + else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -230,33 +264,34 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.value[1].value.path == "test" assert message.value[1].value.input == {"foo": "bar"} - def test_try_stream_run(self, mock_stream_client): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, 
+ structured_output_strategy, + ): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) + driver = OllamaPromptDriver( + model="llama", + stream=True, + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, - {"role": "assistant", "content": "assistant-input"}, - ] - driver = OllamaPromptDriver(model="llama", stream=True, extra_params={"foo": "bar"}) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index c47c3e9c6..496560529 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -12,6 +12,14 @@ class TestOpenAiChatPromptDriverFixtureMixin: + OPENAI_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } OPENAI_TOOLS = [ { "function": { @@ -239,6 +247,7 @@ def mock_chat_completion_stream_create(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = schema.Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -340,11 +349,20 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_run( + self, + mock_chat_completion_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -360,11 +378,23 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_ seed=driver.seed, **{ "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + 
"strict": True, + }, + } + } + if prompt_stack.output_schema is not None and structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -445,12 +475,21 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_stream_run( + self, + mock_chat_completion_stream_create, + prompt_stack, + messages, + use_native_tools, + structured_output_strategy, + ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -469,11 +508,23 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream_options={"include_usage": True}, **{ "tools": self.OPENAI_TOOLS, - "tool_choice": driver.tool_choice, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools else {}, + **{ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": self.OPENAI_STRUCTURED_OUTPUT_SCHEMA, + "strict": True, + }, + } + } + if structured_output_strategy == "native" + else {}, foo="bar", ) @@ -499,8 +550,11 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, + max_tokens=1, + use_native_tools=False, ) # When @@ -530,6 +584,7 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio assert e.value.args[0] == "Completion with more than one choice is not supported yet." 
def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): + prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 809d174b5..442f654d5 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +import schema from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory @@ -316,3 +317,14 @@ def test_field_hierarchy(self): assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is True + + def test_output_schema(self): + agent = Agent() + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].output_schema is None + + agent = Agent(output_schema=schema.Schema({"foo": str})) + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].output_schema is agent.output_schema diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 21a637ff6..da277e81e 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,6 +83,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "structured_output_strategy": "rule", }, } ], diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index e7d44b5af..764c3440c 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -4,9 +4,10 @@ from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import ToolAction from griptape.structures import Agent -from griptape.tasks import ActionsSubtask, PromptTask +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool @@ -257,3 +258,68 @@ def test_origin_task(self): with pytest.raises(Exception, match="ActionSubtask has no origin task."): assert ActionsSubtask("test").origin_task + + def test_structured_output_tool(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool", + path="provide_output", + input={"values": {"test": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, JsonArtifact) + assert subtask.output.value == {"test": "value"} + + def test_structured_output_tool_multiple(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool1", + path="provide_output", + input={"values": {"test1": "value"}}, + ) + ), + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool2", + path="provide_output", + input={"values": {"test2": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask( + tools=[ + StructuredOutputTool(name="StructuredOutputTool1", 
output_schema=schema.Schema({"test": str})), + StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})), + ] + ) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, ListArtifact) + assert len(subtask.output.value) == 2 + assert subtask.output.value[0].value == {"test1": "value"} + assert subtask.output.value[1].value == {"test2": "value"} diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index f457a4b55..d146d2249 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,9 +1,13 @@ +import pytest +import schema + from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.structure.run import Run from griptape.rules import Rule +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.rules.ruleset import Ruleset from griptape.structures import Pipeline from griptape.tasks import PromptTask @@ -172,6 +176,15 @@ def test_prompt_stack_empty_system_content(self): assert task.prompt_stack.messages[2].is_user() assert task.prompt_stack.messages[2].to_text() == "test value" + def test_prompt_stack_empty_native_schema(self): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver(), + rules=[JsonSchemaRule({"foo": {}})], + ) + + assert task.prompt_stack.output_schema is None + def test_rulesets(self): pipeline = Pipeline( rulesets=[Ruleset("Pipeline Ruleset")], @@ -227,3 +240,19 @@ def test_subtasks(self): task.run() assert len(task.subtasks) == 2 + + @pytest.mark.parametrize("structured_output_strategy", ["native", "rule"]) + def test_parse_output(self, structured_output_strategy): + task = PromptTask( + input="foo", + prompt_driver=MockPromptDriver( + structured_output_strategy=structured_output_strategy, + mock_structured_output={"foo": "bar"}, + ), + output_schema=schema.Schema({"foo": str}), + ) + + task.run() + + assert task.output is not None + assert task.output.value == {"foo": "bar"} diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ca0576ebe..5c7f6b394 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,6 +257,7 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", + "structured_output_strategy": "rule", "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3c17ff479..a5e95f4d1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "structured_output_strategy": "rule", }, "tools": [ { diff --git a/tests/unit/tools/test_structured_output_tool.py b/tests/unit/tools/test_structured_output_tool.py new file mode 100644 index 000000000..d310b2f9b --- /dev/null +++ b/tests/unit/tools/test_structured_output_tool.py @@ -0,0 +1,13 @@ +import pytest +import schema + +from griptape.tools import StructuredOutputTool + + +class TestStructuredOutputTool: + @pytest.fixture() + def tool(self): + return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"})) + + def test_provide_output(self, tool): + assert 
tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"}
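---

The `test_verify_structured_output_strategy` cases above pin down which `structured_output_strategy` values each Prompt Driver will accept, and they show that unsupported combinations raise a `ValueError` at construction time rather than at run time. The sketch below is illustrative only and is not part of the patch: it restates those constraints using constructor calls taken directly from the tests, and it assumes `griptape` (with the relevant provider extras) and `pytest` are installed.

```python
# Illustrative sketch, not part of the patch: construction-time validation of
# structured_output_strategy, mirroring the test_verify_structured_output_strategy
# cases above. All constructor calls are copied from the tests.
import pytest

from griptape.drivers import (
    AnthropicPromptDriver,
    GooglePromptDriver,
    HuggingFaceHubPromptDriver,
)

# Anthropic and Google accept the "tool" strategy...
AnthropicPromptDriver(model="foo", structured_output_strategy="tool")
GooglePromptDriver(model="foo", structured_output_strategy="tool")

# ...but reject "native" as soon as the driver is constructed.
with pytest.raises(ValueError, match="does not support `native` structured output strategy"):
    AnthropicPromptDriver(model="foo", structured_output_strategy="native")
with pytest.raises(ValueError, match="does not support `native` structured output strategy"):
    GooglePromptDriver(model="foo", structured_output_strategy="native")

# Hugging Face Hub is the reverse: "native" is supported (the tests show it maps
# to the `grammar` request parameter), while "tool" is rejected.
HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="native")
with pytest.raises(ValueError, match="does not support `tool` structured output strategy"):
    HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="tool")
```

Failing fast in the constructor keeps an unsupported strategy from silently degrading to a different output format mid-run, which is the behavior these unit tests are guarding.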