Skip to content

Commit

Permalink
added 1 samp ttest
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq committed Nov 27, 2023
1 parent abba8a4 commit e2ae1b5
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 41 deletions.
37 changes: 32 additions & 5 deletions python/polars_ds/stats_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def ttest_ind(
within 1e-10 precision from SciPy's result.
In the case of student's t test, the user is responsible for data to have equal length,
and nulls will be ignored when computing mean and variance. As a result, nulls might
cause problems for student's t test. In the case of Welch's t test, data
will be sanitized (nulls, NaNs, Infs will be dropped before the test).
and nulls will be ignored when computing mean and variance. The df will be 2n - 2. As a
result, nulls might cause problems. In the case of Welch's t test, data
will be sanitized (nulls, NaNs, Infs will be dropped before the test), and df will be
counted based on the length of sanitized data.
Parameters
----------
Expand All @@ -40,11 +41,11 @@ def ttest_ind(
m2 = other.mean()
v1 = self._expr.var()
v2 = other.var()
# Note here that nulls are not filtered
# Note here that nulls are not filtered to ensure the same length
cnt = self._expr.count().cast(pl.UInt64)
return m1.register_plugin(
lib=lib,
symbol="pl_student_t_2samp",
symbol="pl_ttest_2samp",
args=[m2, v1, v2, cnt, pl.lit(alternative, dtype=pl.Utf8)],
is_elementwise=False,
returns_scalar=True,
Expand All @@ -66,6 +67,32 @@ def ttest_ind(
returns_scalar=True,
)

def ttest_1samp(self, pop_mean: float, alternative: Alternative = "two-sided") -> pl.Expr:
    """
    Runs a one-sample t test of this column against a hypothesized population mean.

    The column is sanitized first: only values passing `is_finite()` are kept, and the
    sample mean, variance and count handed to the plugin are all computed on that
    filtered column. The degrees of freedom are therefore based on the number of valid
    (non-null, finite) values.

    Parameters
    ----------
    pop_mean
        The expected population mean in the hypothesis test
    alternative
        One of "two-sided", "less" or "greater"
    """
    finite = self._expr.filter(self._expr.is_finite())
    # Argument order must match what the `pl_ttest_1samp` plugin reads:
    # population mean, sample variance, sample count, alternative.
    plugin_args = [
        pl.lit(pop_mean, dtype=pl.Float64),
        finite.var(),
        finite.count().cast(pl.UInt64),
        pl.lit(alternative, dtype=pl.Utf8),
    ]
    return finite.mean().register_plugin(
        lib=lib,
        symbol="pl_ttest_1samp",
        args=plugin_args,
        is_elementwise=False,
        returns_scalar=True,
    )

def normal_test(self) -> pl.Expr:
"""
Perform a normality test which is based on D'Agostino and Pearson's test
Expand Down
4 changes: 1 addition & 3 deletions src/num_ext/fft.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
/// Performs forward FFT.
/// Since data in dataframe are always real numbers, only realfft
/// is implemented and inverse fft is not implemented and even if it
/// is implemented and inverse fft is not implemented and even if it
/// is eventually implemented, it would likely not be a dataframe
/// operation.
use itertools::Either;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
Expand All @@ -16,7 +15,6 @@ fn complex_output(_: &[Field]) -> PolarsResult<Field> {
))
}


#[polars_expr(output_type_func=complex_output)]
fn pl_rfft(inputs: &[Series]) -> PolarsResult<Series> {
// Take a step argument
Expand Down
1 change: 0 additions & 1 deletion src/num_ext/gcd_lcm.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/// GCD and LCM for integers in dataframe.
use num;
use polars::prelude::*;
use polars_core::prelude::arity::binary_elementwise_values;
Expand Down
3 changes: 1 addition & 2 deletions src/num_ext/jaccard.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/// Jaccard similarity for two columns
/// + Jaccard similarity for two columns of lists
///
///
use core::hash::Hash;
use polars::prelude::*;
use pyo3_polars::{
Expand Down
1 change: 0 additions & 1 deletion src/num_ext/ols.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/// OLS using Faer.
use faer::IntoFaer;
use faer::{prelude::*, MatRef};
use polars::prelude::*;
Expand Down
2 changes: 0 additions & 2 deletions src/num_ext/powi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
/// Unfortunately, the pl.col("a").num_ext.powi(pl.col("b")) version may not
/// be faster, likely due to lack of SIMD (my hunch). However, something like
/// pl.col("a").num_ext.powi(16) is significantly faster than Polars's default.

use num::traits::Inv;
use polars::prelude::*;
use polars_core::prelude::arity::binary_elementwise_values;
Expand Down
3 changes: 1 addition & 2 deletions src/num_ext/tp_fp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/// All things true positive, false positive related.
/// ROC AUC, Average Precision, precision, recall, etc.
///
///
use ndarray::ArrayView1;
use polars::{lazy::dsl::count, prelude::*, series::ops::NullBehavior};
use pyo3_polars::derive::polars_expr;
Expand Down
1 change: 0 additions & 1 deletion src/num_ext/trapz.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/// Integration via Trapezoidal rule.
use ndarray::{s, ArrayView1};
use polars::{
prelude::{PolarsError, PolarsResult},
Expand Down
6 changes: 5 additions & 1 deletion src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
/// multi-variate distributions, which is something that I think will not be needed in this
/// package. Another reason is that if I want to do linear algebra, I would use Faer since Faer
/// performs better and nalgebra is too much of a dependency for this package right now.
pub mod beta;
pub mod gamma;
pub mod normal;
Expand All @@ -12,3 +11,8 @@ pub const PREC_ACC: f64 = 0.0000000000000011102230246251565;
pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153;
//pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778;
pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978;

#[inline]
pub fn is_zero(x: f64) -> bool {
x.abs() < PREC_ACC
}
1 change: 0 additions & 1 deletion src/stats_ext/ks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/// KS statistics.
use crate::stats_ext::StatsResult;
use crate::utils::binary_search_right;
use itertools::Itertools;
Expand Down
7 changes: 3 additions & 4 deletions src/stats_ext/normal_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/// Here we implement the test as in SciPy:
/// Here we implement the test as in SciPy:
/// https://github.com/scipy/scipy/blob/v1.11.4/scipy/stats/_stats_py.py#L1836-L1996
///
/// It is a method based on Kurtosis and Skew, and the Chi-2 distribution.
Expand All @@ -9,9 +9,8 @@
/// [2] https://www.stata.com/manuals/rsktest.pdf
///
/// I chose this over the Shapiro Francia test because the distribution is unknown and would require Monte Carlo
use super::{simple_stats_output, StatsResult};
use crate::stats::gamma;
use crate::stats::{gamma, is_zero};
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

Expand Down Expand Up @@ -43,7 +42,7 @@ fn kurtosis_test_statistic(kur: f64, n: usize) -> Result<StatsResult, String> {

let tmp = 2. / (9. * a);
let denom = 1. + x * (2. / (a - 4.)).sqrt();
if denom == 0. {
if is_zero(denom) {
Err("Kurtosis test: Division by 0 encountered.".to_owned())
} else {
let term1 = 1. - tmp;
Expand Down
1 change: 0 additions & 1 deletion src/stats_ext/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
///
/// I think it is ok to use CSPRNGS because it is fast enough and we generally do not
/// want output to be easily guessable.
use itertools::Itertools;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
Expand Down
84 changes: 67 additions & 17 deletions src/stats_ext/t_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/// Student's t test and Welch's t test.
use super::{simple_stats_output, Alternative, StatsResult};
use crate::stats::beta;
use crate::stats::{beta, is_zero};
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

Expand All @@ -17,8 +16,8 @@ fn ttest_ind(
let num = m1 - m2;
// ((var1 + var2) / 2 ).sqrt() * (2./n).sqrt() can be simplified as below
let denom = ((v1 + v2) / n).sqrt();
if denom == 0. {
Ok(StatsResult::new(f64::INFINITY, f64::NAN))
if is_zero(denom) {
Err("T Test: Division by 0 encountered.".into())
} else {
let t = num / denom;
let df = 2. * n - 2.;
Expand All @@ -35,6 +34,34 @@ fn ttest_ind(
}
}

/// One-sample t test computed from summary statistics.
///
/// The statistic is `(mean - pop_mean) / sqrt(var / n)` with `n - 1` degrees of freedom.
/// The p-value comes from `beta::student_t_sf`, evaluated at `-t`, `t` or `|t|` (the
/// latter doubled) depending on `alt`.
///
/// # Errors
/// Returns an error when the standard error is (numerically) zero, or when
/// `beta::student_t_sf` itself fails.
#[inline]
fn ttest_1samp(
    mean: f64,
    pop_mean: f64,
    var: f64,
    n: f64,
    alt: Alternative,
) -> Result<StatsResult, String> {
    // Standard error of the sample mean; bail out early if it cannot be divided by.
    let std_err = (var / n).sqrt();
    if is_zero(std_err) {
        return Err("T Test: Division by 0 encountered.".into());
    }

    let t = (mean - pop_mean) / std_err;
    let df = n - 1.;
    let p = match alt {
        Alternative::Less => beta::student_t_sf(-t, df),
        Alternative::Greater => beta::student_t_sf(t, df),
        // Two-sided: double the one-sided tail probability of |t|.
        Alternative::TwoSided => beta::student_t_sf(t.abs(), df).map(|tail| 2.0 * tail),
    }?;
    Ok(StatsResult::new(t, p))
}

#[inline]
fn welch_t(
m1: f64,
Expand All @@ -49,8 +76,8 @@ fn welch_t(
let vn1 = v1 / n1;
let vn2 = v2 / n2;
let denom = (vn1 + vn2).sqrt();
if denom == 0. {
Ok(StatsResult::new(f64::INFINITY, f64::NAN))
if is_zero(denom) {
Err("T Test: Division by 0 encountered.".into())
} else {
let t = num / denom;
let df = (vn1 + vn2).powi(2) / (vn1.powi(2) / (n1 - 1.) + (vn2.powi(2) / (n2 - 1.)));
Expand All @@ -69,15 +96,15 @@ fn welch_t(
}

#[polars_expr(output_type_func=simple_stats_output)]
fn pl_student_t_2samp(inputs: &[Series]) -> PolarsResult<Series> {
fn pl_ttest_2samp(inputs: &[Series]) -> PolarsResult<Series> {
let mean1 = inputs[0].f64()?;
let mean1 = mean1.get(0).unwrap();
let mean1 = mean1.get(0).unwrap_or(f64::NAN);
let mean2 = inputs[1].f64()?;
let mean2 = mean2.get(0).unwrap();
let mean2 = mean2.get(0).unwrap_or(f64::NAN);
let var1 = inputs[2].f64()?;
let var1 = var1.get(0).unwrap();
let var1 = var1.get(0).unwrap_or(f64::NAN);
let var2 = inputs[3].f64()?;
let var2 = var2.get(0).unwrap();
let var2 = var2.get(0).unwrap_or(f64::NAN);
let n = inputs[4].u64()?;
let n = n.get(0).unwrap() as f64;

Expand Down Expand Up @@ -122,12 +149,7 @@ fn pl_welch_t(inputs: &[Series]) -> PolarsResult<Series> {
let alt = alt.get(0).unwrap();
let alt = super::Alternative::from(alt);

let valid = mean1.is_finite() && mean2.is_finite() && var1.is_finite() && var2.is_finite();
if !valid {
return Err(PolarsError::ComputeError(
"T Test: Sample Mean/Std is found to be NaN or Inf.".into(),
));
}
// No need to check for validity because input is sanitized.

let res = welch_t(mean1, mean2, var1, var2, n1, n2, alt)
.map_err(|e| PolarsError::ComputeError(e.into()));
Expand All @@ -139,3 +161,31 @@ fn pl_welch_t(inputs: &[Series]) -> PolarsResult<Series> {
let out = StructChunked::new("", &[s, p])?;
Ok(out.into_series())
}

/// Polars plugin entry point for the one-sample t test.
///
/// Expected `inputs`, in order, each read at index 0:
/// 0. sample mean (`f64`)
/// 1. hypothesized population mean (`f64`)
/// 2. sample variance (`f64`)
/// 3. sample count (`u64`)
/// 4. alternative (`utf8`), converted via `Alternative::from`
///
/// Returns a struct series with the fields `statistic` and `pvalue`.
///
/// # Errors
/// Returns `PolarsError::ComputeError` when any input value is null or missing, or when
/// the underlying `ttest_1samp` computation fails. Dtype mismatches are propagated from
/// the `f64()` / `u64()` / `utf8()` accessors.
#[polars_expr(output_type_func=simple_stats_output)]
fn pl_ttest_1samp(inputs: &[Series]) -> PolarsResult<Series> {
    // Sanitizing the column upstream removes non-finite values, but the aggregated
    // mean/variance can still be null (e.g. when too few valid values remain after
    // filtering). Report that as an error instead of panicking on `unwrap()`.
    let missing = |what: &str| {
        PolarsError::ComputeError(format!("T Test: {} is null or missing.", what).into())
    };

    let mean = inputs[0]
        .f64()?
        .get(0)
        .ok_or_else(|| missing("sample mean"))?;
    let pop_mean = inputs[1]
        .f64()?
        .get(0)
        .ok_or_else(|| missing("population mean"))?;
    let var = inputs[2]
        .f64()?
        .get(0)
        .ok_or_else(|| missing("sample variance"))?;
    let n = inputs[3]
        .u64()?
        .get(0)
        .ok_or_else(|| missing("sample count"))? as f64;

    let alt = inputs[4]
        .utf8()?
        .get(0)
        .ok_or_else(|| missing("alternative"))?;
    let alt = super::Alternative::from(alt);

    let res = ttest_1samp(mean, pop_mean, var, n, alt)
        .map_err(|e| PolarsError::ComputeError(e.into()))?;

    let s = Series::from_vec("statistic", vec![res.statistic]);
    let pchunked = Float64Chunked::from_iter_options("pvalue", [res.p].into_iter());
    let p = pchunked.into_series();
    let out = StructChunked::new("", &[s, p])?;
    Ok(out.into_series())
}

0 comments on commit e2ae1b5

Please sign in to comment.