Skip to content

Commit

Permalink
Merge pull request #326 from dcSpark/nico/update_main_db
Browse files Browse the repository at this point in the history
update shinkaidb and how we handle threads for commands at the node lvl
  • Loading branch information
rinor authored Apr 12, 2024
2 parents dbde5ae + 441294f commit 6e40794
Show file tree
Hide file tree
Showing 49 changed files with 2,462 additions and 1,268 deletions.
11 changes: 10 additions & 1 deletion src/agent/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{db::db_errors::ShinkaiDBError, managers::model_capabilities_manager::ModelCapabilitiesManagerError};
use crate::{db::db_errors::ShinkaiDBError, managers::model_capabilities_manager::ModelCapabilitiesManagerError, vector_fs::vector_fs_error::VectorFSError};
use anyhow::Error as AnyhowError;
use shinkai_message_primitives::{
schemas::{inbox_name::InboxNameError, shinkai_name::ShinkaiNameError},
Expand Down Expand Up @@ -27,6 +27,7 @@ pub enum AgentError {
MessageTypeParseFailed,
IO(String),
ShinkaiDB(ShinkaiDBError),
VectorFS(VectorFSError),
ShinkaiNameError(ShinkaiNameError),
AgentNotFound,
ContentParseFailed,
Expand Down Expand Up @@ -95,6 +96,7 @@ impl fmt::Display for AgentError {
AgentError::MessageTypeParseFailed => write!(f, "Could not parse message type"),
AgentError::IO(err) => write!(f, "IO error: {}", err),
AgentError::ShinkaiDB(err) => write!(f, "Shinkai DB error: {}", err),
AgentError::VectorFS(err) => write!(f, "VectorFS error: {}", err),
AgentError::AgentNotFound => write!(f, "Agent not found"),
AgentError::ContentParseFailed => write!(f, "Failed to parse content"),
AgentError::ShinkaiNameError(err) => write!(f, "ShinkaiName error: {}", err),
Expand Down Expand Up @@ -158,6 +160,7 @@ impl AgentError {
AgentError::MessageTypeParseFailed => "MessageTypeParseFailed",
AgentError::IO(_) => "IO",
AgentError::ShinkaiDB(_) => "ShinkaiDB",
AgentError::VectorFS(_) => "VectorFS",
AgentError::ShinkaiNameError(_) => "ShinkaiNameError",
AgentError::AgentNotFound => "AgentNotFound",
AgentError::ContentParseFailed => "ContentParseFailed",
Expand Down Expand Up @@ -286,3 +289,9 @@ impl From<ModelCapabilitiesManagerError> for AgentError {
AgentError::AgentsCapabilitiesManagerError(error)
}
}

impl From<VectorFSError> for AgentError {
fn from(err: VectorFSError) -> AgentError {
AgentError::VectorFS(err)
}
}
2 changes: 1 addition & 1 deletion src/agent/execution/chains/cron_creation_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl JobManager {
/// in the JobScope to find relevant content for the LLM to use at each step.
#[async_recursion]
pub async fn start_cron_creation_chain(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
full_job: Job,
job_task: String,
agent: SerializedAgent,
Expand Down
4 changes: 2 additions & 2 deletions src/agent/execution/chains/cron_execution_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct CronExecutionState {
impl JobManager {
#[async_recursion]
pub async fn start_cron_execution_chain_for_subtask(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
full_job: Job,
agent: SerializedAgent,
execution_context: HashMap<String, String>,
Expand Down Expand Up @@ -53,7 +53,7 @@ impl JobManager {
/// in the JobScope to find relevant content for the LLM to use at each step.
#[async_recursion]
pub async fn start_cron_execution_chain_for_main_task(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
full_job: Job,
agent: SerializedAgent,
execution_context: HashMap<String, String>,
Expand Down
2 changes: 1 addition & 1 deletion src/agent/execution/chains/image_analysis_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct CronExecutionState {
impl JobManager {
#[async_recursion]
pub async fn image_analysis_chain(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
full_job: Job,
agent_found: Option<SerializedAgent>,
execution_context: HashMap<String, String>,
Expand Down
6 changes: 3 additions & 3 deletions src/agent/execution/chains/inference_chain_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl JobManager {
/// Returns the final String result from the inferencing, and a new execution context.
#[instrument(skip(generator, vector_fs, db))]
pub async fn inference_chain_router(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
agent_found: Option<SerializedAgent>,
full_job: Job,
Expand Down Expand Up @@ -89,7 +89,7 @@ impl JobManager {
// Could it be based on the first message of the Job?
#[instrument(skip(db))]
pub async fn alt_inference_chain_router(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
agent_found: Option<SerializedAgent>,
full_job: Job,
job_message: JobMessage,
Expand Down Expand Up @@ -145,7 +145,7 @@ impl JobManager {

#[instrument(skip(db, chosen_chain))]
pub async fn cron_inference_chain_router_summary(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
agent_found: Option<SerializedAgent>,
full_job: Job,
task_description: String,
Expand Down
4 changes: 2 additions & 2 deletions src/agent/execution/chains/qa_inference_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl JobManager {
#[async_recursion]
#[instrument(skip(generator, vector_fs, db))]
pub async fn start_qa_inference_chain(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
full_job: Job,
job_task: String,
Expand Down Expand Up @@ -221,7 +221,7 @@ impl JobManager {

async fn no_json_object_retry_logic(
response: Result<JsonValue, AgentError>,
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
full_job: Job,
job_task: String,
Expand Down
52 changes: 26 additions & 26 deletions src/agent/execution/job_execution_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ impl JobManager {
#[instrument(skip(identity_secret_key, generator, unstructured_api, vector_fs, db))]
pub async fn process_job_message_queued(
job_message: JobForProcessing,
db: Weak<Mutex<ShinkaiDB>>,
db: Weak<ShinkaiDB>,
vector_fs: Weak<VectorFS>,
identity_secret_key: SigningKey,
generator: RemoteEmbeddingGenerator,
unstructured_api: UnstructuredAPI,
) -> Result<String, AgentError> {
let db = db.upgrade().ok_or("Failed to upgrade shinkai_db").unwrap();
let vector_fs = vector_fs.upgrade().ok_or("Failed to upgrade vector_db").unwrap();
let job_id = job_message.job_message.job_id.clone();
shinkai_log(
ShinkaiLogOption::JobExecution,
Expand All @@ -68,6 +69,7 @@ impl JobManager {
// If a .jobkai file is found, processing job message is taken over by this alternate logic
let jobkai_found_result = JobManager::should_process_job_files_for_tasks_take_over(
db.clone(),
vector_fs.clone(),
&job_message.job_message,
agent_found.clone(),
full_job.clone(),
Expand All @@ -88,6 +90,7 @@ impl JobManager {
// Processes any files which were sent with the job message
let process_files_result = JobManager::process_job_message_files_for_vector_resources(
db.clone(),
vector_fs.clone(),
&job_message.job_message,
agent_found.clone(),
&mut full_job,
Expand All @@ -101,7 +104,6 @@ impl JobManager {
return Self::handle_error(&db, Some(user_profile), &job_id, &identity_secret_key, e).await;
}

let vector_fs = vector_fs.upgrade().ok_or("Failed to upgrade vector_fs").unwrap();
let inference_chain_result = JobManager::process_inference_chain(
db.clone(),
vector_fs.clone(),
Expand All @@ -124,7 +126,7 @@ impl JobManager {

/// Handle errors by sending an error message to the job inbox
async fn handle_error(
db: &Arc<Mutex<ShinkaiDB>>,
db: &Arc<ShinkaiDB>,
user_profile: Option<ShinkaiName>,
job_id: &str,
identity_secret_key: &SigningKey,
Expand Down Expand Up @@ -153,9 +155,7 @@ impl JobManager {
)
.expect("Failed to build error message");

let mut shinkai_db = db.lock().await;
shinkai_db
.add_message_to_job_inbox(job_id, &shinkai_message, None)
db.add_message_to_job_inbox(job_id, &shinkai_message, None)
.await
.expect("Failed to add error message to job inbox");

Expand All @@ -166,7 +166,7 @@ impl JobManager {
/// and then parses + saves the output result to the DB.
#[instrument(skip(identity_secret_key, db, vector_fs, generator))]
pub async fn process_inference_chain(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
identity_secret_key: SigningKey,
job_message: JobMessage,
Expand Down Expand Up @@ -230,24 +230,23 @@ impl JobManager {
);

// Save response data to DB
let mut shinkai_db = db.lock().await;
shinkai_db.add_step_history(
db.add_step_history(
job_message.job_id.clone(),
job_message.content,
inference_response_content.to_string(),
None,
)?;
shinkai_db
.add_message_to_job_inbox(&job_message.job_id.clone(), &shinkai_message, None)
db.add_message_to_job_inbox(&job_message.job_id.clone(), &shinkai_message, None)
.await?;
shinkai_db.set_job_execution_context(job_message.job_id.clone(), new_execution_context, None)?;
db.set_job_execution_context(job_message.job_id.clone(), new_execution_context, None)?;

Ok(())
}

/// Temporary function to process the files in the job message for tasks
pub async fn should_process_job_files_for_tasks_take_over(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
job_message: &JobMessage,
agent_found: Option<SerializedAgent>,
full_job: Job,
Expand All @@ -268,12 +267,11 @@ impl JobManager {

// Get the files from the DB
let files = {
let shinkai_db = db.lock().await;
let files_result = shinkai_db.get_all_files_from_inbox(job_message.files_inbox.clone());
let files_result = vector_fs.db.get_all_files_from_inbox(job_message.files_inbox.clone());
// Check if there was an error getting the files
match files_result {
Ok(files) => files,
Err(e) => return Err(AgentError::ShinkaiDB(e)),
Err(e) => return Err(AgentError::VectorFS(e)),
}
};

Expand Down Expand Up @@ -321,6 +319,7 @@ impl JobManager {
// Handle CronJobRequest
JobManager::handle_cron_job_request(
db.clone(),
vector_fs.clone(),
agent_found.clone(),
full_job.clone(),
job_message.clone(),
Expand Down Expand Up @@ -415,7 +414,8 @@ impl JobManager {
/// Processes the files sent together with the current job_message into Vector Resources,
/// and saves them either into the local job scope, or the DB depending on `save_to_db_directly`.
pub async fn process_job_message_files_for_vector_resources(
db: Arc<Mutex<ShinkaiDB>>,
db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
job_message: &JobMessage,
agent_found: Option<SerializedAgent>,
full_job: &mut Job,
Expand All @@ -433,6 +433,7 @@ impl JobManager {
// TODO: later we should able to grab errors and return them to the user
let new_scope_entries_result = JobManager::process_files_inbox(
db.clone(),
vector_fs.clone(),
agent_found,
job_message.files_inbox.clone(),
profile,
Expand Down Expand Up @@ -503,8 +504,7 @@ impl JobManager {
}
}
}
let mut shinkai_db = db.lock().await;
shinkai_db.update_job_scope(full_job.job_id().to_string(), full_job.scope.clone())?;
db.update_job_scope(full_job.job_id().to_string(), full_job.scope.clone())?;
}
Err(e) => {
shinkai_log(
Expand All @@ -524,10 +524,11 @@ impl JobManager {
/// If save_to_vector_fs_folder == true, the files will save to the DB and be returned as `VectorFSScopeEntry`s.
/// Else, the files will be returned as LocalScopeEntries and thus held inside.
pub async fn process_files_inbox(
db: Arc<Mutex<ShinkaiDB>>,
_db: Arc<ShinkaiDB>,
vector_fs: Arc<VectorFS>,
agent: Option<SerializedAgent>,
files_inbox: String,
profile: ShinkaiName,
_profile: ShinkaiName,
save_to_vector_fs_folder: Option<VRPath>,
generator: RemoteEmbeddingGenerator,
unstructured_api: UnstructuredAPI,
Expand All @@ -537,12 +538,11 @@ impl JobManager {

// Get the files from the DB
let files = {
let shinkai_db = db.lock().await;
let files_result = shinkai_db.get_all_files_from_inbox(files_inbox.clone());
let files_result = vector_fs.db.get_all_files_from_inbox(files_inbox.clone());
// Check if there was an error getting the files
match files_result {
Ok(files) => files,
Err(e) => return Err(AgentError::ShinkaiDB(e)),
Err(e) => return Err(AgentError::VectorFS(e)),
}
};

Expand Down Expand Up @@ -577,7 +577,7 @@ impl JobManager {

files_map.insert(filename, ScopeEntry::VectorFSItem(fs_scope_entry));
} else {
let local_scope_entry = LocalScopeVRKaiEntry { vrkai: vrkai };
let local_scope_entry = LocalScopeVRKaiEntry { vrkai };
files_map.insert(filename, ScopeEntry::LocalScopeVRKai(local_scope_entry));
}
}
Expand All @@ -597,7 +597,7 @@ impl JobManager {

files_map.insert(filename, ScopeEntry::VectorFSFolder(fs_scope_entry));
} else {
let local_scope_entry = LocalScopeVRPackEntry { vrpack: vrpack };
let local_scope_entry = LocalScopeVRPackEntry { vrpack };
files_map.insert(filename, ScopeEntry::LocalScopeVRPack(local_scope_entry));
}
}
Expand Down
Loading

0 comments on commit 6e40794

Please sign in to comment.