From edb86f5f5a6f0471a79dece2c3f6e0efa240f793 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:45:59 +0800 Subject: [PATCH] Feat/stream react (#2498) --- api/core/features/assistant_cot_runner.py | 309 +++++++++++----------- 1 file changed, 159 insertions(+), 150 deletions(-) diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index c8477fb5d98727..aa4a6797cd6b56 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -133,61 +133,95 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # recale llm max tokens self.recale_llm_max_tokens(self.model_config, prompt_messages) # invoke model - llm_result: LLMResult = model_instance.invoke_llm( + chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_orchestration_config.model_config.parameters, tools=[], stop=app_orchestration_config.model_config.stop, - stream=False, + stream=True, user=self.user_id, callbacks=[], ) # check llm result - if not llm_result: + if not chunks: raise ValueError("failed to invoke llm") - - # get scratchpad - scratchpad = self._extract_response_scratchpad(llm_result.message.content) - agent_scratchpad.append(scratchpad) - - # get llm usage - if llm_result.usage: - increase_usage(llm_usage, llm_result.usage) + usage_dict = {} + react_chunks = self._handle_stream_react(chunks, usage_dict) + scratchpad = AgentScratchpadUnit( + agent_response='', + thought='', + action_str='', + observation='', + action=None + ) + # publish agent thought if it's first iteration if iteration_step == 1: self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + for chunk in react_chunks: + if isinstance(chunk, dict): + scratchpad.agent_response += json.dumps(chunk) + try: + if scratchpad.action: + raise Exception("") + scratchpad.action_str = json.dumps(chunk) + scratchpad.action = AgentScratchpadUnit.Action( + action_name=chunk['action'], + action_input=chunk['action_input'] + ) + except: + scratchpad.thought += json.dumps(chunk) + yield LLMResultChunk( + model=self.model_config.model, + prompt_messages=prompt_messages, + system_fingerprint='', + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=json.dumps(chunk) + ), + usage=None + ) + ) + else: + scratchpad.agent_response += chunk + scratchpad.thought += chunk + yield LLMResultChunk( + model=self.model_config.model, + prompt_messages=prompt_messages, + system_fingerprint='', + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=chunk + ), + usage=None + ) + ) + + agent_scratchpad.append(scratchpad) + + # get llm usage + if 'usage' in usage_dict: + increase_usage(llm_usage, usage_dict['usage']) + else: + usage_dict['usage'] = LLMUsage.empty_usage() + self.save_agent_thought(agent_thought=agent_thought, tool_name=scratchpad.action.action_name if scratchpad.action else '', tool_input=scratchpad.action.action_input if scratchpad.action else '', thought=scratchpad.thought, observation='', - answer=llm_result.message.content, + answer=scratchpad.agent_response, messages_ids=[], - llm_usage=llm_result.usage) + llm_usage=usage_dict['usage']) if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) - # publish agent thought if it's not empty and there is a action - if scratchpad.thought and scratchpad.action: - # check if final answer - if not scratchpad.action.action_name.lower() == "final answer": - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=scratchpad.thought - ), - usage=llm_result.usage, - ), - system_fingerprint='' - ) - if not scratchpad.action: # failed to extract action, return final answer directly final_answer = scratchpad.agent_response or '' @@ -262,7 +296,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save scratchpad scratchpad.observation = observation - scratchpad.agent_response = llm_result.message.content # save agent thought self.save_agent_thought( @@ -271,7 +304,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_input=tool_call_args, thought=None, observation=observation, - answer=llm_result.message.content, + answer=scratchpad.agent_response, messages_ids=message_file_ids, ) self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) @@ -318,6 +351,97 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): system_fingerprint='' ), PublishFrom.APPLICATION_MANAGER) + def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ + -> Generator[Union[str, dict], None, None]: + def parse_json(json_str): + try: + return json.loads(json_str.strip()) + except: + return json_str + + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + if not code_blocks: + return + for block in code_blocks: + json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + yield parse_json(json_text) + + code_block_cache = '' + code_block_delimiter_count = 0 + in_code_block = False + json_cache = '' + json_quote_count = 0 + in_json = False + got_json = False + + for response in llm_response: + response = response.delta.message.content + if not isinstance(response, str): + continue + + # stream + index = 0 + while index < len(response): + steps = 1 + delta = response[index:index+steps] + if delta == '`': + code_block_cache += delta + code_block_delimiter_count += 1 + else: + if not in_code_block: + if code_block_delimiter_count > 0: + yield code_block_cache + code_block_cache = '' + else: + code_block_cache += delta + code_block_delimiter_count = 0 + + if code_block_delimiter_count == 3: + if in_code_block: + yield from extra_json_from_code_block(code_block_cache) + code_block_cache = '' + + in_code_block = not in_code_block + code_block_delimiter_count = 0 + + if not in_code_block: + # handle single json + if delta == '{': + json_quote_count += 1 + in_json = True + json_cache += delta + elif delta == '}': + json_cache += delta + if json_quote_count > 0: + json_quote_count -= 1 + if json_quote_count == 0: + in_json = False + got_json = True + index += steps + continue + else: + if in_json: + json_cache += delta + + if got_json: + got_json = False + yield parse_json(json_cache) + json_cache = '' + json_quote_count = 0 + in_json = False + + if not in_code_block and not in_json: + yield delta.replace('`', '') + + index += steps + + if code_block_cache: + yield code_block_cache + + if json_cache: + yield parse_json(json_cache) + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ fill in inputs from external data tools @@ -363,121 +487,6 @@ def _init_agent_scratchpad(self, return agent_scratchpad - def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit: - """ - extract response from llm response - """ - def extra_quotes() -> AgentScratchpadUnit: - agent_response = content - # try to extract all quotes - pattern = re.compile(r'```(.*?)```', re.DOTALL) - quotes = pattern.findall(content) - - # try to extract action from end to start - for i in range(len(quotes) - 1, 0, -1): - """ - 1. use json load to parse action - 2. use plain text `Action: xxx` to parse action - """ - try: - action = json.loads(quotes[i].replace('```', '')) - action_name = action.get("action") - action_input = action.get("action_input") - agent_thought = agent_response.replace(quotes[i], '') - - if action_name and action_input: - return AgentScratchpadUnit( - agent_response=content, - thought=agent_thought, - action_str=quotes[i], - action=AgentScratchpadUnit.Action( - action_name=action_name, - action_input=action_input, - ) - ) - except: - # try to parse action from plain text - action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE) - action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE) - # delete action from agent response - agent_thought = agent_response.replace(quotes[i], '') - # remove extra quotes - agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL) - # remove Action: xxx from agent thought - agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) - - if action_name and action_input: - return AgentScratchpadUnit( - agent_response=content, - thought=agent_thought, - action_str=quotes[i], - action=AgentScratchpadUnit.Action( - action_name=action_name[0], - action_input=action_input[0], - ) - ) - - def extra_json(): - agent_response = content - # try to extract all json - structures, pair_match_stack = [], [] - started_at, end_at = 0, 0 - for i in range(len(content)): - if content[i] == '{': - pair_match_stack.append(i) - if len(pair_match_stack) == 1: - started_at = i - elif content[i] == '}': - begin = pair_match_stack.pop() - if not pair_match_stack: - end_at = i + 1 - structures.append((content[begin:i+1], (started_at, end_at))) - - # handle the last character - if pair_match_stack: - end_at = len(content) - structures.append((content[pair_match_stack[0]:], (started_at, end_at))) - - for i in range(len(structures), 0, -1): - try: - json_content, (started_at, end_at) = structures[i - 1] - action = json.loads(json_content) - action_name = action.get("action") - action_input = action.get("action_input") - # delete json content from agent response - agent_thought = agent_response[:started_at] + agent_response[end_at:] - # remove extra quotes like ```(json)*\n\n``` - agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL) - # remove Action: xxx from agent thought - agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) - - if action_name and action_input is not None: - return AgentScratchpadUnit( - agent_response=content, - thought=agent_thought, - action_str=json_content, - action=AgentScratchpadUnit.Action( - action_name=action_name, - action_input=action_input, - ) - ) - except: - pass - - agent_scratchpad = extra_quotes() - if agent_scratchpad: - return agent_scratchpad - agent_scratchpad = extra_json() - if agent_scratchpad: - return agent_scratchpad - - return AgentScratchpadUnit( - agent_response=content, - thought=content, - action_str='', - action=None - ) - def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], agent_prompt_message: AgentPromptEntity, ): @@ -591,15 +600,15 @@ def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], # organize prompt messages if mode == "chat": # override system message - overrided = False + overridden = False prompt_messages = prompt_messages.copy() for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): prompt_message.content = system_message - overrided = True + overridden = True break - if not overrided: + if not overridden: prompt_messages.insert(0, SystemPromptMessage( content=system_message, ))