Skip to content

Commit

Permalink
Add slot mechanism for handling early responses
Browse files Browse the repository at this point in the history
Fixes #12 (hopefully)
  • Loading branch information
fwcd committed May 30, 2024
1 parent fb6920c commit d924080
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions lighthouse-client/src/lighthouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,23 @@ use crate::{Check, Error, Result, Spawner, StreamExt};
pub struct Lighthouse<S> {
/// The sink-part of the WebSocket connection.
ws_sink: SplitSink<S, Message>,
/// The response/event handlers, keyed by request id.
txs: Arc<Mutex<HashMap<i32, Sender<ServerMessage<Value>>>>>,
/// The response/event slots, keyed by request id.
slots: Arc<Mutex<HashMap<i32, ResponseSlot>>>,
/// The credentials used to authenticate with the lighthouse.
authentication: Authentication,
/// The next request id. Incremented on every request.
request_id: i32,
}

enum ResponseSlot {
/// Indicates that responses were received before the request task queried it.
/// Generally set by the receive loop task.
EarlyResponses(Vec<ServerMessage<Value>>),
/// Indicates that responses were not received before the request task queried it.
/// Generally set by the request task.
WaitForResponses(Sender<ServerMessage<Value>>),
}

impl<S> Lighthouse<S>
where S: Stream<Item = tungstenite::Result<Message>>
+ Sink<Message, Error = tungstenite::Error>
Expand All @@ -29,36 +38,43 @@ impl<S> Lighthouse<S>
/// Asynchronously runs a receive loop using the provided spawner.
pub fn new<W>(web_socket: S, authentication: Authentication) -> Result<Self> where W: Spawner {
let (ws_sink, ws_stream) = web_socket.split();
let txs = Arc::new(Mutex::new(HashMap::new()));
let slots = Arc::new(Mutex::new(HashMap::new()));
let lh = Self {
ws_sink,
txs: txs.clone(),
slots: slots.clone(),
authentication,
request_id: 0,
};
W::spawn(Self::run_receive_loop(ws_stream, txs));
W::spawn(Self::run_receive_loop(ws_stream, slots));
Ok(lh)
}

/// Runs a loop that continuously receives events.
#[tracing::instrument(skip(ws_stream, txs))]
async fn run_receive_loop(mut ws_stream: SplitStream<S>, txs: Arc<Mutex<HashMap<i32, Sender<ServerMessage<Value>>>>>) {
#[tracing::instrument(skip(ws_stream, slots))]
async fn run_receive_loop(mut ws_stream: SplitStream<S>, slots: Arc<Mutex<HashMap<i32, ResponseSlot>>>) {
loop {
match Self::receive_message_from(&mut ws_stream).await {
Ok(msg) => {
let mut txs = txs.lock().await;
let mut slots = slots.lock().await;
if let Some(request_id) = msg.request_id {
if let Some(tx) = txs.get_mut(&request_id) {
if let Err(e) = tx.send(msg).await {
if e.is_disconnected() {
info!("Receiver for request id {} disconnected, removing the sender...", request_id);
txs.remove(&request_id);
} else {
warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
if let Some(slot) = slots.get_mut(&request_id) {
match slot {
ResponseSlot::EarlyResponses(responses) => {
responses.push(msg);
},
ResponseSlot::WaitForResponses(tx) => {
if let Err(e) = tx.send(msg).await {
if e.is_disconnected() {
info!("Receiver for request id {} disconnected, removing the sender...", request_id);
slots.remove(&request_id);
} else {
warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
}
}
}
}
} else {
warn!("No channel registered for request id {}", request_id);
slots.insert(request_id, ResponseSlot::EarlyResponses(vec![msg]));
}
} else {
warn!("Got message without request id from server: {:?}", msg);
Expand Down Expand Up @@ -233,17 +249,29 @@ impl<S> Lighthouse<S>
where
R: for<'de> Deserialize<'de> {
let rx = {
let mut txs = self.txs.lock().await;
let (tx, rx) = mpsc::channel(4);
txs.insert(request_id, tx);
let capacity = 4;
let (tx, rx) = {
let mut slots = self.slots.lock().await;
if let Some(ResponseSlot::EarlyResponses(responses)) = slots.get_mut(&request_id) {
let (mut tx, rx) = mpsc::channel(capacity.min(responses.len()));
for response in responses.drain(..) {
tx.feed(response).await.map_err(|e| Error::Custom(format!("Could not feed tx with early response: {}", e)))?;
}
tx.flush().await.map_err(|e| Error::Custom(format!("Could not feed tx with early response: {}", e)))?;
(tx, rx)
} else {
mpsc::channel(capacity)
}
};
self.slots.lock().await.insert(request_id, ResponseSlot::WaitForResponses(tx));
rx
};
Ok(rx.map(|s| Ok(s.decode_payload()?)).guard({
let txs = self.txs.clone();
let slots = self.slots.clone();
move || {
tokio::spawn(async move {
let mut txs = txs.lock().await;
txs.remove(&request_id);
let mut slots = slots.lock().await;
slots.remove(&request_id);
});
}
}))
Expand Down

0 comments on commit d924080

Please sign in to comment.