diff --git a/lighthouse-client/src/lighthouse.rs b/lighthouse-client/src/lighthouse.rs index fcc0c67..804a5fe 100644 --- a/lighthouse-client/src/lighthouse.rs +++ b/lighthouse-client/src/lighthouse.rs @@ -11,14 +11,23 @@ use crate::{Check, Error, Result, Spawner, StreamExt}; pub struct Lighthouse { /// The sink-part of the WebSocket connection. ws_sink: SplitSink, - /// The response/event handlers, keyed by request id. - txs: Arc>>>>, + /// The response/event slots, keyed by request id. + slots: Arc>>, /// The credentials used to authenticate with the lighthouse. authentication: Authentication, /// The next request id. Incremented on every request. request_id: i32, } +enum ResponseSlot { + /// Indicates that responses were received before the request task queried it. + /// Generally set by the receive loop task. + EarlyResponses(Vec>), + /// Indicates that responses were not received before the request task queried it. + /// Generally set by the request task. + WaitForResponses(Sender>), +} + impl Lighthouse where S: Stream> + Sink @@ -29,36 +38,43 @@ impl Lighthouse /// Asynchronously runs a receive loop using the provided spawner. pub fn new(web_socket: S, authentication: Authentication) -> Result where W: Spawner { let (ws_sink, ws_stream) = web_socket.split(); - let txs = Arc::new(Mutex::new(HashMap::new())); + let slots = Arc::new(Mutex::new(HashMap::new())); let lh = Self { ws_sink, - txs: txs.clone(), + slots: slots.clone(), authentication, request_id: 0, }; - W::spawn(Self::run_receive_loop(ws_stream, txs)); + W::spawn(Self::run_receive_loop(ws_stream, slots)); Ok(lh) } /// Runs a loop that continuously receives events. - #[tracing::instrument(skip(ws_stream, txs))] - async fn run_receive_loop(mut ws_stream: SplitStream, txs: Arc>>>>) { + #[tracing::instrument(skip(ws_stream, slots))] + async fn run_receive_loop(mut ws_stream: SplitStream, slots: Arc>>) { loop { match Self::receive_message_from(&mut ws_stream).await { Ok(msg) => { - let mut txs = txs.lock().await; + let mut slots = slots.lock().await; if let Some(request_id) = msg.request_id { - if let Some(tx) = txs.get_mut(&request_id) { - if let Err(e) = tx.send(msg).await { - if e.is_disconnected() { - info!("Receiver for request id {} disconnected, removing the sender...", request_id); - txs.remove(&request_id); - } else { - warn!("Could not send message for request id {} via channel: {:?}", request_id, e); + if let Some(slot) = slots.get_mut(&request_id) { + match slot { + ResponseSlot::EarlyResponses(responses) => { + responses.push(msg); + }, + ResponseSlot::WaitForResponses(tx) => { + if let Err(e) = tx.send(msg).await { + if e.is_disconnected() { + info!("Receiver for request id {} disconnected, removing the sender...", request_id); + slots.remove(&request_id); + } else { + warn!("Could not send message for request id {} via channel: {:?}", request_id, e); + } + } } } } else { - warn!("No channel registered for request id {}", request_id); + slots.insert(request_id, ResponseSlot::EarlyResponses(vec![msg])); } } else { warn!("Got message without request id from server: {:?}", msg); @@ -233,17 +249,29 @@ impl Lighthouse where R: for<'de> Deserialize<'de> { let rx = { - let mut txs = self.txs.lock().await; - let (tx, rx) = mpsc::channel(4); - txs.insert(request_id, tx); + let capacity = 4; + let (tx, rx) = { + let mut slots = self.slots.lock().await; + if let Some(ResponseSlot::EarlyResponses(responses)) = slots.get_mut(&request_id) { + let (mut tx, rx) = mpsc::channel(capacity.min(responses.len())); + for response in responses.drain(..) { + tx.feed(response).await.map_err(|e| Error::Custom(format!("Could not feed tx with early response: {}", e)))?; + } + tx.flush().await.map_err(|e| Error::Custom(format!("Could not feed tx with early response: {}", e)))?; + (tx, rx) + } else { + mpsc::channel(capacity) + } + }; + self.slots.lock().await.insert(request_id, ResponseSlot::WaitForResponses(tx)); rx }; Ok(rx.map(|s| Ok(s.decode_payload()?)).guard({ - let txs = self.txs.clone(); + let slots = self.slots.clone(); move || { tokio::spawn(async move { - let mut txs = txs.lock().await; - txs.remove(&request_id); + let mut slots = slots.lock().await; + slots.remove(&request_id); }); } }))