Feat/stream react (#2498)

Yeuoly authored Feb 21, 2024
1 parent adf2651 commit edb86f5
Showing 1 changed file with 159 additions and 150 deletions.
309 changes: 159 additions & 150 deletions api/core/features/assistant_cot_runner.py
@@ -133,61 +133,95 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
-            llm_result: LLMResult = model_instance.invoke_llm(
+            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=[],
                 stop=app_orchestration_config.model_config.stop,
-                stream=False,
+                stream=True,
                 user=self.user_id,
                 callbacks=[],
             )
 
             # check llm result
-            if not llm_result:
+            if not chunks:
                 raise ValueError("failed to invoke llm")
 
-            # get scratchpad
-            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
-            agent_scratchpad.append(scratchpad)
-
-            # get llm usage
-            if llm_result.usage:
-                increase_usage(llm_usage, llm_result.usage)
+            usage_dict = {}
+            react_chunks = self._handle_stream_react(chunks, usage_dict)
+            scratchpad = AgentScratchpadUnit(
+                agent_response='',
+                thought='',
+                action_str='',
+                observation='',
+                action=None
+            )
+
+            # publish agent thought if it's first iteration
+            if iteration_step == 1:
+                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
+
+            for chunk in react_chunks:
+                if isinstance(chunk, dict):
+                    scratchpad.agent_response += json.dumps(chunk)
+                    try:
+                        if scratchpad.action:
+                            raise Exception("")
+                        scratchpad.action_str = json.dumps(chunk)
+                        scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=chunk['action'],
+                            action_input=chunk['action_input']
+                        )
+                    except:
+                        scratchpad.thought += json.dumps(chunk)
+                        yield LLMResultChunk(
+                            model=self.model_config.model,
+                            prompt_messages=prompt_messages,
+                            system_fingerprint='',
+                            delta=LLMResultChunkDelta(
+                                index=0,
+                                message=AssistantPromptMessage(
+                                    content=json.dumps(chunk)
+                                ),
+                                usage=None
+                            )
+                        )
+                else:
+                    scratchpad.agent_response += chunk
+                    scratchpad.thought += chunk
+                    yield LLMResultChunk(
+                        model=self.model_config.model,
+                        prompt_messages=prompt_messages,
+                        system_fingerprint='',
+                        delta=LLMResultChunkDelta(
+                            index=0,
+                            message=AssistantPromptMessage(
+                                content=chunk
+                            ),
+                            usage=None
+                        )
+                    )
+
+            agent_scratchpad.append(scratchpad)
+
+            # get llm usage
+            if 'usage' in usage_dict:
+                increase_usage(llm_usage, usage_dict['usage'])
+            else:
+                usage_dict['usage'] = LLMUsage.empty_usage()
 
             self.save_agent_thought(agent_thought=agent_thought,
                                     tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                     tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                     thought=scratchpad.thought,
                                     observation='',
-                                    answer=llm_result.message.content,
+                                    answer=scratchpad.agent_response,
                                     messages_ids=[],
-                                    llm_usage=llm_result.usage)
+                                    llm_usage=usage_dict['usage'])
 
+            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
+                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
+
-            # publish agent thought if it's not empty and there is a action
-            if scratchpad.thought and scratchpad.action:
-                # check if final answer
-                if not scratchpad.action.action_name.lower() == "final answer":
-                    yield LLMResultChunk(
-                        model=model_instance.model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=scratchpad.thought
-                            ),
-                            usage=llm_result.usage,
-                        ),
-                        system_fingerprint=''
-                    )
-
             if not scratchpad.action:
                 # failed to extract action, return final answer directly
                 final_answer = scratchpad.agent_response or ''
@@ -262,7 +296,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
 
                 # save scratchpad
                 scratchpad.observation = observation
-                scratchpad.agent_response = llm_result.message.content
 
                 # save agent thought
                 self.save_agent_thought(
@@ -271,7 +304,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
                     tool_input=tool_call_args,
                     thought=None,
                     observation=observation,
-                    answer=llm_result.message.content,
+                    answer=scratchpad.agent_response,
                     messages_ids=message_file_ids,
                 )
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -318,6 +351,97 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             system_fingerprint=''
         ), PublishFrom.APPLICATION_MANAGER)
 
+    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
+        -> Generator[Union[str, dict], None, None]:
+        def parse_json(json_str):
+            try:
+                return json.loads(json_str.strip())
+            except:
+                return json_str
+
+        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            if not code_blocks:
+                return
+            for block in code_blocks:
+                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                yield parse_json(json_text)
+
+        code_block_cache = ''
+        code_block_delimiter_count = 0
+        in_code_block = False
+        json_cache = ''
+        json_quote_count = 0
+        in_json = False
+        got_json = False
+
+        for response in llm_response:
+            response = response.delta.message.content
+            if not isinstance(response, str):
+                continue
+
+            # stream
+            index = 0
+            while index < len(response):
+                steps = 1
+                delta = response[index:index+steps]
+                if delta == '`':
+                    code_block_cache += delta
+                    code_block_delimiter_count += 1
+                else:
+                    if not in_code_block:
+                        if code_block_delimiter_count > 0:
+                            yield code_block_cache
+                            code_block_cache = ''
+                    else:
+                        code_block_cache += delta
+                    code_block_delimiter_count = 0
+
+                if code_block_delimiter_count == 3:
+                    if in_code_block:
+                        yield from extra_json_from_code_block(code_block_cache)
+                        code_block_cache = ''
+
+                    in_code_block = not in_code_block
+                    code_block_delimiter_count = 0
+
+                if not in_code_block:
+                    # handle single json
+                    if delta == '{':
+                        json_quote_count += 1
+                        in_json = True
+                        json_cache += delta
+                    elif delta == '}':
+                        json_cache += delta
+                        if json_quote_count > 0:
+                            json_quote_count -= 1
+                            if json_quote_count == 0:
+                                in_json = False
+                                got_json = True
+                                index += steps
+                                continue
+                    else:
+                        if in_json:
+                            json_cache += delta
+
+                    if got_json:
+                        got_json = False
+                        yield parse_json(json_cache)
+                        json_cache = ''
+                        json_quote_count = 0
+                        in_json = False
+
+                if not in_code_block and not in_json:
+                    yield delta.replace('`', '')
+
+                index += steps
+
+        if code_block_cache:
+            yield code_block_cache
+
+        if json_cache:
+            yield parse_json(json_cache)
+
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
         fill in inputs from external data tools
@@ -363,121 +487,6 @@ def _init_agent_scratchpad(self,
 
         return agent_scratchpad
 
-    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
-        """
-        extract response from llm response
-        """
-        def extra_quotes() -> AgentScratchpadUnit:
-            agent_response = content
-            # try to extract all quotes
-            pattern = re.compile(r'```(.*?)```', re.DOTALL)
-            quotes = pattern.findall(content)
-
-            # try to extract action from end to start
-            for i in range(len(quotes) - 1, 0, -1):
-                """
-                1. use json load to parse action
-                2. use plain text `Action: xxx` to parse action
-                """
-                try:
-                    action = json.loads(quotes[i].replace('```', ''))
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    agent_thought = agent_response.replace(quotes[i], '')
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    # try to parse action from plain text
-                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
-                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
-                    # delete action from agent response
-                    agent_thought = agent_response.replace(quotes[i], '')
-                    # remove extra quotes
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name[0],
-                                action_input=action_input[0],
-                            )
-                        )
-
-        def extra_json():
-            agent_response = content
-            # try to extract all json
-            structures, pair_match_stack = [], []
-            started_at, end_at = 0, 0
-            for i in range(len(content)):
-                if content[i] == '{':
-                    pair_match_stack.append(i)
-                    if len(pair_match_stack) == 1:
-                        started_at = i
-                elif content[i] == '}':
-                    begin = pair_match_stack.pop()
-                    if not pair_match_stack:
-                        end_at = i + 1
-                        structures.append((content[begin:i+1], (started_at, end_at)))
-
-            # handle the last character
-            if pair_match_stack:
-                end_at = len(content)
-                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
-
-            for i in range(len(structures), 0, -1):
-                try:
-                    json_content, (started_at, end_at) = structures[i - 1]
-                    action = json.loads(json_content)
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    # delete json content from agent response
-                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
-                    # remove extra quotes like ```(json)*\n\n```
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input is not None:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=json_content,
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    pass
-
-        agent_scratchpad = extra_quotes()
-        if agent_scratchpad:
-            return agent_scratchpad
-        agent_scratchpad = extra_json()
-        if agent_scratchpad:
-            return agent_scratchpad
-
-        return AgentScratchpadUnit(
-            agent_response=content,
-            thought=content,
-            action_str='',
-            action=None
-        )
-
     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                    agent_prompt_message: AgentPromptEntity,
                                    ):
@@ -591,15 +600,15 @@ def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
         # organize prompt messages
         if mode == "chat":
             # override system message
-            overrided = False
+            overridden = False
             prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                 if isinstance(prompt_message, SystemPromptMessage):
                     prompt_message.content = system_message
-                    overrided = True
+                    overridden = True
                     break
 
-            if not overrided:
+            if not overridden:
                 prompt_messages.insert(0, SystemPromptMessage(
                     content=system_message,
                 ))
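
Editor's illustration of the new contract: _handle_stream_react scans each streamed delta character by character, buffering triple-backtick code blocks and bare {...} objects so that action payloads come out as parsed dicts while ordinary thought text comes out as plain strings. The toy helper below is not part of the commit; split_react_response and the demo string are illustrative names written for this page. It reproduces the same output contract on a complete, non-streamed response, which makes the generator's yielded sequence easy to see:

import json
import re

def split_react_response(text: str):
    """Yield thought text as str segments and fenced JSON action blocks as dicts."""
    pos = 0
    for match in re.finditer(r'```(?:json)?\s*(.*?)```', text, re.DOTALL):
        if match.start() > pos:
            yield text[pos:match.start()]      # thought text before the fenced block
        try:
            yield json.loads(match.group(1))   # parsed action payload
        except json.JSONDecodeError:
            yield match.group(1)               # unparseable block falls back to raw text
        pos = match.end()
    if pos < len(text):
        yield text[pos:]                       # trailing thought text

demo = ('Thought: I need to search first.\n'
        '```json\n{"action": "web_search", "action_input": {"query": "dify"}}\n```')
for item in split_react_response(demo):
    print(type(item).__name__, '->', item)

Run against demo, this prints one str item (the thought) followed by one dict item (the action), which is exactly the shape the run loop above dispatches on with isinstance(chunk, dict).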