diff --git a/core/handshake/src/handshake.rs b/core/handshake/src/handshake.rs index 8497ee6fd..130a69571 100644 --- a/core/handshake/src/handshake.rs +++ b/core/handshake/src/handshake.rs @@ -178,7 +178,7 @@ impl Context

{ // Attempt to connect to the service, getting the unix socket. let Some(mut socket) = self.provider.connect(service).await else { - sender.terminate(TerminationReason::InvalidService); + sender.terminate(TerminationReason::InvalidService).await; warn!("failed to connect to service {service}"); return; }; @@ -189,7 +189,7 @@ impl Context

{ }; if let Err(e) = write_header(&header, &mut socket).await { - sender.terminate(TerminationReason::ServiceTerminated); + sender.terminate(TerminationReason::ServiceTerminated).await; warn!("failed to write connection header to service {service}: {e}"); return; } @@ -226,12 +226,12 @@ impl Context

{ let connection_id = u64::from_be_bytes(*arrayref::array_ref![access_token, 0, 8]); let Some(connection) = self.connections.get(&connection_id) else { - sender.terminate(TerminationReason::InvalidToken); + sender.terminate(TerminationReason::InvalidToken).await; return; }; if connection.access_token != access_token { - sender.terminate(TerminationReason::InvalidToken); + sender.terminate(TerminationReason::InvalidToken).await; return; } @@ -241,7 +241,7 @@ impl Context

{ .expect("Failed to get current time") .as_millis() { - sender.terminate(TerminationReason::InvalidToken); + sender.terminate(TerminationReason::InvalidToken).await; return; } @@ -255,7 +255,7 @@ impl Context

{ retry: Some(id), .. } => { let Some(connection) = self.connections.get(&id) else { - sender.terminate(TerminationReason::InvalidToken); + sender.terminate(TerminationReason::InvalidToken).await; return; }; diff --git a/core/handshake/src/proxy.rs b/core/handshake/src/proxy.rs index 1561035ba..19da76717 100644 --- a/core/handshake/src/proxy.rs +++ b/core/handshake/src/proxy.rs @@ -179,7 +179,7 @@ impl Proxy

{ res = receiver.recv() => { match async_map(res, |r| self.handle_incoming(is_primary, r)).await { Some(HandleRequestResult::Ok) if is_primary => { - self.maybe_flush_primary_queue(true, &mut sender); + self.maybe_flush_primary_queue(true, &mut sender).await; }, Some(HandleRequestResult::Ok) => {}, Some(HandleRequestResult::TerminateConnection) => { @@ -201,13 +201,13 @@ impl Proxy

{ // the same connection mode. The client must have lost its connection. // We terminate the current (old) connection with a `ConnectionInUse` error. (true, true) => { - sender.terminate(TerminationReason::ConnectionInUse); + sender.terminate(TerminationReason::ConnectionInUse).await; self.discard_bytes = true; self.queued_primary_response.clear(); State::OnlyPrimaryConnection(pair) }, (false, false) => { - sender.terminate(TerminationReason::ConnectionInUse); + sender.terminate(TerminationReason::ConnectionInUse).await; self.discard_bytes = true; State::OnlySecondaryConnection(pair) } @@ -255,12 +255,12 @@ impl Proxy

{ // This might be a window to flush some pending responses to the // primary. if is_primary { - self.maybe_flush_primary_queue(true, &mut sender); + self.maybe_flush_primary_queue(true, &mut sender).await; } let bytes = self.buffer.split_to(4); let len = u32::from_be_bytes(*array_ref![bytes, 0, 4]) as usize; - sender.start_write(len); + sender.start_write(len).await; self.current_write = len; continue 'inner; // to handle `len` == 0. } @@ -274,7 +274,7 @@ impl Proxy

{ continue 'inner; } - if sender.write(bytes.freeze()).is_err() { + if sender.write(bytes.freeze()).await.is_err() { self.discard_bytes = true; self.queued_primary_response.clear(); return State::NoConnection; @@ -285,7 +285,7 @@ impl Proxy

{ } }; - sender.terminate(reason); + sender.terminate(reason).await; State::Terminated } @@ -312,7 +312,7 @@ impl Proxy

{ res = p_receiver.recv() => { match async_map(res, |r| self.handle_incoming(true, r)).await { Some(HandleRequestResult::Ok) => { - self.maybe_flush_primary_queue(false, &mut p_sender); + self.maybe_flush_primary_queue(false, &mut p_sender).await; }, Some(HandleRequestResult::TerminateConnection) => { break 'outer TerminationReason::InternalError; @@ -393,11 +393,11 @@ impl Proxy

{ // This might be a window to flush some pending responses to the // primary. - self.maybe_flush_primary_queue(false, &mut p_sender); + self.maybe_flush_primary_queue(false, &mut p_sender).await; let bytes = self.buffer.split_to(4); let len = u32::from_be_bytes(*array_ref![bytes, 0, 4]) as usize; - s_sender.start_write(len); + s_sender.start_write(len).await; self.current_write = len; continue 'inner; // to handle `len` == 0. } @@ -412,14 +412,14 @@ impl Proxy

{ } return if self.is_primary_the_current_sender { - if p_sender.write(bytes.freeze()).is_ok() { + if p_sender.write(bytes.freeze()).await.is_ok() { continue 'inner; } self.discard_bytes = true; self.queued_primary_response.clear(); State::OnlySecondaryConnection((s_sender, s_receiver).into()) } else { - if s_sender.write(bytes.freeze()).is_ok() { + if s_sender.write(bytes.freeze()).await.is_ok() { continue 'inner; } self.discard_bytes = true; @@ -431,8 +431,8 @@ impl Proxy

{ } }; - s_sender.terminate(reason); - p_sender.terminate(reason); + s_sender.terminate(reason).await; + p_sender.terminate(reason).await; State::Terminated } @@ -493,7 +493,7 @@ impl Proxy

{ } #[inline(always)] - fn maybe_flush_primary_queue( + async fn maybe_flush_primary_queue( &mut self, only_primary: bool, sender: &mut S, @@ -505,7 +505,7 @@ impl Proxy

{ let is_primary_current_writer = only_primary || self.is_primary_the_current_sender; if !is_primary_current_writer || self.current_write == 0 || self.discard_bytes { while let Some(res) = self.queued_primary_response.pop_back() { - sender.send(res); + sender.send(res).await; } } } diff --git a/core/handshake/src/transports/http/handler.rs b/core/handshake/src/transports/http/handler.rs index a4b62d1f4..aca4e026b 100644 --- a/core/handshake/src/transports/http/handler.rs +++ b/core/handshake/src/transports/http/handler.rs @@ -52,11 +52,7 @@ pub async fn handler( }; let (frame_tx, frame_rx) = async_channel::bounded(8); - // Todo: Fix synchronization between data transfer and connection proxy. - // Notes: Reducing the size of this queue produces enough backpressure to slow down the - // transfer of data and if the socket is dropped before all the data is transferred, - // the proxy drops the connection and the client receives incomplete data. - let (body_tx, body_rx) = async_channel::bounded(2_000_000); + let (body_tx, body_rx) = async_channel::bounded(16); let (termination_tx, termination_rx) = oneshot::channel(); let sender = HttpSender::new(service_id, frame_tx, body_tx, termination_tx); diff --git a/core/handshake/src/transports/http/mod.rs b/core/handshake/src/transports/http/mod.rs index 9b3ff3b09..44a1e2c3c 100644 --- a/core/handshake/src/transports/http/mod.rs +++ b/core/handshake/src/transports/http/mod.rs @@ -95,25 +95,22 @@ impl HttpSender { } #[inline(always)] - fn inner_send(&mut self, bytes: Bytes) { - if let Err(e) = self.body_tx.try_send(Ok(bytes)) { + async fn inner_send(&mut self, bytes: Bytes) { + if let Err(e) = self.body_tx.send(Ok(bytes)).await { warn!("payload dropped, failed to send to write loop: {e}"); } } - fn send_with_http_override(&mut self, bytes: Bytes) { + async fn send_with_http_override(&mut self, bytes: Bytes) { if let Some(header_buffer) = self.header_buffer.as_mut() { header_buffer.extend(bytes); if self.current_write == 0 { // always Some due to previous check let payload = self.header_buffer.take().unwrap_or_default(); - - if let Err(e) = self.body_tx.try_send(Ok(payload.into())) { - warn!("payload dropped, failed to send to write loop: {e}"); - } + self.inner_send(payload.freeze()).await; } - } else if let Err(e) = self.body_tx.try_send(Ok(bytes)) { - warn!("payload dropped, failed to send to write loop: {e}"); + } else { + self.inner_send(bytes).await } } @@ -126,21 +123,21 @@ impl HttpSender { } impl TransportSender for HttpSender { - fn send_handshake_response(&mut self, _: HandshakeResponse) { + async fn send_handshake_response(&mut self, _: HandshakeResponse) { unimplemented!() } - fn send(&mut self, _: ResponseFrame) { + async fn send(&mut self, _: ResponseFrame) { unimplemented!() } - fn terminate(mut self, reason: TerminationReason) { + async fn terminate(mut self, reason: TerminationReason) { if let Some(reason_sender) = self.termination_tx.take() { let _ = reason_sender.send(reason); } } - fn start_write(&mut self, len: usize) { + async fn start_write(&mut self, len: usize) { // if the header buffer is gone it means we sent the headers already and are ready to stream // the body if self.header_buffer.is_none() || !self.service.supports_http_overrides() { @@ -155,7 +152,7 @@ impl TransportSender for HttpSender { self.current_write = len; } - fn write(&mut self, buf: Bytes) -> anyhow::Result { + async fn write(&mut self, buf: Bytes) -> anyhow::Result { let len = buf.len(); debug_assert!(self.current_write != 0); @@ -164,9 +161,9 @@ impl TransportSender for HttpSender { self.current_write -= len; if self.service.supports_http_overrides() { - self.send_with_http_override(buf); + self.send_with_http_override(buf).await; } else { - self.inner_send(buf); + self.inner_send(buf).await; } Ok(len) @@ -187,7 +184,6 @@ impl HttpReceiver { } } -#[async_trait] impl TransportReceiver for HttpReceiver { fn detail(&mut self) -> TransportDetail { self.detail diff --git a/core/handshake/src/transports/mock.rs b/core/handshake/src/transports/mock.rs index e4b0c2275..cfcf6acf7 100644 --- a/core/handshake/src/transports/mock.rs +++ b/core/handshake/src/transports/mock.rs @@ -115,7 +115,7 @@ pub struct MockTransportSender { } impl MockTransportSender { - fn send_inner(&mut self, bytes: Bytes) { + async fn send_inner(&mut self, bytes: Bytes) { self.tx .try_send(bytes) .expect("failed to send bytes over the mock connection") @@ -123,28 +123,29 @@ impl MockTransportSender { } impl TransportSender for MockTransportSender { - fn send_handshake_response(&mut self, response: schema::HandshakeResponse) { - self.send_inner(response.encode()); + async fn send_handshake_response(&mut self, response: schema::HandshakeResponse) { + self.send_inner(response.encode()).await; } - fn send(&mut self, frame: schema::ResponseFrame) { - self.send_inner(frame.encode()); + async fn send(&mut self, frame: schema::ResponseFrame) { + self.send_inner(frame.encode()).await; } - fn start_write(&mut self, len: usize) { + async fn start_write(&mut self, len: usize) { debug_assert!(self.buffer.is_empty()); self.buffer.reserve(len); self.current_write = len; } - fn write(&mut self, buf: Bytes) -> anyhow::Result { + async fn write(&mut self, buf: Bytes) -> anyhow::Result { let len = buf.len(); debug_assert!(len <= self.current_write); self.buffer.put(buf); if self.buffer.len() >= self.current_write { let bytes = self.buffer.split_to(self.current_write).into(); self.current_write = 0; - self.send(schema::ResponseFrame::ServicePayload { bytes }); + self.send(schema::ResponseFrame::ServicePayload { bytes }) + .await; } Ok(len) } @@ -155,7 +156,6 @@ pub struct MockTransportReceiver { rx: async_channel::Receiver, } -#[async_trait] impl TransportReceiver for MockTransportReceiver { async fn recv(&mut self) -> Option { let bytes = self.rx.recv().await.ok()?; diff --git a/core/handshake/src/transports/mod.rs b/core/handshake/src/transports/mod.rs index cfe7c2405..8197169d5 100644 --- a/core/handshake/src/transports/mod.rs +++ b/core/handshake/src/transports/mod.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use axum::Router; use bytes::{BufMut, Bytes, BytesMut}; use fn_sdk::header::TransportDetail; +use futures::Future; use lightning_interfaces::prelude::*; use serde::de::DeserializeOwned; use serde::Serialize; @@ -101,25 +102,30 @@ pub trait Transport: Sized + Send + Sync + 'static { // for secondary connections are added pub trait TransportSender: Sized + Send + Sync + 'static { /// Send the initial handshake response to the client. - fn send_handshake_response(&mut self, response: schema::HandshakeResponse); + fn send_handshake_response( + &mut self, + response: schema::HandshakeResponse, + ) -> impl Future + Send; /// Send a frame to the client. - fn send(&mut self, frame: schema::ResponseFrame); + fn send(&mut self, frame: schema::ResponseFrame) -> impl Future + Send; /// Terminate the connection - fn terminate(mut self, reason: schema::TerminationReason) { - self.send(schema::ResponseFrame::Termination { reason }) + fn terminate(mut self, reason: schema::TerminationReason) -> impl Future + Send { + async move { + self.send(schema::ResponseFrame::Termination { reason }) + .await + } } /// Declare a number of bytes to write as service payloads. - fn start_write(&mut self, len: usize); + fn start_write(&mut self, len: usize) -> impl Future + Send; /// Write some bytes as service payloads. Must ALWAYS be called after /// [`TransportSender::start_write`]. - fn write(&mut self, buf: Bytes) -> anyhow::Result; + fn write(&mut self, buf: Bytes) -> impl Future> + Send; } -#[async_trait] pub trait TransportReceiver: Send + Sync + 'static { /// Returns the transport detail from this connection which is then sent to the service on /// the hello frame. @@ -129,7 +135,7 @@ pub trait TransportReceiver: Send + Sync + 'static { /// Receive a frame from the connection. Returns `None` when the connection /// is closed. - async fn recv(&mut self) -> Option; + fn recv(&mut self) -> impl Future> + Send; } pub async fn spawn_transport_by_config( diff --git a/core/handshake/src/transports/tcp.rs b/core/handshake/src/transports/tcp.rs index 50fe8a5ff..1c091970a 100644 --- a/core/handshake/src/transports/tcp.rs +++ b/core/handshake/src/transports/tcp.rs @@ -127,7 +127,7 @@ fn spawn_handshake_task( let (reader, writer) = stream.into_split(); // Send the frame and the new connection over the channel - tx.send((frame, TcpSender::spawn(writer), TcpReceiver::new(reader))) + tx.send((frame, TcpSender::new(writer), TcpReceiver::new(reader))) .await .ok(); }); @@ -143,7 +143,7 @@ async fn spawn_write_driver(mut writer: OwnedWriteHalf, rx: async_channel::Recei } pub struct TcpSender { - sender: async_channel::Sender, + writer: OwnedWriteHalf, current_write: u32, } @@ -151,30 +151,27 @@ impl TcpSender { /// Create the [`TcpSender`], additionally spawning a task to handle writing bytes to the /// stream. #[inline(always)] - pub fn spawn(writer: OwnedWriteHalf) -> Self { - let (sender, receiver) = async_channel::unbounded(); - tokio::spawn(spawn_write_driver(writer, receiver)); - + pub fn new(writer: OwnedWriteHalf) -> Self { Self { - sender, + writer, current_write: 0, } } #[inline(always)] - fn send_inner(&mut self, bytes: Bytes) { - if let Err(e) = self.sender.try_send(bytes) { - warn!("payload dropped, failed to send to write loop: {e}"); + async fn send_inner(&mut self, buf: &[u8]) { + if let Err(e) = self.writer.write_all(buf).await { + warn!("Dropping payload, failed to write to stream: {e}"); } } } impl TransportSender for TcpSender { - fn send_handshake_response(&mut self, response: schema::HandshakeResponse) { - self.send_inner(delimit_frame(response.encode())); + async fn send_handshake_response(&mut self, response: schema::HandshakeResponse) { + self.send_inner(&delimit_frame(response.encode())).await; } - fn send(&mut self, frame: schema::ResponseFrame) { + async fn send(&mut self, frame: schema::ResponseFrame) { debug_assert!( !matches!( frame, @@ -185,10 +182,10 @@ impl TransportSender for TcpSender { ); let bytes = delimit_frame(frame.encode()); - self.send_inner(bytes); + self.send_inner(&bytes).await; } - fn start_write(&mut self, len: usize) { + async fn start_write(&mut self, len: usize) { let len = len as u32; debug_assert!( self.current_write == 0, @@ -202,16 +199,16 @@ impl TransportSender for TcpSender { buffer.put_u32(len + 1); buffer.put_u8(RES_SERVICE_PAYLOAD_TAG); // write the delimiter and payload tag to the stream - self.send_inner(buffer.into()); + self.send_inner(&buffer).await; } - fn write(&mut self, buf: Bytes) -> anyhow::Result { + async fn write(&mut self, buf: Bytes) -> anyhow::Result { let len = u32::try_from(buf.len())?; debug_assert!(self.current_write != 0); debug_assert!(self.current_write >= len); self.current_write -= len; - self.send_inner(buf); + self.send_inner(&buf).await; Ok(len as usize) } } @@ -230,7 +227,6 @@ impl TcpReceiver { } } -#[async_trait] impl TransportReceiver for TcpReceiver { /// Cancel Safety: /// This method is cancel safe, but could potentially allocate multiple times for the delimiter @@ -324,7 +320,7 @@ mod tests { assert_eq!(REQ_FRAME, frame, "received incorrect request frame"); // Send the response frame - sender.send_handshake_response(RES_FRAME); + sender.send_handshake_response(RES_FRAME).await; // Drop the connection } diff --git a/core/handshake/src/transports/webrtc/mod.rs b/core/handshake/src/transports/webrtc/mod.rs index 9ba03a1c2..70d426014 100644 --- a/core/handshake/src/transports/webrtc/mod.rs +++ b/core/handshake/src/transports/webrtc/mod.rs @@ -144,11 +144,11 @@ impl WebRtcSender { } impl TransportSender for WebRtcSender { - fn send_handshake_response(&mut self, frame: schema::HandshakeResponse) { + async fn send_handshake_response(&mut self, frame: schema::HandshakeResponse) { self.send_inner(&frame.encode()); } - fn send(&mut self, frame: schema::ResponseFrame) { + async fn send(&mut self, frame: schema::ResponseFrame) { debug_assert!( !matches!( frame, @@ -162,7 +162,7 @@ impl TransportSender for WebRtcSender { self.send_inner(&frame.encode()); } - fn start_write(&mut self, len: usize) { + async fn start_write(&mut self, len: usize) { debug_assert!( self.current_write == 0, "all bytes should be written before another call to start_write" @@ -171,7 +171,7 @@ impl TransportSender for WebRtcSender { } // TODO: consider buffering up to the max payload size to send less chunks/extra bytes - fn write(&mut self, mut buf: Bytes) -> anyhow::Result { + async fn write(&mut self, mut buf: Bytes) -> anyhow::Result { debug_assert!(self.current_write >= buf.len()); while !buf.is_empty() { @@ -196,7 +196,6 @@ impl TransportSender for WebRtcSender { /// Receiver for a webrtc connection. pub struct WebRtcReceiver(Receiver); -#[async_trait] impl TransportReceiver for WebRtcReceiver { #[inline(always)] async fn recv(&mut self) -> Option { diff --git a/core/handshake/src/transports/webtransport/connection.rs b/core/handshake/src/transports/webtransport/connection.rs index 7b2693b23..b0bb47657 100644 --- a/core/handshake/src/transports/webtransport/connection.rs +++ b/core/handshake/src/transports/webtransport/connection.rs @@ -2,7 +2,6 @@ use std::sync::{Arc, RwLock}; use std::time::Duration; use anyhow::Result; -use bytes::Bytes; use fleek_crypto::{NodeSecretKey, SecretKey}; use futures::StreamExt; use lightning_interfaces::ShutdownWaiter; @@ -121,11 +120,3 @@ pub async fn handle_incoming_session( } } } - -pub async fn sender_loop(data_rx: async_channel::Receiver, mut network_tx: SendStream) { - while let Ok(data) = data_rx.recv().await { - if let Err(e) = network_tx.write_all(&data).await { - error!("failed to send data: {e:?}"); - } - } -} diff --git a/core/handshake/src/transports/webtransport/mod.rs b/core/handshake/src/transports/webtransport/mod.rs index 247152d54..fefd8b7b6 100644 --- a/core/handshake/src/transports/webtransport/mod.rs +++ b/core/handshake/src/transports/webtransport/mod.rs @@ -13,7 +13,7 @@ use futures::StreamExt; use lightning_interfaces::prelude::*; use lightning_metrics::increment_counter; use tokio::sync::mpsc::{self, Receiver}; -use tracing::{error, info, warn}; +use tracing::{error, info}; use wtransport::tls::Certificate; use wtransport::{Endpoint, SendStream, ServerConfig}; @@ -82,8 +82,6 @@ impl Transport for WebTransport { async fn accept(&mut self) -> Option<(HandshakeRequestFrame, Self::Sender, Self::Receiver)> { let (frame, (frame_writer, frame_reader)) = self.conn_rx.recv().await?; - let (data_tx, data_rx) = async_channel::unbounded(); - tokio::spawn(connection::sender_loop(data_rx, frame_writer)); increment_counter!( "handshake_webtransport_sessions", @@ -93,7 +91,7 @@ impl Transport for WebTransport { Some(( frame, WebTransportSender { - tx: data_tx, + writer: frame_writer, current_write: 0, }, WebTransportReceiver { rx: frame_reader }, @@ -102,25 +100,25 @@ impl Transport for WebTransport { } pub struct WebTransportSender { - tx: async_channel::Sender, + writer: SendStream, current_write: u32, } impl WebTransportSender { #[inline(always)] - fn send_inner(&mut self, bytes: Bytes) { - if let Err(e) = self.tx.try_send(bytes) { - warn!("payload dropped, failed to send to write loop: {e}"); + async fn send_inner(&mut self, buf: &[u8]) { + if let Err(e) = self.writer.write_all(buf).await { + error!("failed to send data: {e:?}"); } } } impl TransportSender for WebTransportSender { - fn send_handshake_response(&mut self, response: HandshakeResponse) { - self.send_inner(delimit_frame(response.encode())); + async fn send_handshake_response(&mut self, response: HandshakeResponse) { + self.send_inner(&delimit_frame(response.encode())).await; } - fn send(&mut self, frame: ResponseFrame) { + async fn send(&mut self, frame: ResponseFrame) { debug_assert!( !matches!( frame, @@ -129,10 +127,10 @@ impl TransportSender for WebTransportSender { "payloads should only be sent via start_write and write" ); - self.send_inner(delimit_frame(frame.encode())); + self.send_inner(&delimit_frame(frame.encode())).await; } - fn start_write(&mut self, len: usize) { + async fn start_write(&mut self, len: usize) { let len = len as u32; debug_assert!( self.current_write == 0, @@ -146,16 +144,16 @@ impl TransportSender for WebTransportSender { buffer.put_u32(len + 1); buffer.put_u8(RES_SERVICE_PAYLOAD_TAG); // write the delimiter and payload tag to the stream - self.send_inner(buffer.into()); + self.send_inner(&buffer).await; } - fn write(&mut self, buf: Bytes) -> anyhow::Result { + async fn write(&mut self, buf: Bytes) -> anyhow::Result { let len = u32::try_from(buf.len())?; debug_assert!(self.current_write != 0); debug_assert!(self.current_write >= len); self.current_write -= len; - self.send_inner(buf); + self.send_inner(&buf).await; Ok(len as usize) } } @@ -164,7 +162,6 @@ pub struct WebTransportReceiver { rx: FramedStreamRx, } -#[async_trait] impl TransportReceiver for WebTransportReceiver { async fn recv(&mut self) -> Option { let data = match self.rx.next().await? {