From 7eea93cdc88254d3521bfb9ab3f2991f90f86763 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 22 Oct 2024 18:26:29 +0000 Subject: [PATCH] refactor(python): Expose group_by_dynamic in pyir Previously we were just ignoring this, which meant we would sometimes fail to fall back correctly in the GPU engine. --- Cargo.lock | 4 + crates/polars-arrow/Cargo.toml | 2 + .../legacy/kernels/rolling/no_nulls/mod.rs | 4 +- .../polars-arrow/src/legacy/kernels/time.rs | 4 +- crates/polars-core/Cargo.toml | 1 + crates/polars-core/src/frame/mod.rs | 4 +- crates/polars-ops/Cargo.toml | 1 + crates/polars-ops/src/frame/join/args.rs | 4 +- .../polars-ops/src/series/ops/is_between.rs | 4 +- .../src/dsl/function_expr/bitwise.rs | 4 +- crates/polars-plan/src/dsl/options.rs | 4 +- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 148 ++++++++++-------- .../src/lazyframe/visitor/nodes.rs | 22 +-- crates/polars-time/Cargo.toml | 1 + crates/polars-time/src/windows/group_by.rs | 10 +- 16 files changed, 125 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6c5dd1cf83f5..bac37ef2819a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2776,6 +2776,7 @@ dependencies = [ "simdutf8", "streaming-iterator", "strength_reduce", + "strum_macros", "tokio", "tokio-util", "version_check", @@ -2839,6 +2840,7 @@ dependencies = [ "regex", "serde", "serde_json", + "strum_macros", "thiserror", "version_check", "xxhash-rust", @@ -3048,6 +3050,7 @@ dependencies = [ "regex-syntax 0.8.5", "serde", "serde_json", + "strum_macros", "unicode-reverse", "version_check", ] @@ -3276,6 +3279,7 @@ dependencies = [ "polars-utils", "regex", "serde", + "strum_macros", ] [[package]] diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 4cbf4a0ad6e4..2bca77858be1 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -71,6 +71,8 @@ ahash = { workspace = true } async-stream = { version = "0.3", optional 
= true } tokio = { workspace = true, optional = true, features = ["io-util"] } +strum_macros = { workspace = true } + [dev-dependencies] criterion = "0.5" crossbeam-channel = { workspace = true } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 3277318e6807..7abe2455e61f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, Num, NumCast}; pub use quantile::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub use sum::*; pub use variance::*; @@ -69,8 +70,9 @@ where Ok(Box::new(arr)) } -#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum QuantileMethod { #[default] Nearest, diff --git a/crates/polars-arrow/src/legacy/kernels/time.rs b/crates/polars-arrow/src/legacy/kernels/time.rs index 08bc285c7ffe..73caabb7587b 100644 --- a/crates/polars-arrow/src/legacy/kernels/time.rs +++ b/crates/polars-arrow/src/legacy/kernels/time.rs @@ -9,6 +9,7 @@ use polars_error::PolarsResult; use polars_error::{polars_bail, PolarsError}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub enum Ambiguous { Earliest, @@ -32,8 +33,9 @@ impl FromStr for Ambiguous { } } -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum NonExistent { Null, Raise, diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index d01ae4dd0203..bb5cdc85cdac 100644 --- a/crates/polars-core/Cargo.toml +++ 
b/crates/polars-core/Cargo.toml @@ -36,6 +36,7 @@ regex = { workspace = true, optional = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } thiserror = { workspace = true } xxhash-rust = { workspace = true } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index b82d43365010..5941970d14cf 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -33,6 +33,7 @@ use arrow::record_batch::RecordBatch; use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] @@ -49,8 +50,9 @@ pub enum NullStrategy { Propagate, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum UniqueKeepStrategy { /// Keep the first unique row. 
First, diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 63e4491b25d8..63e52cffc1e4 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -37,6 +37,7 @@ regex = { workspace = true } regex-syntax = { workspace = true } serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } unicode-reverse = { workspace = true, optional = true } [dependencies.jsonpath_lib] diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index b4d347170cbb..4c845a2ba541 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -18,6 +18,7 @@ pub type ChunkJoinIds = Vec; use polars_core::export::once_cell::sync::Lazy; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; #[derive(Clone, PartialEq, Eq, Debug, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -108,8 +109,9 @@ impl JoinArgs { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum JoinType { Inner, Left, diff --git a/crates/polars-ops/src/series/ops/is_between.rs b/crates/polars-ops/src/series/ops/is_between.rs index 053493d552f6..96b1074b4d82 100644 --- a/crates/polars-ops/src/series/ops/is_between.rs +++ b/crates/polars-ops/src/series/ops/is_between.rs @@ -3,9 +3,11 @@ use std::ops::BitAnd; use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedInterval { #[default] Both, diff 
--git a/crates/polars-plan/src/dsl/function_expr/bitwise.rs b/crates/polars-plan/src/dsl/function_expr/bitwise.rs index 2d4dd779cff0..1f0be9247993 100644 --- a/crates/polars-plan/src/dsl/function_expr/bitwise.rs +++ b/crates/polars-plan/src/dsl/function_expr/bitwise.rs @@ -2,6 +2,7 @@ use std::fmt; use std::sync::Arc; use polars_core::prelude::*; +use strum_macros::IntoStaticStr; use super::{ColumnsUdf, SpecialEq}; use crate::dsl::FieldsMapper; @@ -21,7 +22,8 @@ pub enum BitwiseFunction { } #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] pub enum BitwiseAggFunction { And, Or, diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index 73481796d3e0..259d66af95ae 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -5,6 +5,7 @@ use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::dsl::Selector; @@ -87,8 +88,9 @@ impl Default for WindowType { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum WindowMapping { /// Map the group values to the position #[default] diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 19f0f06a2c8c..5a98398703b9 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. 
exposing a new expression node). - const VERSION: Version = (2, 3); + const VERSION: Version = (3, 0); pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 4798c257ec82..41f04b7c1cad 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -2,9 +2,7 @@ use polars::datatypes::TimeUnit; #[cfg(feature = "iejoin")] use polars::prelude::InequalityOperator; use polars::series::ops::NullBehavior; -use polars_core::prelude::{NonExistent, QuantileMethod}; use polars_core::series::IsSorted; -use polars_ops::prelude::ClosedInterval; use polars_ops::series::InterpolationMethod; #[cfg(feature = "search_sorted")] use polars_ops::series::SearchSortedSide; @@ -16,6 +14,7 @@ use polars_plan::prelude::{ WindowMapping, WindowType, }; use polars_time::prelude::RollingGroupOptions; +use polars_time::{Duration, DynamicGroupOptions}; use pyo3::exceptions::PyNotImplementedError; use pyo3::prelude::*; @@ -44,18 +43,6 @@ pub struct Literal { dtype: PyObject, } -impl IntoPy<PyObject> for Wrap<ClosedInterval> { - fn into_py(self, py: Python<'_>) -> PyObject { - match self.0 { - ClosedInterval::Both => "both", - ClosedInterval::Left => "left", - ClosedInterval::Right => "right", - ClosedInterval::None => "none", - } - .into_py(py) - } -} - #[pyclass(name = "Operator")] #[derive(Copy, Clone)] pub enum PyOperator { @@ -404,15 +391,25 @@ pub struct PyWindowMapping { impl PyWindowMapping { #[getter] fn kind(&self, py: Python<'_>) -> PyResult<PyObject> { - let result = match self.inner { - WindowMapping::GroupsToRows => "groups_to_rows".to_object(py), - WindowMapping::Explode => "explode".to_object(py), - WindowMapping::Join => "join".to_object(py), - }; + let result: &str = self.inner.into(); Ok(result.into_py(py)) } } +impl IntoPy<PyObject> for Wrap<Duration> { + fn into_py(self, py: Python<'_>) -> PyObject { + ( + 
self.0.months(), + self.0.weeks(), + self.0.days(), + self.0.nanoseconds(), + self.0.parsed_int, + self.0.negative(), + ) + .into_py(py) + } +} + #[pyclass(name = "RollingGroupOptions")] pub struct PyRollingGroupOptions { inner: RollingGroupOptions, @@ -427,41 +424,68 @@ impl PyRollingGroupOptions { #[getter] fn period(&self, py: Python<'_>) -> PyResult<PyObject> { - let result = vec![ - self.inner.period.months().to_object(py), - self.inner.period.weeks().to_object(py), - self.inner.period.days().to_object(py), - self.inner.period.nanoseconds().to_object(py), - self.inner.period.parsed_int.to_object(py), - self.inner.period.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.period).into_py(py)) } #[getter] fn offset(&self, py: Python<'_>) -> PyResult<PyObject> { - let result = vec![ - self.inner.offset.months().to_object(py), - self.inner.offset.weeks().to_object(py), - self.inner.offset.days().to_object(py), - self.inner.offset.nanoseconds().to_object(py), - self.inner.offset.parsed_int.to_object(py), - self.inner.offset.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.offset).into_py(py)) } #[getter] fn closed_window(&self, py: Python<'_>) -> PyResult<PyObject> { - let result = match self.inner.closed_window { - polars::time::ClosedWindow::Left => "left".to_object(py), - polars::time::ClosedWindow::Right => "right".to_object(py), - polars::time::ClosedWindow::Both => "both".to_object(py), - polars::time::ClosedWindow::None => "none".to_object(py), - }; - Ok(result.into_py(py)) + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } +} + +#[pyclass(name = "DynamicGroupOptions")] +pub struct PyDynamicGroupOptions { + inner: DynamicGroupOptions, +} + +#[pymethods] +impl PyDynamicGroupOptions { + #[getter] + fn index_column(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(self.inner.index_column.to_object(py)) + } + + #[getter] + fn every(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(Wrap(self.inner.every).into_py(py)) 
+ } + + #[getter] + fn period(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(Wrap(self.inner.period).into_py(py)) + } + + #[getter] + fn offset(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(Wrap(self.inner.offset).into_py(py)) + } + + #[getter] + fn label(&self, py: Python<'_>) -> PyResult<PyObject> { + let result: &str = self.inner.label.into(); + Ok(result.to_object(py)) + } + + #[getter] + fn include_boundaries(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(self.inner.include_boundaries.into_py(py)) + } + + #[getter] + fn closed_window(&self, py: Python<'_>) -> PyResult<PyObject> { + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } + #[getter] + fn start_by(&self, py: Python<'_>) -> PyResult<PyObject> { + let result: &str = self.inner.start_by.into(); + Ok(result.to_object(py)) } } @@ -486,6 +510,14 @@ impl PyGroupbyOptions { .map_or_else(|| py.None(), |f| f.to_object(py))) } + #[getter] + fn dynamic(&self, py: Python<'_>) -> PyResult<PyObject> { + Ok(self.inner.dynamic.as_ref().map_or_else( + || py.None(), + |f| PyDynamicGroupOptions { inner: f.clone() }.into_py(py), + )) + } + #[getter] fn rolling(&self, py: Python<'_>) -> PyResult<PyObject> { Ok(self.inner.rolling.as_ref().map_or_else( @@ -705,15 +737,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> { } => Agg { name: "quantile".to_object(py), arguments: vec![expr.0, quantile.0], - options: match interpol { - QuantileMethod::Nearest => "nearest", - QuantileMethod::Lower => "lower", - QuantileMethod::Higher => "higher", - QuantileMethod::Midpoint => "midpoint", - QuantileMethod::Linear => "linear", - QuantileMethod::Equiprobable => "equiprobable", - } - .to_object(py), + options: Into::<&str>::into(interpol).to_object(py), }, IRAggExpr::Sum(n) => Agg { name: "sum".to_object(py), @@ -743,12 +767,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> { IRAggExpr::Bitwise(n, f) => Agg { name: "bitwise".to_object(py), arguments: vec![n.0], - options: match f { - polars::prelude::BitwiseAggFunction::And => 
"and", - polars::prelude::BitwiseAggFunction::Or => "or", - polars::prelude::BitwiseAggFunction::Xor => "xor", - } - .to_object(py), + options: Into::<&str>::into(f).to_object(py), }, } .into_py(py), @@ -1035,10 +1054,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { time_zone .as_ref() .map_or_else(|| py.None(), |s| s.to_object(py)), - match non_existent { - NonExistent::Null => "nullify", - NonExistent::Raise => "raise", - }, + Into::<&str>::into(non_existent), ) .into_py(py), TemporalFunction::Combine(time_unit) => { @@ -1078,7 +1094,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { BooleanFunction::IsUnique => (PyBooleanFunction::IsUnique,).into_py(py), BooleanFunction::IsDuplicated => (PyBooleanFunction::IsDuplicated,).into_py(py), BooleanFunction::IsBetween { closed } => { - (PyBooleanFunction::IsBetween, Wrap(*closed)).into_py(py) + (PyBooleanFunction::IsBetween, Into::<&str>::into(closed)).into_py(py) }, #[cfg(feature = "is_in")] BooleanFunction::IsIn => (PyBooleanFunction::IsIn,).into_py(py), diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index 7d03f509496a..28c5e459b1e5 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{IdxSize, UniqueKeepStrategy}; +use polars_core::prelude::IdxSize; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; use polars_plan::prelude::{ @@ -454,7 +454,6 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { ))) })?, maintain_order: *maintain_order, - // TODO: dynamic options options: PyGroupbyOptions::new(options.as_ref().clone()).into_py(py), } .into_py(py), @@ -472,23 +471,16 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { right_on: right_on.iter().map(|e| e.into()).collect(), options: { let how = &options.args.how; - + let name = 
Into::<&str>::into(how).to_object(py); ( match how { - JoinType::Left => "left".to_object(py), - JoinType::Right => "right".to_object(py), - JoinType::Inner => "inner".to_object(py), - JoinType::Full => "full".to_object(py), #[cfg(feature = "asof_join")] JoinType::AsOf(_) => { return Err(PyNotImplementedError::new_err("asof join")) }, - JoinType::Cross => "cross".to_object(py), - JoinType::Semi => "leftsemi".to_object(py), - JoinType::Anti => "leftanti".to_object(py), #[cfg(feature = "iejoin")] JoinType::IEJoin(ie_options) => ( - "inequality".to_object(py), + name, crate::Wrap(ie_options.operator1).into_py(py), ie_options .operator2 @@ -496,6 +488,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { .map_or_else(|| py.None(), |op| crate::Wrap(*op).into_py(py)), ) .into_py(py), + _ => name, }, options.args.join_nulls, options.args.slice, @@ -529,12 +522,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { IR::Distinct { input, options } => Distinct { input: input.0, options: ( - match options.keep_strategy { - UniqueKeepStrategy::First => "first", - UniqueKeepStrategy::Last => "last", - UniqueKeepStrategy::None => "none", - UniqueKeepStrategy::Any => "any", - }, + Into::<&str>::into(options.keep_strategy), options.subset.as_ref().map_or_else( || py.None(), |f| { diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index d75d634d213d..6d878b0ba27f 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -23,6 +23,7 @@ now = { version = "0.1" } once_cell = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } +strum_macros = { workspace = true } [dev-dependencies] polars-ops = { workspace = true, features = ["abs"] } diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 9ba3a2d3dbc2..0a40b9af6fbc 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs 
@@ -8,11 +8,13 @@ use polars_core::POOL; use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::prelude::*; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedWindow { Left, Right, @@ -20,16 +22,18 @@ pub enum ClosedWindow { None, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum Label { Left, Right, DataPoint, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum StartBy { WindowBound, DataPoint,