Skip to content

Commit

Permalink
Merge pull request #356 from dcSpark/nico/update-node-stream-logic
Browse files Browse the repository at this point in the history
update comms
  • Loading branch information
nicarq authored Apr 23, 2024
2 parents ffd1cf9 + c1028d9 commit 31fb4b1
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 48 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ storage_streamer.zip
storage_rob
storage_requester
shinkai-libs/shinkai-fs-mirror/top-10-hn
shinkai_node.svg
old_streamer_shinkai_9850.key
126 changes: 79 additions & 47 deletions src/network/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::network_manager::network_job_manager::{
};
use super::node_api::{APIError, APIUseRegistrationCodeSuccessResponse, SendResponseBodyData};
use super::node_error::NodeError;
use super::subscription_manager::external_subscriber_manager::{ExternalSubscriberManager};
use super::subscription_manager::external_subscriber_manager::ExternalSubscriberManager;
use super::subscription_manager::my_subscription_manager::MySubscriptionsManager;
use crate::agent::job_manager::JobManager;
use crate::cron_tasks::cron_manager::CronManager;
Expand Down Expand Up @@ -467,11 +467,7 @@ impl Node {
let encryption_public_key = EncryptionPublicKey::from(&encryption_secret_key);
let node_name = ShinkaiName::new(node_name).unwrap();
{
match db_arc.update_local_node_keys(
node_name.clone(),
encryption_public_key,
identity_public_key,
) {
match db_arc.update_local_node_keys(node_name.clone(), encryption_public_key, identity_public_key) {
Ok(_) => (),
Err(e) => panic!("Failed to update local node keys: {}", e),
}
Expand Down Expand Up @@ -1914,59 +1910,71 @@ impl Node {

tokio::spawn(async move {
let mut socket = socket.lock().await;
let mut header = [0u8; 1]; // Buffer for the message type identifier
if socket.read_exact(&mut header).await.is_ok() {
let mut buffer = Vec::new();
if socket.read_to_end(&mut buffer).await.is_ok() {
let message_type = match header[0] {
let mut length_bytes = [0u8; 4];
if socket.read_exact(&mut length_bytes).await.is_ok() {
let total_length = u32::from_be_bytes(length_bytes) as usize;

// Read the identity length
let mut identity_length_bytes = [0u8; 4];
if socket.read_exact(&mut identity_length_bytes).await.is_err() {
return; // Exit if we fail to read identity length
}
let identity_length = u32::from_be_bytes(identity_length_bytes) as usize;

// Read the identity bytes
let mut identity_bytes = vec![0u8; identity_length];
if socket.read_exact(&mut identity_bytes).await.is_err() {
return; // Exit if we fail to read identity
}

// Calculate the message length excluding the identity length and the identity itself
let msg_length = total_length - 1 - 4 - identity_length; // Subtract 1 for the header and 4 for the identity length bytes

// Initialize buffer to fit the message
let mut buffer = vec![0u8; msg_length];

// Read the header byte to determine the message type
let mut header_byte = [0u8; 1];
if socket.read_exact(&mut header_byte).await.is_ok() {
let message_type = match header_byte[0] {
0x01 => NetworkMessageType::ShinkaiMessage,
0x02 => NetworkMessageType::VRKaiPathPair,
// Add cases for other message types as needed
_ => {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
"Received message with unknown type identifier",
);
return; // Skip processing for unknown message types
return; // Exit the task if the message type is unknown
}
};
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Info,
&format!("Received message of type {:?} from: {:?}", message_type, addr),
);

let destination_socket = socket.peer_addr().expect("Failed to get peer address");
let network_job = NetworkJobQueue {
receiver_address: addr,
unsafe_sender_address: destination_socket,
message_type,
content: buffer.clone(),
date_created: Utc::now(),
};

let mut network_job_manager = network_job_manager.lock().await;
if let Err(e) = network_job_manager.add_network_job_to_queue(&network_job).await {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
&format!("Failed to add network job to queue: {}", e),
);
}
if let Err(e) = socket.flush().await {
// Read the rest of the message into the buffer
if socket.read_exact(&mut buffer).await.is_ok() {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
&format!("Failed to flush the socket: {}", e),
ShinkaiLogLevel::Info,
&format!("Received message of type {:?} from: {:?}", message_type, addr),
);

let destination_socket = socket.peer_addr().expect("Failed to get peer address");
let network_job = NetworkJobQueue {
receiver_address: addr,
unsafe_sender_address: destination_socket,
message_type,
content: buffer.clone(), // Now buffer does not include the header
date_created: Utc::now(),
};

let mut network_job_manager = network_job_manager.lock().await;
if let Err(e) = network_job_manager.add_network_job_to_queue(&network_job).await {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
&format!("Failed to add network job to queue: {}", e),
);
}
}
} else {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
"Failed to read the message type identifier",
);
}
}
conn_limiter_clone.decrement_connection(&ip).await;
Expand Down Expand Up @@ -2049,7 +2057,19 @@ impl Node {
match stream {
Ok(mut stream) => {
let encoded_msg = message.encode_message().unwrap();
let mut data_to_send = vec![0x01]; // Message type identifier for ShinkaiMessage
let identity = &message.external_metadata.recipient;
let identity_bytes = identity.as_bytes();
let identity_length = (identity_bytes.len() as u32).to_be_bytes();

// Prepare the message with a length prefix and identity length
let total_length = (encoded_msg.len() as u32 + 1 + identity_bytes.len() as u32 + 4).to_be_bytes(); // Convert the total length to bytes, adding 1 for the header and 4 for the identity length

let mut data_to_send = Vec::new();
let header_data_to_send = vec![0x01]; // Message type identifier for ShinkaiMessage
data_to_send.extend_from_slice(&total_length);
data_to_send.extend_from_slice(&identity_length);
data_to_send.extend(identity_bytes);
data_to_send.extend(header_data_to_send);
data_to_send.extend_from_slice(&encoded_msg);
let _ = stream.write_all(&data_to_send).await;
let _ = stream.flush().await;
Expand Down Expand Up @@ -2098,6 +2118,7 @@ impl Node {
subscription_id: SubscriptionId,
encryption_key_hex: String,
peer: SocketAddr,
recipient: ShinkaiName,
) {
tokio::spawn(async move {
// Serialize only the VRKaiPath pairs
Expand Down Expand Up @@ -2131,8 +2152,19 @@ impl Node {
};
let vr_kai_serialized = bincode::serialize(&vr_kai).unwrap();

// Prepend nonce to the encrypted data to use it during decryption
let mut data_to_send = vec![0x02]; // Network Message type identifier for VRKaiPathPair
let identity = recipient.get_node_name_string();
let identity_bytes = identity.as_bytes();
let identity_length = (identity_bytes.len() as u32).to_be_bytes();

// Prepare the message with a length prefix, identity length, and identity
let total_length = (vr_kai_serialized.len() as u32 + 1 + identity_bytes.len() as u32 + 4).to_be_bytes(); // Convert the total length to bytes, adding 1 for the header and 4 for the identity length

let mut data_to_send = Vec::new();
let header_data_to_send = vec![0x02]; // Network Message type identifier for VRKaiPathPair
data_to_send.extend_from_slice(&total_length);
data_to_send.extend_from_slice(&identity_length);
data_to_send.extend(identity_bytes);
data_to_send.extend(header_data_to_send);
data_to_send.extend_from_slice(&vr_kai_serialized);

// Convert to Vec<u8> to send over TCP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ impl ExternalSubscriberManager {
.to_string(),
)
})?;
let receiver_name = receiver_identity.full_identity_name;

shinkai_log(
ShinkaiLogOption::ExtSubscriptions,
Expand All @@ -605,7 +606,7 @@ impl ExternalSubscriberManager {
);

// Call the send_encrypted_vrkaipath_pairs function
Node::send_encrypted_vrpack(vr_pack, subscription_id, symmetric_key, receiver_socket_addr).await;
Node::send_encrypted_vrpack(vr_pack, subscription_id, symmetric_key, receiver_socket_addr, receiver_name).await;
Ok(())
}

Expand Down

0 comments on commit 31fb4b1

Please sign in to comment.