Skip to content

Commit

Permalink
refactor(handshake): async TransportSender
Browse files Browse the repository at this point in the history
  • Loading branch information
ozwaldorf committed May 22, 2024
1 parent 0dfc8d0 commit 59637ff
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 112 deletions.
12 changes: 6 additions & 6 deletions core/handshake/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl<P: ExecutorProviderInterface> Context<P> {

// Attempt to connect to the service, getting the unix socket.
let Some(mut socket) = self.provider.connect(service).await else {
sender.terminate(TerminationReason::InvalidService);
sender.terminate(TerminationReason::InvalidService).await;
warn!("failed to connect to service {service}");
return;
};
Expand All @@ -189,7 +189,7 @@ impl<P: ExecutorProviderInterface> Context<P> {
};

if let Err(e) = write_header(&header, &mut socket).await {
sender.terminate(TerminationReason::ServiceTerminated);
sender.terminate(TerminationReason::ServiceTerminated).await;
warn!("failed to write connection header to service {service}: {e}");
return;
}
Expand Down Expand Up @@ -226,12 +226,12 @@ impl<P: ExecutorProviderInterface> Context<P> {
let connection_id = u64::from_be_bytes(*arrayref::array_ref![access_token, 0, 8]);

let Some(connection) = self.connections.get(&connection_id) else {
sender.terminate(TerminationReason::InvalidToken);
sender.terminate(TerminationReason::InvalidToken).await;
return;
};

if connection.access_token != access_token {
sender.terminate(TerminationReason::InvalidToken);
sender.terminate(TerminationReason::InvalidToken).await;
return;
}

Expand All @@ -241,7 +241,7 @@ impl<P: ExecutorProviderInterface> Context<P> {
.expect("Failed to get current time")
.as_millis()
{
sender.terminate(TerminationReason::InvalidToken);
sender.terminate(TerminationReason::InvalidToken).await;
return;
}

Expand All @@ -255,7 +255,7 @@ impl<P: ExecutorProviderInterface> Context<P> {
retry: Some(id), ..
} => {
let Some(connection) = self.connections.get(&id) else {
sender.terminate(TerminationReason::InvalidToken);
sender.terminate(TerminationReason::InvalidToken).await;
return;
};

Expand Down
32 changes: 16 additions & 16 deletions core/handshake/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
res = receiver.recv() => {
match async_map(res, |r| self.handle_incoming(is_primary, r)).await {
Some(HandleRequestResult::Ok) if is_primary => {
self.maybe_flush_primary_queue(true, &mut sender);
self.maybe_flush_primary_queue(true, &mut sender).await;
},
Some(HandleRequestResult::Ok) => {},
Some(HandleRequestResult::TerminateConnection) => {
Expand All @@ -201,13 +201,13 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
// the same connection mode. The client must have lost its connection.
// We terminate the current (old) connection with a `ConnectionInUse` error.
(true, true) => {
sender.terminate(TerminationReason::ConnectionInUse);
sender.terminate(TerminationReason::ConnectionInUse).await;
self.discard_bytes = true;
self.queued_primary_response.clear();
State::OnlyPrimaryConnection(pair)
},
(false, false) => {
sender.terminate(TerminationReason::ConnectionInUse);
sender.terminate(TerminationReason::ConnectionInUse).await;
self.discard_bytes = true;
State::OnlySecondaryConnection(pair)
}
Expand Down Expand Up @@ -255,12 +255,12 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
// This might be a window to flush some pending responses to the
// primary.
if is_primary {
self.maybe_flush_primary_queue(true, &mut sender);
self.maybe_flush_primary_queue(true, &mut sender).await;
}

let bytes = self.buffer.split_to(4);
let len = u32::from_be_bytes(*array_ref![bytes, 0, 4]) as usize;
sender.start_write(len);
sender.start_write(len).await;
self.current_write = len;
continue 'inner; // to handle `len` == 0.
}
Expand All @@ -274,7 +274,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
continue 'inner;
}

if sender.write(bytes.freeze()).is_err() {
if sender.write(bytes.freeze()).await.is_err() {
self.discard_bytes = true;
self.queued_primary_response.clear();
return State::NoConnection;
Expand All @@ -285,7 +285,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
}
};

sender.terminate(reason);
sender.terminate(reason).await;
State::Terminated
}

Expand All @@ -312,7 +312,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
res = p_receiver.recv() => {
match async_map(res, |r| self.handle_incoming(true, r)).await {
Some(HandleRequestResult::Ok) => {
self.maybe_flush_primary_queue(false, &mut p_sender);
self.maybe_flush_primary_queue(false, &mut p_sender).await;
},
Some(HandleRequestResult::TerminateConnection) => {
break 'outer TerminationReason::InternalError;
Expand Down Expand Up @@ -393,11 +393,11 @@ impl<P: ExecutorProviderInterface> Proxy<P> {

// This might be a window to flush some pending responses to the
// primary.
self.maybe_flush_primary_queue(false, &mut p_sender);
self.maybe_flush_primary_queue(false, &mut p_sender).await;

let bytes = self.buffer.split_to(4);
let len = u32::from_be_bytes(*array_ref![bytes, 0, 4]) as usize;
s_sender.start_write(len);
s_sender.start_write(len).await;
self.current_write = len;
continue 'inner; // to handle `len` == 0.
}
Expand All @@ -412,14 +412,14 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
}

return if self.is_primary_the_current_sender {
if p_sender.write(bytes.freeze()).is_ok() {
if p_sender.write(bytes.freeze()).await.is_ok() {
continue 'inner;
}
self.discard_bytes = true;
self.queued_primary_response.clear();
State::OnlySecondaryConnection((s_sender, s_receiver).into())
} else {
if s_sender.write(bytes.freeze()).is_ok() {
if s_sender.write(bytes.freeze()).await.is_ok() {
continue 'inner;
}
self.discard_bytes = true;
Expand All @@ -431,8 +431,8 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
}
};

s_sender.terminate(reason);
p_sender.terminate(reason);
s_sender.terminate(reason).await;
p_sender.terminate(reason).await;
State::Terminated
}

Expand Down Expand Up @@ -493,7 +493,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
}

#[inline(always)]
fn maybe_flush_primary_queue<S: TransportSender>(
async fn maybe_flush_primary_queue<S: TransportSender>(
&mut self,
only_primary: bool,
sender: &mut S,
Expand All @@ -505,7 +505,7 @@ impl<P: ExecutorProviderInterface> Proxy<P> {
let is_primary_current_writer = only_primary || self.is_primary_the_current_sender;
if !is_primary_current_writer || self.current_write == 0 || self.discard_bytes {
while let Some(res) = self.queued_primary_response.pop_back() {
sender.send(res);
sender.send(res).await;
}
}
}
Expand Down
6 changes: 1 addition & 5 deletions core/handshake/src/transports/http/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ pub async fn handler<P: ExecutorProviderInterface>(
};

let (frame_tx, frame_rx) = async_channel::bounded(8);
// Todo: Fix synchronization between data transfer and connection proxy.
// Notes: Reducing the size of this queue produces enough backpressure to slow down the
// transfer of data and if the socket is dropped before all the data is transferred,
// the proxy drops the connection and the client receives incomplete data.
let (body_tx, body_rx) = async_channel::bounded(2_000_000);
let (body_tx, body_rx) = async_channel::bounded(16);
let (termination_tx, termination_rx) = oneshot::channel();

let sender = HttpSender::new(service_id, frame_tx, body_tx, termination_tx);
Expand Down
30 changes: 13 additions & 17 deletions core/handshake/src/transports/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,22 @@ impl HttpSender {
}

#[inline(always)]
fn inner_send(&mut self, bytes: Bytes) {
if let Err(e) = self.body_tx.try_send(Ok(bytes)) {
async fn inner_send(&mut self, bytes: Bytes) {
if let Err(e) = self.body_tx.send(Ok(bytes)).await {
warn!("payload dropped, failed to send to write loop: {e}");
}
}

fn send_with_http_override(&mut self, bytes: Bytes) {
async fn send_with_http_override(&mut self, bytes: Bytes) {
if let Some(header_buffer) = self.header_buffer.as_mut() {
header_buffer.extend(bytes);
if self.current_write == 0 {
// always Some due to previous check
let payload = self.header_buffer.take().unwrap_or_default();

if let Err(e) = self.body_tx.try_send(Ok(payload.into())) {
warn!("payload dropped, failed to send to write loop: {e}");
}
self.inner_send(payload.freeze()).await;
}
} else if let Err(e) = self.body_tx.try_send(Ok(bytes)) {
warn!("payload dropped, failed to send to write loop: {e}");
} else {
self.inner_send(bytes).await
}
}

Expand All @@ -126,21 +123,21 @@ impl HttpSender {
}

impl TransportSender for HttpSender {
fn send_handshake_response(&mut self, _: HandshakeResponse) {
async fn send_handshake_response(&mut self, _: HandshakeResponse) {
unimplemented!()
}

fn send(&mut self, _: ResponseFrame) {
async fn send(&mut self, _: ResponseFrame) {
unimplemented!()
}

fn terminate(mut self, reason: TerminationReason) {
async fn terminate(mut self, reason: TerminationReason) {
if let Some(reason_sender) = self.termination_tx.take() {
let _ = reason_sender.send(reason);
}
}

fn start_write(&mut self, len: usize) {
async fn start_write(&mut self, len: usize) {
// if the header buffer is gone it means we sent the headers already and are ready to stream
// the body
if self.header_buffer.is_none() || !self.service.supports_http_overrides() {
Expand All @@ -155,7 +152,7 @@ impl TransportSender for HttpSender {
self.current_write = len;
}

fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = buf.len();

debug_assert!(self.current_write != 0);
Expand All @@ -164,9 +161,9 @@ impl TransportSender for HttpSender {
self.current_write -= len;

if self.service.supports_http_overrides() {
self.send_with_http_override(buf);
self.send_with_http_override(buf).await;
} else {
self.inner_send(buf);
self.inner_send(buf).await;
}

Ok(len)
Expand All @@ -187,7 +184,6 @@ impl HttpReceiver {
}
}

#[async_trait]
impl TransportReceiver for HttpReceiver {
fn detail(&mut self) -> TransportDetail {
self.detail
Expand Down
18 changes: 9 additions & 9 deletions core/handshake/src/transports/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,36 +115,37 @@ pub struct MockTransportSender {
}

impl MockTransportSender {
fn send_inner(&mut self, bytes: Bytes) {
async fn send_inner(&mut self, bytes: Bytes) {
self.tx
.try_send(bytes)
.expect("failed to send bytes over the mock connection")
}
}

impl TransportSender for MockTransportSender {
fn send_handshake_response(&mut self, response: schema::HandshakeResponse) {
self.send_inner(response.encode());
async fn send_handshake_response(&mut self, response: schema::HandshakeResponse) {
self.send_inner(response.encode()).await;
}

fn send(&mut self, frame: schema::ResponseFrame) {
self.send_inner(frame.encode());
async fn send(&mut self, frame: schema::ResponseFrame) {
self.send_inner(frame.encode()).await;
}

fn start_write(&mut self, len: usize) {
async fn start_write(&mut self, len: usize) {
debug_assert!(self.buffer.is_empty());
self.buffer.reserve(len);
self.current_write = len;
}

fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
async fn write(&mut self, buf: Bytes) -> anyhow::Result<usize> {
let len = buf.len();
debug_assert!(len <= self.current_write);
self.buffer.put(buf);
if self.buffer.len() >= self.current_write {
let bytes = self.buffer.split_to(self.current_write).into();
self.current_write = 0;
self.send(schema::ResponseFrame::ServicePayload { bytes });
self.send(schema::ResponseFrame::ServicePayload { bytes })
.await;
}
Ok(len)
}
Expand All @@ -155,7 +156,6 @@ pub struct MockTransportReceiver {
rx: async_channel::Receiver<Bytes>,
}

#[async_trait]
impl TransportReceiver for MockTransportReceiver {
async fn recv(&mut self) -> Option<schema::RequestFrame> {
let bytes = self.rx.recv().await.ok()?;
Expand Down
22 changes: 14 additions & 8 deletions core/handshake/src/transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use async_trait::async_trait;
use axum::Router;
use bytes::{BufMut, Bytes, BytesMut};
use fn_sdk::header::TransportDetail;
use futures::Future;
use lightning_interfaces::prelude::*;
use serde::de::DeserializeOwned;
use serde::Serialize;
Expand Down Expand Up @@ -101,25 +102,30 @@ pub trait Transport: Sized + Send + Sync + 'static {
// for secondary connections are added
pub trait TransportSender: Sized + Send + Sync + 'static {
/// Send the initial handshake response to the client.
fn send_handshake_response(&mut self, response: schema::HandshakeResponse);
fn send_handshake_response(
&mut self,
response: schema::HandshakeResponse,
) -> impl Future<Output = ()> + Send;

/// Send a frame to the client.
fn send(&mut self, frame: schema::ResponseFrame);
fn send(&mut self, frame: schema::ResponseFrame) -> impl Future<Output = ()> + Send;

/// Terminate the connection
fn terminate(mut self, reason: schema::TerminationReason) {
self.send(schema::ResponseFrame::Termination { reason })
fn terminate(mut self, reason: schema::TerminationReason) -> impl Future<Output = ()> + Send {
async move {
self.send(schema::ResponseFrame::Termination { reason })
.await
}
}

/// Declare a number of bytes to write as service payloads.
fn start_write(&mut self, len: usize);
fn start_write(&mut self, len: usize) -> impl Future<Output = ()> + Send;

/// Write some bytes as service payloads. Must ALWAYS be called after
/// [`TransportSender::start_write`].
fn write(&mut self, buf: Bytes) -> anyhow::Result<usize>;
fn write(&mut self, buf: Bytes) -> impl Future<Output = anyhow::Result<usize>> + Send;
}

#[async_trait]
pub trait TransportReceiver: Send + Sync + 'static {
/// Returns the transport detail from this connection which is then sent to the service on
/// the hello frame.
Expand All @@ -129,7 +135,7 @@ pub trait TransportReceiver: Send + Sync + 'static {

/// Receive a frame from the connection. Returns `None` when the connection
/// is closed.
async fn recv(&mut self) -> Option<schema::RequestFrame>;
fn recv(&mut self) -> impl Future<Output = Option<schema::RequestFrame>> + Send;
}

pub async fn spawn_transport_by_config<P: ExecutorProviderInterface>(
Expand Down
Loading

0 comments on commit 59637ff

Please sign in to comment.