diff --git a/hydroflow_lang/src/graph/ops/fold_keyed.rs b/hydroflow_lang/src/graph/ops/fold_keyed.rs
index be038a804c9c..959ac2c8572f 100644
--- a/hydroflow_lang/src/graph/ops/fold_keyed.rs
+++ b/hydroflow_lang/src/graph/ops/fold_keyed.rs
@@ -143,11 +143,19 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
                     fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
+                    #[inline(always)]
+                    /// A: accumulator type
+                    /// T: iterator item type
+                    /// O: output type
+                    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
+                        f(a, t)
+                    }
+
                     for kv in check_input(#input) {
                         // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
                         #[allow(unknown_lints, clippy::unwrap_or_default)]
                         let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
-                        #[allow(clippy::redundant_closure_call)] (#aggfn)(entry, kv.1);
+                        #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
                     }
                 }
@@ -172,11 +180,19 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
                     fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
+                    #[inline(always)]
+                    /// A: accumulator type
+                    /// T: iterator item type
+                    /// O: output type
+                    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
+                        f(a, t)
+                    }
+
                     for kv in check_input(#input) {
                         // TODO(mingwei): remove `unknown_lints` when `clippy::unwrap_or_default` is stabilized.
                         #[allow(unknown_lints, clippy::unwrap_or_default)]
                         let entry = #hashtable_ident.entry(kv.0).or_insert_with(#initfn);
-                        #[allow(clippy::redundant_closure_call)] (#aggfn)(entry, kv.1);
+                        #[allow(clippy::redundant_closure_call)] call_comb_type(entry, kv.1, #aggfn);
                     }
                 }
@@ -206,11 +222,19 @@ pub const FOLD_KEYED: OperatorConstraints = OperatorConstraints {
                     fn check_input<Iter: ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>>, K: ::std::clone::Clone, V: ::std::clone::Clone>(iter: Iter) -> impl ::std::iter::Iterator<Item = #root::util::PersistenceKeyed::<K, V>> { iter }
+                    #[inline(always)]
+                    /// A: accumulator type
+                    /// T: iterator item type
+                    /// O: output type
+                    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
+                        f(a, t)
+                    }
+
                     for item in check_input(#input) {
                         match item {
                             Persist(k, v) => {
                                 let entry = #hashtable_ident.entry(k).or_insert_with(#initfn);
-                                #[allow(clippy::redundant_closure_call)] (#aggfn)(entry, v);
+                                #[allow(clippy::redundant_closure_call)] call_comb_type(entry, v, #aggfn);
                             },
                             Delete(k) => {
                                 #hashtable_ident.remove(&k);
diff --git a/hydroflow_lang/src/graph/ops/reduce_keyed.rs b/hydroflow_lang/src/graph/ops/reduce_keyed.rs
index c217db1fb57c..cef77cd8975a 100644
--- a/hydroflow_lang/src/graph/ops/reduce_keyed.rs
+++ b/hydroflow_lang/src/graph/ops/reduce_keyed.rs
@@ -132,13 +132,20 @@ pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
                     fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
+                    #[inline(always)]
+                    /// A: accumulator type
+                    /// O: output type
+                    fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
+                        f(acc, item)
+                    }
+
                     for kv in check_input(#input) {
                         match #hashtable_ident.entry(kv.0) {
                             ::std::collections::hash_map::Entry::Vacant(vacant) => {
                                 vacant.insert(kv.1);
                             }
                             ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
-                                #[allow(clippy::redundant_closure_call)] (#aggfn)(occupied.get_mut(), kv.1);
+                                #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
                             }
                         }
                     }
@@ -165,13 +172,20 @@ pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
                     fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter) -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
+                    #[inline(always)]
+                    /// A: accumulator type
+                    /// O: output type
+                    fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
+                        f(acc, item)
+                    }
+
                     for kv in check_input(#input) {
                         match #hashtable_ident.entry(kv.0) {
                             ::std::collections::hash_map::Entry::Vacant(vacant) => {
                                 vacant.insert(kv.1);
                             }
                             ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
-                                #[allow(clippy::redundant_closure_call)] (#aggfn)(occupied.get_mut(), kv.1);
+                                #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
                             }
                         }
                     }
diff --git a/hydroflow_plus/src/stream.rs b/hydroflow_plus/src/stream.rs
index 81a738390745..f67ae8a2fd13 100644
--- a/hydroflow_plus/src/stream.rs
+++ b/hydroflow_plus/src/stream.rs
@@ -338,7 +338,7 @@ impl<'a, T, N: Location<'a>> Stream<'a, T, Windowed, N> {
         }
     }
 
-    pub fn reduce(
+    pub fn reduce(
        &self,
        comb: impl IntoQuotedMut<'a, C>,
    ) -> Stream<'a, T, Windowed, N> {
@@ -467,6 +467,36 @@ impl<'a, K, V1, W, N: Location<'a>> Stream<'a, (K, V1), W, N> {
     }
 }
 
+impl<'a, K: Eq + Hash, V, N: Location<'a>> Stream<'a, (K, V), Windowed, N> {
+    pub fn fold_keyed<A, I: Fn() -> A + 'a, C: Fn(&mut A, V) + 'a>(
+        &self,
+        init: impl IntoQuotedMut<'a, I>,
+        comb: impl IntoQuotedMut<'a, C>,
+    ) -> Stream<'a, (K, A), Windowed, N> {
+        let init = init.splice();
+        let comb = comb.splice();
+
+        if self.is_delta {
+            self.pipeline_op(parse_quote!(fold_keyed::<'static>(#init, #comb)), false)
+        } else {
+            self.pipeline_op(parse_quote!(fold_keyed::<'tick>(#init, #comb)), false)
+        }
+    }
+
+    pub fn reduce_keyed<F: Fn(&mut V, V) + 'a>(
+        &self,
+        comb: impl IntoQuotedMut<'a, F>,
+    ) -> Stream<'a, (K, V), Windowed, N> {
+        let comb = comb.splice();
+
+        if self.is_delta {
+            self.pipeline_op(parse_quote!(reduce_keyed::<'static>(#comb)), false)
+        } else {
+            self.pipeline_op(parse_quote!(reduce_keyed::<'tick>(#comb)), false)
+        }
+    }
+}
+
 fn get_this_crate() -> TokenStream {
     let hydroflow_crate = proc_macro_crate::crate_name("hydroflow_plus")
         .expect("hydroflow_plus should be present in `Cargo.toml`");
diff --git a/hydroflow_plus_test/src/cluster.rs b/hydroflow_plus_test/src/cluster.rs
index e0b75951df01..d7cdfe182f16 100644
--- a/hydroflow_plus_test/src/cluster.rs
+++ b/hydroflow_plus_test/src/cluster.rs
@@ -46,7 +46,7 @@ pub fn map_reduce<'a, D: Deploy<'a>>(
     let cluster = flow.cluster(cluster_spec);
 
     let words = process
-        .source_iter(q!(vec!["abc", "abc", "xyz"]))
+        .source_iter(q!(vec!["abc", "abc", "xyz", "abc"]))
         .map(q!(|s| s.to_string()));
 
     let all_ids_vec = cluster.ids();
@@ -57,12 +57,16 @@ pub fn map_reduce<'a, D: Deploy<'a>>(
     words_partitioned
         .demux_bincode(&cluster)
         .tick_batch()
-        .fold(q!(|| 0), q!(|count, string| *count += string.len()))
-        .inspect(q!(|count| println!("partition count: {}", count)))
+        .map(q!(|string| (string, ())))
+        .fold_keyed(q!(|| 0), q!(|count, _| *count += 1))
+        .inspect(q!(|(string, count)| println!(
+            "partition count: {} - {}",
+            string, count
+        )))
         .send_bincode_interleaved(&process)
         .all_ticks()
-        .fold(q!(|| 0), q!(|total, count| *total += count))
-        .for_each(q!(|data| println!("total: {}", data)));
+        .reduce_keyed(q!(|total, count| *total += count))
+        .for_each(q!(|(string, count)| println!("{}: {}", string, count)));
 
     (process, cluster)
 }
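
Aside, not part of the patch: the generated fold_keyed/reduce_keyed loops above now route the user combinator through the small `call_comb_type` helper instead of calling `(#aggfn)(...)` directly, which gives the closure an explicit `Fn(&mut A, T) -> O` shape with the output type left generic. Below is a minimal standalone sketch of that call pattern; the helper body is copied from the diff, while the `HashMap` standing in for `#hashtable_ident` and the closure standing in for `#aggfn` are illustrative only.

    // Sketch of the call pattern used by the expanded fold_keyed loop.
    use std::collections::HashMap;

    #[inline(always)]
    /// A: accumulator type, T: iterator item type, O: output type.
    fn call_comb_type<A, T, O>(a: &mut A, t: T, f: impl Fn(&mut A, T) -> O) -> O {
        f(a, t)
    }

    fn main() {
        // Plays the role of the generated #hashtable_ident.
        let mut table: HashMap<&str, i32> = HashMap::new();
        for (k, v) in [("abc", 1), ("xyz", 1), ("abc", 1)] {
            let entry = table.entry(k).or_insert_with(|| 0);
            // Plays the role of `call_comb_type(entry, kv.1, #aggfn)`;
            // the closure here is an illustrative stand-in for #aggfn.
            call_comb_type(entry, v, |count, x| *count += x);
        }
        assert_eq!(table["abc"], 2);
        println!("{:?}", table);
    }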