Skip to content

Commit

Permalink
chore(handshake): inline(always) transport send/recv, minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ozwaldorf committed May 22, 2024
1 parent 59637ff commit 7545c38
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 22 deletions.
2 changes: 2 additions & 0 deletions core/handshake/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
}
}

#[inline(always)]
pub fn spawn(self, start: Option<State>) {
tokio::spawn(self.run(start));
}

#[inline(always)]
pub async fn run(mut self, start: Option<State>) {
// Run the main loop of the state each iteration of this loop switches between different
// cases and implementations depending on the connections and their existence. And does
Expand Down
5 changes: 5 additions & 0 deletions core/handshake/src/transports/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,14 @@ impl TransportSender for HttpSender {
unimplemented!()
}

#[inline(always)]
async fn terminate(mut self, reason: TerminationReason) {
if let Some(reason_sender) = self.termination_tx.take() {
let _ = reason_sender.send(reason);
}
}

#[inline(always)]
async fn start_write(&mut self, len: usize) {
// if the header buffer is gone it means we sent the headers already and are ready to stream
// the body
Expand All @@ -152,6 +154,7 @@ impl TransportSender for HttpSender {
self.current_write = len;
}

#[inline(always)]
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = buf.len();

Expand Down Expand Up @@ -185,12 +188,14 @@ impl HttpReceiver {
}

impl TransportReceiver for HttpReceiver {
#[inline(always)]
fn detail(&mut self) -> TransportDetail {
self.detail
.take()
.expect("HTTP Transport detail already taken.")
}

#[inline(always)]
async fn recv(&mut self) -> Option<RequestFrame> {
self.inner.recv().await.ok().flatten()
}
Expand Down
23 changes: 13 additions & 10 deletions core/handshake/src/transports/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,31 @@ pub struct MockTransportSender {
buffer: BytesMut,
}

impl MockTransportSender {
async fn send_inner(&mut self, bytes: Bytes) {
self.tx
.try_send(bytes)
.expect("failed to send bytes over the mock connection")
}
}

impl TransportSender for MockTransportSender {
#[inline(always)]
async fn send_handshake_response(&mut self, response: schema::HandshakeResponse) {
self.send_inner(response.encode()).await;
self.tx
.send(response.encode())
.await
.expect("failed to send bytes over the mock connection")
}

#[inline(always)]
async fn send(&mut self, frame: schema::ResponseFrame) {
self.send_inner(frame.encode()).await;
self.tx
.send(frame.encode())
.await
.expect("failed to send bytes over the mock connection")
}

#[inline(always)]
async fn start_write(&mut self, len: usize) {
debug_assert!(self.buffer.is_empty());
self.buffer.reserve(len);
self.current_write = len;
}

#[inline(always)]
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = buf.len();
debug_assert!(len <= self.current_write);
Expand All @@ -157,6 +159,7 @@ pub struct MockTransportReceiver {
}

impl TransportReceiver for MockTransportReceiver {
#[inline(always)]
async fn recv(&mut self) -> Option<schema::RequestFrame> {
let bytes = self.rx.recv().await.ok()?;
Some(schema::RequestFrame::decode(&bytes).expect("failed to decode request frame"))
Expand Down
4 changes: 3 additions & 1 deletion core/handshake/src/transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub trait TransportSender: Sized + Send + Sync + 'static {
pub trait TransportReceiver: Send + Sync + 'static {
/// Returns the transport detail from this connection which is then sent to the service on
/// the hello frame.
#[inline(always)]
fn detail(&mut self) -> TransportDetail {
TransportDetail::Other
}
Expand Down Expand Up @@ -177,9 +178,10 @@ pub async fn spawn_transport_by_config<P: ExecutorProviderInterface>(
}

/// Delimit a complete frame with a u32 length.
#[inline(always)]
pub fn delimit_frame(bytes: Bytes) -> Bytes {
let mut buf = BytesMut::with_capacity(4 + bytes.len());
buf.put_u32(bytes.len() as u32);
buf.put(bytes);
buf.into()
buf.freeze()
}
16 changes: 7 additions & 9 deletions core/handshake/src/transports/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl Transport for TcpTransport {
}
}

#[inline(always)]
fn spawn_handshake_task(
mut stream: TcpStream,
tx: mpsc::Sender<(schema::HandshakeRequestFrame, TcpSender, TcpReceiver)>,
Expand Down Expand Up @@ -133,15 +134,6 @@ fn spawn_handshake_task(
});
}

/// Driver loop to write outgoing bytes directly to the stream.
async fn spawn_write_driver(mut writer: OwnedWriteHalf, rx: async_channel::Receiver<Bytes>) {
while let Ok(bytes) = rx.recv().await {
if let Err(e) = writer.write_all(&bytes).await {
warn!("Dropping payload, failed to write to stream: {e}");
};
}
}

pub struct TcpSender {
writer: OwnedWriteHalf,
current_write: u32,
Expand All @@ -167,10 +159,12 @@ impl TcpSender {
}

impl TransportSender for TcpSender {
#[inline(always)]
async fn send_handshake_response(&mut self, response: schema::HandshakeResponse) {
self.send_inner(&delimit_frame(response.encode())).await;
}

#[inline(always)]
async fn send(&mut self, frame: schema::ResponseFrame) {
debug_assert!(
!matches!(
Expand All @@ -185,6 +179,7 @@ impl TransportSender for TcpSender {
self.send_inner(&bytes).await;
}

#[inline(always)]
async fn start_write(&mut self, len: usize) {
let len = len as u32;
debug_assert!(
Expand All @@ -202,6 +197,7 @@ impl TransportSender for TcpSender {
self.send_inner(&buffer).await;
}

#[inline(always)]
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = u32::try_from(buf.len())?;
debug_assert!(self.current_write != 0);
Expand All @@ -219,6 +215,7 @@ pub struct TcpReceiver {
}

impl TcpReceiver {
#[inline(always)]
pub fn new(reader: OwnedReadHalf) -> Self {
Self {
reader,
Expand All @@ -231,6 +228,7 @@ impl TransportReceiver for TcpReceiver {
/// Cancel Safety:
/// This method is cancel safe, but could potentially allocate multiple times for the delimiter
/// if canceled.
#[inline(always)]
async fn recv(&mut self) -> Option<schema::RequestFrame> {
loop {
if self.buffer.len() < 4 {
Expand Down
4 changes: 4 additions & 0 deletions core/handshake/src/transports/webrtc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ impl WebRtcSender {
}

impl TransportSender for WebRtcSender {
#[inline(always)]
async fn send_handshake_response(&mut self, frame: schema::HandshakeResponse) {
self.send_inner(&frame.encode());
}

#[inline(always)]
async fn send(&mut self, frame: schema::ResponseFrame) {
debug_assert!(
!matches!(
Expand All @@ -162,6 +164,7 @@ impl TransportSender for WebRtcSender {
self.send_inner(&frame.encode());
}

#[inline(always)]
async fn start_write(&mut self, len: usize) {
debug_assert!(
self.current_write == 0,
Expand All @@ -171,6 +174,7 @@ impl TransportSender for WebRtcSender {
}

// TODO: consider buffering up to the max payload size to send less chunks/extra bytes
#[inline(always)]
async fn write(&mut self, mut buf: Bytes) -> anyhow::Result<usize> {
debug_assert!(self.current_write >= buf.len());

Expand Down
9 changes: 7 additions & 2 deletions core/handshake/src/transports/webtransport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl Transport for WebTransport {
}

async fn accept(&mut self) -> Option<(HandshakeRequestFrame, Self::Sender, Self::Receiver)> {
let (frame, (frame_writer, frame_reader)) = self.conn_rx.recv().await?;
let (frame, (writer, frame_reader)) = self.conn_rx.recv().await?;

increment_counter!(
"handshake_webtransport_sessions",
Expand All @@ -91,7 +91,7 @@ impl Transport for WebTransport {
Some((
frame,
WebTransportSender {
writer: frame_writer,
writer,
current_write: 0,
},
WebTransportReceiver { rx: frame_reader },
Expand All @@ -114,10 +114,12 @@ impl WebTransportSender {
}

impl TransportSender for WebTransportSender {
#[inline(always)]
async fn send_handshake_response(&mut self, response: HandshakeResponse) {
self.send_inner(&delimit_frame(response.encode())).await;
}

#[inline(always)]
async fn send(&mut self, frame: ResponseFrame) {
debug_assert!(
!matches!(
Expand All @@ -130,6 +132,7 @@ impl TransportSender for WebTransportSender {
self.send_inner(&delimit_frame(frame.encode())).await;
}

#[inline(always)]
async fn start_write(&mut self, len: usize) {
let len = len as u32;
debug_assert!(
Expand All @@ -147,6 +150,7 @@ impl TransportSender for WebTransportSender {
self.send_inner(&buffer).await;
}

#[inline(always)]
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = u32::try_from(buf.len())?;
debug_assert!(self.current_write != 0);
Expand All @@ -163,6 +167,7 @@ pub struct WebTransportReceiver {
}

impl TransportReceiver for WebTransportReceiver {
#[inline(always)]
async fn recv(&mut self) -> Option<RequestFrame> {
let data = match self.rx.next().await? {
Ok(data) => data,
Expand Down

0 comments on commit 7545c38

Please sign in to comment.