Skip to content

Commit

Permalink
Merge branch 'mez'
Browse files Browse the repository at this point in the history
  • Loading branch information
massimoalbarello committed Aug 3, 2023
2 parents 5bd3c1a + 4296620 commit bd576b5
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 124 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/ic-websocket-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ serde_cbor = "0.11.2"
tokio = { version = "1.29.1", features = ["full"] }
tokio-native-tls = "0.3.1"
native-tls = "0.2.11"
tokio-util = "0.7.8"
serde_bytes = "0.11.12"
tokio-tungstenite = "0.20.0"
ed25519-compact = "2.0.4"
Expand Down
31 changes: 22 additions & 9 deletions src/ic-websocket-gateway/src/canister_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ pub struct CertifiedMessage {
#[derive(Debug)]
pub struct PollerChannelsPollerEnds {
main_to_poller: Receiver<PollerToClientChannelData>,
poller_to_main: Sender<Principal>,
poller_to_main: Sender<TerminationInfo>,
}

impl PollerChannelsPollerEnds {
pub fn new(
main_to_poller: Receiver<PollerToClientChannelData>,
poller_to_main: Sender<Principal>,
poller_to_main: Sender<TerminationInfo>,
) -> Self {
Self {
main_to_poller,
Expand All @@ -43,7 +43,7 @@ impl PollerChannelsPollerEnds {
}
}

/// contains the information that the main sends to the poller task
/// contains the information that the main sends to the poller task:
/// - NewClientChannel: sending side of the channel use by the poller to send messages to the client
/// - ClientDisconnected: signals the poller which cllient disconnected
#[derive(Debug, Clone)]
Expand All @@ -52,6 +52,14 @@ pub enum PollerToClientChannelData {
ClientDisconnected(ClientPublicKey),
}

/// determines the reason of the poller task termination:
/// - LastClientDisconnected: last client disconnected and therefore there is no need to continue polling
/// - CdkError: error while polling the canister
pub enum TerminationInfo {
LastClientDisconnected(Principal),
CdkError(Principal),
}

pub struct CanisterPoller {
canister_id: Principal,
agent: Arc<Agent>,
Expand Down Expand Up @@ -87,7 +95,7 @@ impl CanisterPoller {
// instead of issuing a new call to get_canister_updates
tokio::pin!(get_messages_operation);

loop {
'poller_loop: loop {
select! {
// receive channel used to send canister updates to new client's task
Some(channel_data) = poller_channels.main_to_poller.recv() => {
Expand All @@ -102,7 +110,7 @@ impl CanisterPoller {
info!("{} clients connected to poller", client_channels.len());
// exit task if last client disconnected
if client_channels.is_empty() {
signal_poller_task_termination(&mut poller_channels.poller_to_main, self.canister_id).await;
signal_poller_task_termination(&mut poller_channels.poller_to_main, TerminationInfo::LastClientDisconnected(self.canister_id)).await;
info!("Terminating poller task as no clients are connected");
break;
}
Expand Down Expand Up @@ -133,8 +141,10 @@ impl CanisterPoller {

match get_nonce_from_message(encoded_message.key) {
Ok(last_nonce) => nonce = last_nonce + 1,
Err(_e) => {
panic!("TODO: graceful shutdown of poller task and related clients disconnection");
Err(e) => {
signal_poller_task_termination(&mut poller_channels.poller_to_main, TerminationInfo::CdkError(self.canister_id)).await;
error!("Terminating poller task due to CDK error: {}", e);
break 'poller_loop;
}
}
}
Expand Down Expand Up @@ -171,8 +181,11 @@ fn get_nonce_from_message(key: String) -> Result<u64, String> {
))
}

async fn signal_poller_task_termination(channel: &mut Sender<Principal>, canister_id: Principal) {
if let Err(e) = channel.send(canister_id).await {
async fn signal_poller_task_termination(
channel: &mut Sender<TerminationInfo>,
info: TerminationInfo,
) {
if let Err(e) = channel.send(info).await {
error!(
"Receiver has been dropped on the pollers connection manager's side. Error: {:?}",
e
Expand Down
176 changes: 111 additions & 65 deletions src/ic-websocket-gateway/src/client_connection_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tokio_tungstenite::{
tungstenite::{Error, Message},
WebSocketStream,
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, span, warn, Instrument, Level};

use crate::{
Expand Down Expand Up @@ -104,54 +105,70 @@ impl WsConnectionsHandler {
}
}

pub async fn listen_for_incoming_requests(&mut self) {
while let Ok((stream, client_addr)) = self.listener.accept().await {
let stream = match self.tls_acceptor {
Some(ref acceptor) => {
let tls_stream = acceptor.accept(stream).await;
match tls_stream {
Ok(tls_stream) => {
debug!("TLS handshake successful");
CustomTcpStream::TcpWithTls(tls_stream)
},
Err(e) => {
error!("TLS handshake failed: {:?}", e);
continue;
pub async fn listen_for_incoming_requests(&mut self, parent_token: CancellationToken) {
// needed to ensure that we stop listening for incoming requests before we start shutting down the connections
let child_token = CancellationToken::new();
loop {
select! {
Ok((stream, client_addr)) = self.listener.accept(), if !parent_token.is_cancelled() => {
let stream = match self.tls_acceptor {
Some(ref acceptor) => {
let tls_stream = acceptor.accept(stream).await;
match tls_stream {
Ok(tls_stream) => {
debug!("TLS handshake successful");
CustomTcpStream::TcpWithTls(tls_stream)
},
Err(e) => {
error!("TLS handshake failed: {:?}", e);
continue;
},
}
},
}
},
None => CustomTcpStream::Tcp(stream),
};
let agent_cl = Arc::clone(&self.agent);
let client_connection_handler_tx_cl = self.client_connection_handler_tx.clone();
// spawn a connection handler task for each incoming client connection
let current_client_id = self.next_client_id;
let span = span!(
Level::INFO,
"handle_client_connection",
client_addr = ?client_addr,
client_id = current_client_id
);
tokio::spawn(
async move {
let client_connection_handler = ClientConnectionHandler::new(
current_client_id,
agent_cl,
client_connection_handler_tx_cl,
None => CustomTcpStream::Tcp(stream),
};
let agent_cl = Arc::clone(&self.agent);
let client_connection_handler_tx_cl = self.client_connection_handler_tx.clone();
// spawn a connection handler task for each incoming client connection
let current_client_id = self.next_client_id;
let span = span!(
Level::INFO,
"handle_client_connection",
client_addr = ?client_addr,
client_id = current_client_id
);
info!("Spawned new connection handler");
match stream {
CustomTcpStream::Tcp(stream) => {
client_connection_handler.handle_stream(stream).await
},
CustomTcpStream::TcpWithTls(stream) => {
client_connection_handler.handle_stream(stream).await
},
}

let child_token_cl = child_token.clone();
tokio::spawn(
async move {
let client_connection_handler = ClientConnectionHandler::new(
current_client_id,
agent_cl,
client_connection_handler_tx_cl,
child_token_cl
);
info!("Spawned new connection handler");
match stream {
CustomTcpStream::Tcp(stream) => {
client_connection_handler.handle_stream(stream).await
},
CustomTcpStream::TcpWithTls(stream) => {
client_connection_handler.handle_stream(stream).await
},
}
info!("Terminated client connection handler task");
}
.instrument(span),
);
self.next_client_id += 1;
},
_ = parent_token.cancelled() => {
child_token.cancel();
warn!("Stopped listening for incoming requests");
break;
}
.instrument(span),
);
self.next_client_id += 1;

}
}
}
}
Expand All @@ -160,17 +177,20 @@ struct ClientConnectionHandler {
id: u64,
agent: Arc<Agent>,
client_connection_handler_tx: Sender<WsConnectionState>,
token: CancellationToken,
}
impl ClientConnectionHandler {
pub fn new(
id: u64,
agent: Arc<Agent>,
client_connection_handler_tx: Sender<WsConnectionState>,
token: CancellationToken,
) -> Self {
Self {
id,
agent,
client_connection_handler_tx,
token,
}
}
pub async fn handle_stream<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) {
Expand All @@ -181,10 +201,12 @@ impl ClientConnectionHandler {
let mut is_first_message = true;
// create channel which will be used to send messages from the canister poller directly to this client
let (message_for_client_tx, mut message_for_client_rx) = mpsc::channel(100);
let (terminate_client_handler_tx, mut terminate_client_handler_rx) =
mpsc::channel(1);
loop {
select! {
// wait for incoming message from client
msg_res = ws_read.try_next() => {
msg_res = ws_read.try_next(), if !self.token.is_cancelled() => {
match msg_res {
Ok(Some(message)) => {
// check if the WebSocket connection is closed
Expand All @@ -206,25 +228,34 @@ impl ClientConnectionHandler {
// the nonce is obtained from the canister every time a client connects and the ws_open is called by the WS Gateway
nonce,
}) => {
info!("Client established IC WebSocket connection");
// let the client know that the IC WS connection is setup correctly
send_ws_message_to_client(&mut ws_write, Message::Text("1".to_string())).await;
// prevent adding a new client to the gateway state while shutting down
if !self.token.is_cancelled() {
info!("Client established IC WebSocket connection");
// let the client know that the IC WS connection is setup correctly
send_ws_message_to_client(&mut ws_write, Message::Text("1".to_string())).await;

// create a new sender side of the channel which will be used to send canister messages
// from the poller task directly to the client's connection handler task
let message_for_client_tx_cl = message_for_client_tx.clone();
// instantiate a new GatewaySession and send it to the main thread
self.send_connection_state_to_clients_manager(
WsConnectionState::ConnectionEstablished(
GatewaySession::new(
self.id,
client_key,
canister_id,
message_for_client_tx_cl,
nonce,
),
)
).await;
// create a new sender side of the channel which will be used to send canister messages
// from the poller task directly to the client's connection handler task
let message_for_client_tx_cl = message_for_client_tx.clone();
let terminate_client_handler_tx_cl = terminate_client_handler_tx.clone();
// instantiate a new GatewaySession and send it to the main thread
self.send_connection_state_to_clients_manager(
WsConnectionState::ConnectionEstablished(
GatewaySession::new(
self.id,
client_key,
canister_id,
message_for_client_tx_cl,
terminate_client_handler_tx_cl,
nonce,
),
)
).await;
}
else {
warn!("Preventing client connection handler task to establish new WS connection");
break;
}
},
Err(e) => {
info!("Client did not follow IC WebSocket establishment protocol: {:?}", e);
Expand Down Expand Up @@ -283,10 +314,25 @@ impl ClientConnectionHandler {
},
Err(e) => error!("Could not serialize canister message. Error: {:?}", e)
}
},
_ = self.token.cancelled() => {
self.send_connection_state_to_clients_manager(
WsConnectionState::ConnectionClosed(self.id)
)
.await;
// close the WebSocket connection
ws_write.close().await.unwrap();
warn!("Terminating client connection handler task");
break;
},
_ = terminate_client_handler_rx.recv() => {
// close the WebSocket connection
ws_write.close().await.unwrap();
error!("Terminating client connection handler task due to CDK error");
break;
}
}
}
info!("Terminating client connection handler task");
},
// no cleanup needed on the WS Gateway has the client's session has never been created
Err(e) => {
Expand Down
Loading

0 comments on commit bd576b5

Please sign in to comment.