Skip to content

Commit

Permalink
improve handling of stream openai
Browse files Browse the repository at this point in the history
  • Loading branch information
nicarq committed Jan 8, 2025
1 parent 46ce2d2 commit 492cb2a
Showing 1 changed file with 125 additions and 94 deletions.
219 changes: 125 additions & 94 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async fn handle_streaming_response(

let mut stream = res.bytes_stream();
let mut response_text = String::new();
let mut previous_json_chunk: String = String::new();
let mut buffer = String::new();
let mut function_calls: Vec<FunctionCall> = Vec::new();
let mut error_message: Option<String> = None;

Expand All @@ -200,110 +200,121 @@ async fn handle_streaming_response(
match item {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk).to_string();
previous_json_chunk += chunk_str.as_str();
let trimmed_chunk_str = previous_json_chunk.trim().to_string();
let data_resp: Result<JsonValue, _> = serde_json::from_str(&trimmed_chunk_str);
match data_resp {
Ok(data) => {
previous_json_chunk = "".to_string();
buffer.push_str(&chunk_str);

// Check for error in the data
if let Some(error) = data.get("error") {
let code = error.get("code").and_then(|c| c.as_str());
let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error");
let formatted_error = format!("{}: {}", code.unwrap_or("Unknown code"), message);
// Process complete messages in the buffer
while let Some(message) = extract_next_complete_message(&mut buffer) {
if message.trim().is_empty() {
continue;
}

return Err(LLMProviderError::LLMServiceUnexpectedError(formatted_error));
}
// Handle [DONE] message
if message.trim() == "[DONE]" {
continue;
}

if let Some(choices) = data.get("choices") {
for choice in choices.as_array().unwrap_or(&vec![]) {
if let Some(message) = choice.get("message") {
if let Some(content) = message.get("content") {
response_text.push_str(content.as_str().unwrap_or(""));
}
if let Some(fc) = message.get("function_call") {
if let Some(name) = fc.get("name") {
let fc_arguments = fc
.get("arguments")
.and_then(|args| args.as_str())
.and_then(|args_str| serde_json::from_str(args_str).ok())
.and_then(|args_value: serde_json::Value| {
args_value.as_object().cloned()
})
.unwrap_or_else(|| serde_json::Map::new());

// Extract tool_router_key
let tool_router_key = tools.as_ref().and_then(|tools_array| {
tools_array.iter().find_map(|tool| {
if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") {
tool.get("tool_router_key")
.and_then(|key| key.as_str().map(|s| s.to_string()))
} else {
None
match serde_json::from_str::<JsonValue>(&message) {
Ok(data) => {
// Check for error in the data
if let Some(error) = data.get("error") {
let code = error.get("code").and_then(|c| c.as_str());
let message = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error");
let formatted_error = format!("{}: {}", code.unwrap_or("Unknown code"), message);
error_message = Some(formatted_error);
continue;
}

if let Some(choices) = data.get("choices") {
for choice in choices.as_array().unwrap_or(&vec![]) {
if let Some(delta) = choice.get("delta") {
// Handle content updates
if let Some(content) = delta.get("content") {
if let Some(content_str) = content.as_str() {
response_text.push_str(content_str);
}
}

// Handle function calls
if let Some(fc) = delta.get("function_call") {
if let Some(name) = fc.get("name") {
let fc_arguments = fc
.get("arguments")
.and_then(|args| args.as_str())
.and_then(|args_str| serde_json::from_str(args_str).ok())
.and_then(|args_value: serde_json::Value| {
args_value.as_object().cloned()
})
.unwrap_or_else(|| serde_json::Map::new());

let tool_router_key = tools.as_ref().and_then(|tools_array| {
tools_array.iter().find_map(|tool| {
if tool.get("name")?.as_str()? == name.as_str().unwrap_or("") {
tool.get("tool_router_key")
.and_then(|key| key.as_str().map(|s| s.to_string()))
} else {
None
}
})
});

function_calls.push(FunctionCall {
name: name.as_str().unwrap_or("").to_string(),
arguments: fc_arguments.clone(),
tool_router_key,
response: None,
});

// Handle WebSocket updates for function calls
if let Some(ref manager) = ws_manager_trait {
if let Some(ref inbox_name) = inbox_name {
if let Some(last_function_call) = function_calls.last() {
let m = manager.lock().await;
let inbox_name_string = inbox_name.to_string();

let function_call_json = serde_json::to_value(last_function_call)
.unwrap_or_else(|_| serde_json::json!({}));

let tool_metadata = ToolMetadata {
tool_name: last_function_call.name.clone(),
tool_router_key: last_function_call.tool_router_key.clone(),
args: function_call_json.as_object().cloned().unwrap_or_default(),
result: None,
status: ToolStatus {
type_: ToolStatusType::Running,
reason: None,
},
};

let ws_message_type = WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata));

let _ = m
.queue_message(
WSTopic::Inbox,
inbox_name_string,
serde_json::to_string(last_function_call)
.unwrap_or_else(|_| "{}".to_string()),
ws_message_type,
true,
)
.await;
}
}
})
});

function_calls.push(FunctionCall {
name: name.as_str().unwrap_or("").to_string(),
arguments: fc_arguments.clone(),
tool_router_key,
response: None,
});
}
}
}
}
}
}
}

// Updated WS message handling for tooling
if let Some(ref manager) = ws_manager_trait {
if let Some(ref inbox_name) = inbox_name {
if let Some(last_function_call) = function_calls.last() {
let m = manager.lock().await;
let inbox_name_string = inbox_name.to_string();

// Serialize FunctionCall to JSON value
let function_call_json =
serde_json::to_value(last_function_call).unwrap_or_else(|_| serde_json::json!({}));

// Prepare ToolMetadata
let tool_metadata = ToolMetadata {
tool_name: last_function_call.name.clone(),
tool_router_key: last_function_call.tool_router_key.clone(),
args: function_call_json.as_object().cloned().unwrap_or_default(),
result: None,
status: ToolStatus {
type_: ToolStatusType::Running,
reason: None,
},
};

let ws_message_type =
WSMessageType::Widget(WidgetMetadata::ToolRequest(tool_metadata));

let _ = m
.queue_message(
WSTopic::Inbox,
inbox_name_string,
serde_json::to_string(last_function_call).unwrap_or_else(|_| "{}".to_string()),
ws_message_type,
true,
)
.await;
}
}
Err(e) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Debug,
format!("Failed to parse message chunk (this may be normal for partial chunks): {:?}", e).as_str(),
);
// Don't set error_message here as this might just be a partial chunk
}
}
Err(_e) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
format!("Error while receiving chunk: {:?} with chunk: {:?}", _e, trimmed_chunk_str).as_str(),
);
error_message = Some(format!("Error while receiving chunk: {:?} with chunk: {:?}", _e, trimmed_chunk_str));
}
}
}
Err(e) => {
Expand All @@ -326,6 +337,26 @@ async fn handle_streaming_response(
Ok(LLMInferenceResponse::new(response_text, json!({}), function_calls, None))
}

/// Extracts the next complete server-sent-event `data` payload from `buffer`.
///
/// The buffer is consumed line by line. Every complete line (terminated by
/// `\n`) is removed from the buffer; the first one that is a `data` field has
/// its payload returned. Lines that are not `data` fields (blank separators,
/// `event:` / `id:` fields, comments, stray text) are discarded.
///
/// A trailing line that has not yet received its `\n` is left untouched in the
/// buffer so that it can be completed by the next network chunk.
///
/// # Returns
///
/// * `Some(payload)` - the text after the `data:` field name, with a single
///   optional leading space and any trailing `\r` removed.
/// * `None` - no complete `data` line is currently available.
fn extract_next_complete_message(buffer: &mut String) -> Option<String> {
    while let Some(newline) = buffer.find('\n') {
        let payload = {
            // Tolerate CRLF line endings and stray leading whitespace.
            let line = buffer[..newline].trim_end_matches('\r').trim_start();

            // The field name must start the line; searching the whole buffer
            // for "data: " would also match that text inside other lines or
            // inside a JSON body. The space after the colon is optional.
            line.strip_prefix("data:")
                .map(|rest| rest.strip_prefix(' ').unwrap_or(rest).to_string())
        };

        // The line has been fully inspected; drop it together with its '\n'.
        buffer.drain(..=newline);

        if payload.is_some() {
            return payload;
        }
    }

    // Either the buffer is empty or it only holds an incomplete line.
    None
}

async fn handle_non_streaming_response(
client: &Client,
url: String,
Expand Down

0 comments on commit 492cb2a

Please sign in to comment.