Skip to content

Commit

Permalink
Merge pull request #726 from dcSpark/feature/mounts
Browse files Browse the repository at this point in the history
Mount files from local HD
  • Loading branch information
acedward authored Dec 24, 2024
2 parents d57db25 + d10439a commit 1159105
Show file tree
Hide file tree
Showing 20 changed files with 936 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::managers::sheet_manager::SheetManager;
use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter};
use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager;
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_trait::async_trait;
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job::{Job, JobLike};
Expand Down Expand Up @@ -83,6 +84,7 @@ impl InferenceChain for GenericInferenceChain {
self.context.ext_agent_payments_manager.clone(),
// self.context.sqlite_logger.clone(),
self.context.llm_stopper.clone(),
fetch_node_environment(),
)
.await?;
Ok(response)
Expand Down Expand Up @@ -122,6 +124,7 @@ impl GenericInferenceChain {
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
// sqlite_logger: Option<Arc<SqliteLogger>>,
llm_stopper: Arc<LLMStopper>,
node_env: NodeEnvironment,
) -> Result<InferenceChainResult, LLMProviderError> {
shinkai_log(
ShinkaiLogOption::JobExecution,
Expand Down Expand Up @@ -327,6 +330,8 @@ impl GenericInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
None,
full_job.job_id.clone(),
node_env.clone(),
);

let mut iteration_count = 0;
Expand Down Expand Up @@ -394,7 +399,10 @@ impl GenericInferenceChain {

// 6) Call workflow or tooling
// Find the ShinkaiTool that has a tool with the function name
let shinkai_tool = tools.iter().find(|tool| tool.name() == function_call.name || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default());
let shinkai_tool = tools.iter().find(|tool| {
tool.name() == function_call.name
|| tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default()
});
if shinkai_tool.is_none() {
eprintln!("Function not found: {}", function_call.name);
return Err(LLMProviderError::FunctionNotFound(function_call.name.clone()));
Expand Down Expand Up @@ -443,6 +451,8 @@ impl GenericInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
Some(function_response),
full_job.job_id.clone(),
node_env.clone(),
);
} else {
// No more function calls required, return the final response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use std::collections::HashMap;

use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator;
use crate::managers::tool_router::ToolCallFunctionResponse;
use crate::network::v2_api::api_v2_commands_app_files::get_app_folder_path;
use crate::network::Node;
use crate::utils::environment::NodeEnvironment;
use serde_json::json;
use shinkai_message_primitives::schemas::job::JobStepResult;
use shinkai_message_primitives::schemas::prompts::Prompt;
Expand All @@ -23,6 +26,8 @@ impl JobPromptGenerator {
job_step_history: Option<Vec<JobStepResult>>,
tools: Vec<ShinkaiTool>,
function_call: Option<ToolCallFunctionResponse>,
job_id: String,
node_env: NodeEnvironment,
) -> Prompt {
let mut prompt = Prompt::new();

Expand Down Expand Up @@ -52,6 +57,16 @@ impl JobPromptGenerator {
priority = priority.saturating_sub(1);
}
}

let folder = get_app_folder_path(node_env, job_id);
let current_files = Node::v2_api_list_app_files_internal(folder.clone(), true);
if let Ok(current_files) = current_files {
prompt.add_content(
format!("Current files: {}", current_files.join(", ")),
SubPromptType::ExtraContext,
97,
);
}
}

// Parses the retrieved nodes as individual sub-prompts, to support priority pruning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::managers::sheet_manager::SheetManager;
use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter};
use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager;
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_trait::async_trait;
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job::{Job, JobLike};
Expand Down Expand Up @@ -79,6 +80,7 @@ impl InferenceChain for SheetUIInferenceChain {
self.context.ext_agent_payments_manager.clone(),
// self.context.sqlite_logger.clone(),
self.context.llm_stopper.clone(),
fetch_node_environment(),
)
.await?;
let job_execution_context = self.context.execution_context.clone();
Expand Down Expand Up @@ -123,6 +125,7 @@ impl SheetUIInferenceChain {
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
// sqlite_logger: Option<Arc<SqliteLogger>>,
llm_stopper: Arc<LLMStopper>,
node_env: NodeEnvironment,
) -> Result<String, LLMProviderError> {
shinkai_log(
ShinkaiLogOption::JobExecution,
Expand Down Expand Up @@ -283,6 +286,8 @@ impl SheetUIInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
None,
full_job.job_id.clone(),
node_env.clone(),
);

let mut iteration_count = 0;
Expand Down Expand Up @@ -419,6 +424,8 @@ impl SheetUIInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
Some(function_response),
full_job.job_id.clone(),
node_env.clone(),
);
} else {
// No more function calls required, return the final response
Expand Down
21 changes: 21 additions & 0 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::time::Instant;

use crate::llm_provider::error::LLMProviderError;
use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, InferenceChainContextTrait};
use crate::network::v2_api::api_v2_commands_app_files::get_app_folder_path;
use crate::network::Node;
use crate::tools::tool_definitions::definition_generation::{generate_tool_definitions, get_rust_tools};
use crate::tools::tool_execution::execution_header_generator::generate_execution_environment;
Expand Down Expand Up @@ -621,6 +622,14 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;

let folder = get_app_folder_path(node_env.clone(), context.full_job().job_id().to_string());
let mounts = Node::v2_api_list_app_files_internal(folder.clone(), true);
if let Err(e) = mounts {
eprintln!("Failed to list app files: {:?}", e);
return Err(LLMProviderError::FunctionExecutionError(format!("{:?}", e)));
}
let mounts = Some(mounts.unwrap_or_default());

let result = python_tool
.run(
envs,
Expand All @@ -635,6 +644,7 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
node_name,
false,
None,
mounts,
)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
let result_str = serde_json::to_string(&result)
Expand Down Expand Up @@ -678,6 +688,7 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false)
.await
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;

let envs = generate_execution_environment(
context.db(),
context.agent().clone().get_id().to_string(),
Expand All @@ -690,6 +701,14 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;

let folder = get_app_folder_path(node_env.clone(), context.full_job().job_id().to_string());
let mounts = Node::v2_api_list_app_files_internal(folder.clone(), true);
if let Err(e) = mounts {
eprintln!("Failed to list app files: {:?}", e);
return Err(LLMProviderError::FunctionExecutionError(format!("{:?}", e)));
}
let mounts = Some(mounts.unwrap_or_default());

let result = deno_tool
.run(
envs,
Expand All @@ -704,6 +723,7 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
node_name,
false,
Some(tool_id),
mounts,
)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
let result_str = serde_json::to_string(&result)
Expand Down Expand Up @@ -1052,6 +1072,7 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
requester_node_name,
true,
Some(tool_id),
None,
)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
let result_str =
Expand Down
98 changes: 91 additions & 7 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2416,6 +2416,7 @@ impl Node {
app_id,
llm_provider,
extra_config,
mounts,
res,
} => {
let db_clone = Arc::clone(&self.db);
Expand Down Expand Up @@ -2444,6 +2445,7 @@ impl Node {
encryption_secret_key,
encryption_public_key,
signing_secret_key,
mounts,
res,
)
.await;
Expand All @@ -2460,6 +2462,7 @@ impl Node {
tool_id,
app_id,
llm_provider,
mounts,
res,
} => {
let db_clone = Arc::clone(&self.db);
Expand All @@ -2478,6 +2481,7 @@ impl Node {
app_id,
llm_provider,
node_name,
mounts,
res,
)
.await;
Expand Down Expand Up @@ -2603,11 +2607,7 @@ impl Node {
let _ = Node::v2_api_import_tool(db_clone, bearer, node_env, url, res).await;
});
}
NodeCommand::V2ApiRemoveTool {
bearer,
tool_key,
res,
} => {
NodeCommand::V2ApiRemoveTool { bearer, tool_key, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_remove_tool(db_clone, bearer, tool_key, res).await;
Expand Down Expand Up @@ -2680,7 +2680,9 @@ impl Node {
let db_clone = Arc::clone(&self.db);
let cron_manager_clone = self.cron_manager.clone().unwrap();
tokio::spawn(async move {
let _ = Node::v2_api_force_execute_cron_task(db_clone, cron_manager_clone, bearer, cron_task_id, res).await;
let _ =
Node::v2_api_force_execute_cron_task(db_clone, cron_manager_clone, bearer, cron_task_id, res)
.await;
});
}
NodeCommand::V2ApiGetCronSchedule { bearer, res } => {
Expand Down Expand Up @@ -2770,7 +2772,12 @@ impl Node {
let _ = Node::v2_export_messages_from_inbox(db_clone, bearer, inbox_name, format, res).await;
});
}
NodeCommand::V2ApiSearchShinkaiTool { bearer, query, agent_or_llm, res } => {
NodeCommand::V2ApiSearchShinkaiTool {
bearer,
query,
agent_or_llm,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_search_shinkai_tool(db_clone, bearer, query, agent_or_llm, res).await;
Expand Down Expand Up @@ -2873,6 +2880,83 @@ impl Node {
.await;
});
}

NodeCommand::V2ApiUploadAppFile {
bearer,
tool_id,
app_id,
file_name,
file_data,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_upload_app_file(
db_clone, bearer, tool_id, app_id, file_name, file_data, node_env, res,
)
.await;
});
}
NodeCommand::V2ApiGetAppFile {
bearer,
tool_id,
app_id,
file_name,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ =
Node::v2_api_get_app_file(db_clone, bearer, tool_id, app_id, file_name, node_env, res).await;
});
}
NodeCommand::V2ApiUpdateAppFile {
bearer,
tool_id,
app_id,
file_name,
new_name,
file_data,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_update_app_file(
db_clone, bearer, tool_id, app_id, file_name, new_name, file_data, node_env, res,
)
.await;
});
}
NodeCommand::V2ApiListAppFiles {
bearer,
tool_id,
app_id,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_list_app_files(db_clone, bearer, tool_id, app_id, node_env, res).await;
});
}
NodeCommand::V2ApiDeleteAppFile {
bearer,
tool_id,
app_id,
file_name,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ =
Node::v2_api_delete_app_file(db_clone, bearer, tool_id, app_id, file_name, node_env, res).await;
});
}

_ => (),
}
}
Expand Down
Loading

0 comments on commit 1159105

Please sign in to comment.