Skip to content

Commit

Permalink
Merge pull request #460 from dcSpark/nico/websocket_identity_not_foun…
Browse files Browse the repository at this point in the history
…d_fix
  • Loading branch information
nicarq authored Jun 30, 2024
2 parents 9e7075f + f40275f commit a873350
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 23 deletions.
2 changes: 0 additions & 2 deletions shinkai-bin/shinkai-node/src/managers/identity_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ impl IdentityManager {
}

pub fn get_all_subidentities(&self) -> Vec<Identity> {
// println!("identities_manager identities: {:?}", self.local_identities);
self.local_identities.clone()
}

Expand Down Expand Up @@ -325,7 +324,6 @@ impl IdentityManager {
#[async_trait]
impl IdentityManagerTrait for IdentityManager {
fn find_by_identity_name(&self, full_profile_name: ShinkaiName) -> Option<&Identity> {
// println!("identities_manager identities: {:?}", self.local_identities);
self.local_identities.iter().find(|identity| {
match identity {
Identity::Standard(identity) => identity.full_identity_name == full_profile_name,
Expand Down
8 changes: 3 additions & 5 deletions shinkai-bin/shinkai-node/src/network/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,9 @@ impl Node {
})));
let proxy_connection_info_weak = Arc::downgrade(&proxy_connection_info);

let identity_manager_trait: Arc<Mutex<Box<dyn IdentityManagerTrait + Send + 'static>>> = {
let identity_manager_inner = identity_manager.lock().await;
let boxed_identity_manager =
Box::new(identity_manager_inner.clone()) as Box<dyn IdentityManagerTrait + Send + 'static>;
Arc::new(Mutex::new(boxed_identity_manager))
let identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send + 'static>> = {
// Cast the Arc<Mutex<IdentityManager>> to Arc<Mutex<dyn IdentityManagerTrait + Send + 'static>>
identity_manager.clone() as Arc<Mutex<dyn IdentityManagerTrait + Send + 'static>>
};

let ws_manager = if ws_address.is_some() {
Expand Down
5 changes: 1 addition & 4 deletions shinkai-bin/shinkai-node/src/network/node_api_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,9 @@ impl Node {
potentially_encrypted_msg: ShinkaiMessage,
schema_type: Option<MessageSchemaType>,
) -> Result<(ShinkaiMessage, Identity), APIError> {
let identity_manager_trait: Box<dyn IdentityManagerTrait + Send> = identity_manager.lock().await.clone_box();
// Decrypt the message body if needed

validate_message_main_logic(
&encryption_secret_key,
Arc::new(Mutex::new(identity_manager_trait)),
identity_manager,
&node_name.clone(),
potentially_encrypted_msg,
schema_type,
Expand Down
3 changes: 1 addition & 2 deletions shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ use x25519_dalek::StaticSecret as EncryptionStaticKey;

pub async fn validate_message_main_logic(
encryption_secret_key: &EncryptionStaticKey,
// identity_manager: Arc<Mutex<dyn IdentityManagerTrait + Send>>,
identity_manager: Arc<Mutex<Box<dyn IdentityManagerTrait + Send>>>,
identity_manager: Arc<Mutex<dyn IdentityManagerTrait + Send>>,
node_profile_name: &ShinkaiName,
potentially_encrypted_msg: ShinkaiMessage,
schema_type: Option<MessageSchemaType>,
Expand Down
7 changes: 3 additions & 4 deletions shinkai-bin/shinkai-node/src/network/ws_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub struct WebSocketManager {
shared_keys: HashMap<String, String>,
shinkai_db: Weak<ShinkaiDB>,
node_name: ShinkaiName,
identity_manager_trait: Arc<Mutex<Box<dyn IdentityManagerTrait + Send>>>,
identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send>>,
encryption_secret_key: EncryptionStaticKey,
message_queue: MessageQueue,
}
Expand All @@ -134,7 +134,7 @@ impl WebSocketManager {
pub async fn new(
shinkai_db: Weak<ShinkaiDB>,
node_name: ShinkaiName,
identity_manager_trait: Arc<Mutex<Box<dyn IdentityManagerTrait + Send>>>,
identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send>>,
encryption_secret_key: EncryptionStaticKey,
) -> Arc<Mutex<Self>> {
let manager = Arc::new(Mutex::new(Self {
Expand Down Expand Up @@ -185,7 +185,6 @@ impl WebSocketManager {
}
}

// TODO: shouldn't this be encrypted?
pub async fn user_validation(
&self,
shinkai_name: ShinkaiName,
Expand Down Expand Up @@ -308,7 +307,7 @@ impl WebSocketManager {
WebSocketManagerError::UserValidationFailed(format!("Failed to deserialize WSMessage: {}", e))
})?;

eprintln!("ws_message: {:?}", ws_message);
// eprintln!("ws_message: {:?}", ws_message);

// Validate shared_key if it exists
if let Some(shared_key) = &ws_message.shared_key {
Expand Down
10 changes: 4 additions & 6 deletions shinkai-bin/shinkai-node/tests/it/websocket_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ async fn test_websocket() {
let (node1_encryption_sk, node1_encryption_pk) = unsafe_deterministic_encryption_keypair(0);

let node_name = ShinkaiName::new(node1_identity_name.to_string()).unwrap();
let identity_manager_trait = Arc::new(Mutex::new(
Box::new(MockIdentityManager::new()) as Box<dyn IdentityManagerTrait + Send>
));
let identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send>> =
Arc::new(Mutex::new(MockIdentityManager::new()));

let inbox_name1 = InboxName::get_job_inbox_name_from_params(job_id1.to_string()).unwrap();
let inbox_name2 = InboxName::get_job_inbox_name_from_params(job_id2.to_string()).unwrap();
Expand Down Expand Up @@ -484,9 +483,8 @@ async fn test_websocket_smart_inbox() {
let (node1_encryption_sk, node1_encryption_pk) = unsafe_deterministic_encryption_keypair(0);

let node_name = ShinkaiName::new(node1_identity_name.to_string()).unwrap();
let identity_manager_trait = Arc::new(Mutex::new(
Box::new(MockIdentityManager::new()) as Box<dyn IdentityManagerTrait + Send>
));
let identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send>> =
Arc::new(Mutex::new(MockIdentityManager::new()));

let inbox_name1 = InboxName::get_job_inbox_name_from_params(job_id1.to_string()).unwrap();
let inbox_name1_string = match inbox_name1 {
Expand Down

0 comments on commit a873350

Please sign in to comment.