diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 6941ff8fa270c7..9a98b441e2c187 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( @@ -128,12 +128,13 @@ class VariableMessage(BaseModel): variable_value: str = Field(..., description="The value of the variable") stream: bool = Field(default=False, description="Whether the variable is streamed") - @field_validator("variable_value", mode="before") + @model_validator(mode="before") @classmethod - def transform_variable_value(cls, value, values) -> Any: + def transform_variable_value(cls, values) -> Any: """ Only basic types and lists are allowed. """ + value = values.get("variable_value") if not isinstance(value, dict | list | str | int | float | bool): raise ValueError("Only basic types and lists are allowed.") @@ -142,7 +143,7 @@ def transform_variable_value(cls, value, values) -> Any: if not isinstance(value, str): raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - return value + return values @field_validator("variable_name", mode="before") @classmethod