Skip to content

Commit

Permalink
Merge branch 'feature/tools-def-exec' into agallardol/fix-missing-v2-…
Browse files Browse the repository at this point in the history
…search-tool
  • Loading branch information
agallardol committed Nov 13, 2024
2 parents dacc102 + 6a382f5 commit 74fc7b3
Show file tree
Hide file tree
Showing 27 changed files with 520 additions and 485 deletions.
26 changes: 11 additions & 15 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use shinkai_message_primitives::schemas::shinkai_tool_offering::{
use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier};
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
use shinkai_tools_primitives::tools::argument::ToolArgument;
use shinkai_tools_primitives::tools::argument::{ToolArgument, ToolOutputArg};
use shinkai_tools_primitives::tools::error::ToolError;
use shinkai_tools_primitives::tools::js_toolkit::JSToolkit;
use shinkai_tools_primitives::tools::network_tool::NetworkTool;
Expand Down Expand Up @@ -172,7 +172,7 @@ impl ToolRouter {

let toolkit = JSToolkit::new(&name, vec![definition.clone()]);
for tool in toolkit.tools {
let shinkai_tool = ShinkaiTool::JS(tool.clone(), true);
let shinkai_tool = ShinkaiTool::Deno(tool.clone(), true);
let lance_db = self.lance_db.write().await;
lance_db.set_tool(&shinkai_tool).await?;
}
Expand Down Expand Up @@ -206,6 +206,7 @@ impl ToolRouter {
description: "".to_string(),
is_required: true,
}],
output_arg: ToolOutputArg { json: "".to_string() },
embedding: None,
restrictions: None,
};
Expand All @@ -232,6 +233,7 @@ impl ToolRouter {
description: "The URL of the YouTube video".to_string(),
is_required: true,
}],
output_arg: ToolOutputArg { json: "".to_string() },
embedding: None,
restrictions: None,
};
Expand All @@ -245,9 +247,9 @@ impl ToolRouter {
if std::env::var("ADD_TESTING_NETWORK_ECHO").unwrap_or_else(|_| "false".to_string()) == "true" {
let lance_db = self.lance_db.write().await;
if let Some(shinkai_tool) = lance_db.get_tool("local:::shinkai-tool-echo:::shinkai__echo").await? {
if let ShinkaiTool::JS(mut js_tool, _) = shinkai_tool {
if let ShinkaiTool::Deno(mut js_tool, _) = shinkai_tool {
js_tool.name = "network__echo".to_string();
let modified_tool = ShinkaiTool::JS(js_tool, true);
let modified_tool = ShinkaiTool::Deno(js_tool, true);
lance_db.set_tool(&modified_tool).await?;
}
}
Expand All @@ -256,9 +258,9 @@ impl ToolRouter {
.get_tool("local:::shinkai-tool-youtube-transcript:::shinkai__youtube_transcript")
.await?
{
if let ShinkaiTool::JS(mut js_tool, _) = shinkai_tool {
if let ShinkaiTool::Deno(mut js_tool, _) = shinkai_tool {
js_tool.name = "youtube_transcript_with_timestamps".to_string();
let modified_tool = ShinkaiTool::JS(js_tool, true);
let modified_tool = ShinkaiTool::Deno(js_tool, true);
lance_db.set_tool(&modified_tool).await?;
}
}
Expand Down Expand Up @@ -378,12 +380,6 @@ impl ToolRouter {
let function_args = function_call.arguments.clone();

match shinkai_tool {
ShinkaiTool::Deno(_, _) => {
return Ok(ToolCallFunctionResponse {
response: "Deno!".to_string(),
function_call,
});
}
ShinkaiTool::Python(_, _) => {
return Ok(ToolCallFunctionResponse {
response: "Deno!".to_string(),
Expand Down Expand Up @@ -414,9 +410,9 @@ impl ToolRouter {
// });
// }
}
ShinkaiTool::JS(js_tool, _) => {
ShinkaiTool::Deno(deno_tool, _) => {
let function_config = shinkai_tool.get_config_from_env();
let result = js_tool
let result = deno_tool
.run(function_args, function_config)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
let result_str = serde_json::to_string(&result)
Expand Down Expand Up @@ -714,7 +710,7 @@ impl ToolRouter {
let function_config = shinkai_tool.get_config_from_env();

let js_tool = match shinkai_tool {
ShinkaiTool::JS(js_tool, _) => js_tool,
ShinkaiTool::Deno(js_tool, _) => js_tool,
_ => return Err(LLMProviderError::FunctionNotFound(js_tool_name.to_string())),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use ed25519_dalek::SigningKey;
use futures::Future;
use shinkai_db::db::{ShinkaiDB, Topic};
use shinkai_job_queue_manager::job_queue_manager::JobQueueManager;
use shinkai_message_primitives::schemas::invoices::{Invoice, InvoiceError, InvoiceRequest, InvoiceRequestNetworkError, InvoiceStatusEnum};
use shinkai_message_primitives::schemas::invoices::{
Invoice, InvoiceError, InvoiceRequest, InvoiceRequestNetworkError, InvoiceStatusEnum,
};
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::shinkai_tool_offering::{ShinkaiToolOffering, UsageType, UsageTypeInquiry};
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::MessageSchemaType;
Expand Down Expand Up @@ -1064,7 +1066,7 @@ mod tests {
for (name, definition) in tools {
let toolkit = JSToolkit::new(&name, vec![definition.clone()]);
for tool in toolkit.tools {
let mut shinkai_tool = ShinkaiTool::JS(tool.clone(), true);
let mut shinkai_tool = ShinkaiTool::Deno(tool.clone(), true);
eprintln!("shinkai_tool name: {:?}", shinkai_tool.name());
let embedding = generator
.generate_embedding_default(&shinkai_tool.format_embedding_string())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use shinkai_message_primitives::{
signatures::clone_signature_secret_key,
},
};
use shinkai_tools_primitives::tools::{network_tool::NetworkTool, shinkai_tool::ShinkaiToolHeader};
use shinkai_tools_primitives::tools::{
argument::ToolOutputArg, network_tool::NetworkTool, shinkai_tool::ShinkaiToolHeader,
};
use shinkai_vector_fs::vector_fs::vector_fs::VectorFS;
use tokio::sync::Mutex;
use x25519_dalek::StaticSecret as EncryptionStaticKey;
Expand Down Expand Up @@ -570,6 +572,7 @@ impl MyAgentOfferingsManager {
true, // Assuming the tool is activated by default
tool_header.config.expect("Config is required"),
vec![], // TODO: Fix input_args
ToolOutputArg { json: "".to_string() },
None,
None,
);
Expand Down
15 changes: 12 additions & 3 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2930,9 +2930,12 @@ impl Node {
} => {
let db_clone = Arc::clone(&self.db);
let lance_db_clone = self.lance_db.clone();
let vector_fs_clone = self.vector_fs.clone();
let identity_manager_clone = self.identity_manager.clone();
let sheet_manager_clone = self.sheet_manager.clone();
let job_manager = self.job_manager.clone().unwrap();
let node_name = self.node_name.clone();
let identity_manager = self.identity_manager.clone();
let encryption_secret_key = self.encryption_secret_key.clone();
let encryption_public_key = self.encryption_public_key;
let signing_secret_key = self.identity_secret_key.clone();

tokio::spawn(async move {
let _ = Node::execute_command(
Expand All @@ -2942,6 +2945,12 @@ impl Node {
tool_router_key,
tool_type,
parameters,
node_name,
identity_manager,
job_manager,
encryption_secret_key,
encryption_public_key,
signing_secret_key,
res,
)
.await;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ use shinkai_message_primitives::{
shinkai_message::{MessageBody, MessageData, ShinkaiMessage},
shinkai_message_schemas::{
APIAddAgentRequest, APIAddOllamaModels, APIChangeJobAgentRequest, APIGetMessagesFromInboxRequest,
APIReadUpToTimeRequest, IdentityPermissions, MessageSchemaType,
RegistrationCodeRequest, RegistrationCodeType,
APIReadUpToTimeRequest, IdentityPermissions, MessageSchemaType, RegistrationCodeRequest,
RegistrationCodeType,
},
},
shinkai_utils::{
Expand Down Expand Up @@ -1675,7 +1675,7 @@ impl Node {
// Add the toolkit using LanceShinkaiDb
let lance_db = lance_db.write().await;
for tool in toolkit.tools {
let shinkai_tool = ShinkaiTool::JS(tool.clone(), true);
let shinkai_tool = ShinkaiTool::Deno(tool.clone(), true);
if let Err(err) = lance_db.set_tool(&shinkai_tool).await {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ impl Node {
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Convert the String output to a Value
let definitions = generate_tool_definitions(language, lance_db).await;
let value = Value::String(definitions);

// Send the result
res.send(Ok(value)).await.map_err(|e| NodeError {
message: format!("Failed to send response: {}", e),
})?;

let definitions = generate_tool_definitions(language, lance_db, false).await;
match definitions {
Ok(definitions) => {
let _ = res.send(Ok(Value::String(definitions))).await;
}
Err(e) => {
let _ = res.send(Err(e)).await;
}
}
Ok(())
}

Expand All @@ -48,6 +49,12 @@ impl Node {
tool_router_key: String,
tool_type: ToolType,
parameters: Map<String, Value>,
node_name: ShinkaiName,
identity_manager: Arc<Mutex<IdentityManager>>,
job_manager: Arc<Mutex<JobManager>>,
encryption_secret_key: EncryptionStaticKey,
encryption_public_key: EncryptionPublicKey,
signing_secret_key: SigningKey,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Execute the tool directly
Expand All @@ -59,26 +66,28 @@ impl Node {
db,
lance_db,
bearer,
node_name,
identity_manager,
job_manager,
encryption_secret_key,
encryption_public_key,
signing_secret_key,
)
.await;

match result {
Ok(result) => {
println!("[execute_command] Tool execution successful: {}", tool_router_key);
if let Err(e) = res.send(Ok(result)).await {
eprintln!("[execute_command] Failed to send success response: {}", e);
return Err(NodeError {
message: format!("Failed to send response: {}", e),
});
}
let _ = res.send(Ok(result)).await;
}
Err(e) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Error executing tool: {}", e),
};
let _ = res.send(Err(api_error)).await;
let _ = res
.send(Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Error executing tool: {}", e),
}))
.await;
}
}

Expand Down Expand Up @@ -131,16 +140,10 @@ impl Node {

match implementation {
Ok(implementation_) => {
// Send response
res.send(Ok(implementation_)).await.map_err(|e| NodeError {
message: format!("Failed to send response: {}", e),
})?;
let _ = res.send(Ok(implementation_)).await;
}
Err(e) => {
let api_error = APIError::from(e);
res.send(Err(api_error)).await.map_err(|e| NodeError {
message: format!("Failed to send response: {}", e),
})?;
let _ = res.send(Err(e)).await;
}
}

Expand Down Expand Up @@ -188,16 +191,10 @@ impl Node {

match metadata {
Ok(metadata_) => {
// Send response
res.send(Ok(metadata_)).await.map_err(|e| NodeError {
message: format!("Failed to send response: {}", e),
})?;
let _ = res.send(Ok(metadata_)).await;
}
Err(e) => {
let api_error = APIError::from(e);
res.send(Err(api_error)).await.map_err(|e| NodeError {
message: format!("Failed to send response: {}", e),
})?;
let _ = res.send(Err(e)).await;
}
}

Expand Down
Loading

0 comments on commit 74fc7b3

Please sign in to comment.