From 79b0d906b64cd8f79a36f2fb9fa47068f4e820d7 Mon Sep 17 00:00:00 2001 From: abstractqqq Date: Thu, 2 Nov 2023 01:55:03 -0400 Subject: [PATCH 1/2] worked on lstsq for a bit --- python/polars_ds/extensions.py | 22 ++++- src/num_ext/expressions.rs | 154 +++++++++++++++++++++++++++++++-- src/str_ext/expressions.rs | 6 +- tests/adhoc.py | 32 +++++-- 4 files changed, 196 insertions(+), 18 deletions(-) diff --git a/python/polars_ds/extensions.py b/python/polars_ds/extensions.py index bca129c2..cd0b1afa 100644 --- a/python/polars_ds/extensions.py +++ b/python/polars_ds/extensions.py @@ -250,11 +250,27 @@ def lstsq(self, *other: pl.Expr) -> pl.Expr: Parameters ---------- other - Either an int or a Polars expression + List of Polars expressions + """ + return self._expr.register_plugin( + lib=lib, + symbol="pl_lstsq", + args=list(other), + is_elementwise=False, + ) + + def lstsq2(self, *other: pl.Expr) -> pl.Expr: """ - return self._expr._register_plugin( + Computes least squares solution to a linear matrix equation. + + Parameters + ---------- + other + List of Polars expressions + """ + return self._expr.register_plugin( lib=lib, - symbol="lstsq", + symbol="pl_lstsq2", args=list(other), is_elementwise=False, ) diff --git a/src/num_ext/expressions.rs b/src/num_ext/expressions.rs index 20b7c99c..a01aa77d 100644 --- a/src/num_ext/expressions.rs +++ b/src/num_ext/expressions.rs @@ -1,10 +1,11 @@ use faer::prelude::*; use faer::solvers::Qr; use faer::{IntoFaer, IntoNdarray}; -use ndarray::ArrayView2; +// use faer::polars::{polars_to_faer_f64, Frame}; +use ndarray::{ArrayView2, Array2}; use num; use polars::prelude::*; -use polars::prelude::*; +use polars_core::utils::rayon::prelude::{IntoParallelRefIterator, ParallelIterator, IndexedParallelIterator}; use pyo3_polars::derive::polars_expr; #[polars_expr(output_type=Int64)] @@ -98,11 +99,12 @@ fn lstsq_output(input_fields: &[Field]) -> PolarsResult { /// This function returns a struct series with betas, y_pred, and residuals #[polars_expr(output_type_func=lstsq_output)] -fn lstsq(inputs: &[Series]) -> PolarsResult { - // Iterate over the inputs and name each one with .with_name() and collect them into a vector - let mut series_vec = Vec::new(); +fn pl_lstsq(inputs: &[Series]) -> PolarsResult { + // Iterate over the inputs and name each one with .with_name() and collect them into a vector + let mut series_vec = Vec::with_capacity(inputs.len()); // Have to name each one because they don't have names if passed in via .over() + for (i, series) in inputs[1..].iter().enumerate() { let series = series.clone().with_name(&format!("x{i}")); series_vec.push(series); @@ -116,6 +118,7 @@ fn lstsq(inputs: &[Series]) -> PolarsResult { .to_owned() .into_shape((inputs[0].len(), 1)) .unwrap(); + let y = y.view().into_faer(); // Create a polars DataFrame from the input series @@ -127,6 +130,8 @@ fn lstsq(inputs: &[Series]) -> PolarsResult { .unwrap() .to_owned(); let x = x.view().into_faer(); + + // Solving Least Square Qr::new(x); let betas = Qr::new(x).solve_lstsq(y); let preds = x * &betas; @@ -141,10 +146,10 @@ fn lstsq(inputs: &[Series]) -> PolarsResult { .map(|(beta, name)| Series::new(name, vec![*beta; inputs[0].len()])) .collect(); // Add a series of residuals and y_pred to the output - let y_pred_series = - Series::new("y_pred", preds_array.iter().copied().collect::>()); - let resid_series = - Series::new("resid", resid_array.iter().copied().collect::>()); + let y_pred_series = Series::from_iter(preds_array).with_name("y_pred"); + + let resid_series = Series::from_iter(resid_array).with_name("resid"); + out_series.push(y_pred_series); out_series.push(resid_series); let out = StructChunked::new("results", &out_series)?.into_series(); @@ -156,3 +161,134 @@ fn lstsq(inputs: &[Series]) -> PolarsResult { } } } + + +#[polars_expr(output_type_func=lstsq_output)] +fn pl_lstsq2(inputs: &[Series]) -> PolarsResult { + + + // let beta_names: Vec = (0..(inputs.len()-1)).map(|i| format!("x{i}")).collect(); + let nrows = inputs[0].len(); + // let ncols = inputs.len() - 1; + + // y + let y = inputs[0].f64()?; + let y = y.to_ndarray()?.into_shape((nrows,1)).unwrap(); + let y = y.view().into_faer(); + + // X + let df_x = DataFrame::new(inputs[1..].to_vec())?; + + match df_x.to_ndarray::(IndexOrder::Fortran) { + + Ok(x) => { + + // Solving Least Square, without bias term + let x = x.view().into_faer(); + Qr::new(x); + let betas = Qr::new(x).solve_lstsq(y); + let preds = x * &betas; + let preds_array = preds.as_ref().into_ndarray(); + let resid = y - &preds; + let resid_array: ArrayView2 = resid.as_ref().into_ndarray(); + let betas = betas.as_ref().into_ndarray(); + + let mut out_series: Vec = Vec::with_capacity(betas.len() + 2); + for (i, b) in betas.into_iter().enumerate() { + out_series.push( + // A copy + Series::from_vec(&format!("x{i}") , vec![*b; nrows]) + ); + } + out_series.push( + // A copy + Series::from_iter(preds_array).with_name("y_pred") + ); + out_series.push( + // A copy + Series::from_iter(resid_array).with_name("resid") + ); + + let out = StructChunked::new("results", &out_series)?.into_series(); + Ok(out) + + } + , Err(e) => { + Err(e) + } + + } + + + +} + + +// #[polars_expr(output_type_func=lstsq_output)] +// fn lstsq2(inputs: &[Series]) -> PolarsResult { +// // Iterate over the inputs and name each one with .with_name() and collect them into a vector +// let mut series_vec = Vec::new(); + +// // Have to name each one because they don't have names if passed in via .over() +// for (i, series) in inputs[1..].iter().enumerate() { +// let series = series.clone().with_name(&format!("x{i}")); +// series_vec.push(series); +// } +// let beta_names: Vec = series_vec.iter().map(|s| s.name().to_string()).collect(); + +// let y = &inputs[0]; + + +// let df_y = df!(y.name() => y)? +// .lazy(); +// let mat_y = polars_to_faer_f64(df_y); + + +// let y = &inputs[0] +// .f64() +// .unwrap() +// .to_ndarray() +// .unwrap() +// .to_owned() +// .into_shape((inputs[0].len(), 1)) +// .unwrap(); +// let y = y.view().into_faer(); + +// // Create a polars DataFrame from the input series +// let todf = DataFrame::new(series_vec); +// match todf { +// Ok(df) => { +// let x = df +// .to_ndarray::(IndexOrder::Fortran) +// .unwrap() +// .to_owned(); +// let x = x.view().into_faer(); +// Qr::new(x); +// let betas = Qr::new(x).solve_lstsq(y); +// let preds = x * &betas; +// let preds_array = preds.as_ref().into_ndarray(); +// let resids = y - &preds; +// let resid_array: ArrayView2 = resids.as_ref().into_ndarray(); +// let betas = betas.as_ref().into_ndarray(); + +// let mut out_series: Vec = betas +// .iter() +// .zip(beta_names.iter()) +// .map(|(beta, name)| Series::new(name, vec![*beta; inputs[0].len()])) +// .collect(); +// // Add a series of residuals and y_pred to the output +// let y_pred_series = +// Series::new("y_pred", preds_array.iter().copied().collect::>()); +// let resid_series = +// Series::new("resid", resid_array.iter().copied().collect::>()); +// out_series.push(y_pred_series); +// out_series.push(resid_series); +// let out = StructChunked::new("results", &out_series)?.into_series(); +// Ok(out) +// } +// Err(e) => { +// println!("Error: {}", e); +// PolarsResult::Err(e) +// } +// } +// } \ No newline at end of file diff --git a/src/str_ext/expressions.rs b/src/str_ext/expressions.rs index ad3df2ed..c12b18a1 100644 --- a/src/str_ext/expressions.rs +++ b/src/str_ext/expressions.rs @@ -25,6 +25,7 @@ pub fn snowball_stem(word:Option<&str>, no_stopwords:bool) -> Option { } } + #[inline] pub fn hamming_dist(s1:&str, s2:&str) -> Option { if s1.len() != s2.len() { @@ -37,6 +38,7 @@ pub fn hamming_dist(s1:&str, s2:&str) -> Option { ) } + #[inline] pub fn levenshtein_dist(s1:&str, s2:&str) -> u32 { // It is possible to go faster by not using a matrix to represent the @@ -69,8 +71,8 @@ pub fn levenshtein_dist(s1:&str, s2:&str) -> u32 { dp[len1][len2] } -// Wrapper for Polars Extension +// Wrapper for Polars Extension #[polars_expr(output_type=Utf8)] fn pl_snowball_stem(inputs: &[Series]) -> PolarsResult { let ca = inputs[0].utf8()?; @@ -110,6 +112,7 @@ fn pl_levenshtein_dist(inputs: &[Series]) -> PolarsResult { } } + #[polars_expr(output_type=Float64)] fn pl_str_jaccard(inputs: &[Series]) -> PolarsResult { let ca1 = inputs[0].utf8()?; @@ -178,6 +181,7 @@ fn pl_str_jaccard(inputs: &[Series]) -> PolarsResult { } } + #[polars_expr(output_type=UInt32)] fn pl_hamming_dist(inputs: &[Series]) -> PolarsResult { let ca1 = inputs[0].utf8()?; diff --git a/tests/adhoc.py b/tests/adhoc.py index 8f1d031e..2e48b0d3 100644 --- a/tests/adhoc.py +++ b/tests/adhoc.py @@ -1,15 +1,37 @@ import polars as pl +import timeit from polars_ds.extensions import StrExt # noqa: F401 +def least_square1(df:pl.dataframe) -> pl.DataFrame: + + return df.select( + pl.col("y").num_ext.lstsq(pl.col("a"), pl.col("b")) + ) + +def least_square2(df:pl.dataframe) -> pl.DataFrame: + + return df.select( + pl.col("y").num_ext.lstsq2(pl.col("a"), pl.col("b")) + ) + + if __name__ == "__main__": df = pl.DataFrame({ - "a":["karolin", "karolin", "kathrin", "0000", "2173896"], - "b":["kathrin", "kerstin", "kerstin", "1111", "2233796"] + "a":pl.Series(range(500_000), dtype=pl.Float64), + "b":pl.Series([1.0] * 500_000, dtype=pl.Float64), + "y":pl.Series(range(500_000), dtype=pl.Float64) + 0.5, }) + res1 = least_square1(df) + res2 = least_square2(df) + + from polars.testing import assert_frame_equal - res = df.select( - pl.col("a").str_ext.hamming_dist(pl.col("b")) + assert_frame_equal( + res1, res2 ) - print(res) \ No newline at end of file + time1 = timeit.timeit(lambda: least_square1(df), number = 10) + time2 = timeit.timeit(lambda: least_square2(df), number = 10) + print(f"Time for Implementation 1: {time1:.4f}s.") + print(f"Time for Implementation 2: {time2:.4f}s.") \ No newline at end of file From 6ac4bb03cf295eb9a15f0cf3b73fa71097cfa87f Mon Sep 17 00:00:00 2001 From: abstractqqq Date: Thu, 2 Nov 2023 23:45:30 -0400 Subject: [PATCH 2/2] better lstsq --- Cargo.lock | 1 + Cargo.toml | 1 + Makefile | 23 ++ python/polars_ds/extensions.py | 49 ++- requirements.txt | 4 + src/num_ext/expressions.rs | 297 +++++++----------- src/str_ext/expressions.rs | 73 ++++- .../Untitled-checkpoint.ipynb | 285 +++++++++++++++++ tests/Untitled.ipynb | 285 +++++++++++++++++ tests/adhoc.py | 37 --- 10 files changed, 804 insertions(+), 251 deletions(-) create mode 100644 Makefile create mode 100644 requirements.txt create mode 100644 tests/.ipynb_checkpoints/Untitled-checkpoint.ipynb create mode 100644 tests/Untitled.ipynb delete mode 100644 tests/adhoc.py diff --git a/Cargo.lock b/Cargo.lock index 5b1f3179..7c25388c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1449,6 +1449,7 @@ name = "polars_ds" version = "0.1.0" dependencies = [ "faer", + "hashbrown", "jemallocator", "ndarray", "num", diff --git a/Cargo.toml b/Cargo.toml index 7a4b4694..59a4dc7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ polars-lazy = "0.34" num = "0.4.1" faer = {version = "0.14.1", features = ["ndarray"]} ndarray = "0.15.6" +hashbrown = "0.14.2" [target.'cfg(target_os = "linux")'.dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..cf668219 --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +SHELL=/bin/bash + +venv: ## Set up virtual environment + python3 -m venv .venv + .venv/bin/pip install -r requirements.txt + +install: venv + unset CONDA_PREFIX && \ + source .venv/bin/activate && maturin develop -m Cargo.toml + +dev-release: venv + unset CONDA_PREFIX && \ + source .venv/bin/activate && maturin develop --release -m Cargo.toml + pip install . + +pre-commit: venv + cargo fmt --all --manifest-path Cargo.toml && cargo clippy --all-features --manifest-path Cargo.toml + +# run: install +# source .venv/bin/activate && python run.py + +# run-release: install-release +# source venv/bin/activate && python run.py \ No newline at end of file diff --git a/python/polars_ds/extensions.py b/python/polars_ds/extensions.py index cd0b1afa..b8531982 100644 --- a/python/polars_ds/extensions.py +++ b/python/polars_ds/extensions.py @@ -133,7 +133,8 @@ def hubor_loss(self, other: pl.Expr, delta: float) -> pl.Expr: def l1_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr: """ - Computes L1 loss (normalized L1 distance) between this and the other expression + Computes L1 loss (normalized L1 distance) between this and the other expression. This + is the norm without 1/p power. Parameters ---------- @@ -149,7 +150,9 @@ def l1_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr: def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr: """ - Computes L2 loss (normalized L2 distance) between this and the other expression + Computes L2 loss (normalized L2 distance) between this and the other expression. This + is the norm without 1/p power. + Parameters ---------- @@ -166,7 +169,9 @@ def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr: def lp_loss(self, other: pl.Expr, p: float, normalize: bool = True) -> pl.Expr: """ - Computes LP loss (normalized LP distance) between this and the other expression + Computes LP loss (normalized LP distance) between this and the other expression. This + is the norm without 1/p power. + for p finite. Parameters @@ -179,9 +184,9 @@ def lp_loss(self, other: pl.Expr, p: float, normalize: bool = True) -> pl.Expr: if p <= 0: raise ValueError(f"Input `p` must be > 0, not {p}") - temp = (self._expr - other).abs().pow(p) + temp = (self._expr - other).abs().pow(p).sum() if normalize: - return temp / self._expr.count() + return (temp / self._expr.count()) return temp def chebyshev_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr: @@ -243,36 +248,30 @@ def smape(self, other: pl.Expr) -> pl.Expr: denominator = 1.0 / (self._expr.abs() + other.abs()) return (1.0 / self._expr.count()) * numerator.dot(denominator) - def lstsq(self, *other: pl.Expr) -> pl.Expr: + def lstsq(self, other: list[pl.Expr], add_bias:bool=False) -> pl.Expr: """ - Computes least squares solution to a linear matrix equation. + Computes least squares solution to a linear matrix equation. If columns are + not linearly independent, some numerical issue or error may occur. Unrealistic + coefficient values is an indication of `silent` numerical problem during the + computation. + + If add_bias is true, it will be the last coefficient in the output + and output will have length |other| + 1 Parameters ---------- other - List of Polars expressions + List of Polars expressions. They should have the same size. + add_bias + Whether to add a bias term """ - return self._expr.register_plugin( - lib=lib, - symbol="pl_lstsq", - args=list(other), - is_elementwise=False, - ) - def lstsq2(self, *other: pl.Expr) -> pl.Expr: - """ - Computes least squares solution to a linear matrix equation. - - Parameters - ---------- - other - List of Polars expressions - """ return self._expr.register_plugin( lib=lib, - symbol="pl_lstsq2", - args=list(other), + symbol="pl_lstsq", + args=[pl.lit(add_bias, dtype=pl.Boolean)] + other, is_elementwise=False, + returns_scalar=True ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..c59abeca --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +maturin +polars +numpy +pytest \ No newline at end of file diff --git a/src/num_ext/expressions.rs b/src/num_ext/expressions.rs index a01aa77d..8a33ea6e 100644 --- a/src/num_ext/expressions.rs +++ b/src/num_ext/expressions.rs @@ -1,43 +1,46 @@ -use faer::prelude::*; -use faer::solvers::Qr; +use faer::{prelude::*, MatRef, Side}; use faer::{IntoFaer, IntoNdarray}; // use faer::polars::{polars_to_faer_f64, Frame}; -use ndarray::{ArrayView2, Array2}; +use ndarray::{Array2, Array1}; use num; use polars::prelude::*; -use polars_core::utils::rayon::prelude::{IntoParallelRefIterator, ParallelIterator, IndexedParallelIterator}; +use polars::chunked_array::ops::arity::binary_elementwise; use pyo3_polars::derive::polars_expr; +fn optional_gcd(op_a:Option, op_b:Option) -> Option { + if let (Some(a), Some(b)) = (op_a, op_b) { + Some(num::integer::gcd(a, b)) + } else { + None + } +} + +fn optional_lcm(op_a:Option, op_b:Option) -> Option { + if let (Some(a), Some(b)) = (op_a, op_b) { + Some(num::integer::lcm(a, b)) + } else { + None + } +} + + #[polars_expr(output_type=Int64)] fn pl_gcd(inputs: &[Series]) -> PolarsResult { + let ca1 = inputs[0].i64()?; let ca2 = inputs[1].i64()?; - if ca2.len() == 1 { let b = ca2.get(0).unwrap(); - let out: Int64Chunked = ca1 - .into_iter() - .map(|op_a| { - if let Some(a) = op_a { - Some(num::integer::gcd(a, b)) - } else { - None - } - }) - .collect(); + let out:Int64Chunked = ca1.apply_generic(|op_a:Option| { + if let Some(a) = op_a { + Some(num::integer::gcd(a, b)) + } else { + None + } + }); Ok(out.into_series()) } else if ca1.len() == ca2.len() { - let out: Int64Chunked = ca1 - .into_iter() - .zip(ca2.into_iter()) - .map(|(op_a, op_b)| { - if let (Some(a), Some(b)) = (op_a, op_b) { - Some(num::integer::gcd(a, b)) - } else { - None - } - }) - .collect(); + let out:Int64Chunked = binary_elementwise(ca1, ca2, optional_gcd); Ok(out.into_series()) } else { Err(PolarsError::ComputeError( @@ -50,32 +53,18 @@ fn pl_gcd(inputs: &[Series]) -> PolarsResult { fn pl_lcm(inputs: &[Series]) -> PolarsResult { let ca1 = inputs[0].i64()?; let ca2 = inputs[1].i64()?; - if ca2.len() == 1 { let b = ca2.get(0).unwrap(); - let out: Int64Chunked = ca1 - .into_iter() - .map(|op_a| { - if let Some(a) = op_a { - Some(num::integer::lcm(a, b)) - } else { - None - } - }) - .collect(); + let out:Int64Chunked = ca1.apply_generic(|op_a:Option| { + if let Some(a) = op_a { + Some(num::integer::lcm(a, b)) + } else { + None + } + }); Ok(out.into_series()) } else if ca1.len() == ca2.len() { - let out: Int64Chunked = ca1 - .into_iter() - .zip(ca2.into_iter()) - .map(|(op_a, op_b)| { - if let (Some(a), Some(b)) = (op_a, op_b) { - Some(num::integer::lcm(a, b)) - } else { - None - } - }) - .collect(); + let out:Int64Chunked = binary_elementwise(ca1, ca2, optional_lcm); Ok(out.into_series()) } else { Err(PolarsError::ComputeError( @@ -84,166 +73,121 @@ fn pl_lcm(inputs: &[Series]) -> PolarsResult { } } -// I am not sure this is right. I still don't quite understand the purpose of this. -fn lstsq_output(input_fields: &[Field]) -> PolarsResult { - Ok(Field::new( - "betas", - DataType::Struct( - input_fields[1..] - .iter() - .map(|f| Field::new(&format!("beta_{}", f.name()), DataType::Float64)) - .collect(), - ), - )) -} - -/// This function returns a struct series with betas, y_pred, and residuals -#[polars_expr(output_type_func=lstsq_output)] -fn pl_lstsq(inputs: &[Series]) -> PolarsResult { - - // Iterate over the inputs and name each one with .with_name() and collect them into a vector - let mut series_vec = Vec::with_capacity(inputs.len()); - // Have to name each one because they don't have names if passed in via .over() - - for (i, series) in inputs[1..].iter().enumerate() { - let series = series.clone().with_name(&format!("x{i}")); - series_vec.push(series); - } - let beta_names: Vec = series_vec.iter().map(|s| s.name().to_string()).collect(); - let y = &inputs[0] - .f64() - .unwrap() - .to_ndarray() - .unwrap() - .to_owned() - .into_shape((inputs[0].len(), 1)) - .unwrap(); - - let y = y.view().into_faer(); - // Create a polars DataFrame from the input series - let todf = DataFrame::new(series_vec); - match todf { - Ok(df) => { - let x = df - .to_ndarray::(IndexOrder::Fortran) - .unwrap() - .to_owned(); - let x = x.view().into_faer(); +// Use QR to solve +fn faer_lstsq_qr( + x: MatRef, + y: MatRef +) -> Result, String> { - // Solving Least Square - Qr::new(x); - let betas = Qr::new(x).solve_lstsq(y); - let preds = x * &betas; - let preds_array = preds.as_ref().into_ndarray(); - let resids = y - &preds; - let resid_array: ArrayView2 = resids.as_ref().into_ndarray(); - let betas = betas.as_ref().into_ndarray(); + let qr = x.qr(); + let betas = qr.solve_lstsq(y); + Ok(betas.as_ref().into_ndarray().to_owned()) - let mut out_series: Vec = betas - .iter() - .zip(beta_names.iter()) - .map(|(beta, name)| Series::new(name, vec![*beta; inputs[0].len()])) - .collect(); - // Add a series of residuals and y_pred to the output - let y_pred_series = Series::from_iter(preds_array).with_name("y_pred"); +} - let resid_series = Series::from_iter(resid_array).with_name("resid"); - - out_series.push(y_pred_series); - out_series.push(resid_series); - let out = StructChunked::new("results", &out_series)?.into_series(); - Ok(out) - } - Err(e) => { - println!("Error: {}", e); - PolarsResult::Err(e) - } +// Closed form. +fn faer_lstsq_cf( + x: MatRef, + y: MatRef +) -> Result, String> { + + let xt = x.transpose(); + let xtx = xt * x; + let decomp = xtx.cholesky(Side::Lower); // .unwrap(); + if let Ok(cholesky) = decomp { + let xtx_inv = cholesky.inverse(); + let betas = xtx_inv * xt * y; + Ok( + betas.as_ref().into_ndarray().to_owned() + ) + } else { + Err("Linear algebra error. Likely cause: column duplication or extremely high correlation.".to_owned()) } -} +} -#[polars_expr(output_type_func=lstsq_output)] -fn pl_lstsq2(inputs: &[Series]) -> PolarsResult { +fn lstsq_beta_output(_: &[Field]) -> PolarsResult { + Ok(Field::new("betas", DataType::List(Box::new(DataType::Float64)))) +} +#[polars_expr(output_type_func=lstsq_beta_output)] +fn pl_lstsq(inputs: &[Series]) -> PolarsResult { - // let beta_names: Vec = (0..(inputs.len()-1)).map(|i| format!("x{i}")).collect(); let nrows = inputs[0].len(); - // let ncols = inputs.len() - 1; - + let add_bias = inputs[1].bool()?; + let add_bias:bool = add_bias.get(0).unwrap(); // y let y = inputs[0].f64()?; let y = y.to_ndarray()?.into_shape((nrows,1)).unwrap(); let y = y.view().into_faer(); - // X - let df_x = DataFrame::new(inputs[1..].to_vec())?; - + // X, Series is ref counted, so cheap + let mut vec_series: Vec = inputs[2..].iter().enumerate().map( + |(i,s)| s.clone().with_name(&i.to_string()) + ).collect(); + if add_bias { + let one = Series::new_empty("cst", &DataType::Float64); + vec_series.push( + one.extend_constant(polars::prelude::AnyValue::Float64(1.), nrows)? + ) + } + let df_x = DataFrame::new(vec_series)?; + // Copy data match df_x.to_ndarray::(IndexOrder::Fortran) { Ok(x) => { - // Solving Least Square, without bias term + // Change this after faer updates let x = x.view().into_faer(); - Qr::new(x); - let betas = Qr::new(x).solve_lstsq(y); - let preds = x * &betas; - let preds_array = preds.as_ref().into_ndarray(); - let resid = y - &preds; - let resid_array: ArrayView2 = resid.as_ref().into_ndarray(); - let betas = betas.as_ref().into_ndarray(); + let betas = faer_lstsq_qr(x,y); // .unwrap(); + match betas { + Ok(b) => { + let betas:Array1 = Array1::from_iter(b); + let mut builder:ListPrimitiveChunkedBuilder = + ListPrimitiveChunkedBuilder::new("betas", 1, betas.len(), DataType::Float64); - let mut out_series: Vec = Vec::with_capacity(betas.len() + 2); - for (i, b) in betas.into_iter().enumerate() { - out_series.push( - // A copy - Series::from_vec(&format!("x{i}") , vec![*b; nrows]) - ); + builder.append_slice(betas.as_slice().unwrap()); + let out = builder.finish(); + Ok(out.into_series()) + }, + Err(e) => Err(PolarsError::ComputeError(e.into())) } - out_series.push( - // A copy - Series::from_iter(preds_array).with_name("y_pred") - ); - out_series.push( - // A copy - Series::from_iter(resid_array).with_name("resid") - ); - - let out = StructChunked::new("results", &out_series)?.into_series(); - Ok(out) - - } - , Err(e) => { - Err(e) } - + , Err(e) => Err(e) } +} +// ----------------------------------------------------------------------------------------- -} +// // I am not sure this is right. I still don't quite understand the purpose of this. +// fn lstsq_output(input_fields: &[Field]) -> PolarsResult { +// Ok(Field::new( +// "betas", +// DataType::Struct( +// input_fields[1..] +// .iter() +// .map(|f| Field::new(&format!("beta_{}", f.name()), DataType::Float64)) +// .collect(), +// ), +// )) +// } +// /// This function returns a struct series with betas, y_pred, and residuals // #[polars_expr(output_type_func=lstsq_output)] -// fn lstsq2(inputs: &[Series]) -> PolarsResult { -// // Iterate over the inputs and name each one with .with_name() and collect them into a vector -// let mut series_vec = Vec::new(); +// fn pl_lstsq_old(inputs: &[Series]) -> PolarsResult { +// // Iterate over the inputs and name each one with .with_name() and collect them into a vector +// let mut series_vec = Vec::with_capacity(inputs.len()); // // Have to name each one because they don't have names if passed in via .over() + // for (i, series) in inputs[1..].iter().enumerate() { // let series = series.clone().with_name(&format!("x{i}")); // series_vec.push(series); // } // let beta_names: Vec = series_vec.iter().map(|s| s.name().to_string()).collect(); - -// let y = &inputs[0]; - - -// let df_y = df!(y.name() => y)? -// .lazy(); -// let mat_y = polars_to_faer_f64(df_y); - - // let y = &inputs[0] // .f64() // .unwrap() @@ -252,6 +196,7 @@ fn pl_lstsq2(inputs: &[Series]) -> PolarsResult { // .to_owned() // .into_shape((inputs[0].len(), 1)) // .unwrap(); + // let y = y.view().into_faer(); // // Create a polars DataFrame from the input series @@ -263,6 +208,8 @@ fn pl_lstsq2(inputs: &[Series]) -> PolarsResult { // .unwrap() // .to_owned(); // let x = x.view().into_faer(); + +// // Solving Least Square // Qr::new(x); // let betas = Qr::new(x).solve_lstsq(y); // let preds = x * &betas; @@ -277,10 +224,10 @@ fn pl_lstsq2(inputs: &[Series]) -> PolarsResult { // .map(|(beta, name)| Series::new(name, vec![*beta; inputs[0].len()])) // .collect(); // // Add a series of residuals and y_pred to the output -// let y_pred_series = -// Series::new("y_pred", preds_array.iter().copied().collect::>()); -// let resid_series = -// Series::new("resid", resid_array.iter().copied().collect::>()); +// let y_pred_series = Series::from_iter(preds_array).with_name("y_pred"); + +// let resid_series = Series::from_iter(resid_array).with_name("resid"); + // out_series.push(y_pred_series); // out_series.push(resid_series); // let out = StructChunked::new("results", &out_series)?.into_series(); diff --git a/src/str_ext/expressions.rs b/src/str_ext/expressions.rs index c12b18a1..ea0e4282 100644 --- a/src/str_ext/expressions.rs +++ b/src/str_ext/expressions.rs @@ -1,4 +1,6 @@ use polars::prelude::*; +use hashbrown::HashSet; +// use polars::chunked_array::ops::arity::binary_elementwise; use polars_core::utils::rayon::prelude::{ParallelIterator, IndexedParallelIterator}; use crate::str_ext::consts::EN_STOPWORDS; use pyo3_polars::derive::polars_expr; @@ -81,6 +83,7 @@ fn pl_snowball_stem(inputs: &[Series]) -> PolarsResult { Ok(out.into_series()) } + #[polars_expr(output_type=UInt32)] fn pl_levenshtein_dist(inputs: &[Series]) -> PolarsResult { let ca1 = inputs[0].utf8()?; @@ -88,15 +91,18 @@ fn pl_levenshtein_dist(inputs: &[Series]) -> PolarsResult { if ca2.len() == 1 { let r = ca2.get(0).unwrap(); - let out: UInt32Chunked = ca1.par_iter().map(|op_s| { - if let Some(s) = op_s { - Some(levenshtein_dist(s, r)) - } else { - None + let out: UInt32Chunked = ca1.par_iter().map( + |op_s| { + if let Some(s) = op_s { + Some(levenshtein_dist(s, r)) + } else { + None + } } - }).collect(); + ).collect(); Ok(out.into_series()) } else if ca1.len() == ca2.len() { + let out: UInt32Chunked = ca1.par_iter_indexed() .zip(ca2.par_iter_indexed()) .map(|(op_w1, op_w2)| { @@ -112,6 +118,45 @@ fn pl_levenshtein_dist(inputs: &[Series]) -> PolarsResult { } } +// #[polars_expr(output_type=UInt32)] +// fn pl_levenshtein_dist2(inputs: &[Series]) -> PolarsResult { +// let ca1 = inputs[0].utf8()?; +// let ca2 = inputs[1].utf8()?; + +// if ca2.len() == 1 { +// let r = ca2.get(0).unwrap(); +// let op = |x:Option<&str>| { +// if let Some(s) = x { +// Some(levenshtein_dist(s, r)) +// } else { +// None +// } +// }; +// // ca1.apply_generic(op) +// let out: UInt32Chunked = ca1.apply_generic(op); +// Ok(out.into_series()) +// } else if ca1.len() == ca2.len() { + +// let op = |x:Option<&str>,y:Option<&str>| { +// if let (Some(s1), Some(s2)) = (x,y) { +// Some(levenshtein_dist(s1, s2)) +// } else { +// None +// } +// }; +// let out:UInt32Chunked = binary_elementwise( +// ca1, +// ca2, +// op +// ); +// Ok(out.into_series()) +// } else { +// Err(PolarsError::ComputeError("Inputs must have the same length.".into())) +// } +// } + + +// binary_elementwise() #[polars_expr(output_type=Float64)] fn pl_str_jaccard(inputs: &[Series]) -> PolarsResult { @@ -124,21 +169,21 @@ fn pl_str_jaccard(inputs: &[Series]) -> PolarsResult { if ca2.len() == 1 { let r = ca2.get(0).unwrap(); - let s2 = if r.len() > n { - PlHashSet::from_iter( + let s2: HashSet<&str> = if r.len() > n { + HashSet::from_iter( r.as_bytes().windows(n).map(|sl| str::from_utf8(sl).unwrap() ) )} else { - PlHashSet::from_iter([r]) + HashSet::from_iter([r]) }; let out: Float64Chunked = ca1.par_iter().map(|op_s| { if let Some(s) = op_s { - let s1 = if s.len() > n { - PlHashSet::from_iter( + let s1: HashSet<&str> = if s.len() > n { + HashSet::from_iter( s.as_bytes().windows(n).map(|sl| str::from_utf8(sl).unwrap()) ) } else { - PlHashSet::from_iter([s]) + HashSet::from_iter([s]) }; let intersection = s1.intersection(&s2).count(); Some( @@ -156,10 +201,10 @@ fn pl_str_jaccard(inputs: &[Series]) -> PolarsResult { .map(|(op_w1, op_w2)| { if let (Some(w1), Some(w2)) = (op_w1, op_w2) { if (w1.len() >= n) & (w2.len() >= n) { - let s1 = PlHashSet::from_iter( + let s1: HashSet<&str> = HashSet::from_iter( w1.as_bytes().windows(n).map(|sl| str::from_utf8(sl).unwrap()) ); - let s2 = PlHashSet::from_iter( + let s2: HashSet<&str> = HashSet::from_iter( w2.as_bytes().windows(n).map(|sl| str::from_utf8(sl).unwrap()) ); let intersection = s1.intersection(&s2).count(); diff --git a/tests/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/tests/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 00000000..532124e8 --- /dev/null +++ b/tests/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "529f4422-5c3a-4bd6-abe0-a15edfc62abb", + "metadata": {}, + "outputs": [], + "source": [ + "from polars_ds.extensions import StrExt, NumExt\n", + "import polars as pl\n", + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "592cb462-ee66-49df-9f65-740116abcac2", + "metadata": {}, + "outputs": [], + "source": [ + "df = pl.DataFrame({\n", + " \"a\":[\"Atlanta\"] * 100_000,\n", + " \"b\":[\"Atlantis\"] * 100_000,\n", + "})\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f82ebcaa-25df-4ddd-b166-55f06e979593", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 4)
dummyaby
stri64i64f64
"a"0-100000100000.5
"a"1-99999100001.5
"a"2-99998100002.5
"a"3-99997100003.5
"a"4-99996100004.5
" + ], + "text/plain": [ + "shape: (5, 4)\n", + "┌───────┬─────┬─────────┬──────────┐\n", + "│ dummy ┆ a ┆ b ┆ y │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ i64 ┆ i64 ┆ f64 │\n", + "╞═══════╪═════╪═════════╪══════════╡\n", + "│ a ┆ 0 ┆ -100000 ┆ 100000.5 │\n", + "│ a ┆ 1 ┆ -99999 ┆ 100001.5 │\n", + "│ a ┆ 2 ┆ -99998 ┆ 100002.5 │\n", + "│ a ┆ 3 ┆ -99997 ┆ 100003.5 │\n", + "│ a ┆ 4 ┆ -99996 ┆ 100004.5 │\n", + "└───────┴─────┴─────────┴──────────┘" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pl.DataFrame({\n", + " \"dummy\": [\"a\"] * 50_000 + [\"b\"] * 50_000,\n", + " \"a\": range(100_000),\n", + " \"b\": range(-100_000, 0),\n", + " \"y\": pl.Series(range(100_000, 200_000)) + 0.5\n", + "})\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5821d4d0-fe4f-4864-9d56-0a2c0ef03334", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (2, 2)
dummybetas
strlist[f64]
"a"[0.668022, 0.331978, 133198.343417]
"b"[-0.411356, 1.411356, 241136.143309]
" + ], + "text/plain": [ + "shape: (2, 2)\n", + "┌───────┬───────────────────────────────────┐\n", + "│ dummy ┆ betas │\n", + "│ --- ┆ --- │\n", + "│ str ┆ list[f64] │\n", + "╞═══════╪═══════════════════════════════════╡\n", + "│ a ┆ [0.668022, 0.331978, 133198.3434… │\n", + "│ b ┆ [-0.411356, 1.411356, 241136.143… │\n", + "└───────┴───────────────────────────────────┘" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.group_by(\"dummy\").agg(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")], add_bias = True)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3041be37-fa78-4fcc-bb04-a5d598aeedf8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (2, 2)
dummybetas
strlist[f64]
"a"[2.000005, -1.000005]
"b"[2.000005, -1.000005]
" + ], + "text/plain": [ + "shape: (2, 2)\n", + "┌───────┬───────────────────────┐\n", + "│ dummy ┆ betas │\n", + "│ --- ┆ --- │\n", + "│ str ┆ list[f64] │\n", + "╞═══════╪═══════════════════════╡\n", + "│ a ┆ [2.000005, -1.000005] │\n", + "│ b ┆ [2.000005, -1.000005] │\n", + "└───────┴───────────────────────┘" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.group_by(\"dummy\").agg(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")], add_bias = False)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3f9ad447-c7bb-4830-b996-882a382f0854", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "757 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")]))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ff9b6b21-532a-4655-a8a4-e5396c760a67", + "metadata": {}, + "outputs": [ + { + "ename": "ComputeError", + "evalue": "the length of the window expression did not match that of the group\n> group: (\"a\")\n> group length: 1\n> output: 'shape: (2,)\nSeries: '' [f64]\n[\n\t2.000005\n\t-1.000005\n]'\n\nError originated in expression: 'col(\"y\")./home/abstractqqq/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars_ds/_polars_ds.cpython-311-x86_64-linux-gnu.so:pl_lstsq([false.strict_cast(Boolean), col(\"a\"), col(\"b\")]).over([col(\"dummy\")])'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mComputeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43my\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_ext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstsq\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mover\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdummy\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/dataframe/frame.py:7766\u001b[0m, in \u001b[0;36mDataFrame.select\u001b[0;34m(self, *exprs, **named_exprs)\u001b[0m\n\u001b[1;32m 7664\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mselect\u001b[39m(\n\u001b[1;32m 7665\u001b[0m \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39mexprs: IntoExpr \u001b[38;5;241m|\u001b[39m Iterable[IntoExpr], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnamed_exprs: IntoExpr\n\u001b[1;32m 7666\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataFrame:\n\u001b[1;32m 7667\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 7668\u001b[0m \u001b[38;5;124;03m Select columns from this DataFrame.\u001b[39;00m\n\u001b[1;32m 7669\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7764\u001b[0m \n\u001b[1;32m 7765\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 7766\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mexprs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnamed_exprs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_eager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/utils/deprecation.py:100\u001b[0m, in \u001b[0;36mdeprecate_renamed_parameter..decorate..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(function)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs: P\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: P\u001b[38;5;241m.\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 97\u001b[0m _rename_keyword_argument(\n\u001b[1;32m 98\u001b[0m old_name, new_name, kwargs, function\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, version\n\u001b[1;32m 99\u001b[0m )\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/lazyframe/frame.py:1787\u001b[0m, in \u001b[0;36mLazyFrame.collect\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, no_optimization, streaming, _eager)\u001b[0m\n\u001b[1;32m 1774\u001b[0m comm_subplan_elim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 1776\u001b[0m ldf \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ldf\u001b[38;5;241m.\u001b[39moptimization_toggle(\n\u001b[1;32m 1777\u001b[0m type_coercion,\n\u001b[1;32m 1778\u001b[0m predicate_pushdown,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1785\u001b[0m _eager,\n\u001b[1;32m 1786\u001b[0m )\n\u001b[0;32m-> 1787\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrap_df(ldf\u001b[38;5;241m.\u001b[39mcollect())\n", + "\u001b[0;31mComputeError\u001b[0m: the length of the window expression did not match that of the group\n> group: (\"a\")\n> group length: 1\n> output: 'shape: (2,)\nSeries: '' [f64]\n[\n\t2.000005\n\t-1.000005\n]'\n\nError originated in expression: 'col(\"y\")./home/abstractqqq/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars_ds/_polars_ds.cpython-311-x86_64-linux-gnu.so:pl_lstsq([false.strict_cast(Boolean), col(\"a\"), col(\"b\")]).over([col(\"dummy\")])'" + ] + } + ], + "source": [ + "df.select(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")]).over(\"dummy\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bbfd7a6-2edd-4406-9729-c6d649429d8f", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq2(pl.col(\"a\"), pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86caf610-76c2-4a58-9199-71f05a3d5afb", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq(pl.col(\"a\"), pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe338708-1585-4682-937e-40099f895a76", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"a\").num_ext.gcd(15))\n", + "%timeit df.select(pl.col(\"a\").num_ext.gcd(pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "275dac95-04ed-4c0e-9be9-4d9631e25a58", + "metadata": {}, + "outputs": [], + "source": [ + "df.select(pl.col(\"a\").num_ext.gcd2(pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9acfcf37-85f9-44d6-96b4-897cd3fca4e6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "803ba306-791f-4a1e-a5b6-a6d6b9b55447", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/Untitled.ipynb b/tests/Untitled.ipynb new file mode 100644 index 00000000..532124e8 --- /dev/null +++ b/tests/Untitled.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "529f4422-5c3a-4bd6-abe0-a15edfc62abb", + "metadata": {}, + "outputs": [], + "source": [ + "from polars_ds.extensions import StrExt, NumExt\n", + "import polars as pl\n", + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "592cb462-ee66-49df-9f65-740116abcac2", + "metadata": {}, + "outputs": [], + "source": [ + "df = pl.DataFrame({\n", + " \"a\":[\"Atlanta\"] * 100_000,\n", + " \"b\":[\"Atlantis\"] * 100_000,\n", + "})\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f82ebcaa-25df-4ddd-b166-55f06e979593", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 4)
dummyaby
stri64i64f64
"a"0-100000100000.5
"a"1-99999100001.5
"a"2-99998100002.5
"a"3-99997100003.5
"a"4-99996100004.5
" + ], + "text/plain": [ + "shape: (5, 4)\n", + "┌───────┬─────┬─────────┬──────────┐\n", + "│ dummy ┆ a ┆ b ┆ y │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ i64 ┆ i64 ┆ f64 │\n", + "╞═══════╪═════╪═════════╪══════════╡\n", + "│ a ┆ 0 ┆ -100000 ┆ 100000.5 │\n", + "│ a ┆ 1 ┆ -99999 ┆ 100001.5 │\n", + "│ a ┆ 2 ┆ -99998 ┆ 100002.5 │\n", + "│ a ┆ 3 ┆ -99997 ┆ 100003.5 │\n", + "│ a ┆ 4 ┆ -99996 ┆ 100004.5 │\n", + "└───────┴─────┴─────────┴──────────┘" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pl.DataFrame({\n", + " \"dummy\": [\"a\"] * 50_000 + [\"b\"] * 50_000,\n", + " \"a\": range(100_000),\n", + " \"b\": range(-100_000, 0),\n", + " \"y\": pl.Series(range(100_000, 200_000)) + 0.5\n", + "})\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5821d4d0-fe4f-4864-9d56-0a2c0ef03334", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (2, 2)
dummybetas
strlist[f64]
"a"[0.668022, 0.331978, 133198.343417]
"b"[-0.411356, 1.411356, 241136.143309]
" + ], + "text/plain": [ + "shape: (2, 2)\n", + "┌───────┬───────────────────────────────────┐\n", + "│ dummy ┆ betas │\n", + "│ --- ┆ --- │\n", + "│ str ┆ list[f64] │\n", + "╞═══════╪═══════════════════════════════════╡\n", + "│ a ┆ [0.668022, 0.331978, 133198.3434… │\n", + "│ b ┆ [-0.411356, 1.411356, 241136.143… │\n", + "└───────┴───────────────────────────────────┘" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.group_by(\"dummy\").agg(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")], add_bias = True)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3041be37-fa78-4fcc-bb04-a5d598aeedf8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (2, 2)
dummybetas
strlist[f64]
"a"[2.000005, -1.000005]
"b"[2.000005, -1.000005]
" + ], + "text/plain": [ + "shape: (2, 2)\n", + "┌───────┬───────────────────────┐\n", + "│ dummy ┆ betas │\n", + "│ --- ┆ --- │\n", + "│ str ┆ list[f64] │\n", + "╞═══════╪═══════════════════════╡\n", + "│ a ┆ [2.000005, -1.000005] │\n", + "│ b ┆ [2.000005, -1.000005] │\n", + "└───────┴───────────────────────┘" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.group_by(\"dummy\").agg(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")], add_bias = False)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3f9ad447-c7bb-4830-b996-882a382f0854", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "757 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")]))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ff9b6b21-532a-4655-a8a4-e5396c760a67", + "metadata": {}, + "outputs": [ + { + "ename": "ComputeError", + "evalue": "the length of the window expression did not match that of the group\n> group: (\"a\")\n> group length: 1\n> output: 'shape: (2,)\nSeries: '' [f64]\n[\n\t2.000005\n\t-1.000005\n]'\n\nError originated in expression: 'col(\"y\")./home/abstractqqq/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars_ds/_polars_ds.cpython-311-x86_64-linux-gnu.so:pl_lstsq([false.strict_cast(Boolean), col(\"a\"), col(\"b\")]).over([col(\"dummy\")])'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mComputeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43my\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_ext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstsq\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mover\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdummy\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/dataframe/frame.py:7766\u001b[0m, in \u001b[0;36mDataFrame.select\u001b[0;34m(self, *exprs, **named_exprs)\u001b[0m\n\u001b[1;32m 7664\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mselect\u001b[39m(\n\u001b[1;32m 7665\u001b[0m \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39mexprs: IntoExpr \u001b[38;5;241m|\u001b[39m Iterable[IntoExpr], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnamed_exprs: IntoExpr\n\u001b[1;32m 7666\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataFrame:\n\u001b[1;32m 7667\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 7668\u001b[0m \u001b[38;5;124;03m Select columns from this DataFrame.\u001b[39;00m\n\u001b[1;32m 7669\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7764\u001b[0m \n\u001b[1;32m 7765\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 7766\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mexprs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnamed_exprs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_eager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/utils/deprecation.py:100\u001b[0m, in \u001b[0;36mdeprecate_renamed_parameter..decorate..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(function)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs: P\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: P\u001b[38;5;241m.\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 97\u001b[0m _rename_keyword_argument(\n\u001b[1;32m 98\u001b[0m old_name, new_name, kwargs, function\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, version\n\u001b[1;32m 99\u001b[0m )\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars/lazyframe/frame.py:1787\u001b[0m, in \u001b[0;36mLazyFrame.collect\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, no_optimization, streaming, _eager)\u001b[0m\n\u001b[1;32m 1774\u001b[0m comm_subplan_elim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 1776\u001b[0m ldf \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ldf\u001b[38;5;241m.\u001b[39moptimization_toggle(\n\u001b[1;32m 1777\u001b[0m type_coercion,\n\u001b[1;32m 1778\u001b[0m predicate_pushdown,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1785\u001b[0m _eager,\n\u001b[1;32m 1786\u001b[0m )\n\u001b[0;32m-> 1787\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrap_df(ldf\u001b[38;5;241m.\u001b[39mcollect())\n", + "\u001b[0;31mComputeError\u001b[0m: the length of the window expression did not match that of the group\n> group: (\"a\")\n> group length: 1\n> output: 'shape: (2,)\nSeries: '' [f64]\n[\n\t2.000005\n\t-1.000005\n]'\n\nError originated in expression: 'col(\"y\")./home/abstractqqq/Desktop/MY/Projects/polars_ds_extension/.venv/lib/python3.11/site-packages/polars_ds/_polars_ds.cpython-311-x86_64-linux-gnu.so:pl_lstsq([false.strict_cast(Boolean), col(\"a\"), col(\"b\")]).over([col(\"dummy\")])'" + ] + } + ], + "source": [ + "df.select(\n", + " pl.col(\"y\").num_ext.lstsq([pl.col(\"a\"), pl.col(\"b\")]).over(\"dummy\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bbfd7a6-2edd-4406-9729-c6d649429d8f", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq2(pl.col(\"a\"), pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86caf610-76c2-4a58-9199-71f05a3d5afb", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"y\").num_ext.lstsq(pl.col(\"a\"), pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe338708-1585-4682-937e-40099f895a76", + "metadata": {}, + "outputs": [], + "source": [ + "%timeit df.select(pl.col(\"a\").num_ext.gcd(15))\n", + "%timeit df.select(pl.col(\"a\").num_ext.gcd(pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "275dac95-04ed-4c0e-9be9-4d9631e25a58", + "metadata": {}, + "outputs": [], + "source": [ + "df.select(pl.col(\"a\").num_ext.gcd2(pl.col(\"b\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9acfcf37-85f9-44d6-96b4-897cd3fca4e6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "803ba306-791f-4a1e-a5b6-a6d6b9b55447", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/adhoc.py b/tests/adhoc.py deleted file mode 100644 index 2e48b0d3..00000000 --- a/tests/adhoc.py +++ /dev/null @@ -1,37 +0,0 @@ -import polars as pl -import timeit -from polars_ds.extensions import StrExt # noqa: F401 - -def least_square1(df:pl.dataframe) -> pl.DataFrame: - - return df.select( - pl.col("y").num_ext.lstsq(pl.col("a"), pl.col("b")) - ) - -def least_square2(df:pl.dataframe) -> pl.DataFrame: - - return df.select( - pl.col("y").num_ext.lstsq2(pl.col("a"), pl.col("b")) - ) - - -if __name__ == "__main__": - df = pl.DataFrame({ - "a":pl.Series(range(500_000), dtype=pl.Float64), - "b":pl.Series([1.0] * 500_000, dtype=pl.Float64), - "y":pl.Series(range(500_000), dtype=pl.Float64) + 0.5, - }) - - res1 = least_square1(df) - res2 = least_square2(df) - - from polars.testing import assert_frame_equal - - assert_frame_equal( - res1, res2 - ) - - time1 = timeit.timeit(lambda: least_square1(df), number = 10) - time2 = timeit.timeit(lambda: least_square2(df), number = 10) - print(f"Time for Implementation 1: {time1:.4f}s.") - print(f"Time for Implementation 2: {time2:.4f}s.") \ No newline at end of file