From e2ae1b5409fecdd3f7eb7e1e0e5a7ade3553711c Mon Sep 17 00:00:00 2001 From: abstractqqq Date: Mon, 27 Nov 2023 07:11:52 -0500 Subject: [PATCH] added 1 samp ttest --- python/polars_ds/stats_ext.py | 37 ++++++++++++--- src/num_ext/fft.rs | 4 +- src/num_ext/gcd_lcm.rs | 1 - src/num_ext/jaccard.rs | 3 +- src/num_ext/ols.rs | 1 - src/num_ext/powi.rs | 2 - src/num_ext/tp_fp.rs | 3 +- src/num_ext/trapz.rs | 1 - src/stats/mod.rs | 6 ++- src/stats_ext/ks.rs | 1 - src/stats_ext/normal_test.rs | 7 ++- src/stats_ext/sample.rs | 1 - src/stats_ext/t_test.rs | 84 ++++++++++++++++++++++++++++------- 13 files changed, 110 insertions(+), 41 deletions(-) diff --git a/python/polars_ds/stats_ext.py b/python/polars_ds/stats_ext.py index 87cb3d26..cbcc3e96 100644 --- a/python/polars_ds/stats_ext.py +++ b/python/polars_ds/stats_ext.py @@ -21,9 +21,10 @@ def ttest_ind( within 1e-10 precision from SciPy's result. In the case of student's t test, the user is responsible for data to have equal length, - and nulls will be ignored when computing mean and variance. As a result, nulls might - cause problems for student's t test. In the case of Welch's t test, data - will be sanitized (nulls, NaNs, Infs will be dropped before the test). + and nulls will be ignored when computing mean and variance. The df will be 2n - 2. As a + result, nulls might cause problems. In the case of Welch's t test, data + will be sanitized (nulls, NaNs, Infs will be dropped before the test), and df will be + counted based on the length of sanitized data. 
Parameters ---------- @@ -40,11 +41,11 @@ def ttest_ind( m2 = other.mean() v1 = self._expr.var() v2 = other.var() - # Note here that nulls are not filtered + # Note here that nulls are not filtered to ensure the same length cnt = self._expr.count().cast(pl.UInt64) return m1.register_plugin( lib=lib, - symbol="pl_student_t_2samp", + symbol="pl_ttest_2samp", args=[m2, v1, v2, cnt, pl.lit(alternative, dtype=pl.Utf8)], is_elementwise=False, returns_scalar=True, @@ -66,6 +67,32 @@ def ttest_ind( returns_scalar=True, ) + def ttest_1samp(self, pop_mean: float, alternative: Alternative = "two-sided") -> pl.Expr: + """ + Performs a standard 1 sample t test using reference column and expected mean. This function + sanitizes the self column first. The df is the count of valid (non-null, finite) values minus 1. + + Parameters + ---------- + pop_mean + The expected population mean in the hypothesis test + alternative + One of "two-sided", "less" or "greater" + """ + s1 = self._expr.filter(self._expr.is_finite()) + sm = s1.mean() + pm = pl.lit(pop_mean, dtype=pl.Float64) + var = s1.var() + cnt = s1.count().cast(pl.UInt64) + alt = pl.lit(alternative, dtype=pl.Utf8) + return sm.register_plugin( + lib=lib, + symbol="pl_ttest_1samp", + args=[pm, var, cnt, alt], + is_elementwise=False, + returns_scalar=True, + ) + def normal_test(self) -> pl.Expr: """ Perform a normality test which is based on D'Agostino and Pearson's test diff --git a/src/num_ext/fft.rs b/src/num_ext/fft.rs index 63b632e4..b1d13e4b 100644 --- a/src/num_ext/fft.rs +++ b/src/num_ext/fft.rs @@ -1,9 +1,8 @@ /// Performs forward FFT. /// Since data in dataframe are always real numbers, only realfft -/// is implemented and inverse fft is not implemented and even if it +/// is implemented and inverse fft is not implemented and even if it /// is eventually implemented, it would likely not be a dataframe /// operation. 
- use itertools::Either; use polars::prelude::*; use pyo3_polars::derive::polars_expr; @@ -16,7 +15,6 @@ fn complex_output(_: &[Field]) -> PolarsResult { )) } - #[polars_expr(output_type_func=complex_output)] fn pl_rfft(inputs: &[Series]) -> PolarsResult { // Take a step argument diff --git a/src/num_ext/gcd_lcm.rs b/src/num_ext/gcd_lcm.rs index 2a060e4a..854f8250 100644 --- a/src/num_ext/gcd_lcm.rs +++ b/src/num_ext/gcd_lcm.rs @@ -1,5 +1,4 @@ /// GCD and LCM for integers in dataframe. - use num; use polars::prelude::*; use polars_core::prelude::arity::binary_elementwise_values; diff --git a/src/num_ext/jaccard.rs b/src/num_ext/jaccard.rs index 476346f2..deaa763e 100644 --- a/src/num_ext/jaccard.rs +++ b/src/num_ext/jaccard.rs @@ -1,7 +1,6 @@ /// Jaccard similarity for two columns /// + Jaccard similarity for two columns of lists -/// - +/// use core::hash::Hash; use polars::prelude::*; use pyo3_polars::{ diff --git a/src/num_ext/ols.rs b/src/num_ext/ols.rs index d3779bb4..6e422a75 100644 --- a/src/num_ext/ols.rs +++ b/src/num_ext/ols.rs @@ -1,5 +1,4 @@ /// OLS using Faer. - use faer::IntoFaer; use faer::{prelude::*, MatRef}; use polars::prelude::*; diff --git a/src/num_ext/powi.rs b/src/num_ext/powi.rs index 2a43cf53..217d5b0c 100644 --- a/src/num_ext/powi.rs +++ b/src/num_ext/powi.rs @@ -2,8 +2,6 @@ /// Unfortunately, the pl.col("a").num_ext.powi(pl.col("b")) version may not /// be faster, likely due to lack of SIMD (my hunch). However, something like /// pl.col("a").num_ext.powi(16) is significantly faster than Polars's default. - - use num::traits::Inv; use polars::prelude::*; use polars_core::prelude::arity::binary_elementwise_values; diff --git a/src/num_ext/tp_fp.rs b/src/num_ext/tp_fp.rs index 398cf39d..e5314504 100644 --- a/src/num_ext/tp_fp.rs +++ b/src/num_ext/tp_fp.rs @@ -1,7 +1,6 @@ /// All things true positive, false positive related. /// ROC AUC, Average Precision, precision, recall, etc. 
-/// - +/// use ndarray::ArrayView1; use polars::{lazy::dsl::count, prelude::*, series::ops::NullBehavior}; use pyo3_polars::derive::polars_expr; diff --git a/src/num_ext/trapz.rs b/src/num_ext/trapz.rs index ad23d0fd..a0895532 100644 --- a/src/num_ext/trapz.rs +++ b/src/num_ext/trapz.rs @@ -1,5 +1,4 @@ /// Integration via Trapezoidal rule. - use ndarray::{s, ArrayView1}; use polars::{ prelude::{PolarsError, PolarsResult}, diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 89d41920..6805d2d5 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -3,7 +3,6 @@ /// multi-variate distributions, which is something that I think will not be needed in this /// package. Another reason is that if I want to do linear algebra, I would use Faer since Faer /// performs better and nalgebra is too much of a dependency for this package right now. - pub mod beta; pub mod gamma; pub mod normal; @@ -12,3 +11,8 @@ pub const PREC_ACC: f64 = 0.0000000000000011102230246251565; pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153; //pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778; pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978; + +#[inline] +pub fn is_zero(x: f64) -> bool { + x.abs() < PREC_ACC +} diff --git a/src/stats_ext/ks.rs b/src/stats_ext/ks.rs index a5a4117d..f4e1cf5b 100644 --- a/src/stats_ext/ks.rs +++ b/src/stats_ext/ks.rs @@ -1,5 +1,4 @@ /// KS statistics. 
- use crate::stats_ext::StatsResult; use crate::utils::binary_search_right; use itertools::Itertools; diff --git a/src/stats_ext/normal_test.rs b/src/stats_ext/normal_test.rs index 04a90c48..afbc2876 100644 --- a/src/stats_ext/normal_test.rs +++ b/src/stats_ext/normal_test.rs @@ -1,4 +1,4 @@ -/// Here we implement the test as in SciPy: +/// Here we implement the test as in SciPy: /// https://github.com/scipy/scipy/blob/v1.11.4/scipy/stats/_stats_py.py#L1836-L1996 /// /// It is a method based on Kurtosis and Skew, and the Chi-2 distribution. @@ -9,9 +9,8 @@ /// [2] https://www.stata.com/manuals/rsktest.pdf /// /// I chose this over the Shapiro Francia test because the distribution is unknown and would require Monte Carlo - use super::{simple_stats_output, StatsResult}; -use crate::stats::gamma; +use crate::stats::{gamma, is_zero}; use polars::prelude::*; use pyo3_polars::derive::polars_expr; @@ -43,7 +42,7 @@ fn kurtosis_test_statistic(kur: f64, n: usize) -> Result { let tmp = 2. / (9. * a); let denom = 1. + x * (2. / (a - 4.)).sqrt(); - if denom == 0. { + if is_zero(denom) { Err("Kurtosis test: Division by 0 encountered.".to_owned()) } else { let term1 = 1. - tmp; diff --git a/src/stats_ext/sample.rs b/src/stats_ext/sample.rs index def93dc3..87b2d525 100644 --- a/src/stats_ext/sample.rs +++ b/src/stats_ext/sample.rs @@ -5,7 +5,6 @@ /// /// I think it is ok to use CSPRNGS because it is fast enough and we generally do not /// want output to be easily guessable. - use itertools::Itertools; use polars::prelude::*; use pyo3_polars::derive::polars_expr; diff --git a/src/stats_ext/t_test.rs b/src/stats_ext/t_test.rs index 75261d84..bce55898 100644 --- a/src/stats_ext/t_test.rs +++ b/src/stats_ext/t_test.rs @@ -1,7 +1,6 @@ /// Student's t test and Welch's t test. 
- use super::{simple_stats_output, Alternative, StatsResult}; -use crate::stats::beta; +use crate::stats::{beta, is_zero}; use polars::prelude::*; use pyo3_polars::derive::polars_expr; @@ -17,8 +16,8 @@ fn ttest_ind( let num = m1 - m2; // ((var1 + var2) / 2 ).sqrt() * (2./n).sqrt() can be simplified as below let denom = ((v1 + v2) / n).sqrt(); - if denom == 0. { - Ok(StatsResult::new(f64::INFINITY, f64::NAN)) + if is_zero(denom) { + Err("T Test: Division by 0 encountered.".into()) } else { let t = num / denom; let df = 2. * n - 2.; @@ -35,6 +34,34 @@ fn ttest_ind( } } +#[inline] +fn ttest_1samp( + mean: f64, + pop_mean: f64, + var: f64, + n: f64, + alt: Alternative, +) -> Result { + let num = mean - pop_mean; + let denom = (var / n).sqrt(); + if is_zero(denom) { + Err("T Test: Division by 0 encountered.".into()) + } else { + let t = num / denom; + let df = n - 1.; + let p = match alt { + Alternative::Less => beta::student_t_sf(-t, df), + Alternative::Greater => beta::student_t_sf(t, df), + Alternative::TwoSided => match beta::student_t_sf(t.abs(), df) { + Ok(p) => Ok(2.0 * p), + Err(e) => Err(e), + }, + }; + let p = p?; + Ok(StatsResult::new(t, p)) + } +} + #[inline] fn welch_t( m1: f64, @@ -49,8 +76,8 @@ fn welch_t( let vn1 = v1 / n1; let vn2 = v2 / n2; let denom = (vn1 + vn2).sqrt(); - if denom == 0. { - Ok(StatsResult::new(f64::INFINITY, f64::NAN)) + if is_zero(denom) { + Err("T Test: Division by 0 encountered.".into()) } else { let t = num / denom; let df = (vn1 + vn2).powi(2) / (vn1.powi(2) / (n1 - 1.) 
+ (vn2.powi(2) / (n2 - 1.))); @@ -69,15 +96,15 @@ fn welch_t( } #[polars_expr(output_type_func=simple_stats_output)] -fn pl_student_t_2samp(inputs: &[Series]) -> PolarsResult { +fn pl_ttest_2samp(inputs: &[Series]) -> PolarsResult { let mean1 = inputs[0].f64()?; - let mean1 = mean1.get(0).unwrap(); + let mean1 = mean1.get(0).unwrap_or(f64::NAN); let mean2 = inputs[1].f64()?; - let mean2 = mean2.get(0).unwrap(); + let mean2 = mean2.get(0).unwrap_or(f64::NAN); let var1 = inputs[2].f64()?; - let var1 = var1.get(0).unwrap(); + let var1 = var1.get(0).unwrap_or(f64::NAN); let var2 = inputs[3].f64()?; - let var2 = var2.get(0).unwrap(); + let var2 = var2.get(0).unwrap_or(f64::NAN); let n = inputs[4].u64()?; let n = n.get(0).unwrap() as f64; @@ -122,12 +149,7 @@ fn pl_welch_t(inputs: &[Series]) -> PolarsResult { let alt = alt.get(0).unwrap(); let alt = super::Alternative::from(alt); - let valid = mean1.is_finite() && mean2.is_finite() && var1.is_finite() && var2.is_finite(); - if !valid { - return Err(PolarsError::ComputeError( - "T Test: Sample Mean/Std is found to be NaN or Inf.".into(), - )); - } + // No need to check for validity because input is sanitized. let res = welch_t(mean1, mean2, var1, var2, n1, n2, alt) .map_err(|e| PolarsError::ComputeError(e.into())); @@ -139,3 +161,31 @@ fn pl_welch_t(inputs: &[Series]) -> PolarsResult { let out = StructChunked::new("", &[s, p])?; Ok(out.into_series()) } + +#[polars_expr(output_type_func=simple_stats_output)] +fn pl_ttest_1samp(inputs: &[Series]) -> PolarsResult { + let mean = inputs[0].f64()?; + let mean = mean.get(0).unwrap_or(f64::NAN); + let pop_mean = inputs[1].f64()?; + let pop_mean = pop_mean.get(0).unwrap_or(f64::NAN); + let var = inputs[2].f64()?; + let var = var.get(0).unwrap_or(f64::NAN); + let n = inputs[3].u64()?; + let n = n.get(0).unwrap() as f64; + + let alt = inputs[4].utf8()?; + let alt = alt.get(0).unwrap(); + let alt = super::Alternative::from(alt); + + // Sanitizing may leave too few rows, so null mean/var are mapped to NaN above instead of panicking. 
+ + let res = + ttest_1samp(mean, pop_mean, var, n, alt).map_err(|e| PolarsError::ComputeError(e.into())); + let res = res?; + + let s = Series::from_vec("statistic", vec![res.statistic]); + let pchunked = Float64Chunked::from_iter_options("pvalue", [res.p].into_iter()); + let p = pchunked.into_series(); + let out = StructChunked::new("", &[s, p])?; + Ok(out.into_series()) +}