diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index 5ee7af25..8a5750ae 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -123,7 +123,7 @@ class CompletionBody(ModelConfigBody): async def eval_rwkv( model: AbstractRWKV, request: Request, - body: ModelConfigBody | ChatCompletionBody, + body: ModelConfigBody, prompt: str, stream: bool, stop: Union[str, List[str], None], @@ -160,41 +160,44 @@ async def eval_rwkv( response, prompt_tokens, completion_tokens = "", 0, 0 completion_start_time = None - for response, delta, prompt_tokens, completion_tokens in model.generate( - prompt, - stop=stop, - ): - if not completion_start_time: - completion_start_time = time.time() - if await request.is_disconnected(): - break - if stream: - yield json.dumps( - { - "object": ( - "chat.completion.chunk" - if chat_mode - else "text_completion" - ), - # "response": response, - "model": model.name, - "choices": [ - ( - { - "delta": {"content": delta}, - "index": 0, - "finish_reason": None, - } + try: + for response, delta, prompt_tokens, completion_tokens in model.generate( + prompt, + stop=stop, + ): + if not completion_start_time: + completion_start_time = time.time() + if await request.is_disconnected(): + break + if stream: + yield json.dumps( + { + "object": ( + "chat.completion.chunk" if chat_mode - else { - "text": delta, - "index": 0, - "finish_reason": None, - } - ) - ], - } - ) + else "text_completion" + ), + # "response": response, + "model": model.name, + "choices": [ + ( + { + "delta": {"content": delta}, + "index": 0, + "finish_reason": None, + } + if chat_mode + else { + "text": delta, + "index": 0, + "finish_reason": None, + } + ) + ], + } + ) + except: + pass # torch_gc() requests_num = requests_num - 1 completion_end_time = time.time() @@ -245,9 +248,7 @@ async def eval_rwkv( yield "[DONE]" else: # !stream yield { - "id": "", "object": "chat.completion" if chat_mode else "text_completion", - "created": int(time.time()), "model": model.name, "choices": [ ( @@ -360,32 +361,31 @@ def chat_template( ) system = "System" if body.system_name is None else body.system_name - tool = "Obersavtion" + tool = "Observation" for message in body.messages: append_message: str = "" - match message.role: - case Role.User.value: - append_message = f"{user}{interface} " + message.content - case Role.Assistant.value: - if message.content is None: - if message.tool_calls and len(message.tool_calls) > 0: - name = message.tool_calls[0].function.name - arguments = json.loads(message.tool_calls[0].function.arguments) - arguments = ", ".join( - [f'"{k}"="{v}"' for k, v in arguments.items()] - ) - append_message = ( - f"{bot}{interface} " - + f"{name}\n```python\ntool_call({arguments})\n```" - ) - else: - continue + if message.role == Role.User.value: + append_message = f"{user}{interface} " + message.content + elif message.role == Role.Assistant.value: + if message.content is None: + if message.tool_calls and len(message.tool_calls) > 0: + name = message.tool_calls[0].function.name + arguments = json.loads(message.tool_calls[0].function.arguments) + arguments = ", ".join( + [f'"{k}"="{v}"' for k, v in arguments.items()] + ) + append_message = ( + f"{bot}{interface} " + + f"{name}\n```python\ntool_call({arguments})\n```" + ) else: - append_message = f"{bot}{interface} " + message.content - case Role.System.value: - append_message = f"{system}{interface} " + message.content - case Role.Tool.value: - append_message = f"{tool}{interface} " + message.content + continue + else: + append_message = f"{bot}{interface} " + message.content + elif message.role == Role.System.value: + append_message = f"{system}{interface} " + message.content + elif message.role == Role.Tool.value: + append_message = f"{tool}{interface} " + message.content completion_text += append_message + "\n\n" completion_text += f"{bot}{interface}" return completion_text @@ -422,19 +422,33 @@ async def chat_completions(body: ChatCompletionBody, request: Request): # if not body.presystem: # body.stop.append("\n\n") - if body.tool_choice != "none" and body.tools is not None and len(body.tools) > 0: + if ( + body.tool_choice != "none" and body.tools is not None and len(body.tools) > 0 + ) or body.messages[-1].role == Role.Tool.value: return await chat_with_tools(model, body, request, completion_text) else: return await chat(model, body, request, completion_text) +tool_call_id_timestamps = {} + + async def chat_with_tools( model: TextRWKV, body: ChatCompletionBody, request: Request, completion_text: str ): - system = "System" if body.system_name is None else body.system_name + system = "System" interface = model.interface - tools = [tool.function for tool in body.tools] - tools_text = json.dumps(jsonable_encoder(tools), indent=2) + is_with_tool_call_id = body.messages[-1].role == Role.Tool.value + if is_with_tool_call_id: + tool_call_id = body.messages[-1].tool_call_id + tools_text = tool_call_id_timestamps.get(tool_call_id) + else: + tools = [tool.function for tool in body.tools] + tools_text = json.dumps(jsonable_encoder(tools), indent=2) + tool_call_id = generate_tool_call_id() + tool_call_id_timestamps[tool_call_id] = tools_text + if len(tool_call_id_timestamps) > 1000: + tool_call_id_timestamps.pop(next(iter(tool_call_id_timestamps))) # Function Call Prompts tools_text = f"""\ @@ -443,30 +457,41 @@ async def chat_with_tools( completion_text = tools_text + "\n" + completion_text + if is_with_tool_call_id: + return await chat(model, body, request, completion_text) if body.stream: - response = async_generator_stream_respose(model, body, request, completion_text) + response = async_generator_stream_response_tool_call( + model, body, request, completion_text, tool_call_id + ) return EventSourceResponse(response) else: response = await chat(model, body, request, completion_text) - response = postprocess_response(response) + if response is not None: + response = postprocess_response(response, tool_call_id) return response -async def async_generator_stream_respose( - model: TextRWKV, body: ChatCompletionBody, request: Request, completion_text: str +def generate_tool_call_id(): + return "call_" + "".join(random.sample(string.ascii_letters + string.digits, 24)) + + +async def async_generator_stream_response_tool_call( + model: TextRWKV, + body: ChatCompletionBody, + request: Request, + completion_text: str, + tool_call_id: str, ): # NOTE: There is none of existing failure analysis. # Initialization gen = eval_rwkv( model, request, body, completion_text, body.stream, body.stop, True - ) # Get an asnyc generator handle + ) # Get an async generator handle content: str = "" - function_id: str = "call_" + "".join( - random.sample(string.ascii_letters + string.digits, 24) - ) flag_is_function_call_confirmed = False flag_is_common_confirmed = False + convert_equal_to_colon = False # Loop, there is only one existing endpoint. done = False @@ -483,18 +508,15 @@ async def async_generator_stream_respose( } ) yield "[DONE]" + break try: - response = await anext(gen) # Generate a delta response + response = await gen.__anext__() # Generate a delta response if response == "[DONE]": done = True continue except StopAsyncIteration: - # Too few inference result - if not flag_is_function_call_confirmed and not flag_is_common_confirmed: - response_decoded["choices"][0]["delta"]["content"] = content - yield json.dumps(response_decoded) - break # The EXPECTED endpoint of the loop and the function + break if flag_is_common_confirmed: yield response @@ -502,9 +524,10 @@ async def async_generator_stream_respose( # Post process response response_decoded = json.loads(response) # Decode string - if response_decoded["choices"][0]["delta"] == {}: + delta = response_decoded["choices"][0]["delta"] + if delta == {}: continue - delta_content = response_decoded["choices"][0]["delta"]["content"] + delta_content: str = delta["content"] content += delta_content if flag_is_function_call_confirmed: @@ -521,10 +544,18 @@ async def async_generator_stream_respose( if ( pair[0] in stack and pair[1] in stack - and stack.index(pair[0]) < stack.index(pair[1]) + and ( + ( + pair[0] != pair[1] + and stack.index(pair[0]) < stack.index(pair[1]) + ) + or (pair[0] == pair[1] and stack.count(pair[0]) >= 2) + ) ): stack.remove(pair[0]) stack.remove(pair[1]) + if pair[0] == '"' or pair[0] == "'": + convert_equal_to_colon = True if "(" not in stack and ")" not in stack: done = True response_decoded["choices"][0]["delta"] = { @@ -534,8 +565,16 @@ async def async_generator_stream_respose( "function": { "arguments": ( '"' - if delta_content.startswith('"') - else "" + if delta_content.strip().startswith( + '"' + ) + else ( + "'" + if delta_content.strip().startswith( + "'" + ) + else "" + ) ) + "}", }, @@ -547,8 +586,9 @@ async def async_generator_stream_respose( if done: continue - delta_content = delta_content.replace("=", ":") - # content = content.replace(r'"', r"\"") # XXX: Check whether to reserve. + if "=" in delta_content and convert_equal_to_colon: + delta_content = delta_content.replace("=", ":") + convert_equal_to_colon = False response_decoded["choices"][0]["delta"]["content"] = None response_decoded["choices"][0]["delta"] = { "tool_calls": [ @@ -568,11 +608,11 @@ async def async_generator_stream_respose( # Unconfirmed Response, check content field by the followings: # Up to 4 line feeds: Common Response. # Up to 60 characters: Common Response. - # Up to 44 charaters under markdown code block unclosed: Common Response. - # Feild "```Functionname\ntool_call(...)```" detected: Function Call Response. + # Up to 44 characters under markdown code block unclosed: Common Response. + # Field "```FunctionName\ntool_call(...)```" detected: Function Call Response. # - There will be 2 responses generated. # Default: Unsure Response. - # - Recheck with the next delta.content feild added. + # - Recheck with the next delta.content field added. """ # Constant LIMIT_LINE_FEEDS = 4 @@ -581,7 +621,7 @@ async def async_generator_stream_respose( REGEX_BLOCKS_HEADERS = r"([\w]+)[\s]*```[\w\s]*tool_call\(" # Regex - regex_match_function_call_head: re.Match | None = re.search( + regex_match_function_call_head: Union[re.Match, None] = re.search( REGEX_BLOCKS_HEADERS, content ) @@ -595,10 +635,7 @@ async def async_generator_stream_respose( ) ): flag_is_common_confirmed = True - response_decoded["choices"][0]["delta"]["content"] = content - yield json.dumps(response_decoded) - del response_decoded - del content + yield response continue # Confirm Function call Response @@ -613,13 +650,11 @@ async def async_generator_stream_respose( # Generate a function call details response name = regex_match_function_call_head.group(1) - del response_decoded["choices"][0]["delta"]["role"] - del response_decoded["choices"][0]["delta"]["content"] response_decoded["choices"][0]["delta"] = { "tool_calls": [ { "index": 0, - "id": function_id, + "id": tool_call_id, "type": "function", "function": { "name": name, @@ -635,15 +670,20 @@ async def async_generator_stream_respose( "index": 0, "function": { "arguments": "{" - + ('"' if delta_content.endswith('"') else ""), + + ( + '"' + if delta_content.strip().endswith('"') + else ( + "'" + if delta_content.strip().endswith("'") + else "" + ) + ), }, } ] } yield json.dumps(response_decoded) - - # Reset content buffer - # content = feild_function_call_block.group(2) continue # Default: Unsure Response @@ -651,7 +691,7 @@ async def async_generator_stream_respose( # End of loop body -def postprocess_response(response: dict): +def postprocess_response(response: dict, tool_call_id: str): # NOTE: There is none of existing failure analysis. REGEX_BLOCKS = r"([\w]+)[\s]*```[\w\s]*tool_call(.*?)\n*```" REGEX_ARGS = r'[\'"]([^\'"]+)[\'"]\s*=\s*[\'"]([^\'"]+)[\'"]' @@ -664,12 +704,14 @@ def postprocess_response(response: dict): name = regex_match.group(1) function = regex_match.group(2).strip() - arguments = json.dumps(dict(re.findall(REGEX_ARGS, function))) + try: + arguments = json.dumps(dict(re.findall(REGEX_ARGS, function))) + except: + return response tool_calls = [ { - "id": "call_" - + "".join(random.sample(string.ascii_letters + string.digits, 24)), + "id": tool_call_id, "type": "function", "function": { "name": name, @@ -680,7 +722,6 @@ def postprocess_response(response: dict): response["choices"][0]["message"]["tool_calls"] = tool_calls response["choices"][0]["message"]["content"] = None - response["choices"][0]["logprobs"] = None response["choices"][0]["finish_reason"] = "tool_calls" return response @@ -703,7 +744,7 @@ def postprocess_response(response: dict): # arguments = json.dumps(dict(re.findall(REGEX_ARGS, function))) # tool_calls.append( # { -# "id": "call_" + "".join(random.sample(string.ascii_letters + string.digits, 24)), +# "id": tool_call_id, # "type": "function", # "function": { # "name": name, @@ -714,7 +755,6 @@ def postprocess_response(response: dict): # response["choices"][0]["message"]["tool_calls"] = tool_calls # response["choices"][0]["message"]["content"] = None -# response["choices"][0]["logprobs"] = None # response["choices"][0]["finish_reason"] = "tool_calls" # return response