From 5ac6485b594d4fc2f1c79e3401fddb686ef75d65 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Wed, 17 Jul 2024 16:00:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9PYDANTIC=5FV2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain_glm/agents/output_parsers/tools.py | 2 +- .../agents/zhipuai_all_tools/base.py | 4 ++-- langchain_glm/embeddings/base.py | 19 ++++++++----------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/langchain_glm/agents/output_parsers/tools.py b/langchain_glm/agents/output_parsers/tools.py index 023830b..bc3758a 100644 --- a/langchain_glm/agents/output_parsers/tools.py +++ b/langchain_glm/agents/output_parsers/tools.py @@ -181,7 +181,7 @@ def parse_ai_message_to_tool_action( actions.append(function_tool_result_stack.popleft()) else: for too_call in tool_calls: - if "function" == too_call["name"]: + if too_call["name"] not in AdapterAllToolStructType.__members__.values(): actions.append(function_tool_result_stack.popleft()) elif too_call["name"] == AdapterAllToolStructType.CODE_INTERPRETER: actions.append(code_interpreter_action_result_stack.popleft()) diff --git a/langchain_glm/agents/zhipuai_all_tools/base.py b/langchain_glm/agents/zhipuai_all_tools/base.py index 0e76cea..a10e69e 100644 --- a/langchain_glm/agents/zhipuai_all_tools/base.py +++ b/langchain_glm/agents/zhipuai_all_tools/base.py @@ -141,7 +141,7 @@ class ZhipuAIAllToolsRunnable(RunnableSerializable[Dict, OutputType]): agent_executor: AgentExecutor """ZhipuAI AgentExecutor.""" - model_name: str = Field(default="tob-alltools-api-dev") + model_name: str = Field(default="glm-4-alltools") """工具模型""" callback: AgentExecutorAsyncIteratorCallbackHandler """ZhipuAI AgentExecutor callback.""" @@ -193,7 +193,7 @@ def create_agent_executor( streaming=True, verbose=True, callbacks=callbacks, - model_name=model_name, + model=model_name, temperature=temperature, **kwargs, ) diff --git a/langchain_glm/embeddings/base.py b/langchain_glm/embeddings/base.py index afb6261..f1c1650 100644 --- a/langchain_glm/embeddings/base.py +++ b/langchain_glm/embeddings/base.py @@ -22,6 +22,7 @@ import zhipuai from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import ( + BaseModel, Extra, Field, SecretStr, @@ -32,8 +33,6 @@ get_from_dict_or_env, get_pydantic_field_names, ) -from typing_extensions import ClassVar -from zhipuai.core import PYDANTIC_V2, BaseModel, ConfigDict logger = logging.getLogger(__name__) @@ -84,16 +83,14 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings): http_client: Union[Any, None] = None """Optional httpx.Client.""" - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="forbid", populate_by_name=True - ) - else: - class Config: - allow_population_by_field_name = True + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + allow_population_by_field_name = True - @root_validator(pre=True) + @root_validator(pre=True, allow_reuse=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) @@ -119,7 +116,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["model_kwargs"] = extra return values - @root_validator() + @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" zhipuai_api_key = get_from_dict_or_env(