diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs index f6090ec88..7badc4d82 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs @@ -178,7 +178,7 @@ async fn handle_streaming_response( let mut stream = res.bytes_stream(); let mut response_text = String::new(); - let mut previous_json_chunk: String = String::new(); + let mut buffer = String::new(); let mut function_calls: Vec = Vec::new(); let mut error_message: Option = None; @@ -200,110 +200,121 @@ async fn handle_streaming_response( match item { Ok(chunk) => { let chunk_str = String::from_utf8_lossy(&chunk).to_string(); - previous_json_chunk += chunk_str.as_str(); - let trimmed_chunk_str = previous_json_chunk.trim().to_string(); - let data_resp: Result = serde_json::from_str(&trimmed_chunk_str); - match data_resp { - Ok(data) => { - previous_json_chunk = "".to_string(); + buffer.push_str(&chunk_str); - // Check for error in the data - if let Some(error) = data.get("error") { - let code = error.get("code").and_then(|c| c.as_str()); - let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error"); - let formatted_error = format!("{}: {}", code.unwrap_or("Unknown code"), message); + // Process complete messages in the buffer + while let Some(message) = extract_next_complete_message(&mut buffer) { + if message.trim().is_empty() { + continue; + } - return Err(LLMProviderError::LLMServiceUnexpectedError(formatted_error)); - } + // Handle [DONE] message + if message.trim() == "[DONE]" { + continue; + } - if let Some(choices) = data.get("choices") { - for choice in choices.as_array().unwrap_or(&vec![]) { - if let Some(message) = choice.get("message") { - if let Some(content) = message.get("content") { - response_text.push_str(content.as_str().unwrap_or("")); - } - if let Some(fc) = message.get("function_call") { - if let Some(name) = fc.get("name") { - let fc_arguments = fc - .get("arguments") - .and_then(|args| args.as_str()) - .and_then(|args_str| serde_json::from_str(args_str).ok()) - .and_then(|args_value: serde_json::Value| { - args_value.as_object().cloned() - }) - .unwrap_or_else(|| serde_json::Map::new()); - - // Extract tool_router_key - let tool_router_key = tools.as_ref().and_then(|tools_array| { - tools_array.iter().find_map(|tool| { - if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") { - tool.get("tool_router_key") - .and_then(|key| key.as_str().map(|s| s.to_string())) - } else { - None + match serde_json::from_str::(&message) { + Ok(data) => { + // Check for error in the data + if let Some(error) = data.get("error") { + let code = error.get("code").and_then(|c| c.as_str()); + let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error"); + let formatted_error = format!("{}: {}", code.unwrap_or("Unknown code"), message); + error_message = Some(formatted_error); + continue; + } + + if let Some(choices) = data.get("choices") { + for choice in choices.as_array().unwrap_or(&vec![]) { + if let Some(delta) = choice.get("delta") { + // Handle content updates + if let Some(content) = delta.get("content") { + if let Some(content_str) = content.as_str() { + response_text.push_str(content_str); + } + } + + // Handle function calls + if let Some(fc) = delta.get("function_call") { + if let Some(name) = fc.get("name") { + let fc_arguments = fc + .get("arguments") + .and_then(|args| args.as_str()) + .and_then(|args_str| serde_json::from_str(args_str).ok()) + .and_then(|args_value: serde_json::Value| { + args_value.as_object().cloned() + }) + .unwrap_or_else(|| serde_json::Map::new()); + + let tool_router_key = tools.as_ref().and_then(|tools_array| { + tools_array.iter().find_map(|tool| { + if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") { + tool.get("tool_router_key") + .and_then(|key| key.as_str().map(|s| s.to_string())) + } else { + None + } + }) + }); + + function_calls.push(FunctionCall { + name: name.as_str().unwrap_or("").to_string(), + arguments: fc_arguments.clone(), + tool_router_key, + response: None, + }); + + // Handle WebSocket updates for function calls + if let Some(ref manager) = ws_manager_trait { + if let Some(ref inbox_name) = inbox_name { + if let Some(last_function_call) = function_calls.last() { + let m = manager.lock().await; + let inbox_name_string = inbox_name.to_string(); + + let function_call_json = serde_json::to_value(last_function_call) + .unwrap_or_else(|_| serde_json::json!({})); + + let tool_metadata = ToolMetadata { + tool_name: last_function_call.name.clone(), + tool_router_key: last_function_call.tool_router_key.clone(), + args: function_call_json.as_object().cloned().unwrap_or_default(), + result: None, + status: ToolStatus { + type_: ToolStatusType::Running, + reason: None, + }, + }; + + let ws_message_type = WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); + + let _ = m + .queue_message( + WSTopic::Inbox, + inbox_name_string, + serde_json::to_string(last_function_call) + .unwrap_or_else(|_| "{}".to_string()), + ws_message_type, + true, + ) + .await; + } } - }) - }); - - function_calls.push(FunctionCall { - name: name.as_str().unwrap_or("").to_string(), - arguments: fc_arguments.clone(), - tool_router_key, - response: None, - }); + } + } } } } } } - - // Updated WS message handling for tooling - if let Some(ref manager) = ws_manager_trait { - if let Some(ref inbox_name) = inbox_name { - if let Some(last_function_call) = function_calls.last() { - let m = manager.lock().await; - let inbox_name_string = inbox_name.to_string(); - - // Serialize FunctionCall to JSON value - let function_call_json = - serde_json::to_value(last_function_call).unwrap_or_else(|_| serde_json::json!({})); - - // Prepare ToolMetadata - let tool_metadata = ToolMetadata { - tool_name: last_function_call.name.clone(), - tool_router_key: last_function_call.tool_router_key.clone(), - args: function_call_json.as_object().cloned().unwrap_or_default(), - result: None, - status: ToolStatus { - type_: ToolStatusType::Running, - reason: None, - }, - }; - - let ws_message_type = - WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata)); - - let _ = m - .queue_message( - WSTopic::Inbox, - inbox_name_string, - serde_json::to_string(last_function_call).unwrap_or_else(|_| "{}".to_string()), - ws_message_type, - true, - ) - .await; - } - } + Err(e) => { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + format!("Failed to parse message chunk (this may be normal for partial chunks): {:?}", e).as_str(), + ); + // Don't set error_message here as this might just be a partial chunk } } - Err(_e) => { - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Error, - format!("Error while receiving chunk: {:?} with chunk: {:?}", _e, trimmed_chunk_str).as_str(), - ); - error_message = Some(format!("Error while receiving chunk: {:?} with chunk: {:?}", _e, trimmed_chunk_str)); - } } } Err(e) => { @@ -326,6 +337,26 @@ async fn handle_streaming_response( Ok(LLMInferenceResponse::new(response_text, json!({}), function_calls, None)) } +/// Helper function to extract the next complete message from a buffer +fn extract_next_complete_message(buffer: &mut String) -> Option { + // Look for "data: " prefix and newline + if let Some(start) = buffer.find("data: ") { + if let Some(end) = buffer[start..].find('\n') { + let message = buffer[start + 6..start + end].to_string(); + buffer.drain(..=start + end); + Some(message) + } else { + None // Incomplete message + } + } else { + // No "data: " prefix found, clear any leading incomplete data + if let Some(newline) = buffer.find('\n') { + buffer.drain(..=newline); + } + None + } +} + async fn handle_non_streaming_response( client: &Client, url: String,