diff --git a/src/pipeline.rs b/src/pipeline.rs index fb3f314..a7958de 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -1,6 +1,10 @@ use std::future::Future; -use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt}; +use futures::{ + future::{self, BoxFuture}, + stream::FuturesUnordered, + FutureExt, Stream, StreamExt, +}; use tokio::{ sync::mpsc::{self, Receiver}, task::{JoinError, JoinHandle}, @@ -244,6 +248,108 @@ where self.pump(crate::pumps::batch::BatchPump { n }) } + /// Batch items from the input pipeline while the provided function returns `Some`. Similar to the iterator `fold` method, + /// the function takes a state and an item, and returns a new state. This allows the user to control when to emit a batch + /// based on the accumulated state. On batch emission, the state is reset. + /// + /// If the pipeline input is finite, the last batch emitted may be partial. + /// + /// # Example + /// ```rust + /// use pumps::Pipeline; + /// # tokio::runtime::Runtime::new().unwrap().block_on(async { + /// let (mut output, h) = Pipeline::from_iter(vec![1, 2, 3, 4, 5]) + /// .batch_while(0, |state, x| { + /// let sum = state + x; + /// + /// (sum < 10).then_some(sum) + /// }) + /// .build(); + /// + /// assert_eq!(output.recv().await, Some(vec![1, 2, 3, 4])); + /// assert_eq!(output.recv().await, Some(vec![5])); + /// assert_eq!(output.recv().await, None); + /// # }); + pub fn batch_while(self, state_init: State, mut while_fn: F) -> Pipeline> + where + F: FnMut(State, &Out) -> Option + Send + 'static, + State: Send + Clone + 'static, + { + self.pump( + crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump { + state_init, + while_fn: move |state: State, x: &Out| { + let new_state = while_fn(state.clone(), x); + new_state.map(|new_state| (new_state, future::pending())) + }, + }, + ) + } + + /// Batch items from the input pipeline while the provided function returns `Some`. 
Similar to the iterator `fold` method, + /// the function takes a state and an item, and returns a new state. This allows the user to control when to emit a batch + /// based on the accumulated state. + /// + /// This variant of batch allows the user to return a future alongside the new batch state. If this future resolves + /// before the batch is emitted, the batch will be emitted, and the state will be reset. Only the most + /// recent future is considered. For example, this feature can be used to implement a timeout for the batch. + /// + /// If the pipeline input is finite, the last batch emitted may be partial. + /// + /// # Example + /// ```rust + /// use pumps::Pipeline; + /// use std::time::Duration; + /// use tokio::{time::sleep, sync::mpsc}; + /// + /// # async fn send(input_sender: &mpsc::Sender, data: Vec) { + /// # for x in data { + /// # input_sender.send(x).await.unwrap(); + /// # } + /// # } + /// + /// # tokio::runtime::Runtime::new().unwrap().block_on(async { + /// let (input_sender, input_receiver) = mpsc::channel(100); + /// + /// let (mut output, h) = Pipeline::from(input_receiver) + /// .batch_while_with_expiry(0, |state, x| { + /// let sum = state + x; + /// + /// // batch until accumulated sum of 10, or 100ms passed without activity + /// (sum < 10).then_some((sum, sleep(Duration::from_millis(100)))) + /// }) + /// .build(); + /// + /// send(&input_sender, vec![1,2]).await; + /// + /// sleep(Duration::from_millis(200)).await; + /// + /// assert_eq!(output.recv().await, Some(vec![1, 2])); + /// + /// send(&input_sender, vec![3, 3, 4]).await; + /// assert_eq!(output.recv().await, Some(vec![3, 3, 4])); + /// + /// drop(input_sender); + /// assert_eq!(output.recv().await, None); + /// # }); + pub fn batch_while_with_expiry( + self, + state_init: State, + while_fn: F, + ) -> Pipeline> + where + F: FnMut(State, &Out) -> Option<(State, Fut)> + Send + 'static, + Fut: Future + Send, + State: Send + Clone + 'static, + { + self.pump( + 
crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump { + state_init, + while_fn, + }, + ) + } + /// Attach a skip pump to the pipeline. Skip will skip the first `n` items in the pipeline. pub fn skip(self, n: usize) -> Pipeline { self.pump(crate::pumps::skip::SkipPump { n }) diff --git a/src/pumps/backpressure.rs b/src/pumps/backpressure.rs index 9042433..e3b4afd 100644 --- a/src/pumps/backpressure.rs +++ b/src/pumps/backpressure.rs @@ -32,9 +32,10 @@ where #[cfg(test)] mod test { - use std::time::Duration; use tokio::sync::mpsc; + use crate::test_utils::wait_for_capacity; + #[tokio::test] async fn backpressure_buffers_1_input_at_a_time() { let (input_sender, input_receiver) = mpsc::channel(1); @@ -88,11 +89,6 @@ mod test { .backpressure(2) .build(); - async fn wait_for_capacity(sender: &mpsc::Sender) -> usize { - tokio::time::sleep(Duration::from_millis(20)).await; - sender.capacity() - } - input_sender.send(1).await.unwrap(); // this arrives to the output channel assert_eq!(wait_for_capacity(&input_sender).await, 1); diff --git a/src/pumps/batch.rs b/src/pumps/batch.rs index 6afeeb5..8026b75 100644 --- a/src/pumps/batch.rs +++ b/src/pumps/batch.rs @@ -44,7 +44,7 @@ mod tests { use crate::Pipeline; #[tokio::test] - async fn batch_works() { + async fn batch_batches_n_items() { let (input_sender, input_receiver) = mpsc::channel(100); let (mut output_receiver, join_handle) = Pipeline::from(input_receiver).batch(2).build(); diff --git a/src/pumps/batch_while_with_expiry.rs b/src/pumps/batch_while_with_expiry.rs new file mode 100644 index 0000000..8fc09fa --- /dev/null +++ b/src/pumps/batch_while_with_expiry.rs @@ -0,0 +1,193 @@ +use std::future::Future; + +use futures::FutureExt; +use tokio::{ + sync::mpsc::{self, Receiver}, + task::JoinHandle, +}; + +use crate::Pump; + +pub struct BatchWhileWithExpiryPump { + pub(crate) state_init: State, + pub(crate) while_fn: F, +} + +impl Pump> for BatchWhileWithExpiryPump +where + F: FnMut(State, &T) -> Option<(State, 
Fut)> + Send + 'static, + Fut: Future + Send, + State: Send + Clone + 'static, + T: Send + 'static, +{ + fn spawn(mut self, mut input_receiver: Receiver) -> (Receiver>, JoinHandle<()>) { + let (output_sender, output_receiver) = mpsc::channel(1); + + let h = tokio::spawn(async move { + let mut batch = Vec::new(); + let mut current_state = self.state_init.clone(); + + let mut expiry_fut = futures::future::pending().boxed(); + + loop { + tokio::select! { + biased; + + _ = expiry_fut => { + expiry_fut = futures::future::pending().boxed(); + current_state = self.state_init.clone(); + + if !batch.is_empty() { + let batch = std::mem::take(&mut batch); + if let Err(_e) = output_sender.send(batch).await { + break; + } + } + } + + input = input_receiver.recv() => { + let Some(input) = input else { + break; + }; + + let res = (self.while_fn)(current_state.clone(), &input); + batch.push(input); + + if let Some((new_state, new_state_fut)) = res { + current_state = new_state; + expiry_fut = new_state_fut.boxed(); + } else { + current_state = self.state_init.clone(); + expiry_fut = futures::future::pending().boxed(); + + let batch = std::mem::take(&mut batch); + if let Err(_e) = output_sender.send(batch).await { + break; + } + } + } + + } + } + + if !batch.is_empty() { + // Try to send the remaining batch, in case the output channel is still alive. 
+ let _ = output_sender.send(batch).await; + } + }); + + (output_receiver, h) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::sync::mpsc; + + use crate::{test_utils::wait_for_len, Pipeline}; + + #[tokio::test] + async fn batch_while_with_expiry_batches_while_predicate_is_true() { + let (input_sender, input_receiver) = mpsc::channel(100); + + let (mut output_receiver, join_handle) = Pipeline::from(input_receiver) + .batch_while_with_expiry(0, |state, x| { + let sum = state + x; + + (sum < 10).then_some((sum, tokio::time::sleep(Duration::from_millis(100)))) + }) + .build(); + + input_sender.send(4).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(wait_for_len(&output_receiver).await, 0); + + input_sender.send(2).await.unwrap(); // reached 10. should batch + input_sender.send(3).await.unwrap(); + input_sender.send(3).await.unwrap(); + + assert_eq!(output_receiver.recv().await, Some(vec![4, 5, 2])); + + input_sender.send(3).await.unwrap(); + assert_eq!(wait_for_len(&output_receiver).await, 0); + + input_sender.send(2).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(output_receiver.recv().await, Some(vec![3, 3, 3, 2])); + + drop(input_sender); + assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch + + join_handle.await.unwrap(); + } + + #[tokio::test] + async fn batch_while_with_expiry_emits_batch_on_expire() { + let (input_sender, input_receiver) = mpsc::channel(100); + + let expiry = Duration::from_millis(100); + + let (mut output_receiver, join_handle) = Pipeline::from(input_receiver) + .batch_while_with_expiry(0, move |state, x| { + let sum = state + x; + + (sum < 10).then_some((sum, tokio::time::sleep(expiry))) + }) + .build(); + + input_sender.send(4).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(wait_for_len(&output_receiver).await, 0); + tokio::time::sleep(expiry * 2).await; + + assert_eq!(output_receiver.recv().await, Some(vec![4, 5])); + + 
input_sender.send(2).await.unwrap(); + input_sender.send(3).await.unwrap(); + input_sender.send(6).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(output_receiver.recv().await, Some(vec![2, 3, 6])); + + drop(input_sender); + assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch + assert_eq!(output_receiver.recv().await, None); // channel is closed after the last batch + join_handle.await.unwrap(); + } + + #[tokio::test] + async fn batch_while_batches_while_predicate_is_true() { + let (input_sender, input_receiver) = mpsc::channel(100); + + let (mut output_receiver, join_handle) = Pipeline::from(input_receiver) + .batch_while(0, |state, x| { + let sum = state + x; + + (sum < 10).then_some(sum) + }) + .build(); + + input_sender.send(4).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(wait_for_len(&output_receiver).await, 0); + + input_sender.send(2).await.unwrap(); // reached 10. should batch + input_sender.send(3).await.unwrap(); + input_sender.send(3).await.unwrap(); + + assert_eq!(output_receiver.recv().await, Some(vec![4, 5, 2])); + + input_sender.send(3).await.unwrap(); + assert_eq!(wait_for_len(&output_receiver).await, 0); + + input_sender.send(2).await.unwrap(); + input_sender.send(5).await.unwrap(); + assert_eq!(output_receiver.recv().await, Some(vec![3, 3, 3, 2])); + + drop(input_sender); + assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch + + join_handle.await.unwrap(); + } +} diff --git a/src/pumps/mod.rs b/src/pumps/mod.rs index 18c3cd1..0c53015 100644 --- a/src/pumps/mod.rs +++ b/src/pumps/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod backpressure; pub(crate) mod backpressure_with_relief_valve; pub(crate) mod batch; +pub(crate) mod batch_while_with_expiry; pub(crate) mod enumerate; pub(crate) mod filter_map; pub(crate) mod flatten; diff --git a/src/test_utils.rs b/src/test_utils.rs index d556434..67e7fdb 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,7 +1,11 @@ -use 
std::{collections::HashMap, sync::Arc, time::SystemTime}; +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, SystemTime}, +}; use futures::{future::BoxFuture, FutureExt}; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; pub struct TestValue { pub id: i32, @@ -145,3 +149,13 @@ impl FutureTimings { } } } + +pub async fn wait_for_capacity(sender: &mpsc::Sender) -> usize { + tokio::time::sleep(Duration::from_millis(20)).await; + sender.capacity() +} + +pub async fn wait_for_len(receiver: &mpsc::Receiver) -> usize { + tokio::time::sleep(Duration::from_millis(20)).await; + receiver.len() +}