Skip to content

Commit

Permalink
Merge pull request #780 from dcSpark/feature/vf-mounts-prompts
Browse files Browse the repository at this point in the history
Feature/vf mounts prompts
  • Loading branch information
acedward authored Jan 15, 2025
2 parents 96942ea + 58efcf6 commit 25b0472
Show file tree
Hide file tree
Showing 13 changed files with 398 additions and 897 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentO
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_trait::async_trait;
use shinkai_embedding::embedding_generator::RemoteEmbeddingGenerator;
use shinkai_fs::shinkai_file_manager::ShinkaiFileManager;
use shinkai_fs::shinkai_fs_error::ShinkaiFsError;
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job::{Job, JobLike};
use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
Expand Down Expand Up @@ -175,8 +177,8 @@ impl GenericInferenceChain {
|| !job_filenames.is_empty()
{
let ret = JobManager::search_for_chunks_in_resources(
merged_fs_files_paths,
merged_fs_folder_paths,
merged_fs_files_paths.clone(),
merged_fs_folder_paths.clone(),
job_filenames.clone(),
full_job.job_id.clone(),
full_job.scope(),
Expand Down Expand Up @@ -344,7 +346,16 @@ impl GenericInferenceChain {
}
});

let additional_files = Self::get_additional_files(
&db,
&full_job,
job_filenames.clone(),
merged_fs_files_paths.clone(),
merged_fs_folder_paths.clone(),
)?;

let mut filled_prompt = JobPromptGenerator::generic_inference_prompt(
db.clone(),
custom_system_prompt.clone(),
custom_prompt.clone(),
user_message.clone(),
Expand All @@ -355,7 +366,9 @@ impl GenericInferenceChain {
tools.clone(),
None,
full_job.job_id.clone(),
additional_files.clone(),
node_env.clone(),
db.clone(),
);

let mut iteration_count = 0;
Expand Down Expand Up @@ -461,6 +474,7 @@ impl GenericInferenceChain {

// Update prompt with error information for retry
filled_prompt = JobPromptGenerator::generic_inference_prompt(
db.clone(),
custom_system_prompt.clone(),
custom_prompt.clone(),
user_message.clone(),
Expand All @@ -474,7 +488,9 @@ impl GenericInferenceChain {
response: error_msg.clone(),
}),
full_job.job_id.clone(),
additional_files.clone(),
node_env.clone(),
db.clone(),
);

// Set flag to retry and break out of the function calls loop
Expand Down Expand Up @@ -517,13 +533,22 @@ impl GenericInferenceChain {
last_function_response = Some(function_response);
}

let additional_files = Self::get_additional_files(
&db,
&full_job,
job_filenames.clone(),
merged_fs_files_paths.clone(),
merged_fs_folder_paths.clone(),
)?;

// If we need to retry, continue the outer loop
if should_retry {
continue;
}

// 7) Call LLM again with the response (for formatting)
filled_prompt = JobPromptGenerator::generic_inference_prompt(
db.clone(),
custom_system_prompt.clone(),
custom_prompt.clone(),
user_message.clone(),
Expand All @@ -534,7 +559,9 @@ impl GenericInferenceChain {
tools.clone(),
last_function_response,
full_job.job_id.clone(),
additional_files,
node_env.clone(),
db.clone(),
);
} else {
// No more function calls required, return the final response
Expand Down Expand Up @@ -613,4 +640,32 @@ impl GenericInferenceChain {
}
}
}

/// Collects the absolute paths of every extra file that should be exposed
/// to the prompt for this job: the merged agent/context fs files and
/// folders, plus any files living inside the job's own folder.
///
/// A missing job folder is treated as "no job files" rather than an error;
/// only failures resolving the agent/context paths are propagated.
pub fn get_additional_files(
    db: &SqliteManager,
    full_job: &Job,
    job_filenames: Vec<String>,
    merged_fs_files_paths: Vec<ShinkaiPath>,
    merged_fs_folder_paths: Vec<ShinkaiPath>,
) -> Result<Vec<String>, ShinkaiFsError> {
    let mut additional_files: Vec<String> = Vec::new();

    // Agent/context files. The path vectors are owned and not used again,
    // so they are moved in directly (the previous `.clone()`s were redundant).
    additional_files.extend(ShinkaiFileManager::get_absolute_path_for_additional_files(
        merged_fs_files_paths,
        merged_fs_folder_paths,
    )?);

    // Job files: resolve the filenames against the job's folder if the
    // database knows one. `job_id` is borrowed as-is — no clone needed.
    if let Ok(folder_path) = db.get_job_folder_name(&full_job.job_id) {
        additional_files.extend(ShinkaiFileManager::get_absolute_paths_with_folder(
            job_filenames,
            folder_path.path,
        ));
    }

    Ok(additional_files)
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
use serde_json::json;
use std::collections::HashMap;
use shinkai_fs::shinkai_file_manager::ShinkaiFileManager;
use shinkai_sqlite::SqliteManager;
use std::{collections::HashMap, fs};

use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator;
use crate::managers::tool_router::ToolCallFunctionResponse;

use crate::network::v2_api::api_v2_commands_app_files::get_app_folder_path;
use crate::network::Node;
use crate::tools::tool_implementation::native_tools::sql_processor::get_current_tables;
use crate::utils::environment::NodeEnvironment;
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::shinkai_fs::ShinkaiFileChunkCollection;
use shinkai_message_primitives::schemas::subprompts::SubPromptType;
use shinkai_message_primitives::shinkai_message::shinkai_message::ShinkaiMessage;
use shinkai_message_primitives::{schemas::prompts::Prompt, shinkai_utils::job_scope::MinimalJobScope};
use shinkai_tools_primitives::tools::shinkai_tool::ShinkaiTool;
use std::sync::mpsc;
use std::sync::{mpsc, Arc};
use tokio::runtime::Runtime;

impl JobPromptGenerator {
/// A basic generic prompt generator
/// summary_text is the content generated by an LLM on parsing (if exist)
#[allow(clippy::too_many_arguments)]
pub fn generic_inference_prompt(
db: Arc<SqliteManager>,
custom_system_prompt: Option<String>,
custom_user_prompt: Option<String>,
user_message: String,
Expand All @@ -31,7 +32,9 @@ impl JobPromptGenerator {
tools: Vec<ShinkaiTool>,
function_call: Option<ToolCallFunctionResponse>,
job_id: String,
node_env: NodeEnvironment,
additional_files: Vec<String>,
_node_env: NodeEnvironment,
_db: Arc<SqliteManager>,
) -> Prompt {
let mut prompt = Prompt::new();

Expand Down Expand Up @@ -69,14 +72,16 @@ impl JobPromptGenerator {
// Wait for the result
let current_tables = rx.recv().unwrap();
if let Ok(current_tables) = current_tables {
prompt.add_content(
format!(
"<current_tables>\n{}\n</current_tables>\n",
current_tables.join("; \n")
),
SubPromptType::ExtraContext,
97,
);
if !current_tables.is_empty() {
prompt.add_content(
format!(
"<current_tables>\n{}\n</current_tables>\n",
current_tables.join("; \n")
),
SubPromptType::ExtraContext,
97,
);
}
}
}
}
Expand All @@ -89,16 +94,21 @@ impl JobPromptGenerator {
priority = priority.saturating_sub(1);
}
}
let folder = get_app_folder_path(node_env, job_id.clone());
let current_files = Node::v2_api_list_app_files_internal(folder.clone(), true);
if let Ok(current_files) = current_files {
if !current_files.is_empty() {
prompt.add_content(
format!("<current_files>\n{}\n</current_files>\n", current_files.join("\n")),
SubPromptType::ExtraContext,
97,
);
}
let mut all_files = vec![];
// Add job scope files
let job_scope = ShinkaiFileManager::get_absolute_path_for_job_scope(&db, &job_id);
if let Ok(job_scope) = job_scope {
all_files.extend(job_scope);
}
// Add fs files and Agent files
all_files.extend(additional_files);

if !all_files.is_empty() {
prompt.add_content(
format!("<current_files>\n{}\n</current_files>\n", all_files.join("\n")),
SubPromptType::ExtraContext,
97,
);
}
}

Expand All @@ -108,7 +118,7 @@ impl JobPromptGenerator {
if has_ret_nodes && !user_message.is_empty() {
prompt.add_content("--- start --- \n".to_string(), SubPromptType::ExtraContext, 97);
}

prompt.add_ret_node_content(ret_nodes, SubPromptType::ExtraContext, 96);

if has_ret_nodes && !user_message.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use serde_json::Value as JsonValue;
use shinkai_embedding::embedding_generator::RemoteEmbeddingGenerator;
use shinkai_message_primitives::schemas::job::Job;
use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::ws_types::WSUpdateHandler;
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::FunctionCallMetadata;
Expand Down Expand Up @@ -79,7 +80,9 @@ pub trait InferenceChainContextTrait: Send + Sync {
fn job_callback_manager(&self) -> Option<Arc<Mutex<JobCallbackManager>>>;
// fn sqlite_logger(&self) -> Option<Arc<SqliteLogger>>;
fn llm_stopper(&self) -> Arc<LLMStopper>;

fn fs_files_paths(&self) -> &Vec<ShinkaiPath>;
fn llm_provider(&self) -> &ProviderOrAgent;
fn job_filenames(&self) -> &Vec<String>;
fn clone_box(&self) -> Box<dyn InferenceChainContextTrait>;
}

Expand Down Expand Up @@ -109,7 +112,7 @@ impl InferenceChainContextTrait for InferenceChainContext {
fn db(&self) -> Arc<SqliteManager> {
Arc::clone(&self.db)
}

fn full_job(&self) -> &Job {
&self.full_job
}
Expand Down Expand Up @@ -189,6 +192,18 @@ impl InferenceChainContextTrait for InferenceChainContext {
fn clone_box(&self) -> Box<dyn InferenceChainContextTrait> {
Box::new(self.clone())
}

// Borrowed view of the filesystem file paths attached to this context.
fn fs_files_paths(&self) -> &Vec<ShinkaiPath> {
&self.fs_files_paths
}

// The LLM provider (or agent) configured for this inference context.
fn llm_provider(&self) -> &ProviderOrAgent {
&self.llm_provider
}

// Borrowed list of filenames associated with the job.
fn job_filenames(&self) -> &Vec<String> {
&self.job_filenames
}
}

/// Struct that represents the generalized context available to all chains as input. Note not all chains require
Expand All @@ -200,7 +215,7 @@ pub struct InferenceChainContext {
pub user_message: ParsedUserMessage,
pub user_tool_selected: Option<String>,
pub fs_files_paths: Vec<ShinkaiPath>,
pub job_filenames: Vec<String>,
pub job_filenames: Vec<String>,
pub message_hash_id: Option<String>,
pub image_files: HashMap<String, String>,
pub llm_provider: ProviderOrAgent,
Expand Down Expand Up @@ -229,7 +244,7 @@ impl InferenceChainContext {
user_message: ParsedUserMessage,
user_tool_selected: Option<String>,
fs_files_paths: Vec<ShinkaiPath>,
job_filenames: Vec<String>,
job_filenames: Vec<String>,
message_hash_id: Option<String>,
image_files: HashMap<String, String>,
llm_provider: ProviderOrAgent,
Expand Down Expand Up @@ -496,6 +511,18 @@ impl InferenceChainContextTrait for Box<dyn InferenceChainContextTrait> {
fn clone_box(&self) -> Box<dyn InferenceChainContextTrait> {
(**self).clone_box()
}

// Forwards to the boxed context implementation.
fn fs_files_paths(&self) -> &Vec<ShinkaiPath> {
(**self).fs_files_paths()
}

// Forwards to the boxed context implementation.
fn llm_provider(&self) -> &ProviderOrAgent {
(**self).llm_provider()
}

// Forwards to the boxed context implementation.
fn job_filenames(&self) -> &Vec<String> {
(**self).job_filenames()
}
}

/// A Mock implementation of the InferenceChainContextTrait for testing purposes.
Expand All @@ -511,6 +538,9 @@ pub struct MockInferenceChainContext {
pub my_agent_payments_manager: Option<Arc<Mutex<MyAgentOfferingsManager>>>,
pub ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
pub llm_stopper: Arc<LLMStopper>,
pub fs_files_paths: Vec<ShinkaiPath>,
pub job_filenames: Vec<String>,
pub llm_provider: ProviderOrAgent,
}

impl MockInferenceChainContext {
Expand All @@ -527,6 +557,9 @@ impl MockInferenceChainContext {
my_agent_payments_manager: Option<Arc<Mutex<MyAgentOfferingsManager>>>,
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
llm_stopper: Arc<LLMStopper>,
fs_files_paths: Vec<ShinkaiPath>,
job_filenames: Vec<String>,
llm_provider: ProviderOrAgent,
) -> Self {
Self {
user_message,
Expand All @@ -540,6 +573,9 @@ impl MockInferenceChainContext {
my_agent_payments_manager,
ext_agent_payments_manager,
llm_stopper,
fs_files_paths,
job_filenames,
llm_provider,
}
}
}
Expand All @@ -563,6 +599,9 @@ impl Default for MockInferenceChainContext {
my_agent_payments_manager: None,
ext_agent_payments_manager: None,
llm_stopper: Arc::new(LLMStopper::new()),
fs_files_paths: vec![],
job_filenames: vec![],
llm_provider: ProviderOrAgent::LLMProvider(SerializedLLMProvider::mock_provider()),
}
}
}
Expand Down Expand Up @@ -667,6 +706,18 @@ impl InferenceChainContextTrait for MockInferenceChainContext {
fn clone_box(&self) -> Box<dyn InferenceChainContextTrait> {
Box::new(self.clone())
}

// Mock accessor: returns the paths configured on this test context.
fn fs_files_paths(&self) -> &Vec<ShinkaiPath> {
&self.fs_files_paths
}

// Mock accessor: returns the provider configured on this test context.
fn llm_provider(&self) -> &ProviderOrAgent {
&self.llm_provider
}

// Mock accessor: returns the job filenames configured on this test context.
fn job_filenames(&self) -> &Vec<String> {
&self.job_filenames
}
}

impl Clone for MockInferenceChainContext {
Expand All @@ -683,6 +734,9 @@ impl Clone for MockInferenceChainContext {
my_agent_payments_manager: self.my_agent_payments_manager.clone(),
ext_agent_payments_manager: self.ext_agent_payments_manager.clone(),
llm_stopper: self.llm_stopper.clone(),
fs_files_paths: self.fs_files_paths.clone(),
job_filenames: self.job_filenames.clone(),
llm_provider: self.llm_provider.clone(),
}
}
}
Loading

0 comments on commit 25b0472

Please sign in to comment.