Skip to content

Commit

Permalink
Merge pull request #794 from dcSpark/feature/config-tool
Browse files Browse the repository at this point in the history
Feature/config tool
  • Loading branch information
nicarq authored Jan 15, 2025
2 parents e9c3994 + b2998b9 commit 96942ea
Show file tree
Hide file tree
Showing 9 changed files with 540 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,6 @@ impl GenericInferenceChain {
let shinkai_tool = shinkai_tool.unwrap();

// Note: here we can add logic to handle the case that we have network tools

// TODO: if shinkai_tool is None we need to retry with the LLM (hallucination)
let function_response = match tool_router
.as_ref()
Expand All @@ -451,13 +450,15 @@ impl GenericInferenceChain {
Ok(response) => response,
Err(e) => {
match &e {
LLMProviderError::ToolRouterError(ref error_msg) if error_msg.contains("Invalid function arguments") => {
LLMProviderError::ToolRouterError(ref error_msg)
if error_msg.contains("Invalid function arguments") =>
{
// For invalid arguments, we'll retry with the LLM by including the error message
// in the next prompt to help it fix the parameters
let mut function_call_with_error = function_call.clone();
function_call_with_error.response = Some(error_msg.clone());
tool_calls_history.push(function_call_with_error);

// Update prompt with error information for retry
filled_prompt = JobPromptGenerator::generic_inference_prompt(
custom_system_prompt.clone(),
Expand All @@ -475,18 +476,20 @@ impl GenericInferenceChain {
full_job.job_id.clone(),
node_env.clone(),
);

// Set flag to retry and break out of the function calls loop
iteration_count += 1;
should_retry = true;
break;
},
LLMProviderError::ToolRouterError(ref error_msg) if error_msg.contains("MissingConfigError") => {
}
LLMProviderError::ToolRouterError(ref error_msg)
if error_msg.contains("MissingConfigError") =>
{
// For missing config, we'll pass through the error directly
// This will show up in the UI prompting the user to update their config
eprintln!("Missing config error: {:?}", error_msg);
return Err(e);
},
}
_ => {
eprintln!("Error calling function: {:?}", e);
return Err(e);
Expand Down
15 changes: 10 additions & 5 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ impl ToolRouter {
function_call,
});
}
ShinkaiTool::Rust(_rust_tool, _is_enabled) => {
ShinkaiTool::Rust(rust_tool, _is_enabled) => {
// Get app_id from Cron UI if present, otherwise use job_id
let app_id = match context.full_job().associated_ui().as_ref() {
Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(),
Expand Down Expand Up @@ -582,6 +582,13 @@ impl ToolRouter {
));
}

check_tool(
shinkai_tool.tool_router_key().to_string_without_version().clone(),
vec![],
function_args.clone(),
rust_tool.input_args.clone(),
)?;

let result = execute_custom_tool(
&shinkai_tool.tool_router_key().to_string_without_version().clone(),
function_args,
Expand Down Expand Up @@ -1072,10 +1079,8 @@ impl ToolRouter {

// Check if the top vector search result has a score under 0.2
if let Some((tool, score)) = vector_tools.first() {
if *score < 0.2 {
if seen_ids.insert(tool.tool_router_key.clone()) {
combined_tools.push(tool.clone());
}
if seen_ids.insert(tool.tool_router_key.clone()) {
combined_tools.push(tool.clone());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ use crate::tools::tool_implementation;
// TODO keep in sync with execution_custom.rs
pub fn get_rust_tools() -> Vec<ShinkaiToolHeader> {
let mut custom_tools = Vec::new();
custom_tools.push(tool_implementation::native_tools::llm_prompt_processor::LmPromptProcessorTool::new().tool);
custom_tools.push(tool_implementation::native_tools::llm_prompt_processor::LlmPromptProcessorTool::new().tool);
custom_tools.push(tool_implementation::native_tools::sql_processor::SQLProcessorTool::new().tool);
custom_tools.push(tool_implementation::native_tools::tool_knowledge::KnowledgeTool::new().tool);
custom_tools.push(tool_implementation::native_tools::config_setup::ConfigSetupTool::new().tool);
custom_tools
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,25 @@ pub async fn execute_custom_tool(
)
.await
}
s if s == "local:::rust_toolkit:::shinkai_tool_config_updater" => {
tool_implementation::native_tools::config_setup::ConfigSetupTool::execute(
bearer,
tool_id,
app_id,
db,
node_name,
identity_manager,
job_manager,
encryption_secret_key,
encryption_public_key,
signing_secret_key,
&parameters,
llm_provider,
)
.await
}
s if s == "local:::rust_toolkit:::shinkai_llm_prompt_processor" => {
tool_implementation::native_tools::llm_prompt_processor::LmPromptProcessorTool::execute(
tool_implementation::native_tools::llm_prompt_processor::LlmPromptProcessorTool::execute(
bearer,
tool_id,
app_id,
Expand Down
Loading

0 comments on commit 96942ea

Please sign in to comment.