From beb273c52cf127673d172ceebbb28d4a1555d58a Mon Sep 17 00:00:00 2001 From: Nico <1622112+nicarq@users.noreply.github.com> Date: Thu, 9 Jan 2025 06:22:20 -0600 Subject: [PATCH 01/18] fix (#779) --- .../shinkai-node/src/managers/tool_router.rs | 6 +- .../network/v2_api/api_v2_commands_tools.rs | 4 +- .../tool_execution/execution_coordinator.rs | 4 +- .../tool_execution/execution_deno_dynamic.rs | 4 +- .../execution_python_dynamic.rs | 2 +- .../src/schemas/llm_providers/agent.rs | 14 +--- .../src/schemas/tool_router_key.rs | 78 ++++++++----------- .../shinkai-sqlite/src/agent_manager.rs | 50 +++++++++++- .../src/shinkai_tool_manager.rs | 40 +++++----- .../shinkai-sqlite/src/tool_playground.rs | 6 +- .../src/tools/deno_tools.rs | 2 +- .../src/tools/js_toolkit.rs | 2 +- .../src/tools/python_tools.rs | 2 +- .../src/tools/shinkai_tool.rs | 4 +- 14 files changed, 118 insertions(+), 100 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index 30c9d45c3..ecb643e35 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -498,7 +498,7 @@ impl ToolRouter { .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let app_id = context.full_job().job_id().to_string(); let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); - let tools = python_tool.tools.clone().unwrap_or_default(); + let tools = python_tool.tools.clone(); let support_files = generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false) .await @@ -612,7 +612,7 @@ impl ToolRouter { .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let app_id = context.full_job().job_id().to_string(); let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); - let tools = deno_tool.tools.clone().unwrap_or_default(); + let tools = deno_tool.tools.clone(); let support_files = generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false) .await @@ -965,7 +965,7 @@ impl ToolRouter { .node_storage_path .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; - let tools = js_tool.clone().tools.unwrap_or_default(); + let tools = js_tool.clone().tools.clone(); let app_id = format!("external_{}", uuid::Uuid::new_v4()); let tool_id = shinkai_tool.tool_router_key().clone().to_string_without_version(); let support_files = diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index 63dffba0d..590167d2d 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -453,7 +453,7 @@ impl Node { author: payload.metadata.author.clone(), version: payload.metadata.version.clone(), js_code: payload.code.clone(), - tools: payload.metadata.tools.clone(), + tools: payload.metadata.tools.clone().unwrap_or_default(), config: payload.metadata.configurations.clone(), oauth: payload.metadata.oauth.clone(), description: payload.metadata.description.clone(), @@ -477,7 +477,7 @@ impl Node { version: payload.metadata.version.clone(), author: payload.metadata.author.clone(), py_code: payload.code.clone(), - tools: payload.metadata.tools.clone(), + tools: payload.metadata.tools.clone().unwrap_or_default(), config: payload.metadata.configurations.clone(), oauth: payload.metadata.oauth.clone(), description: payload.metadata.description.clone(), diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs index 67e261e75..e1b2dd98f 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs @@ -179,7 +179,7 @@ pub async fn execute_tool_cmd( .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let support_files = generate_tool_definitions( - python_tool.tools.clone().unwrap_or_default(), + python_tool.tools.clone(), CodeLanguage::Python, db, false, @@ -224,7 +224,7 @@ pub async fn execute_tool_cmd( .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let support_files = generate_tool_definitions( - deno_tool.tools.clone().unwrap_or_default(), + deno_tool.tools.clone(), CodeLanguage::Typescript, db, false, diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs index 69a6f1b29..f9f1ca0d2 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs @@ -35,7 +35,7 @@ pub async fn execute_deno_tool( author: "system".to_string(), version: "1.0.0".to_string(), js_code: code, - tools: None, + tools: vec![], config: extra_config.clone(), oauth: oauth.clone(), description: "Deno runtime execution".to_string(), @@ -125,7 +125,7 @@ pub fn check_deno_tool( author: "system".to_string(), version: "1.0".to_string(), js_code: code, - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Deno runtime execution".to_string(), diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs index bce489071..4659093c2 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs @@ -36,7 +36,7 @@ pub async fn execute_python_tool( version: "1.0.0".to_string(), author: "system".to_string(), py_code: code, - tools: None, + tools: vec![], config: extra_config.clone(), description: "Python runtime execution".to_string(), keywords: vec![], diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs index 87fc84a52..f4c6a88bd 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs @@ -13,7 +13,8 @@ pub struct Agent { pub ui_description: String, pub knowledge: Vec, pub storage_path: String, - #[serde(deserialize_with = "deserialize_tools")] + #[serde(serialize_with = "ToolRouterKey::serialize_tool_router_keys", + deserialize_with = "ToolRouterKey::deserialize_tool_router_keys")] pub tools: Vec, pub debug_mode: bool, pub config: Option, @@ -21,17 +22,6 @@ pub struct Agent { pub scope: MinimalJobScope, } -fn deserialize_tools<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let tool_strings: Vec = Vec::deserialize(deserializer)?; - tool_strings - .into_iter() - .map(|s| ToolRouterKey::from_string(&s).map_err(serde::de::Error::custom)) - .collect() -} - #[cfg(test)] mod tests { use super::*; diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs index 0593f9ef2..1423795b5 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs @@ -30,49 +30,29 @@ impl ToolRouterKey { } } - pub fn deserialize_tool_router_keys<'de, D>(deserializer: D) -> Result>, D::Error> + pub fn deserialize_tool_router_keys<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { - let string_vec: Option> = Option::deserialize(deserializer)?; - - match string_vec { - Some(vec) => { - let router_keys = vec - .into_iter() - .map(|s| Self::from_string(&s)) - .collect::, _>>() - .map_err(serde::de::Error::custom)?; - Ok(Some(router_keys)) - } - None => Ok(None), - } + let string_vec: Vec = Vec::deserialize(deserializer)?; + string_vec + .into_iter() + .map(|s| Self::from_string(&s).map_err(serde::de::Error::custom)) + .collect() } pub fn serialize_tool_router_keys( - keys: &Option>, + tools: &Vec, serializer: S ) -> Result where S: serde::Serializer, { - match keys { - Some(keys) => { - let strings: Vec = keys - .iter() - .map(|k| { - // If version is Some, use to_string_with_version() - if k.version.is_some() { - k.to_string_with_version() - } else { - k.to_string_without_version() - } - }) - .collect(); - strings.serialize(serializer) - } - None => serializer.serialize_none(), - } + let strings: Vec = tools + .iter() + .map(|k| k.to_string_with_version()) + .collect(); + strings.serialize(serializer) } fn sanitize(input: &str) -> String { @@ -91,11 +71,15 @@ impl ToolRouterKey { } pub fn to_string_with_version(&self) -> String { + if self.version.is_none() { + return self.to_string_without_version(); + } + let sanitized_source = Self::sanitize(&self.source); let sanitized_toolkit_name = Self::sanitize(&self.toolkit_name); let sanitized_name = Self::sanitize(&self.name); - let version_str = self.version.clone().unwrap_or_else(|| "none".to_string()); + let version_str = self.version.clone().unwrap(); let key = format!( "{}:::{}:::{}:::{}", @@ -174,20 +158,6 @@ mod tests { ); } - #[test] - fn test_tool_router_key_to_string_with_version_none() { - let key = ToolRouterKey::new( - "local".to_string(), - "rust_toolkit".to_string(), - "concat_strings".to_string(), - None, - ); - assert_eq!( - key.to_string_with_version(), - "local:::rust_toolkit:::concat_strings:::none" - ); - } - #[test] fn test_tool_router_key_from_string_without_version() { let key_str = "local:::rust_toolkit:::concat_strings"; @@ -257,4 +227,18 @@ mod tests { assert!(!key_string.contains(' '), "Key string should not contain spaces"); assert_eq!(key_string, "local:::deno_toolkit:::versioned_tool"); } + + #[test] + fn test_tool_router_key_to_string_with_version_returns_without_version_when_none() { + let key = ToolRouterKey::new( + "local".to_string(), + "rust_toolkit".to_string(), + "concat_strings".to_string(), + None, + ); + assert_eq!( + key.to_string_with_version(), + "local:::rust_toolkit:::concat_strings" + ); + } } diff --git a/shinkai-libs/shinkai-sqlite/src/agent_manager.rs b/shinkai-libs/shinkai-sqlite/src/agent_manager.rs index 5364eedc5..f41f53f82 100644 --- a/shinkai-libs/shinkai-sqlite/src/agent_manager.rs +++ b/shinkai-libs/shinkai-sqlite/src/agent_manager.rs @@ -1,5 +1,5 @@ use rusqlite::params; -use shinkai_message_primitives::schemas::{llm_providers::agent::Agent, shinkai_name::ShinkaiName}; +use shinkai_message_primitives::schemas::{llm_providers::agent::Agent, shinkai_name::ShinkaiName, tool_router_key::ToolRouterKey}; use crate::{SqliteManager, SqliteManagerError}; @@ -30,7 +30,10 @@ impl SqliteManager { let knowledge = serde_json::to_string(&agent.knowledge).unwrap(); let config = agent.config.map(|c| serde_json::to_string(&c).unwrap()); - let tools = serde_json::to_string(&agent.tools).unwrap(); + let tools: Vec = agent.tools.iter() + .map(|t| t.to_string_with_version()) + .collect(); + let tools = serde_json::to_string(&tools).unwrap(); let scope = serde_json::to_string(&agent.scope).unwrap(); tx.execute( @@ -201,7 +204,10 @@ impl SqliteManager { let knowledge = serde_json::to_string(&updated_agent.knowledge).unwrap(); let config = updated_agent.config.map(|c| serde_json::to_string(&c).unwrap()); - let tools = serde_json::to_string(&updated_agent.tools).unwrap(); + let tools: Vec = updated_agent.tools.iter() + .map(|t| t.to_string_with_version()) + .collect(); + let tools = serde_json::to_string(&tools).unwrap(); let scope = serde_json::to_string(&updated_agent.scope).unwrap(); // Serialize the scope tx.execute( @@ -410,4 +416,42 @@ mod tests { assert_eq!(agent.storage_path, updated_agent.storage_path); assert_eq!(agent.debug_mode, updated_agent.debug_mode); } + + #[test] + fn test_add_and_get_agent_with_tools() { + let db = setup_test_db(); + + // Create a proper ToolRouterKey + let tool = ToolRouterKey::new( + "local".to_string(), + "rust_toolkit".to_string(), + "test_tool".to_string(), + Some("1.0".to_string()) + ); + + let agent = Agent { + agent_id: "test_agent".to_string(), + name: "Test Agent".to_string(), + full_identity_name: ShinkaiName::new("@@test_user.shinkai/main".to_string()).unwrap(), + llm_provider_id: "test_llm_provider".to_string(), + ui_description: "Test description".to_string(), + knowledge: Default::default(), + storage_path: "test_storage_path".to_string(), + tools: vec![tool.clone()], // Add the ToolRouterKey + debug_mode: false, + config: None, + scope: Default::default(), + }; + let profile = ShinkaiName::new("@@test_user.shinkai/main".to_string()).unwrap(); + + // Add the agent + db.add_agent(agent.clone(), &profile).unwrap(); + + // Retrieve the agent + let retrieved_agent = db.get_agent(&agent.agent_id).unwrap().unwrap(); + + // Verify the tools were correctly stored and retrieved + assert_eq!(retrieved_agent.tools.len(), 1); + assert_eq!(retrieved_agent.tools[0], tool); + } } diff --git a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs index f935df4bd..9974fff75 100644 --- a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs +++ b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs @@ -812,7 +812,7 @@ mod tests { author: "Deno Author".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno!');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "A Deno tool for testing".to_string(), @@ -884,7 +884,7 @@ mod tests { author: "Deno Author 1".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno 1!');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "A Deno tool for testing 1".to_string(), @@ -906,7 +906,7 @@ mod tests { author: "Deno Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno 2!');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "A Deno tool for testing 2".to_string(), @@ -928,7 +928,7 @@ mod tests { author: "Deno Author 3".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno 3!');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "A Deno tool for testing 3".to_string(), @@ -987,7 +987,7 @@ mod tests { author: "Author 1".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 1');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "First Deno tool".to_string(), @@ -1009,7 +1009,7 @@ mod tests { author: "Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 2');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Second Deno tool".to_string(), @@ -1031,7 +1031,7 @@ mod tests { author: "Author 3".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 3');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Third Deno tool".to_string(), @@ -1120,7 +1120,7 @@ mod tests { author: "Deno Author".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno!');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "A Deno tool for testing duplicates".to_string(), @@ -1166,7 +1166,7 @@ mod tests { author: "Author 1".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 1');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Process and manipulate images".to_string(), @@ -1187,7 +1187,7 @@ mod tests { author: "Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 2');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Analyze text content".to_string(), @@ -1208,7 +1208,7 @@ mod tests { author: "Author 3".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 3');".to_string(), - tools: None, + tools: vec![], config: vec![], oauth: None, description: "Visualize data sets".to_string(), @@ -1272,7 +1272,7 @@ mod tests { version: "1.0.0".to_string(), author: "Author 1".to_string(), js_code: "console.log('Enabled');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "An enabled tool for testing".to_string(), keywords: vec!["enabled".to_string(), "test".to_string()], @@ -1294,7 +1294,7 @@ mod tests { author: "Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Disabled');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A disabled tool for testing".to_string(), keywords: vec!["disabled".to_string(), "test".to_string()], @@ -1383,7 +1383,7 @@ mod tests { author: "Author 1".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Enabled Non-Network');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "An enabled non-network tool".to_string(), keywords: vec!["enabled".to_string(), "non-network".to_string()], @@ -1405,7 +1405,7 @@ mod tests { author: "Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Disabled Non-Network');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A disabled non-network tool".to_string(), keywords: vec!["disabled".to_string(), "non-network".to_string()], @@ -1535,7 +1535,7 @@ mod tests { author: "Author 1".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 1');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "First test tool".to_string(), keywords: vec!["test".to_string(), "one".to_string()], @@ -1557,7 +1557,7 @@ mod tests { author: "Author 2".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 2');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "Second test tool".to_string(), keywords: vec!["test".to_string(), "two".to_string()], @@ -1579,7 +1579,7 @@ mod tests { author: "Author 3".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Tool 3');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "Third test tool".to_string(), keywords: vec!["test".to_string(), "three".to_string()], @@ -1661,7 +1661,7 @@ mod tests { author: "Version Author".to_string(), version: "1.0".to_string(), js_code: "console.log('Version 1');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A tool with version 1.0".to_string(), keywords: vec!["version".to_string(), "test".to_string()], @@ -1683,7 +1683,7 @@ mod tests { author: "Version Author".to_string(), version: "2.0".to_string(), js_code: "console.log('Version 2');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A tool with version 2.0".to_string(), keywords: vec!["version".to_string(), "test".to_string()], diff --git a/shinkai-libs/shinkai-sqlite/src/tool_playground.rs b/shinkai-libs/shinkai-sqlite/src/tool_playground.rs index 4c537acfc..aa8c2047a 100644 --- a/shinkai-libs/shinkai-sqlite/src/tool_playground.rs +++ b/shinkai-libs/shinkai-sqlite/src/tool_playground.rs @@ -372,7 +372,7 @@ mod tests { author: "Deno Author".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno!');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A Deno tool for testing".to_string(), keywords: vec!["deno".to_string(), "test".to_string()], @@ -499,7 +499,7 @@ mod tests { author: "Deno Author".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno!');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A Deno tool for testing".to_string(), keywords: vec!["deno".to_string(), "test".to_string()], @@ -572,7 +572,7 @@ mod tests { author: "Deno Author".to_string(), version: "1.0.0".to_string(), js_code: "console.log('Hello, Deno!');".to_string(), - tools: None, + tools: vec![], config: vec![], description: "A Deno tool for testing".to_string(), keywords: vec!["deno".to_string(), "test".to_string()], diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index 90b363424..af71a2298 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -32,7 +32,7 @@ pub struct DenoTool { #[serde(default)] #[serde(deserialize_with = "ToolRouterKey::deserialize_tool_router_keys")] #[serde(serialize_with = "ToolRouterKey::serialize_tool_router_keys")] - pub tools: Option>, + pub tools: Vec, pub config: Vec, pub description: String, pub keywords: Vec, diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs index c7caf9b1e..70332f766 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs @@ -59,7 +59,7 @@ impl JSToolkit { config, oauth: None, js_code: definition.code.clone().unwrap_or_default(), - tools: None, + tools: vec![], description: definition.description.clone(), keywords: definition.keywords.clone(), input_args: input_args.clone(), diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs index 8043acfa5..4c19c6014 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs @@ -31,7 +31,7 @@ pub struct PythonTool { #[serde(default)] #[serde(deserialize_with = "ToolRouterKey::deserialize_tool_router_keys")] #[serde(serialize_with = "ToolRouterKey::serialize_tool_router_keys")] - pub tools: Option>, + pub tools: Vec, pub config: Vec, pub description: String, pub keywords: Vec, diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs index 96a05de79..be8ea0b1a 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs @@ -429,7 +429,7 @@ mod tests { author: "1.0".to_string(), version: "1.0.0".to_string(), js_code: "".to_string(), - tools: None, + tools: vec![], keywords: vec![], activated: false, embedding: None, @@ -510,7 +510,7 @@ mod tests { config: vec![], author: tool_definition.author.clone(), js_code: tool_definition.code.clone().unwrap_or_default(), - tools: None, + tools: vec![], keywords: tool_definition.keywords.clone(), activated: false, embedding: None, From 99cf570c16a5b99539c87e97beb3e005dac14e0a Mon Sep 17 00:00:00 2001 From: Nico <1622112+nicarq@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:49:36 -0600 Subject: [PATCH 02/18] test (#778) --- .../src/tools/deno_tools.rs | 214 +++++++++++++++++- .../src/tools/js_toolkit.rs | 1 + .../src/tools/tool_config.rs | 15 +- .../src/tools/tool_playground.rs | 4 +- 4 files changed, 231 insertions(+), 3 deletions(-) diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index af71a2298..13749ebcb 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -5,7 +5,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use std::{env, thread}; use super::parameters::Parameters; -use super::tool_config::{OAuth, ToolConfig}; +use super::tool_config::{BasicConfig, OAuth, ToolConfig}; use super::tool_output_arg::ToolOutputArg; use super::tool_playground::{SqlQuery, SqlTable}; use crate::tools::error::ToolError; @@ -431,6 +431,7 @@ impl ToolResult { #[cfg(test)] mod tests { use super::*; + use serde_json::json; #[test] fn test_deserialize_jstool_result_with_hashmap_properties() { @@ -473,4 +474,215 @@ mod tests { panic!("address property missing"); } } + + #[test] + fn test_deserialize_deno_tool() { + let json_data = r#"{ + "author": "Shinkai", + "config": [ + { + "BasicConfig": { + "description": "", + "key_name": "name", + "key_value": null, + "required": true + } + }, + { + "BasicConfig": { + "description": "", + "key_name": "privateKey", + "key_value": null, + "required": true + } + }, + { + "BasicConfig": { + "description": "", + "key_name": "useServerSigner", + "key_value": null, + "required": false + } + } + ], + "description": "Tool for creating a Coinbase wallet", + "input_args": { + "properties": {}, + "required": [], + "type": "object" + }, + "name": "Coinbase Wallet Creator", + "output_arg": { + "json": "" + }, + "toolkit_name": "deno-toolkit", + "version": "1.0.0", + "js_code": "", + "keywords": [], + "activated": false, + "result": { + "type": "object", + "properties": {}, + "required": [] + } + }"#; + + let deserialized: DenoTool = serde_json::from_str(json_data).expect("Failed to deserialize DenoTool"); + + assert_eq!(deserialized.author, "Shinkai"); + assert_eq!(deserialized.name, "Coinbase Wallet Creator"); + assert_eq!(deserialized.toolkit_name, "deno-toolkit"); + assert_eq!(deserialized.version, "1.0.0"); + assert_eq!(deserialized.description, "Tool for creating a Coinbase wallet"); + + // Verify config entries + assert_eq!(deserialized.config.len(), 3); + if let ToolConfig::BasicConfig(config) = &deserialized.config[0] { + assert_eq!(config.key_name, "name"); + assert!(config.required); + } + if let ToolConfig::BasicConfig(config) = &deserialized.config[1] { + assert_eq!(config.key_name, "privateKey"); + assert!(config.required); + } + if let ToolConfig::BasicConfig(config) = &deserialized.config[2] { + assert_eq!(config.key_name, "useServerSigner"); + assert!(!config.required); + } + } + + #[test] + fn test_email_fetcher_tool_config() { + let tool = DenoTool { + toolkit_name: "deno-toolkit".to_string(), + name: "Email Fetcher".to_string(), + author: "Shinkai".to_string(), + version: "1.0.0".to_string(), + description: "Fetches emails from an IMAP server".to_string(), + keywords: vec!["email".to_string(), "imap".to_string()], + js_code: "".to_string(), + tools: None, + config: vec![ + ToolConfig::BasicConfig(BasicConfig { + key_name: "imap_server".to_string(), + description: "The IMAP server address".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: None, + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "username".to_string(), + description: "The username for the IMAP account".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: None, + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "password".to_string(), + description: "The password for the IMAP account".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: None, + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "port".to_string(), + description: "The port number for the IMAP server (defaults to 993 for IMAPS)".to_string(), + required: false, + type_name: Some("integer".to_string()), + key_value: None, + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "ssl".to_string(), + description: "Whether to use SSL for the IMAP connection (defaults to true)".to_string(), + required: false, + type_name: Some("boolean".to_string()), + key_value: None, + }), + ], + input_args: Parameters::new(), + output_arg: ToolOutputArg { json: "".to_string() }, + activated: false, + embedding: None, + result: ToolResult::new("object".to_string(), json!({}), vec![]), + sql_tables: None, + sql_queries: None, + file_inbox: None, + oauth: None, + assets: None, + }; + + // Test check_required_config_fields with no values set + assert!(!tool.check_required_config_fields(), "Should fail when required fields have no values"); + + // Create a tool with values set for required fields + let mut tool_with_values = tool.clone(); + tool_with_values.config = vec![ + ToolConfig::BasicConfig(BasicConfig { + key_name: "imap_server".to_string(), + description: "The IMAP server address".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: Some("imap.example.com".to_string()), + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "username".to_string(), + description: "The username for the IMAP account".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: Some("user@example.com".to_string()), + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "password".to_string(), + description: "The password for the IMAP account".to_string(), + required: true, + type_name: Some("string".to_string()), + key_value: Some("password123".to_string()), + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "port".to_string(), + description: "The port number for the IMAP server (defaults to 993 for IMAPS)".to_string(), + required: false, + type_name: Some("integer".to_string()), + key_value: None, + }), + ToolConfig::BasicConfig(BasicConfig { + key_name: "ssl".to_string(), + description: "Whether to use SSL for the IMAP connection (defaults to true)".to_string(), + required: false, + type_name: Some("boolean".to_string()), + key_value: None, + }), + ]; + + assert!(tool_with_values.check_required_config_fields(), "Should pass when required fields have values"); + + // Test serialization/deserialization + let serialized = serde_json::to_string(&tool).expect("Failed to serialize DenoTool"); + let deserialized: DenoTool = serde_json::from_str(&serialized).expect("Failed to deserialize DenoTool"); + + assert_eq!(deserialized.config.len(), 5, "Should have 5 configuration items"); + + // Check specific configs + let imap_server_config = deserialized.config.iter().find(|c| match c { + ToolConfig::BasicConfig(bc) => bc.key_name == "imap_server", + _ => false, + }).unwrap(); + if let ToolConfig::BasicConfig(config) = imap_server_config { + assert_eq!(config.description, "The IMAP server address"); + assert_eq!(config.type_name, Some("string".to_string())); + assert!(config.required); + assert_eq!(config.key_value, None); + } + + let port_config = deserialized.config.iter().find(|c| match c { + ToolConfig::BasicConfig(bc) => bc.key_name == "port", + _ => false, + }).unwrap(); + if let ToolConfig::BasicConfig(config) = port_config { + assert_eq!(config.description, "The port number for the IMAP server (defaults to 993 for IMAPS)"); + assert_eq!(config.type_name, Some("integer".to_string())); + assert!(!config.required); + assert_eq!(config.key_value, None); + } + } } diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs index 70332f766..bf480320a 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/js_toolkit.rs @@ -129,6 +129,7 @@ impl JSToolkit { required: definition.configurations["required"] .as_array() .map_or(false, |req| req.iter().any(|r| r == key)), + type_name: value["type"].as_str().map(String::from), key_value: None, }) }) diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_config.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_config.rs index 3d8514c1c..50e8737a8 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_config.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_config.rs @@ -41,6 +41,7 @@ impl ToolConfig { key_name: config.key_name.clone(), description: config.description.clone(), required: config.required, + type_name: config.type_name.clone(), key_value: None, }), } @@ -52,12 +53,17 @@ impl ToolConfig { if let Some(obj) = value.as_object() { for (key, val) in obj { - let key_value = val.as_str().map(String::from); + let (key_value, type_name) = if let Some(val_obj) = val.as_object() { + (None, val_obj.get("type").and_then(|v| v.as_str()).map(String::from)) + } else { + (val.as_str().map(String::from), None) + }; let basic_config = BasicConfig { key_name: key.clone(), description: format!("Description for {}", key), required: false, // Set default or determine from context + type_name, key_value, }; configs.push(ToolConfig::BasicConfig(basic_config)); @@ -75,11 +81,13 @@ impl ToolConfig { let description = obj.get("description").and_then(|v| v.as_str()).unwrap_or_default(); let required = obj.get("required").and_then(|v| v.as_bool()).unwrap_or(false); let key_value = obj.get("key_value").and_then(|v| v.as_str()).map(String::from); + let type_name = obj.get("type").and_then(|v| v.as_str()).map(String::from); let basic_config = BasicConfig { key_name: key_name.to_string(), description: description.to_string(), required, + type_name, key_value, }; return Some(ToolConfig::BasicConfig(basic_config)); @@ -219,6 +227,8 @@ pub struct BasicConfig { pub key_name: String, pub description: String, pub required: bool, + #[serde(rename = "type")] + pub type_name: Option, pub key_value: Option, } @@ -232,6 +242,7 @@ mod tests { "key_name": "apiKey", "description": "API Key for weather service", "required": true, + "type": "string", "key_value": "63d35ff6068c3103ccd1227546935111" }"#; @@ -243,6 +254,7 @@ mod tests { assert_eq!(config.key_name, "apiKey"); assert_eq!(config.description, "API Key for weather service"); assert!(config.required); + assert_eq!(config.type_name, Some("string".to_string())); assert_eq!(config.key_value, Some("63d35ff6068c3103ccd1227546935111".to_string())); } _ => panic!("Parsed ToolConfig is not a BasicConfig"), @@ -267,6 +279,7 @@ mod tests { assert_eq!(config.key_name, "apiKey"); assert_eq!(config.description, ""); assert!(!config.required); + assert_eq!(config.type_name, None); assert_eq!(config.key_value, Some("63d35ff6068c3103ccd1227546935111".to_string())); } _ => panic!("Parsed ToolConfig is not a BasicConfig"), diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs index 1ef639728..5ddd95d4e 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs @@ -85,12 +85,14 @@ where let configs = properties .iter() .map(|(key, val)| { - let description = val.get("type").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let description = val.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let type_name = val.get("type").and_then(|v| v.as_str()).map(String::from); let required = required_keys.contains(key); let basic_config = BasicConfig { key_name: key.clone(), description, required, + type_name, key_value: None, // or extract a default value if needed }; ToolConfig::BasicConfig(basic_config) From 31844aa993d17df6cd6dd3d37077da14d9107285 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Thu, 9 Jan 2025 13:08:42 -0600 Subject: [PATCH 03/18] add phi4 --- .../shinkai-node/src/managers/model_capabilities_manager.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs index 155c2d330..2aa083c07 100644 --- a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs +++ b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs @@ -427,6 +427,7 @@ impl ModelCapabilitiesManager { model_type if model_type.starts_with("falcon2") => 8_000, model_type if model_type.starts_with("llama3-chatqa") => 8_000, model_type if model_type.starts_with("llava-phi3") => 4_000, + model_type if model_type.starts_with("phi4") => 16_000, model_type if model_type.contains("minicpm-v") => 8_000, model_type if model_type.starts_with("dolphin-llama3") => 8_000, model_type if model_type.starts_with("command-r-plus") => 128_000, From 538708958e3d3f887ef565b9d743e3119fb142ce Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Thu, 9 Jan 2025 13:48:19 -0600 Subject: [PATCH 04/18] extend api support --- .../src/network/v2_api/api_v2_commands_vecfs.rs | 4 +++- .../tests/it/utils/shinkai_testing_framework.rs | 5 ++++- shinkai-bin/shinkai-node/tests/it/utils/vecfs_test_utils.rs | 5 ++++- .../src/shinkai_message/shinkai_message_schemas.rs | 1 + .../shinkai_utils/shinkai_message_builder_bundled_vecfs.rs | 6 +++++- .../shinkai-tools-primitives/src/tools/deno_tools.rs | 2 +- 6 files changed, 18 insertions(+), 5 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_vecfs.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_vecfs.rs index 6e21ca473..513c51eb1 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_vecfs.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_vecfs.rs @@ -41,8 +41,10 @@ impl Node { let vr_path = ShinkaiPath::from_string(input_payload.path); + let depth = input_payload.depth.unwrap_or(1); + // Use list_directory_contents_with_depth to get directory contents with depth 1 - let directory_contents = ShinkaiFileManager::list_directory_contents_with_depth(vr_path, &db, 1); + let directory_contents = ShinkaiFileManager::list_directory_contents_with_depth(vr_path, &db, depth); if let Err(e) = directory_contents { let api_error = APIError { diff --git a/shinkai-bin/shinkai-node/tests/it/utils/shinkai_testing_framework.rs b/shinkai-bin/shinkai-node/tests/it/utils/shinkai_testing_framework.rs index 4a1c8eda4..212bcbc11 100644 --- a/shinkai-bin/shinkai-node/tests/it/utils/shinkai_testing_framework.rs +++ b/shinkai-bin/shinkai-node/tests/it/utils/shinkai_testing_framework.rs @@ -208,7 +208,10 @@ impl ShinkaiTestingFramework { /// Retrieves simplified path information and optionally prints it based on `should_print`. pub async fn retrieve_and_print_path_simplified(&self, path: &str, should_print: bool) -> serde_json::Value { - let payload = APIVecFsRetrievePathSimplifiedJson { path: path.to_string() }; + let payload = APIVecFsRetrievePathSimplifiedJson { + path: path.to_string(), + depth: Some(1), + }; let msg = generate_message_with_payload( serde_json::to_string(&payload).unwrap(), MessageSchemaType::VecFsRetrievePathSimplifiedJson, diff --git a/shinkai-bin/shinkai-node/tests/it/utils/vecfs_test_utils.rs b/shinkai-bin/shinkai-node/tests/it/utils/vecfs_test_utils.rs index e6a388663..5dd76ee27 100644 --- a/shinkai-bin/shinkai-node/tests/it/utils/vecfs_test_utils.rs +++ b/shinkai-bin/shinkai-node/tests/it/utils/vecfs_test_utils.rs @@ -134,7 +134,10 @@ pub async fn retrieve_file_info( path: &str, is_simple: bool, ) -> Value { - let payload = APIVecFsRetrievePathSimplifiedJson { path: path.to_string() }; + let payload = APIVecFsRetrievePathSimplifiedJson { + path: path.to_string(), + depth: Some(1), + }; let msg = generate_message_with_payload( serde_json::to_string(&payload).unwrap(), diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs index 6fddcaf12..f38291f57 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs @@ -481,6 +481,7 @@ pub struct APIAddAgentRequest { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, ToSchema)] pub struct APIVecFsRetrievePathSimplifiedJson { pub path: String, + pub depth: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, ToSchema)] diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder_bundled_vecfs.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder_bundled_vecfs.rs index 6a15be9a3..356a7e799 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder_bundled_vecfs.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder_bundled_vecfs.rs @@ -380,6 +380,7 @@ impl ShinkaiMessageBuilder { #[allow(dead_code)] pub fn vecfs_retrieve_path_simplified( path: &str, + depth: Option, my_encryption_secret_key: EncryptionStaticKey, my_signature_secret_key: SigningKey, receiver_public_key: EncryptionPublicKey, @@ -388,7 +389,10 @@ impl ShinkaiMessageBuilder { node_receiver: ShinkaiNameString, node_receiver_subidentity: ShinkaiNameString, ) -> Result { - let payload = APIVecFsRetrievePathSimplifiedJson { path: path.to_string() }; + let payload = APIVecFsRetrievePathSimplifiedJson { + path: path.to_string(), + depth: depth, + }; Self::create_vecfs_message( payload, diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index 13749ebcb..0e7820a39 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -561,7 +561,7 @@ mod tests { description: "Fetches emails from an IMAP server".to_string(), keywords: vec!["email".to_string(), "imap".to_string()], js_code: "".to_string(), - tools: None, + tools: vec![], config: vec![ ToolConfig::BasicConfig(BasicConfig { key_name: "imap_server".to_string(), From deb00a2572dabc39acdfba16edb490c14858ad57 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Thu, 9 Jan 2025 14:08:11 -0600 Subject: [PATCH 05/18] add error for ollama models eg not supporting tools --- .../shinkai-node/src/llm_provider/error.rs | 3 +++ .../src/llm_provider/providers/ollama.rs | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/error.rs b/shinkai-bin/shinkai-node/src/llm_provider/error.rs index 5456b4bcf..443febf95 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/error.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/error.rs @@ -84,6 +84,7 @@ pub enum LLMProviderError { AgentNotFound(String), MessageTooLargeForLLM { max_tokens: usize, used_tokens: usize }, SomeError(String), + APIError(String), } impl fmt::Display for LLMProviderError { @@ -176,6 +177,7 @@ impl fmt::Display for LLMProviderError { write!(f, "Message too large for LLM: Used {} tokens, but the maximum allowed is {}.", used_tokens, max_tokens) }, LLMProviderError::SomeError(s) => write!(f, "{}", s), + LLMProviderError::APIError(s) => write!(f, "{}", s), } } } @@ -256,6 +258,7 @@ impl LLMProviderError { LLMProviderError::AgentNotFound(_) => "AgentNotFound", LLMProviderError::MessageTooLargeForLLM { .. } => "MessageTooLargeForLLM", LLMProviderError::SomeError(_) => "SomeError", + LLMProviderError::APIError(_) => "APIError", }; format!("Error {} with message: {}", error_name, self) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs index c1fe35a68..64813bfd5 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs @@ -290,6 +290,19 @@ async fn process_stream( if !previous_json_chunk.is_empty() { chunk_str = previous_json_chunk.clone() + chunk_str.as_str(); } + + // First check if it's an error response + if let Ok(error_response) = serde_json::from_str::(&chunk_str) { + if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Error, + format!("Ollama API Error: {}", error_msg).as_str(), + ); + return Err(LLMProviderError::APIError(error_msg.to_string())); + } + } + let data_resp: Result = serde_json::from_str(&chunk_str); match data_resp { Ok(data) => { @@ -484,6 +497,19 @@ async fn handle_non_streaming_response( result = &mut response_future => { let res = result?; let response_body = res.text().await?; + + // First check if it's an error response + if let Ok(error_response) = serde_json::from_str::(&response_body) { + if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Error, + format!("Ollama API Error: {}", error_msg).as_str(), + ); + return Err(LLMProviderError::APIError(error_msg.to_string())); + } + } + let response_json: serde_json::Value = serde_json::from_str(&response_body)?; if let Some(message) = response_json.get("message") { From cefa0937ca19b81a530317db994dcda08f765cc1 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 18:26:55 -0300 Subject: [PATCH 06/18] chore: bump version to 0.9.5 in Cargo.toml (#782) Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Nicolas Arqueros Co-authored-by: Alfredo Gallardo --- Cargo.lock | 28 ++++++++++++++-------------- Cargo.toml | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0363a1260..633b67a9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6573,7 +6573,7 @@ checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" [[package]] name = "shinkai-spreadsheet-llm" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-trait", "chrono", @@ -6588,7 +6588,7 @@ dependencies = [ [[package]] name = "shinkai_crypto_identities" -version = "0.9.4" +version = "0.9.5" dependencies = [ "chrono", "dashmap", @@ -6604,7 +6604,7 @@ dependencies = [ [[package]] name = "shinkai_embedding" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-trait", "bincode", @@ -6625,7 +6625,7 @@ dependencies = [ [[package]] name = "shinkai_fs" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-trait", "bincode", @@ -6656,7 +6656,7 @@ dependencies = [ [[package]] name = "shinkai_http_api" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-channel 1.9.0", "bytes", @@ -6681,7 +6681,7 @@ dependencies = [ [[package]] name = "shinkai_job_queue_manager" -version = "0.9.4" +version = "0.9.5" dependencies = [ "chrono", "serde", @@ -6695,7 +6695,7 @@ dependencies = [ [[package]] name = "shinkai_message_primitives" -version = "0.9.4" +version = "0.9.5" dependencies = [ "aes-gcm", "async-trait", @@ -6722,14 +6722,14 @@ dependencies = [ [[package]] name = "shinkai_mini_libs" -version = "0.9.4" +version = "0.9.5" dependencies = [ "base64 0.22.1", ] [[package]] name = "shinkai_node" -version = "0.9.4" +version = "0.9.5" dependencies = [ "aes-gcm", "anyhow", @@ -6801,7 +6801,7 @@ dependencies = [ [[package]] name = "shinkai_ocr" -version = "0.9.4" +version = "0.9.5" dependencies = [ "anyhow", "image 0.25.5", @@ -6816,7 +6816,7 @@ dependencies = [ [[package]] name = "shinkai_sheet" -version = "0.9.4" +version = "0.9.5" dependencies = [ "async-channel 1.9.0", "chrono", @@ -6830,7 +6830,7 @@ dependencies = [ [[package]] name = "shinkai_sqlite" -version = "0.9.4" +version = "0.9.5" dependencies = [ "bincode", "blake3", @@ -6861,7 +6861,7 @@ dependencies = [ [[package]] name = "shinkai_tcp_relayer" -version = "0.9.4" +version = "0.9.5" dependencies = [ "chrono", "clap 3.2.25", @@ -6880,7 +6880,7 @@ dependencies = [ [[package]] name = "shinkai_tools_primitives" -version = "0.9.4" +version = "0.9.5" dependencies = [ "anyhow", "regex", diff --git a/Cargo.toml b/Cargo.toml index a2c8445e1..40b6ded89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.9.4" +version = "0.9.5" edition = "2021" authors = ["Nico Arqueros "] From 166017fb7f172550e13ead61c0e4e39b5a3a3a11 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Thu, 9 Jan 2025 16:28:45 -0500 Subject: [PATCH 07/18] Fixes env vars for run node scripts. Add Docker file to build from binary. --- .github/binary.Dockerfile | 19 +++++++++++++++++++ scripts/run_all_localhost.sh | 10 +++++++--- scripts/run_local_ai_with_proxy.sh | 1 + scripts/run_node_localhost.sh | 1 + 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 .github/binary.Dockerfile diff --git a/.github/binary.Dockerfile b/.github/binary.Dockerfile new file mode 100644 index 000000000..546660142 --- /dev/null +++ b/.github/binary.Dockerfile @@ -0,0 +1,19 @@ +FROM ubuntu:24.10 AS downloader + RUN apt-get update && apt-get install -y curl unzip + ARG SHINKAI_NODE_VERSION + RUN curl -L -o shinkai-node.zip https://download.shinkai.com/shinkai-node/binaries/production/x86_64-unknown-linux-gnu/${SHINKAI_NODE_VERSION:-latest}.zip + RUN FILE_SIZE=$(stat -c %s /shinkai-node.zip) && \ + if [ $FILE_SIZE -lt 26214400 ]; then \ + echo "Error: shinkai-node file is less than 25MB" && \ + exit 1; \ + fi + RUN unzip shinkai-node.zip + RUN chmod +x /shinkai-node + + FROM ubuntu:24.10 + RUN apt-get update && apt-get install -y openssl ca-certificates + WORKDIR /app + COPY --from=downloader /shinkai-node ./shinkai-node + + EXPOSE 9550 + ENTRYPOINT ["/bin/sh", "-c", "/app/shinkai-node"] \ No newline at end of file diff --git a/scripts/run_all_localhost.sh b/scripts/run_all_localhost.sh index 9a9cd0a9c..3acca999e 100755 --- a/scripts/run_all_localhost.sh +++ b/scripts/run_all_localhost.sh @@ -1,9 +1,11 @@ #!/bin/bash -export NODE_IP="0.0.0.0" -export NODE_PORT="9552" export NODE_API_IP="0.0.0.0" +export NODE_IP="0.0.0.0" export NODE_API_PORT="9550" +export NODE_WS_PORT="9551" +export NODE_PORT="9552" +export NODE_HTTPS_PORT="9553" export IDENTITY_SECRET_KEY="df3f619804a92fdb4057192dc43dd748ea778adc52bc498ce80524c014b81119" export ENCRYPTION_SECRET_KEY="d83f619804a92fdb4057192dc43dd748ea778adc52bc498ce80524c014b81159" export PING_INTERVAL_SECS="0" @@ -14,9 +16,11 @@ export STARTING_NUM_QR_DEVICES="1" export FIRST_DEVICE_NEEDS_REGISTRATION_CODE="false" export LOG_SIMPLE="true" export NO_SECRET_FILE="true" -export EMBEDDINGS_SERVER_URL="http://localhost:9081/" +export EMBEDDINGS_SERVER_URL="http://localhost:11434/" export PROXY_IDENTITY="@@relayer_pub_01.arb-sep-shinkai" export SHINKAI_TOOLS_RUNNER_DENO_BINARY_PATH="${workspaceFolder}/shinkai-bin/shinkai-node/shinkai-tools-runner-resources/deno" +export SHINKAI_TOOLS_RUNNER_UV_BINARY_PATH="${workspaceFolder}/shinkai-bin/shinkai-node/shinkai-tools-runner-resources/uv" +export SHINKAI_TOOLS_DIRECTORY_URL="https://download.shinkai.com/tools/directory.json" export INITIAL_AGENT_NAMES="o_mixtral" export INITIAL_AGENT_URLS="http://localhost:11434" diff --git a/scripts/run_local_ai_with_proxy.sh b/scripts/run_local_ai_with_proxy.sh index dc3cd6632..bcfa2aba7 100755 --- a/scripts/run_local_ai_with_proxy.sh +++ b/scripts/run_local_ai_with_proxy.sh @@ -27,5 +27,6 @@ export LOG_ALL="1" export DEBUG_VRKAI="1" # export PROXY_IDENTITY="@@kao_tcp_relayer.arb-sep-shinkai" export PROXY_IDENTITY="@@relayer_pub_01.arb-sep-shinkai" +export SHINKAI_TOOLS_DIRECTORY_URL="https://download.shinkai.com/tools/directory.json" cargo run --bin shinkai_node --package shinkai_node \ No newline at end of file diff --git a/scripts/run_node_localhost.sh b/scripts/run_node_localhost.sh index d080815cd..9661819fa 100755 --- a/scripts/run_node_localhost.sh +++ b/scripts/run_node_localhost.sh @@ -13,6 +13,7 @@ export FIRST_DEVICE_NEEDS_REGISTRATION_CODE="false" export LOG_SIMPLE="true" export EMBEDDINGS_SERVER_URL="http://localhost:11434" # assumes that you installed the embeddings server locally using ollama (shinkai-apps helps you handling all of this) # export EMBEDDINGS_SERVER_URL="https://public.shinkai.com/x-em" # if you prefer to use the public embeddings server +export SHINKAI_TOOLS_DIRECTORY_URL="https://download.shinkai.com/tools/directory.json" # Add these lines to enable all log options export LOG_ALL=1 From d666000279c5c094d6087f26ff372768bb6ac44f Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Thu, 9 Jan 2025 17:45:36 -0600 Subject: [PATCH 08/18] update --- shinkai-bin/shinkai-node/src/managers/tool_router.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index ecb643e35..81db4f399 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -148,7 +148,7 @@ impl ToolRouter { let start_time = Instant::now(); let url = env::var("SHINKAI_TOOLS_DIRECTORY_URL") - .map_err(|_| ToolError::MissingConfigError("SHINKAI_TOOLS_DIRECTORY_URL not set".to_string()))?; + .unwrap_or_else(|_| "https://download.shinkai.com/tools/directory.json".to_string()); let response = reqwest::get(url).await.map_err(|e| ToolError::RequestError(e))?; From 776c9beaf1ef01b235860c1d8d5da5fcd32e91e0 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Fri, 10 Jan 2025 10:37:58 -0600 Subject: [PATCH 09/18] fix --- .../shinkai-node/src/managers/tool_router.rs | 25 ++++++++++++++++--- .../shinkai_message_schemas.rs | 1 + 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index 81db4f399..c681051a0 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -24,7 +24,7 @@ use shinkai_message_primitives::schemas::shinkai_tool_offering::{ use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage; use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier}; use shinkai_message_primitives::schemas::ws_types::{PaymentMetadata, WSMessageType, WidgetMetadata}; -use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic; +use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{AssociatedUI, WSTopic}; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_sqlite::errors::SqliteManagerError; use shinkai_sqlite::files::prompts_data; @@ -496,7 +496,13 @@ impl ToolRouter { .node_storage_path .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; - let app_id = context.full_job().job_id().to_string(); + + // Get app_id from Cron UI if present, otherwise use job_id + let app_id = match context.full_job().associated_ui().as_ref() { + Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), + _ => context.full_job().job_id().to_string() + }; + let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); let tools = python_tool.tools.clone(); let support_files = @@ -555,7 +561,12 @@ impl ToolRouter { }); } ShinkaiTool::Rust(_rust_tool, _is_enabled) => { - let app_id = context.full_job().job_id().to_string(); + // Get app_id from Cron UI if present, otherwise use job_id + let app_id = match context.full_job().associated_ui().as_ref() { + Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), + _ => context.full_job().job_id().to_string() + }; + let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); let function_config = shinkai_tool.get_config_from_env(); let function_config_vec: Vec = function_config.into_iter().collect(); @@ -610,7 +621,13 @@ impl ToolRouter { .node_storage_path .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; - let app_id = context.full_job().job_id().to_string(); + + // Get app_id from Cron UI if present, otherwise use job_id + let app_id = match context.full_job().associated_ui().as_ref() { + Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), + _ => context.full_job().job_id().to_string() + }; + let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); let tools = deno_tool.tools.clone(); let support_files = diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs index f38291f57..60f8164ff 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs @@ -296,6 +296,7 @@ pub struct SymmetricKeyExchange { pub enum AssociatedUI { Sheet(String), Playground, + Cron(String), // Add more variants as needed } From cdc907c678db68ae5125bd8d7d992ad1cf0fc017 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Fri, 10 Jan 2025 10:41:38 -0600 Subject: [PATCH 10/18] add cron id to job --- shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs b/shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs index e8688522c..e3c7736a4 100644 --- a/shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs +++ b/shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs @@ -15,7 +15,7 @@ use shinkai_message_primitives::{ shinkai_name::ShinkaiName, ws_types::WSUpdateHandler, }, - shinkai_message::shinkai_message_schemas::JobMessage, + shinkai_message::shinkai_message_schemas::{AssociatedUI, JobMessage}, shinkai_utils::{ shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}, signatures::clone_signature_secret_key, @@ -331,6 +331,8 @@ impl CronManager { if job_creation_info_clone.is_hidden.is_none() { job_creation_info_clone.is_hidden = Some(true); } + // Set the associated UI to Cron with the task ID + job_creation_info_clone.associated_ui = Some(AssociatedUI::Cron(cron_job.task_id.to_string())); let job_id = job_manager .lock() From 1b3ca730302b397fb834303a5dd71130daee5d0a Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Fri, 10 Jan 2025 13:48:34 -0600 Subject: [PATCH 11/18] fix stream --- shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs index 7badc4d82..6c2d6831f 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs @@ -98,6 +98,7 @@ impl LLMService for OpenAI { "model": self.model_type, "messages": messages_json, "max_tokens": result.remaining_output_tokens, + "stream": is_stream, }); // Conditionally add functions to the payload if tools_json is not empty From 1258bd9c74839b9a83106832c935a378e9cfca49 Mon Sep 17 00:00:00 2001 From: Alfredo Gallardo Date: Fri, 10 Jan 2025 16:54:41 -0300 Subject: [PATCH 12/18] fix: os path for multiplatform compatibility --- Cargo.lock | 14 +++++++++++++- shinkai-libs/shinkai-fs/Cargo.toml | 2 ++ .../shinkai-fs/src/shinkai_file_manager.rs | 15 ++++++++++----- .../shinkai-message-primitives/Cargo.toml | 2 ++ .../src/shinkai_utils/shinkai_path.rs | 12 +++++++----- 5 files changed, 34 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 633b67a9e..71649811e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "Inflector" @@ -4542,6 +4542,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "os_path" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "360a6ecb129f544ba5ae18776ca8779cf3cf979c8133e9eefe9464ea74741f6b" +dependencies = [ + "regex", + "serde", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -6636,6 +6646,7 @@ dependencies = [ "futures", "keyphrases", "lazy_static", + "os_path", "rand 0.8.5", "regex", "reqwest 0.11.27", @@ -6705,6 +6716,7 @@ dependencies = [ "chrono", "ed25519-dalek", "hex", + "os_path", "rand 0.8.5", "regex", "rust_decimal", diff --git a/shinkai-libs/shinkai-fs/Cargo.toml b/shinkai-libs/shinkai-fs/Cargo.toml index 498985927..016e5258c 100644 --- a/shinkai-libs/shinkai-fs/Cargo.toml +++ b/shinkai-libs/shinkai-fs/Cargo.toml @@ -29,6 +29,8 @@ csv = "1.1.6" utoipa = "4.2.3" regex = { workspace = true } +os_path = "0.8.0" + [dependencies.serde] workspace = true features = ["derive"] diff --git a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs index 621b4fa26..1d1987831 100644 --- a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs +++ b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs @@ -713,29 +713,34 @@ mod tests { let level2_contents = level1_info.children.as_ref().unwrap(); assert_eq!(level2_contents.len(), 2); // One directory and one file - let file1_info = level2_contents.iter().find(|info| info.path == "level1/file1.txt").unwrap(); + let file1_path = os_path::OsPath::from("level1/file1.txt").to_string(); + let file1_info = level2_contents.iter().find(|info| info.path == file1_path).unwrap(); assert!(!file1_info.is_directory); assert!(file1_info.has_embeddings, "File 'level1/file1.txt' should have embeddings."); - let level2_info = level2_contents.iter().find(|info| info.path == "level1/level2").unwrap(); + let level2_path = os_path::OsPath::from("level1/level2").to_string(); + let level2_info = level2_contents.iter().find(|info| info.path == level2_path).unwrap(); assert!(level2_info.is_directory); assert!(level2_info.children.is_some()); let level3_contents = level2_info.children.as_ref().unwrap(); assert_eq!(level3_contents.len(), 2); // One directory and one file - let file2_info = level3_contents.iter().find(|info| info.path == "level1/level2/file2.txt").unwrap(); + let file2_path = os_path::OsPath::from("level1/level2/file2.txt").to_string(); + let file2_info = level3_contents.iter().find(|info| info.path == file2_path).unwrap(); assert!(!file2_info.is_directory); assert!(file2_info.has_embeddings, "File 'level1/level2/file2.txt' should have embeddings."); - let level3_info = level3_contents.iter().find(|info| info.path == "level1/level2/level3").unwrap(); + let level3_path = os_path::OsPath::from("level1/level2/level3").to_string(); + let level3_info = level3_contents.iter().find(|info| info.path == level3_path).unwrap(); assert!(level3_info.is_directory); assert!(level3_info.children.is_some()); let level3_files = level3_info.children.as_ref().unwrap(); assert_eq!(level3_files.len(), 1); // Only one file - let file3_info = level3_files.iter().find(|info| info.path == "level1/level2/level3/file3.txt").unwrap(); + let file3_path = os_path::OsPath::from("level1/level2/level3/file3.txt").to_string(); + let file3_info = level3_files.iter().find(|info| info.path == file3_path).unwrap(); assert!(!file3_info.is_directory); assert!(!file3_info.has_embeddings, "File 'level1/level2/level3/file3.txt' should not have embeddings."); } diff --git a/shinkai-libs/shinkai-message-primitives/Cargo.toml b/shinkai-libs/shinkai-message-primitives/Cargo.toml index 5c06caaa3..3ca9d74e0 100644 --- a/shinkai-libs/shinkai-message-primitives/Cargo.toml +++ b/shinkai-libs/shinkai-message-primitives/Cargo.toml @@ -29,6 +29,8 @@ tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3", optional = true } +os_path = { version = "0.8.0" } + [lib] crate-type = ["cdylib", "rlib"] diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs index fdf29932c..f6e0afc53 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs @@ -15,7 +15,7 @@ impl ShinkaiPath { /// Private helper method to create a ShinkaiPath from a &str. pub fn new(path: &str) -> Self { let base_path = Self::base_path(); - let path_buf = PathBuf::from(path); + let path_buf = os_path::OsPath::from(path).to_pathbuf(); // PathBuf::from(path); let final_path = if path_buf.is_absolute() { if path_buf.starts_with(&base_path) { @@ -227,7 +227,7 @@ mod tests { env::var("NODE_STORAGE_PATH").unwrap() )) ); - assert_eq!(path.relative_path(), "word_files/christmas.docx"); + assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string()); } #[test] @@ -240,7 +240,7 @@ mod tests { path.as_path(), Path::new("storage/filesystem/word_files/christmas.docx") ); - assert_eq!(path.relative_path(), "word_files/christmas.docx"); + assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string()); } #[test] @@ -248,7 +248,7 @@ mod tests { fn test_relative_path_outside_base() { let _dir = testing_create_tempdir_and_set_env_var(); let absolute_outside = ShinkaiPath::from_string("/some/other/path".to_string()); - assert_eq!(absolute_outside.relative_path(), "some/other/path"); + assert_eq!(absolute_outside.relative_path(), os_path::OsPath::from("some/other/path").to_string()); } #[test] @@ -349,6 +349,8 @@ mod tests { let serialized_path = serde_json::to_string(&path).unwrap(); // Check if the serialized output matches the expected relative path - assert_eq!(serialized_path, "\"word_files/christmas.docx\""); + let serialized_path_str = serde_json::to_string(&os_path::OsPath::from("word_files/christmas.docx").to_string()).unwrap(); + + assert_eq!(serialized_path, serialized_path_str); } } From 6ab607b1bc2dba593a82906a0b1e1a1af40f724f Mon Sep 17 00:00:00 2001 From: Eddie Date: Fri, 10 Jan 2025 17:37:43 -0300 Subject: [PATCH 13/18] Removed storage prefix --- .../shinkai-node/src/managers/tool_router.rs | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index c681051a0..05d354ac3 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -166,13 +166,7 @@ impl ToolRouter { let tool_urls = tools .iter() - .map(|tool| { - ( - tool["name"].as_str(), - tool["file"].as_str(), - tool["router_key"].as_str(), - ) - }) + .map(|tool| (tool["name"].as_str(), tool["file"].as_str(), tool["routerKey"].as_str())) .collect::>() .into_iter() .filter(|(name, url, router_key)| url.is_some() && name.is_some() && router_key.is_some()) @@ -496,11 +490,11 @@ impl ToolRouter { .node_storage_path .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; - + // Get app_id from Cron UI if present, otherwise use job_id let app_id = match context.full_job().associated_ui().as_ref() { Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), - _ => context.full_job().job_id().to_string() + _ => context.full_job().job_id().to_string(), }; let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); @@ -513,10 +507,10 @@ impl ToolRouter { let envs = generate_execution_environment( context.db(), context.agent().clone().get_id().to_string(), - format!("jid-{}", tool_id), - format!("jid-{}", app_id), + tool_id.clone(), + app_id.clone(), shinkai_tool.tool_router_key().to_string_without_version().clone(), - format!("jid-{}", app_id), + app_id.clone(), &python_tool.oauth, ) .await @@ -545,8 +539,8 @@ impl ToolRouter { function_args, function_config_vec, node_storage_path, - app_id, - tool_id, + app_id.clone(), + tool_id.clone(), node_name, false, None, @@ -564,7 +558,7 @@ impl ToolRouter { // Get app_id from Cron UI if present, otherwise use job_id let app_id = match context.full_job().associated_ui().as_ref() { Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), - _ => context.full_job().job_id().to_string() + _ => context.full_job().job_id().to_string(), }; let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); @@ -575,7 +569,6 @@ impl ToolRouter { let llm_provider = context.agent().get_llm_provider_id().to_string(); let bearer = db.read_api_v2_key().unwrap_or_default().unwrap_or_default(); - let job_callback_manager = context.job_callback_manager(); let mut job_manager: Option>> = None; if let Some(job_callback_manager) = &job_callback_manager { @@ -584,7 +577,9 @@ impl ToolRouter { } if job_manager.is_none() { - return Err(LLMProviderError::FunctionExecutionError("Job manager is not available".to_string())); + return Err(LLMProviderError::FunctionExecutionError( + "Job manager is not available".to_string(), + )); } let result = execute_custom_tool( @@ -621,11 +616,11 @@ impl ToolRouter { .node_storage_path .clone() .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; - + // Get app_id from Cron UI if present, otherwise use job_id let app_id = match context.full_job().associated_ui().as_ref() { Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), - _ => context.full_job().job_id().to_string() + _ => context.full_job().job_id().to_string(), }; let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); @@ -638,10 +633,10 @@ impl ToolRouter { let envs = generate_execution_environment( context.db(), context.agent().clone().get_id().to_string(), - format!("jid-{}", app_id), - format!("jid-{}", tool_id), + app_id.clone(), + tool_id.clone(), shinkai_tool.tool_router_key().to_string_without_version().clone(), - format!("jid-{}", app_id), + app_id.clone(), &deno_tool.oauth, ) .await From 5d83ce680bf2d2114ee5ff99c4a4789b303e77e6 Mon Sep 17 00:00:00 2001 From: Alfredo Gallardo Date: Fri, 10 Jan 2025 18:16:43 -0300 Subject: [PATCH 14/18] fix: normalize paths when write & read relative paths --- .../shinkai-fs/src/shinkai_file_manager.rs | 3 ++- .../src/shinkai_utils/job_scope.rs | 8 +++--- .../shinkai-sqlite/src/file_system.rs | 25 ++++++++++++++++--- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs index 1d1987831..56e041093 100644 --- a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs +++ b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs @@ -317,9 +317,10 @@ mod tests { } fn create_test_parsed_file(id: i64, relative_path: &str) -> ParsedFile { + let pf_relative_path = SqliteManager::normalize_path(relative_path); ParsedFile { id: Some(id), - relative_path: relative_path.to_string(), + relative_path: pf_relative_path.to_string(), original_extension: None, description: None, source: None, diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs index 94a9e2aea..4ffe965ed 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs @@ -75,8 +75,8 @@ mod tests { let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize"); assert_eq!(deserialized.vector_fs_items.len(), 2); - assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1"); - assert_eq!(deserialized.vector_fs_items[1].relative_path(), "path/to/file2"); + assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string()); + assert_eq!(deserialized.vector_fs_items[1].relative_path(), os_path::OsPath::from("path/to/file2").to_string()); assert_eq!(deserialized.vector_fs_folders.len(), 1); assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)"); assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k); @@ -93,9 +93,9 @@ mod tests { let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize"); assert_eq!(deserialized.vector_fs_items.len(), 1); - assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1"); + assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string()); assert_eq!(deserialized.vector_fs_folders.len(), 1); - assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)"); + assert_eq!(deserialized.vector_fs_folders[0].relative_path(), os_path::OsPath::from("My Files (Private)").to_string()); assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k); // Check default } } diff --git a/shinkai-libs/shinkai-sqlite/src/file_system.rs b/shinkai-libs/shinkai-sqlite/src/file_system.rs index 22953bbc9..79967851b 100644 --- a/shinkai-libs/shinkai-sqlite/src/file_system.rs +++ b/shinkai-libs/shinkai-sqlite/src/file_system.rs @@ -6,6 +6,13 @@ use shinkai_message_primitives::{ }; impl SqliteManager { + // TODO: This is a temporary workaround for Windows paths. We should handle this more robustly. + pub fn normalize_path(path: &str) -> String { + let mut path = path.replace("\\\\", "/"); + path = path.replace("\\", "/"); + path + } + pub fn initialize_filesystem_tables(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { // parsed_files table conn.execute( @@ -70,21 +77,24 @@ impl SqliteManager { let mut conn = self.get_connection()?; let tx = conn.transaction()?; + let pf_relative_path = Self::normalize_path(&pf.relative_path); + let exists: bool = tx.query_row( "SELECT EXISTS(SELECT 1 FROM parsed_files WHERE relative_path = ?)", - [&pf.relative_path], + [&pf_relative_path], |row| row.get(0), )?; if exists { return Err(SqliteManagerError::DataAlreadyExists); } + let relative_path = Self::normalize_path(&pf.relative_path); tx.execute( "INSERT INTO parsed_files (relative_path, original_extension, description, source, embedding_model_used, keywords, distribution_info, created_time, tags, total_tokens, total_characters) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", params![ - pf.relative_path, + relative_path, pf.original_extension, pf.description, pf.source, @@ -104,6 +114,9 @@ impl SqliteManager { pub fn get_parsed_file_by_rel_path(&self, rel_path: &str) -> Result, SqliteManagerError> { let conn = self.get_connection()?; + + let rel_path = Self::normalize_path(rel_path); + let mut stmt = conn.prepare( " SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords, @@ -150,13 +163,14 @@ impl SqliteManager { return Err(SqliteManagerError::DataNotFound); } + let relative_path = Self::normalize_path(&pf.relative_path); tx.execute( "UPDATE parsed_files SET relative_path = ?1, original_extension = ?2, description = ?3, source = ?4, embedding_model_used = ?5, keywords = ?6, distribution_info = ?7, created_time = ?8, tags = ?9, total_tokens = ?10, total_characters = ?11 WHERE id = ?12", params![ - pf.relative_path, + relative_path, pf.original_extension, pf.description, pf.source, @@ -448,6 +462,9 @@ impl SqliteManager { // ------------------------- pub fn update_folder_paths(&self, old_prefix: &str, new_prefix: &str) -> Result<(), SqliteManagerError> { + let old_prefix = Self::normalize_path(old_prefix); + let new_prefix = Self::normalize_path(new_prefix); + let mut conn = self.get_connection()?; let tx = conn.transaction()?; @@ -469,6 +486,7 @@ impl SqliteManager { &self, directory_path: &str, ) -> Result, SqliteManagerError> { + let directory_path = Self::normalize_path(directory_path); let conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords, @@ -547,6 +565,7 @@ impl SqliteManager { /// Retrieve all parsed files whose relative paths start with the given prefix. pub fn get_parsed_files_by_prefix(&self, prefix: &str) -> Result, SqliteManagerError> { + let prefix = Self::normalize_path(prefix); let conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords, From 1a6447936286a450f81289b8f0a4906a832795d3 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Fri, 10 Jan 2025 16:42:15 -0600 Subject: [PATCH 15/18] fix --- shinkai-bin/shinkai-node/src/utils/qr_code_setup.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/shinkai-bin/shinkai-node/src/utils/qr_code_setup.rs b/shinkai-bin/shinkai-node/src/utils/qr_code_setup.rs index bb9615647..b1a6d79e4 100644 --- a/shinkai-bin/shinkai-node/src/utils/qr_code_setup.rs +++ b/shinkai-bin/shinkai-node/src/utils/qr_code_setup.rs @@ -108,9 +108,13 @@ pub fn save_qr_data_to_local_image(qr_data: QRSetupData, name: String) { } pub fn print_qr_data_to_console(qr_data: QRSetupData, node_profile: &str) { + let version = env!("CARGO_PKG_VERSION"); + let api_key = std::env::var("API_V2_KEY").unwrap_or_else(|_| "Not set".to_string()); // Print qr_data to console in a beautiful way println!("Please scan the QR code below with your phone to register this device:"); println!("---------------------------------------------------------------"); + println!("Node version: {}", version); + println!("API v2 key (Bearer): {}", api_key); println!("Node registration code: {}", qr_data.registration_code); println!("Node profile: {}", node_profile); println!("Node identity type: {}", qr_data.identity_type); From 87b5b390294dd8ad72aa073cbf4e79a5b65a7f01 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Fri, 10 Jan 2025 17:37:49 -0600 Subject: [PATCH 16/18] remove --- shinkai-libs/shinkai-sqlite/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shinkai-libs/shinkai-sqlite/src/lib.rs b/shinkai-libs/shinkai-sqlite/src/lib.rs index 6b4fa861b..f2247d65d 100644 --- a/shinkai-libs/shinkai-sqlite/src/lib.rs +++ b/shinkai-libs/shinkai-sqlite/src/lib.rs @@ -889,7 +889,7 @@ impl SqliteManager { // Method to set the version and determine if a global reset is needed pub fn set_version(&self, version: &str) -> Result<()> { // Note: add breaking versions here as needed - let breaking_versions = ["0.9.0", "0.9.1", "0.9.2", "0.9.3", "0.9.4"]; + let breaking_versions = ["0.9.0", "0.9.1", "0.9.2", "0.9.3", "0.9.4", "0.9.5"]; let needs_global_reset = self.get_version().map_or(false, |(current_version, _)| { breaking_versions From bd04eafaa4d8e4de5874f5e95fd48402813bf3ac Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 10 Jan 2025 20:18:58 -0500 Subject: [PATCH 17/18] Fix binary dockerfile. --- .github/binary.Dockerfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/binary.Dockerfile b/.github/binary.Dockerfile index 546660142..0ebc3b2f2 100644 --- a/.github/binary.Dockerfile +++ b/.github/binary.Dockerfile @@ -7,13 +7,13 @@ FROM ubuntu:24.10 AS downloader echo "Error: shinkai-node file is less than 25MB" && \ exit 1; \ fi - RUN unzip shinkai-node.zip - RUN chmod +x /shinkai-node + RUN unzip -o shinkai-node.zip -d ./node + RUN chmod +x /node/shinkai-node - FROM ubuntu:24.10 + FROM ubuntu:24.10 AS runner RUN apt-get update && apt-get install -y openssl ca-certificates WORKDIR /app - COPY --from=downloader /shinkai-node ./shinkai-node + COPY --from=downloader /node ./ EXPOSE 9550 ENTRYPOINT ["/bin/sh", "-c", "/app/shinkai-node"] \ No newline at end of file From 5d05dfc066891c4728b9b2545259d6d9964f3c90 Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Sun, 12 Jan 2025 21:39:12 -0600 Subject: [PATCH 18/18] groq update --- .../src/llm_provider/providers/groq.rs | 186 ++++++++++++------ .../llm_provider/providers/shared/groq_api.rs | 2 +- .../providers/shared/openai_api.rs | 126 +++++++++++- .../managers/model_capabilities_manager.rs | 31 ++- .../src/api_v2/api_v2_handlers_tools.rs | 3 +- 5 files changed, 274 insertions(+), 74 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs index ba4dc310c..a08d1696f 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs @@ -56,6 +56,10 @@ impl LLMService for Groq { } }; + // Extract tools_json from the result and keep original for matching + let tools_json = result.functions.clone().unwrap_or_else(Vec::new); + let original_tools = tools_json.clone(); // Keep original for matching + let mut payload = json!({ "model": self.model_type, "messages": messages_json, @@ -63,6 +67,40 @@ impl LLMService for Groq { "stream": is_stream, }); + // Add tools to payload if they exist, but remove tool_router_key + if !tools_json.is_empty() { + let tools: Vec = tools_json.iter().map(|tool| { + let mut function = tool.clone(); + if let Some(obj) = function.as_object_mut() { + obj.remove("tool_router_key"); + } + json!({ + "type": "function", + "function": function + }) + }).collect(); + payload["tools"] = serde_json::Value::Array(tools); + payload["tool_choice"] = json!("auto"); + } + + // Clean up message content if needed + if let Some(messages) = payload.get_mut("messages") { + if let Some(messages_array) = messages.as_array_mut() { + for message in messages_array { + if let Some(content) = message.get_mut("content") { + if let Some(content_array) = content.as_array() { + // If content is an array with a single text element, simplify it + if content_array.len() == 1 { + if let Some(text_obj) = content_array[0].get("text") { + *content = text_obj.clone(); + } + } + } + } + } + } + } + // Add options to payload add_options_to_payload(&mut payload, config.as_ref()); @@ -82,10 +120,10 @@ impl LLMService for Groq { payload, key.clone(), inbox_name, - ws_manager_trait, + ws_manager_trait.clone(), llm_stopper, session_id, - result.functions, + Some(original_tools), // Pass original tools with router keys ) .await } else { @@ -96,7 +134,8 @@ impl LLMService for Groq { key.clone(), inbox_name, llm_stopper, - result.functions, + ws_manager_trait.clone(), + Some(original_tools), // Pass original tools with router keys ) .await } @@ -209,38 +248,46 @@ async fn handle_streaming_response( if let Some(content) = delta.get("content") { response_text.push_str(content.as_str().unwrap_or("")); } - if let Some(fc) = delta.get("function_call") { - if let Some(name) = fc.get("name") { - let fc_arguments = fc - .get("arguments") - .and_then(|args| args.as_str()) - .and_then(|args_str| serde_json::from_str(args_str).ok()) - .and_then(|args_value: serde_json::Value| { - args_value.as_object().cloned() - }) - .unwrap_or_else(|| serde_json::Map::new()); - - // Search for the tool_router_key in the tools array - let tool_router_key = tools.as_ref().and_then(|tools_array| { - tools_array.iter().find_map(|tool| { - if tool.get("name")?.as_str()? - == name.as_str().unwrap_or("") - { - tool.get("tool_router_key").and_then(|key| { - key.as_str().map(|s| s.to_string()) - }) - } else { - None + if let Some(fc) = delta.get("tool_calls") { + if let Some(tool_calls_array) = fc.as_array() { + for tool_call in tool_calls_array { + if let Some(function) = tool_call.get("function") { + if let Some(name) = function.get("name") { + let fc_arguments = function + .get("arguments") + .and_then(|args| args.as_str()) + .and_then(|args_str| serde_json::from_str(args_str).ok()) + .and_then(|args_value: serde_json::Value| { + args_value.as_object().cloned() + }) + .unwrap_or_else(|| serde_json::Map::new()); + + // Search for the tool_router_key in the tools array + let tool_router_key = tools.as_ref().and_then(|tools_array| { + tools_array.iter().find_map(|tool| { + if let Some(function) = tool.get("function") { + if function.get("name")?.as_str()? == name.as_str().unwrap_or("") { + function.get("tool_router_key").and_then(|key| { + key.as_str().map(|s| s.to_string()) + }) + } else { + None + } + } else { + None + } + }) + }); + + function_calls.push(FunctionCall { + name: name.as_str().unwrap_or("").to_string(), + arguments: fc_arguments.clone(), + tool_router_key, + response: None, + }); } - }) - }); - - function_calls.push(FunctionCall { - name: name.as_str().unwrap_or("").to_string(), - arguments: fc_arguments.clone(), - tool_router_key, - response: None, - }); + } + } } } } @@ -349,7 +396,8 @@ async fn handle_non_streaming_response( api_key: String, inbox_name: Option, llm_stopper: Arc, - tools: Option>, // Add tools parameter + ws_manager_trait: Option>>, + tools: Option>, ) -> Result { let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500)); let response_fut = client @@ -418,38 +466,50 @@ async fn handle_non_streaming_response( .collect::>() .join(" "); - let function_call: Option = data.choices.iter().find_map(|choice| { - choice.message.function_call.clone().map(|fc| { - let arguments = serde_json::from_str::(&fc.arguments) - .ok() - .and_then(|args_value: serde_json::Value| args_value.as_object().cloned()) - .unwrap_or_else(|| serde_json::Map::new()); - - // Search for the tool_router_key in the tools array - let tool_router_key = tools.as_ref().and_then(|tools_array| { - tools_array.iter().find_map(|tool| { - if tool.get("name")?.as_str()? == fc.name { - tool.get("tool_router_key").and_then(|key| key.as_str().map(|s| s.to_string())) - } else { - None - } - }) - }); - - FunctionCall { - name: fc.name, - arguments, - tool_router_key, // Set the tool_router_key - response: None, + let function_calls: Vec = data.choices.iter().flat_map(|choice| { + let mut calls = Vec::new(); + + // Handle tool_calls + if let Some(tool_calls) = &choice.message.tool_calls { + for tool_call in tool_calls { + let arguments = serde_json::from_str::(&tool_call.function.arguments) + .ok() + .and_then(|args_value: serde_json::Value| args_value.as_object().cloned()) + .unwrap_or_else(|| serde_json::Map::new()); + + // Find matching tool and extract router key + let tool_router_key = tools.as_ref().and_then(|tools_array| { + tools_array.iter().find_map(|tool| { + if let Some(name) = tool.get("name").and_then(|n| n.as_str()) { + if name == tool_call.function.name { + tool.get("tool_router_key") + .and_then(|key| key.as_str()) + .map(|s| s.to_string()) + } else { + None + } + } else { + None + } + }) + }); + + calls.push(FunctionCall { + name: tool_call.function.name.clone(), + arguments, + tool_router_key, + response: None, + }); } - }) - }); - eprintln!("Function Call: {:?}", function_call); - eprintln!("Response String: {:?}", response_string); + } + + calls + }).collect(); + return Ok(LLMInferenceResponse::new( response_string, json!({}), - function_call.map_or_else(Vec::new, |fc| vec![fc]), + function_calls, None, )); } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/groq_api.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/groq_api.rs index ac7b08f96..a79f04dad 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/groq_api.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/groq_api.rs @@ -107,7 +107,7 @@ pub fn groq_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Re Ok(PromptResult { messages: PromptResultEnum::Value(messages_json), - functions: None, + functions: result.functions, remaining_output_tokens: result.remaining_output_tokens, tokens_used: result.tokens_used, }) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/openai_api.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/openai_api.rs index d2344ecf3..0bfb251d9 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/openai_api.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/shared/openai_api.rs @@ -17,12 +17,39 @@ pub struct OpenAIResponse { created: u64, pub choices: Vec, usage: Usage, + system_fingerprint: Option, + #[serde(rename = "x_groq", default)] + groq: Option, } #[derive(Debug, Deserialize)] pub struct Choice { pub index: i32, pub message: OpenAIApiMessage, + #[serde(rename = "finish_reason")] + pub finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +pub struct GroqInfo { + id: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct OpenAIApiMessage { + pub role: String, + pub content: Option, + pub function_call: Option, + #[serde(default)] + pub tool_calls: Option>, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -45,13 +72,6 @@ pub enum MessageContent { ImageUrl { url: String }, } -#[derive(Debug, Clone, Deserialize)] -pub struct OpenAIApiMessage { - pub role: String, - pub content: Option, - pub function_call: Option, -} - impl Serialize for OpenAIApiMessage { fn serialize(&self, serializer: S) -> Result where @@ -74,6 +94,14 @@ pub struct Usage { prompt_tokens: i32, completion_tokens: i32, total_tokens: i32, + #[serde(default)] + queue_time: Option, + #[serde(default)] + prompt_time: Option, + #[serde(default)] + completion_time: Option, + #[serde(default)] + total_time: Option, } pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Result { @@ -367,4 +395,88 @@ mod tests { panic!("Expected text content"); } } + + #[test] + fn test_groq_response_with_tool_calls() { + let response_text = r#"{ + "id": "chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc", + "object": "chat.completion", + "created": 1736736692, + "model": "llama-3.2-11b-vision-preview", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "tool_calls": [ + { + "id": "call_sa3n", + "type": "function", + "function": { + "name": "duckduckgo_search", + "arguments": "{\"message\": \"best movie 2024\"}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "queue_time": 0.018144843999999993, + "prompt_tokens": 1185, + "prompt_time": 0.077966956, + "completion_tokens": 21, + "completion_time": 0.028, + "total_tokens": 1206, + "total_time": 0.105966956 + }, + "system_fingerprint": "fp_9cb648b966", + "x_groq": { + "id": "req_01jhes5nvkedsb8hcw0x912fa6" + } + }"#; + + let response: OpenAIResponse = serde_json::from_str(response_text).expect("Failed to deserialize"); + + // Verify basic response fields + assert_eq!(response.id, "chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc"); + assert_eq!(response.object, "chat.completion"); + assert_eq!(response.created, 1736736692); + assert_eq!(response.system_fingerprint, Some("fp_9cb648b966".to_string())); + + // Verify choices + assert_eq!(response.choices.len(), 1); + let choice = &response.choices[0]; + assert_eq!(choice.index, 0); + assert_eq!(choice.finish_reason, Some("tool_calls".to_string())); + + // Verify tool calls + let message = &choice.message; + assert_eq!(message.role, "assistant"); + assert!(message.content.is_none()); + + let tool_calls = message.tool_calls.as_ref().expect("Should have tool_calls"); + assert_eq!(tool_calls.len(), 1); + + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.id, "call_sa3n"); + assert_eq!(tool_call.call_type, "function"); + assert_eq!(tool_call.function.name, "duckduckgo_search"); + assert_eq!(tool_call.function.arguments, "{\"message\": \"best movie 2024\"}"); + + // Verify usage + assert_eq!(response.usage.prompt_tokens, 1185); + assert_eq!(response.usage.completion_tokens, 21); + assert_eq!(response.usage.total_tokens, 1206); + assert!(response.usage.queue_time.is_some()); + assert!(response.usage.prompt_time.is_some()); + assert!(response.usage.completion_time.is_some()); + assert!(response.usage.total_time.is_some()); + + // Verify Groq info + let groq = response.groq.expect("Should have Groq info"); + assert_eq!(groq.id, "req_01jhes5nvkedsb8hcw0x912fa6"); + } } diff --git a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs index 2aa083c07..72c5b985d 100644 --- a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs +++ b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs @@ -457,6 +457,23 @@ impl ModelCapabilitiesManager { model_type if model_type.starts_with("llama3.1") => 128_000, model_type if model_type.starts_with("llama3") || model_type.starts_with("llava-llama3") => 8_000, model_type if model_type.starts_with("claude") => 200_000, + model_type if model_type.starts_with("llama-3.3-70b-versatile") => 128_000, + model_type if model_type.starts_with("llama-3.1-8b-instant") => 128_000, + model_type if model_type.starts_with("llama-guard-3-8b") => 8_192, + model_type if model_type.starts_with("llama3-70b-8192") => 8_192, + model_type if model_type.starts_with("llama3-8b-8192") => 8_192, + model_type if model_type.starts_with("mixtral-8x7b-32768") => 32_768, + model_type if model_type.starts_with("gemma2-9b-it") => 8_192, + model_type if model_type.starts_with("llama-3.3-70b-specdec") => 8_192, + model_type if model_type.starts_with("llama-3.2-1b-preview") => 128_000, + model_type if model_type.starts_with("llama-3.2-3b-preview") => 128_000, + model_type if model_type.starts_with("llama-3.2-11b-vision-preview") => 128_000, + model_type if model_type.starts_with("llama-3.2-90b-vision-preview") => 128_000, + model_type if model_type.starts_with("llama-3.2") => 128_000, + model_type if model_type.starts_with("llama3.3") => 128_000, + model_type if model_type.starts_with("llama3.4") => 128_000, + model_type if model_type.starts_with("llama-3.1") => 128_000, + model_type if model_type.starts_with("llama3.1") => 128_000, _ => 4096, // Default token count if no specific model type matches } } @@ -637,7 +654,19 @@ impl ModelCapabilitiesManager { || model.model_type.starts_with("qwq") } LLMProviderInterface::Groq(model) => { - model.model_type.starts_with("llama-3.2") + model.model_type.starts_with("llama-3.3-70b-versatile") + || model.model_type.starts_with("llama-3.1-8b-instant") + || model.model_type.starts_with("llama-guard-3-8b") + || model.model_type.starts_with("llama3-70b-8192") + || model.model_type.starts_with("llama3-8b-8192") + || model.model_type.starts_with("mixtral-8x7b-32768") + || model.model_type.starts_with("gemma2-9b-it") + || model.model_type.starts_with("llama-3.3-70b-specdec") + || model.model_type.starts_with("llama-3.2-1b-preview") + || model.model_type.starts_with("llama-3.2-3b-preview") + || model.model_type.starts_with("llama-3.2-11b-vision-preview") + || model.model_type.starts_with("llama-3.2-90b-vision-preview") + || model.model_type.starts_with("llama-3.2") || model.model_type.starts_with("llama3.2") || model.model_type.starts_with("llama-3.1") || model.model_type.starts_with("llama3.1") diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs index 330af2503..56950c2f1 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs @@ -140,6 +140,7 @@ pub fn tool_routes( .and(warp::body::json()) .and_then(tool_implementation_code_update_handler); + // Resolves shinkai://file URLs to actual file bytes, providing secure access to files in the node's storage let resolve_shinkai_file_protocol_route = warp::path("resolve_shinkai_file_protocol") .and(warp::get()) .and(with_sender(node_commands_sender.clone())) @@ -1028,8 +1029,6 @@ pub async fn code_execution_handler( ) -> Result { let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); - eprintln!("payload: {:?}", payload); - // Convert parameters to a Map if it isn't already let parameters = match payload.parameters { Value::Object(map) => map,