-
Notifications
You must be signed in to change notification settings - Fork 4
/
server.py
133 lines (110 loc) · 3.72 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-
import logging.config
import threading
from typing import List, Tuple
from fastapi import APIRouter, Body, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.agents import tool
from langchain_community.tools import ShellTool
from langchain_core.agents import AgentAction
from pydantic.v1 import Extra, Field
from sse_starlette.sse import EventSourceResponse
from uvicorn import Config, Server
from zhipuai.core.logs import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
from langchain_glm.agent_toolkits import BaseToolOutput
from langchain_glm.agents.zhipuai_all_tools import ZhipuAIAllToolsRunnable
from langchain_glm.agents.zhipuai_all_tools.base import OutputType
@tool
def calculate(text: str = Field(description="a math expression")) -> BaseToolOutput:
    """
    Useful to answer questions about simple calculations.
    translate user question to a math expression that can be evaluated by numexpr.
    """
    # Imported lazily so the module loads even when numexpr is absent;
    # the dependency is only needed when the agent actually calls this tool.
    import numexpr

    try:
        answer = str(numexpr.evaluate(text))
    except Exception as exc:
        # Surface the failure to the agent as text instead of raising.
        answer = f"wrong: {exc}"
    return BaseToolOutput(answer)
@tool
def shell(query: str = Field(description="The command to execute")):
    """Use Shell to execute system shell commands"""
    # Local renamed from `tool` to `shell_tool`: the original shadowed the
    # imported `tool` decorator inside this function body.
    # SECURITY NOTE(review): ShellTool executes arbitrary shell commands with
    # the server's privileges — only wire this tool to trusted agents/input.
    shell_tool = ShellTool()
    return BaseToolOutput(shell_tool.run(tool_input=query))
# Accumulator of (agent action, tool output) pairs passed to every agent run.
# NOTE(review): module-level mutable state shared by all /chat requests and
# sessions — presumably intended for the commented-out extend() in chat();
# confirm this cross-request sharing is deliberate.
intermediate_steps: List[Tuple[AgentAction, BaseToolOutput]] = []
async def chat(
    query: str = Body(..., description="用户输入", examples=["帮我计算100+1"]),
    message_id: str = Body(None, description="数据库消息ID"),
    history: List = Body(
        [],
        description="历史对话,设为一个整数可以从数据库中读取历史消息",
        examples=[
            [
                {"role": "user", "content": "你好"},
                {"role": "assistant", "content": "有什么需要帮助的"},
            ]
        ],
    ),
):
    """Agent chat endpoint: run the all-tools agent and stream its output as SSE."""
    executor = ZhipuAIAllToolsRunnable.create_agent_executor(
        model_name="glm-4-alltools",
        history=history,
        intermediate_steps=intermediate_steps,
        tools=[
            {"type": "code_interpreter"},
            {"type": "web_browser"},
            {"type": "drawing_tool"},
            calculate,
        ],
    )
    outputs = executor.invoke(chat_input=query)

    async def event_stream():
        # Each chunk of agent output becomes one server-sent event.
        async for item in outputs:
            yield item.to_json()
        # if executor.callback.out:
        #     intermediate_steps.extend(executor.callback.intermediate_steps)

    return EventSourceResponse(event_stream())
if __name__ == "__main__":
    # File logging at debug level under logs/local_<ms-timestamp>/; the two
    # size arguments (3 GiB each) are rotation thresholds per get_config_dict.
    logging_conf = get_config_dict(
        "debug",
        get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
        1024 * 1024 * 1024 * 3,
        1024 * 1024 * 1024 * 3,
    )
    logging.config.dictConfig(logging_conf)  # type: ignore

    # NOTE(review): wide-open CORS (any origin, with credentials) — acceptable
    # for a local demo, tighten before exposing this service.
    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Single POST /chat route streaming the agent's output (see chat()).
    chat_router = APIRouter()
    chat_router.add_api_route(
        "/chat",
        chat,
        response_model=OutputType,
        status_code=status.HTTP_200_OK,
        methods=["POST"],
        description="与llm模型对话(通过LLMChain)",
    )
    app.include_router(chat_router)

    # Local-only bind; the server reuses the same logging config as above.
    config = Config(
        app=app,
        host="127.0.0.1",
        port=10000,
        log_config=logging_conf,
    )
    _server = Server(config)

    def run_server():
        # Set to 2 seconds. NOTE(review): uvicorn's documented graceful-shutdown
        # knob is Config.timeout_graceful_shutdown; confirm an ad-hoc
        # `shutdown_timeout` attribute on Server is actually honored.
        _server.shutdown_timeout = 2
        _server.run()

    # Start-then-join behaves like a blocking run, except uvicorn skips
    # installing signal handlers when run() is off the main thread —
    # presumably deliberate; confirm before simplifying to _server.run().
    _server_thread = threading.Thread(target=run_server)
    _server_thread.start()
    _server_thread.join()