Skip to content

Commit

Permalink
Merge pull request #544 from dcSpark/nico/node_bump_080
Browse files Browse the repository at this point in the history
Nico/node bump 080
  • Loading branch information
nicarq authored Sep 9, 2024
2 parents ecb9ffe + b6a1d49 commit 10ee938
Show file tree
Hide file tree
Showing 76 changed files with 2,407 additions and 4,002 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion scripts/run_local_ai_with_proxy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export STATIC_SERVER_FOLDER="./static_server_example"
export INITIAL_AGENT_NAMES="my_gpt,llama3_8b"
export INITIAL_AGENT_URLS="https://api.openai.com,http://localhost:11434"
export INITIAL_AGENT_MODELS="openai:gpt-4o,ollama:llama3:8b-instruct-q4_1"
export RPC_URL="https://public.stackup.sh/api/v1/node/arbitrum-sepolia"
export RPC_URL="https://arbitrum-sepolia.blockpi.network/v1/rpc/public"
export CONTRACT_ADDRESS="0x1d2D57F78Bc3B878aF68c411a03AcF327c85e0D6"
export SUBSCRIPTION_HTTP_UPLOAD_INTERVAL_MINUTES="1"
export SUBSCRIPTION_UPDATE_CACHE_INTERVAL_MINUTES="1"
Expand Down
6 changes: 3 additions & 3 deletions shinkai-bin/shinkai-node/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shinkai_node"
version = "0.7.34"
version = "0.8.0"
edition = "2021"
authors.workspace = true
# this causes `cargo run` in the workspace root to run this package
Expand All @@ -17,7 +17,7 @@ doctest = false

[build-dependencies]
reqwest = { version = "0.11.26", features = ["json", "tokio-native-tls", "blocking", "stream"] }
shinkai_tools_runner = { version = "0.7.10" } # change to a crate later on
shinkai_tools_runner = { version = "0.7.12" }

[dependencies]
async-trait = "0.1.74"
Expand Down Expand Up @@ -78,7 +78,7 @@ console-subscriber = { version = "0.1", optional = true }
rust_decimal = "1.17.0"
aws-types = "1.2.0"
aws-config = { version = "1.2.1", features = ["behavior-version-latest"] }
shinkai_tools_runner = { version = "0.7.10", features = ["built-in-tools"] } # change to a crate later on
shinkai_tools_runner = { version = "0.7.12", features = ["built-in-tools"] }
lancedb = "0.8.0"
arrow = "52.1"
arrow-array = "52.1"
Expand Down
37 changes: 36 additions & 1 deletion shinkai-bin/shinkai-node/src/db/db_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::time::Instant;
use super::{db_errors::ShinkaiDBError, db_main::Topic, ShinkaiDB};
use crate::llm_provider::execution::prompts::prompts::Prompt;
use crate::llm_provider::execution::prompts::subprompts::SubPromptType;
use crate::llm_provider::job::{Job, JobLike, JobStepResult};
use crate::llm_provider::job::{Job, JobConfig, JobLike, JobStepResult};
use crate::network::ws_manager::WSUpdateHandler;

use rocksdb::WriteBatch;
Expand All @@ -24,6 +24,7 @@ impl ShinkaiDB {
scope: JobScope,
is_hidden: bool,
associated_ui: Option<AssociatedUI>,
config: Option<JobConfig>,
) -> Result<(), ShinkaiDBError> {
let start = std::time::Instant::now();

Expand Down Expand Up @@ -98,6 +99,13 @@ impl ShinkaiDB {
batch.put_cf(cf_inbox, associated_ui_key.as_bytes(), &associated_ui_value);
}

// Serialize and put config if it exists
if let Some(cfg) = &config {
let config_key = format!("jobinbox_{}_config", job_id);
let config_value = serde_json::to_vec(cfg)?;
batch.put_cf(cf_inbox, config_key.as_bytes(), &config_value);
}

self.db.write(batch)?;

let batch_write_duration = batch_write_start.elapsed();
Expand All @@ -115,6 +123,20 @@ impl ShinkaiDB {
Ok(())
}

/// Updates the config of a job
pub fn update_job_config(&self, job_id: &str, config: JobConfig) -> Result<(), ShinkaiDBError> {
let cf_inbox = self.get_cf_handle(Topic::Inbox).unwrap();
let config_key = format!("jobinbox_{}_config", job_id);

// Serialize the config
let config_value = serde_json::to_vec(&config)?;

// Update the config in the database
self.db.put_cf(cf_inbox, config_key.as_bytes(), &config_value)?;

Ok(())
}

/// Changes the llm provider of a specific job
pub fn change_job_llm_provider(&self, job_id: &str, new_agent_id: &str) -> Result<(), ShinkaiDBError> {
let cf_inbox = self.get_cf_handle(Topic::Inbox).unwrap();
Expand Down Expand Up @@ -181,6 +203,7 @@ impl ShinkaiDB {
unprocessed_messages,
execution_context,
associated_ui,
config,
) = self.get_job_data(job_id, true)?;

// Construct the job
Expand All @@ -196,6 +219,7 @@ impl ShinkaiDB {
unprocessed_messages,
execution_context,
associated_ui,
config,
};

let duration = start.elapsed();
Expand Down Expand Up @@ -224,6 +248,7 @@ impl ShinkaiDB {
unprocessed_messages,
execution_context,
associated_ui,
config,
) = self.get_job_data(job_id, false)?;

// Construct the job
Expand All @@ -239,6 +264,7 @@ impl ShinkaiDB {
unprocessed_messages,
execution_context,
associated_ui,
config,
};

let duration = start.elapsed();
Expand Down Expand Up @@ -271,6 +297,7 @@ impl ShinkaiDB {
Vec<String>,
HashMap<String, String>,
Option<AssociatedUI>,
Option<JobConfig>,
),
ShinkaiDBError,
> {
Expand Down Expand Up @@ -329,6 +356,13 @@ impl ShinkaiDB {
.flatten()
.and_then(|value| serde_json::from_slice(&value).ok());

let config_value = self
.db
.get_cf(cf_jobs, format!("jobinbox_{}_config", job_id).as_bytes())
.ok()
.flatten()
.and_then(|value| serde_json::from_slice(&value).ok());

Ok((
scope,
is_finished,
Expand All @@ -340,6 +374,7 @@ impl ShinkaiDB {
unprocessed_messages,
self.get_job_execution_context(job_id)?,
associated_ui_value,
config_value,
))
}

Expand Down
4 changes: 3 additions & 1 deletion shinkai-bin/shinkai-node/src/lance_db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ pub mod ollama_embedding_fn;
pub mod shinkai_lance_db;
pub mod shinkai_tool_schema;
pub mod shinkai_lancedb_error;
pub mod shinkai_lance_version;
pub mod shinkai_lance_version;
pub mod shinkai_prompt_schema;
pub mod shinkai_prompt_db;
49 changes: 17 additions & 32 deletions shinkai-bin/shinkai-node/src/lance_db/shinkai_lance_db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::tools::error::ToolError;
use crate::tools::js_toolkit_headers::{BasicConfig, ToolConfig};
use crate::tools::js_toolkit_headers::ToolConfig;
use crate::tools::shinkai_tool::{ShinkaiTool, ShinkaiToolHeader};
use arrow_array::{Array, BinaryArray, BooleanArray};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray};
Expand Down Expand Up @@ -32,8 +32,9 @@ pub struct LanceShinkaiDb {
connection: Connection,
pub tool_table: Table,
pub version_table: Table,
embedding_model: EmbeddingModelType,
embedding_function: OllamaEmbeddingFunction,
pub prompt_table: Table,
pub embedding_model: EmbeddingModelType,
pub embedding_function: OllamaEmbeddingFunction,
}

impl LanceShinkaiDb {
Expand All @@ -52,13 +53,15 @@ impl LanceShinkaiDb {
let connection = connect(&db_path).execute().await?;
let version_table = Self::create_version_table(&connection).await?;
let tool_table = Self::create_tool_router_table(&connection, &embedding_model).await?;
let prompt_table = Self::create_prompt_table(&connection, &embedding_model).await?;
let api_url = generator.api_url;
let embedding_function = OllamaEmbeddingFunction::new(&api_url, embedding_model.clone());

Ok(LanceShinkaiDb {
connection,
tool_table,
version_table,
prompt_table,
embedding_model,
embedding_function,
})
Expand Down Expand Up @@ -364,13 +367,10 @@ impl LanceShinkaiDb {
.await
.map_err(|e| ShinkaiLanceDBError::ToolError(e.to_string()))?;

let results = query
.try_collect::<Vec<_>>()
.await
.map_err(|e| ShinkaiLanceDBError::ToolError(e.to_string()))?;

let mut res = query;
let mut workflows = Vec::new();
for batch in results {

while let Some(Ok(batch)) = res.next().await {
let tool_header_array = batch
.column_by_name(ShinkaiToolSchema::tool_header_field())
.unwrap()
Expand Down Expand Up @@ -407,13 +407,10 @@ impl LanceShinkaiDb {
.await
.map_err(|e| ShinkaiLanceDBError::ToolError(e.to_string()))?;

let results = query
.try_collect::<Vec<_>>()
.await
.map_err(|e| ShinkaiLanceDBError::ToolError(e.to_string()))?;

let mut res = query;
let mut tools = Vec::new();
for batch in results {

while let Some(Ok(batch)) = res.next().await {
let tool_header_array = batch
.column_by_name(ShinkaiToolSchema::tool_header_field())
.unwrap()
Expand Down Expand Up @@ -498,18 +495,15 @@ impl LanceShinkaiDb {
query_builder = query_builder.only_if(filter.to_string());
}

let results = query_builder
let query = query_builder
.execute()
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;

let mut res = query;
let mut tool_headers = Vec::new();
let batches = results
.try_collect::<Vec<_>>()
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;

for batch in batches {
while let Some(Ok(batch)) = res.next().await {
let tool_header_array = batch
.column_by_name(ShinkaiToolSchema::tool_header_field())
.unwrap()
Expand Down Expand Up @@ -561,18 +555,14 @@ impl LanceShinkaiDb {
.nearest_to(embedding)
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;

let results = query
let mut res = query
.execute()
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;

let mut tool_headers = Vec::new();
let batches = results
.try_collect::<Vec<_>>()
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;

for batch in batches {
while let Some(Ok(batch)) = res.next().await {
let tool_header_array = batch
.column_by_name(ShinkaiToolSchema::tool_header_field())
.unwrap()
Expand Down Expand Up @@ -654,7 +644,6 @@ mod tests {

#[tokio::test]
async fn test_vector_search_and_basics() -> Result<(), ShinkaiLanceDBError> {

setup();

let generator = RemoteEmbeddingGenerator::new_default();
Expand Down Expand Up @@ -759,7 +748,6 @@ mod tests {

#[tokio::test]
async fn test_add_tools_and_workflows() -> Result<(), ShinkaiLanceDBError> {

setup();

// Set the environment variable to enable testing workflows
Expand Down Expand Up @@ -849,7 +837,6 @@ mod tests {

#[tokio::test]
async fn test_add_tool_and_update_config() -> Result<(), ShinkaiLanceDBError> {

setup();

let generator = RemoteEmbeddingGenerator::new_default();
Expand Down Expand Up @@ -926,7 +913,6 @@ mod tests {

#[tokio::test]
async fn test_add_workflow_and_js_tool() -> Result<(), ShinkaiLanceDBError> {

setup();

let generator = RemoteEmbeddingGenerator::new_default();
Expand Down Expand Up @@ -994,7 +980,6 @@ mod tests {

#[tokio::test]
async fn test_has_any_js_tools() -> Result<(), ShinkaiLanceDBError> {

setup();

let generator = RemoteEmbeddingGenerator::new_default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ impl LanceShinkaiDb {
#[cfg(test)]
mod tests {
use super::*;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::init_default_tracing;
use shinkai_vector_resources::embedding_generator::EmbeddingGenerator;
use shinkai_vector_resources::embedding_generator::RemoteEmbeddingGenerator;
use std::fs;
Expand Down
12 changes: 12 additions & 0 deletions shinkai-bin/shinkai-node/src/lance_db/shinkai_lancedb_error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::tools::error::ToolError;
use lancedb::Error as LanceDbError;
use rocksdb::Error as RocksDbError;
use shinkai_vector_resources::resource_errors::VRError;
use std::fmt;

#[derive(Debug)]
Expand All @@ -12,6 +13,8 @@ pub enum ShinkaiLanceDBError {
InvalidPath(String),
ShinkaiDBError(String),
RocksDBError(String),
DatabaseError(String),
EmbeddingGenerationError(String),
}

impl fmt::Display for ShinkaiLanceDBError {
Expand All @@ -24,6 +27,8 @@ impl fmt::Display for ShinkaiLanceDBError {
ShinkaiLanceDBError::InvalidPath(err) => write!(f, "Invalid path error: {}", err),
ShinkaiLanceDBError::ShinkaiDBError(err) => write!(f, "ShinkaiDB error: {}", err),
ShinkaiLanceDBError::RocksDBError(err) => write!(f, "RocksDB error: {}", err),
ShinkaiLanceDBError::DatabaseError(err) => write!(f, "Database error: {}", err),
ShinkaiLanceDBError::EmbeddingGenerationError(err) => write!(f, "Embedding generation error: {}", err),
}
}
}
Expand Down Expand Up @@ -54,3 +59,10 @@ impl From<RocksDbError> for ShinkaiLanceDBError {
ShinkaiLanceDBError::RocksDBError(err.to_string())
}
}

// Add this implementation
impl From<VRError> for ShinkaiLanceDBError {
fn from(err: VRError) -> Self {
ShinkaiLanceDBError::Schema(err.to_string())
}
}
Loading

0 comments on commit 10ee938

Please sign in to comment.