Skip to content

Commit

Permalink
Add batch_while and batch_while_with_expiry
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpusch committed Nov 2, 2024
1 parent 3b15d43 commit db83ef3
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 10 deletions.
108 changes: 107 additions & 1 deletion src/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::future::Future;

use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt};
use futures::{
future::{self, BoxFuture},
stream::FuturesUnordered,
FutureExt, Stream, StreamExt,
};
use tokio::{
sync::mpsc::{self, Receiver},
task::{JoinError, JoinHandle},
Expand Down Expand Up @@ -244,6 +248,108 @@ where
self.pump(crate::pumps::batch::BatchPump { n })
}

/// Batch items from the input pipeline while the provided function returns `Some`. Similarly to the
/// iterator `fold` method, the function takes the current state and an item, and returns a new state.
/// This allows the user to control when to emit a batch based on the accumulated state.
///
/// When the function returns `None`, the batch (including the item that was just inspected) is
/// emitted and the state is reset to `state_init`.
///
/// If the pipeline input is finite, the last batch emitted may be partial.
///
/// # Example
/// ```rust
/// use pumps::Pipeline;
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let (mut output, h) = Pipeline::from_iter(vec![1, 2, 3, 4, 5])
///     .batch_while(0, |state, x| {
///         let sum = state + x;
///
///         (sum < 10).then_some(sum)
///     })
///     .build();
///
/// assert_eq!(output.recv().await, Some(vec![1, 2, 3, 4]));
/// assert_eq!(output.recv().await, Some(vec![5]));
/// assert_eq!(output.recv().await, None);
/// # });
/// ```
pub fn batch_while<F, State>(self, state_init: State, mut while_fn: F) -> Pipeline<Vec<Out>>
where
    F: FnMut(State, &Out) -> Option<State> + Send + 'static,
    State: Send + Clone + 'static,
{
    self.pump(
        crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump {
            state_init,
            // This variant has no expiry: every accepted state is paired with a future that
            // never resolves, so only `while_fn` returning `None` closes a batch.
            while_fn: move |state: State, x: &Out| {
                // `state` is already owned here, so it is handed over as-is (no clone needed).
                while_fn(state, x).map(|new_state| (new_state, future::pending()))
            },
        },
    )
}

/// Batch items from the input pipeline while the provided function returns `Some`. Similarly to the
/// iterator `fold` method, the function takes the current state and an item, and returns a new state.
/// This allows the user to control when to emit a batch based on the accumulated state.
///
/// This variant of batch lets the user return a future alongside the new batch state. If this future
/// resolves before the batch is emitted, the batch is emitted and the state is reset. Only the most
/// recent future is considered. For example, this feature can be used to implement a timeout for
/// the batch.
///
/// If the pipeline input is finite, the last batch emitted may be partial.
///
/// # Example
/// ```rust
/// use pumps::Pipeline;
/// use std::time::Duration;
/// use tokio::{time::sleep, sync::mpsc};
///
/// # async fn send(input_sender: &mpsc::Sender<i32>, data: Vec<i32>) {
/// #     for x in data {
/// #         input_sender.send(x).await.unwrap();
/// #     }
/// # }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let (input_sender, input_receiver) = mpsc::channel(100);
///
/// let (mut output, h) = Pipeline::from(input_receiver)
///     .batch_while_with_expiry(0, |state, x| {
///         let sum = state + x;
///
///         // batch until accumulated sum of 10, or 100ms passed without activity
///         (sum < 10).then_some((sum, sleep(Duration::from_millis(100))))
///     })
///     .build();
///
/// send(&input_sender, vec![1,2]).await;
///
/// sleep(Duration::from_millis(200)).await;
///
/// assert_eq!(output.recv().await, Some(vec![1, 2]));
///
/// send(&input_sender, vec![3, 3, 4]).await;
/// assert_eq!(output.recv().await, Some(vec![3, 3, 4]));
///
/// drop(input_sender);
/// assert_eq!(output.recv().await, None);
/// # });
/// ```
pub fn batch_while_with_expiry<F, Fut, State>(
    self,
    state_init: State,
    while_fn: F,
) -> Pipeline<Vec<Out>>
where
    F: FnMut(State, &Out) -> Option<(State, Fut)> + Send + 'static,
    Fut: Future<Output = ()> + Send,
    State: Send + Clone + 'static,
{
    use crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump;

    // The pump owns both the initial state and the user callback.
    let expiring_batcher = BatchWhileWithExpiryPump {
        state_init,
        while_fn,
    };

    self.pump(expiring_batcher)
}

/// Attach a skip pump to the pipeline. Skip will skip the first `n` items in the pipeline.
pub fn skip(self, n: usize) -> Pipeline<Out> {
self.pump(crate::pumps::skip::SkipPump { n })
Expand Down
8 changes: 2 additions & 6 deletions src/pumps/backpressure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ where

#[cfg(test)]
mod test {
use std::time::Duration;
use tokio::sync::mpsc;

use crate::test_utils::wait_for_capacity;

#[tokio::test]
async fn backpressure_buffers_1_input_at_a_time() {
let (input_sender, input_receiver) = mpsc::channel(1);
Expand Down Expand Up @@ -88,11 +89,6 @@ mod test {
.backpressure(2)
.build();

// Test-local helper: sleeps for 20 ms, then returns the sender's current `capacity()`.
async fn wait_for_capacity(sender: &mpsc::Sender<i32>) -> usize {
tokio::time::sleep(Duration::from_millis(20)).await;
sender.capacity()
}

input_sender.send(1).await.unwrap(); // this arrives to the output channel
assert_eq!(wait_for_capacity(&input_sender).await, 1);

Expand Down
2 changes: 1 addition & 1 deletion src/pumps/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod tests {
use crate::Pipeline;

#[tokio::test]
async fn batch_works() {
async fn batch_batches_n_items() {
let (input_sender, input_receiver) = mpsc::channel(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver).batch(2).build();
Expand Down
193 changes: 193 additions & 0 deletions src/pumps/batch_while_with_expiry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use std::future::Future;

use futures::FutureExt;
use tokio::{
sync::mpsc::{self, Receiver},
task::JoinHandle,
};

use crate::Pump;

/// Pump that collects incoming items into `Vec` batches, emitting a batch either when
/// `while_fn` returns `None` or when the future returned alongside the latest state resolves.
pub struct BatchWhileWithExpiryPump<F, State> {
/// State every batch starts from; it is cloned at start-up and again after each emitted batch.
pub(crate) state_init: State,
/// Called with the current state and a reference to each incoming item. Returning
/// `Some((new_state, expiry))` keeps the batch open; returning `None` emits it.
pub(crate) while_fn: F,
}

impl<T, F, Fut, State> Pump<T, Vec<T>> for BatchWhileWithExpiryPump<F, State>
where
    F: FnMut(State, &T) -> Option<(State, Fut)> + Send + 'static,
    Fut: Future<Output = ()> + Send,
    State: Send + Clone + 'static,
    T: Send + 'static,
{
    /// Spawns the batching task and returns the receiver of emitted batches together with the
    /// task's join handle.
    ///
    /// Every incoming item is pushed onto the open batch after `while_fn` has inspected it:
    /// * `Some((state, expiry))` keeps the batch open, stores `state`, and replaces the pending
    ///   expiry future with `expiry` (only the most recent one is tracked);
    /// * `None` emits the batch, including the item that was just inspected, and resets the state.
    ///
    /// If the tracked expiry future resolves first, the open batch is emitted and the state is
    /// reset. When the input channel closes, a non-empty remaining batch is sent as a final,
    /// possibly partial, batch. The task stops as soon as the output channel is closed.
    fn spawn(mut self, mut input_receiver: Receiver<T>) -> (Receiver<Vec<T>>, JoinHandle<()>) {
        let (output_sender, output_receiver) = mpsc::channel(1);

        let h = tokio::spawn(async move {
            let mut batch = Vec::new();
            let mut current_state = self.state_init.clone();

            // No batch is open yet, so there is nothing that could expire.
            let mut expiry_fut = futures::future::pending().boxed();

            loop {
                tokio::select! {
                    // Poll branches in declaration order: a resolved expiry is handled before
                    // the next input is read.
                    biased;

                    _ = expiry_fut => {
                        expiry_fut = futures::future::pending().boxed();
                        current_state = self.state_init.clone();

                        if !batch.is_empty() {
                            let batch = std::mem::take(&mut batch);
                            if let Err(_e) = output_sender.send(batch).await {
                                // Output side is gone; nobody is left to receive batches.
                                break;
                            }
                        }
                    }

                    input = input_receiver.recv() => {
                        let Some(input) = input else {
                            // Input channel closed: leave the loop and flush what is left.
                            break;
                        };

                        // The state is handed over by value instead of being cloned; both
                        // outcomes below assign `current_state` again before its next use.
                        let res = (self.while_fn)(current_state, &input);
                        batch.push(input);

                        if let Some((new_state, new_state_fut)) = res {
                            current_state = new_state;
                            expiry_fut = new_state_fut.boxed();
                        } else {
                            current_state = self.state_init.clone();
                            expiry_fut = futures::future::pending().boxed();

                            let batch = std::mem::take(&mut batch);
                            if let Err(_e) = output_sender.send(batch).await {
                                break;
                            }
                        }
                    }

                }
            }

            if !batch.is_empty() {
                // Try to send the remaining batch, in case the output channel is still alive.
                let _ = output_sender.send(batch).await;
            }
        });

        (output_receiver, h)
    }
}

// Tests for `batch_while_with_expiry` and `batch_while`. The callback used throughout sums the
// items and keeps a batch open while the running sum stays below 10.
#[cfg(test)]
mod tests {
use std::time::Duration;

use tokio::sync::mpsc;

use crate::{test_utils::wait_for_len, Pipeline};

// Batches are emitted by the callback returning `None`; the 100 ms expiry is not expected to fire.
#[tokio::test]
async fn batch_while_with_expiry_batches_while_predicate_is_true() {
let (input_sender, input_receiver) = mpsc::channel(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver)
.batch_while_with_expiry(0, |state, x| {
let sum = state + x;

(sum < 10).then_some((sum, tokio::time::sleep(Duration::from_millis(100))))
})
.build();

// 4 + 5 = 9 is still below 10, so nothing is emitted yet.
input_sender.send(4).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(wait_for_len(&output_receiver).await, 0);

input_sender.send(2).await.unwrap(); // sum is now 11 (not < 10): the batch is emitted
input_sender.send(3).await.unwrap();
input_sender.send(3).await.unwrap();

assert_eq!(output_receiver.recv().await, Some(vec![4, 5, 2]));

// 3 + 3 + 3 = 9 keeps the second batch open.
input_sender.send(3).await.unwrap();
assert_eq!(wait_for_len(&output_receiver).await, 0);

input_sender.send(2).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(output_receiver.recv().await, Some(vec![3, 3, 3, 2]));

// Closing the input flushes the remaining partial batch.
drop(input_sender);
assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch

join_handle.await.unwrap();
}

// An open batch is emitted once the expiry future of its latest item resolves.
#[tokio::test]
async fn batch_while_with_expiry_emits_batch_on_expire() {
let (input_sender, input_receiver) = mpsc::channel(100);

let expiry = Duration::from_millis(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver)
.batch_while_with_expiry(0, move |state, x| {
let sum = state + x;

(sum < 10).then_some((sum, tokio::time::sleep(expiry)))
})
.build();

input_sender.send(4).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(wait_for_len(&output_receiver).await, 0);
// Wait past the expiry so the open batch is emitted by the timer, not by the callback.
tokio::time::sleep(expiry * 2).await;

assert_eq!(output_receiver.recv().await, Some(vec![4, 5]));

// The state was reset on expiry: 2 + 3 + 6 = 11 closes the next batch.
input_sender.send(2).await.unwrap();
input_sender.send(3).await.unwrap();
input_sender.send(6).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(output_receiver.recv().await, Some(vec![2, 3, 6]));

drop(input_sender);
assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch
assert_eq!(output_receiver.recv().await, None); // output is closed after the last batch
join_handle.await.unwrap();
}

// Same scenario as the first test, but through `batch_while`, which has no expiry.
#[tokio::test]
async fn batch_while_batches_while_predicate_is_true() {
let (input_sender, input_receiver) = mpsc::channel(100);

let (mut output_receiver, join_handle) = Pipeline::from(input_receiver)
.batch_while(0, |state, x| {
let sum = state + x;

(sum < 10).then_some(sum)
})
.build();

input_sender.send(4).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(wait_for_len(&output_receiver).await, 0);

input_sender.send(2).await.unwrap(); // sum is now 11 (not < 10): the batch is emitted
input_sender.send(3).await.unwrap();
input_sender.send(3).await.unwrap();

assert_eq!(output_receiver.recv().await, Some(vec![4, 5, 2]));

input_sender.send(3).await.unwrap();
assert_eq!(wait_for_len(&output_receiver).await, 0);

input_sender.send(2).await.unwrap();
input_sender.send(5).await.unwrap();
assert_eq!(output_receiver.recv().await, Some(vec![3, 3, 3, 2]));

drop(input_sender);
assert_eq!(output_receiver.recv().await, Some(vec![5])); // the last batch

join_handle.await.unwrap();
}
}
1 change: 1 addition & 0 deletions src/pumps/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub(crate) mod backpressure;
pub(crate) mod backpressure_with_relief_valve;
pub(crate) mod batch;
pub(crate) mod batch_while_with_expiry;
pub(crate) mod enumerate;
pub(crate) mod filter_map;
pub(crate) mod flatten;
Expand Down
18 changes: 16 additions & 2 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::{collections::HashMap, sync::Arc, time::SystemTime};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, SystemTime},
};

use futures::{future::BoxFuture, FutureExt};
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};

pub struct TestValue {
pub id: i32,
Expand Down Expand Up @@ -145,3 +149,13 @@ impl FutureTimings {
}
}
}

/// Sleeps for 20 ms (presumably to let spawned pipeline tasks make progress), then returns
/// `sender.capacity()`.
///
/// Generic over the channel's item type, like [`wait_for_len`], so it is not tied to `i32` channels.
pub async fn wait_for_capacity<T>(sender: &mpsc::Sender<T>) -> usize {
    tokio::time::sleep(Duration::from_millis(20)).await;
    sender.capacity()
}

/// Sleeps for 20 ms (presumably to let spawned pipeline tasks make progress), then returns
/// `receiver.len()`.
pub async fn wait_for_len<T>(receiver: &mpsc::Receiver<T>) -> usize {
    let settle_time = Duration::from_millis(20);
    tokio::time::sleep(settle_time).await;

    receiver.len()
}

0 comments on commit db83ef3

Please sign in to comment.