"
+ ],
+ "text/plain": [
+ "shape: (50_000, 2)\n",
+ "┌───────────────────────────────────┬──────┐\n",
+ "│ sen ┆ word │\n",
+ "│ --- ┆ --- │\n",
+ "│ str ┆ str │\n",
+ "╞═══════════════════════════════════╪══════╡\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ … ┆ … │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "│ Hello, world! I'm going to churc… ┆ word │\n",
+ "└───────────────────────────────────┴──────┘"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.filter(\n",
+ " pl.col(\"word\").str_ext.levenshtein_dist(\"world\") == 1\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4f45d3d-d3b9-4fde-9ed5-b3d01d0fa1ba",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8073ff19-21da-449d-87c5-2791a574bc81",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02a88a93-8805-4a97-a94e-196fba7090c5",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyproject.toml b/pyproject.toml
index 539aedeb..cfc5809d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
"License :: OSI Approved :: MIT License",
]
-version = "0.1.0"
+version = "0.1.1"
authors = [
{name = "Tianren Qin", email = "tq9695@gmail.com"},
{name = "Nelson Griffiths", email = "nelsongriffiths123@gmail.com"}
@@ -33,6 +33,7 @@ module-name = "polars_ds._polars_ds"
[project.optional-dependencies]
dev = [
"pytest >= 7.4.1",
+ "pre-commit"
]
[tool.ruff]
diff --git a/python/polars_ds/__init__.py b/python/polars_ds/__init__.py
index 71a71e46..51246c53 100644
--- a/python/polars_ds/__init__.py
+++ b/python/polars_ds/__init__.py
@@ -1,2 +1,8 @@
+version = "0.1.1"
-version = "0.1.0"
\ No newline at end of file
+from polars_ds.extensions import NumExt, StrExt # noqa: E402
+
+__all__ = [
+ "NumExt",
+ "StrExt"
+]
\ No newline at end of file
diff --git a/python/polars_ds/extensions.py b/python/polars_ds/extensions.py
index 391ccd55..78b89552 100644
--- a/python/polars_ds/extensions.py
+++ b/python/polars_ds/extensions.py
@@ -1,6 +1,7 @@
import polars as pl
from typing import Union
from polars.utils.udfs import _get_shared_lib_location
+# from polars.type_aliases import IntoExpr
lib = _get_shared_lib_location(__file__)
@@ -114,59 +115,73 @@ def lcm(self, other: Union[int, pl.Expr]) -> pl.Expr:
is_elementwise=True,
)
- def hubor_loss(self, other: pl.Expr, delta: float) -> pl.Expr:
+ def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
"""
- Computes huber loss between this and the other expression
+ Computes huber loss between this and the other expression. This assumes
+ this expression is actual, and the input is predicted, although the order
+ does not matter in this case.
Parameters
----------
- other
- Either an int or a Polars expression
+ pred
+ A Polars expression representing predictions
"""
- temp = (self._expr - other).abs()
+ temp = (self._expr - pred).abs()
return (
- pl.when(temp <= delta)
- .then(0.5 * temp.pow(2))
- .otherwise(delta * (temp - 0.5 * delta))
- / self._expr.count()
+ pl.when(temp <= delta).then(0.5 * temp.pow(2)).otherwise(delta * (temp - 0.5 * delta)) / self._expr.count()
)
- def l1_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
+ def l1_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
- Computes L1 loss (normalized L1 distance) between this and the other expression. This
- is the norm without 1/p power.
+ Computes L1 loss (absolute difference) between this and the other expression.
Parameters
----------
- other
- Either an int or a Polars expression
+ pred
+ A Polars expression representing predictions
normalize
If true, divide the result by length of the series
"""
- temp = (self._expr - other).abs().sum()
+ temp = (self._expr - pred).abs().sum()
if normalize:
return temp / self._expr.count()
return temp
- def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
+ def l2_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
Computes L2 loss (normalized L2 distance) between this and the other expression. This
is the norm without 1/p power.
-
Parameters
----------
- other
- Either an int or a Polars expression
+ pred
+ A Polars expression representing predictions
normalize
If true, divide the result by length of the series
"""
- temp = self._expr - other
+ temp = self._expr - pred
temp = temp.dot(temp)
if normalize:
return temp / self._expr.count()
return temp
+ def msle(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes the mean square log error.
+
+ Parameters
+ ----------
+ pred
+ A Polars expression representing predictions
+ normalize
+ If true, divide the result by length of the series
+ """
+ diff = self._expr.log1p() - pred.log1p()
+ out = diff.dot(diff)
+ if normalize:
+ return out / self._expr.count()
+ return out
+
# def lp_loss(self, other: pl.Expr, p: float, normalize: bool = True) -> pl.Expr:
# """
# Computes LP loss (normalized LP distance) between this and the other expression. This
@@ -189,30 +204,30 @@ def l2_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
# return (temp / self._expr.count())
# return temp
- def chebyshev_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
+ def chebyshev_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
Alias for l_inf_loss.
"""
- return self.l_inf_dist(other, normalize)
+ return self.l_inf_loss(pred, normalize)
- def l_inf_loss(self, other: pl.Expr, normalize: bool = True) -> pl.Expr:
+ def l_inf_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
Computes L^infinity loss between this and the other expression
Parameters
- ----------
- other
- Either an int or a Polars expression
+ ----------
+ pred
+ A Polars expression representing predictions
normalize
If true, divide the result by length of the series
"""
- temp = self._expr - other
+ temp = self._expr - pred
out = pl.max_horizontal(temp.min().abs(), temp.max().abs())
if normalize:
return out / self._expr.count()
return out
- def mape(self, other: pl.Expr, weighted: bool = False) -> pl.Expr:
+ def mape(self, pred: pl.Expr, weighted: bool = False) -> pl.Expr:
"""
Computes mean absolute percentage error between self and other. Self is actual.
If weighted, it will compute the weighted version as defined here:
@@ -221,17 +236,17 @@ def mape(self, other: pl.Expr, weighted: bool = False) -> pl.Expr:
Parameters
----------
- other
- Either an int or a Polars expression
+ pred
+ A Polars expression representing predictions
weighted
If true, computes wMAPE in the wikipedia article
"""
if weighted:
- return (self._expr - other).abs().sum() / self._expr.abs().sum()
+ return (self._expr - pred).abs().sum() / self._expr.abs().sum()
else:
- return (1 - other / self._expr).abs().mean()
+ return (1 - pred / self._expr).abs().mean()
- def smape(self, other: pl.Expr) -> pl.Expr:
+ def smape(self, pred: pl.Expr) -> pl.Expr:
"""
Computes symmetric mean absolute percentage error between self and other. Self is actual.
The value is always between 0 and 1. This is the third version in the wikipedia without
@@ -241,30 +256,146 @@ def smape(self, other: pl.Expr) -> pl.Expr:
Parameters
----------
- other
- Either an int or a Polars expression
+ pred
+ A Polars expression representing predictions
"""
- numerator = (self._expr - other).abs()
- denominator = 1.0 / (self._expr.abs() + other.abs())
+ numerator = (self._expr - pred).abs()
+ denominator = 1.0 / (self._expr.abs() + pred.abs())
return (1.0 / self._expr.count()) * numerator.dot(denominator)
- def bce(self, actual: pl.Expr, normalize:bool=True) -> pl.Expr:
+ def bce(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
"""
- Treats self as the prediction. and computes Binary Cross Entropy loss.
+ Computes Binary Cross Entropy loss, treating self as the actual binary label.
Parameters
----------
- actual
- The actual binary lable. Note: if this column is not binary, then the result
- will be nonsense.
+ pred
+ The predicted probability.
normalize
Whether to divide by N.
"""
- out = actual.dot(self._expr.log()) + (1 - actual).dot((1 - self._expr).log())
+ out = self._expr.dot(pred.log()) + (1 - self._expr).dot((1 - pred).log())
if normalize:
return -(out / self._expr.count())
return -out
+ def r2(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Returns the coefficient of determination for a regression model.
+
+ Parameters
+ ----------
+ pred
+ A Polars expression representing predictions
+ """
+ diff = self._expr - pred
+ ss_res = diff.dot(diff)
+ diff2 = self._expr - self._expr.mean()
+ ss_tot = diff2.dot(diff2)
+ return 1.0 - ss_res / ss_tot
+
+ def adjusted_r2(self, pred: pl.Expr, p: int) -> pl.Expr:
+ """
+ Returns the adjusted r2 for a regression model.
+
+ Parameters
+ ----------
+ pred
+ A Polars expression representing predictions
+ p
+ The total number of explanatory variables in the model
+ """
+ diff = self._expr - pred
+ ss_res = diff.dot(diff)
+ diff2 = self._expr - self._expr.mean()
+ ss_tot = diff2.dot(diff2)
+ df_res = self._expr.count() - p
+ df_tot = self._expr.count() - 1
+ return 1.0 - (ss_res / df_res) / (ss_tot / df_tot)
+
+ def powi(self, n: Union[int, pl.Expr]) -> pl.Expr:
+ """
+ Computes positive integer power using the fast exponentiation algorithm. This is the
+ fastest when n is an integer input (Faster than Polars's builtin when n >= 16). When n
+ is an expression, it would depend on values in the expression (Still researching...)
+
+ Parameters
+ ----------
+ n
+ A single positive int or an expression representing a column of type i32. If type is
+ not i32, an error will occur.
+ """
+
+ if isinstance(n, int):
+ n_ = pl.lit(n, pl.Int32)
+ else:
+ n_ = n
+
+ return self._expr.register_plugin(
+ lib=lib, symbol="pl_fast_exp", args=[n_], is_elementwise=True, returns_scalar=False
+ )
+
+ def t_2samp(self, other: pl.Expr) -> pl.Expr:
+ """
+ Computes the t statistics for an Independent two-sample t-test. It is highly recommended
+ that nulls be imputed before calling this.
+
+ Parameters
+ ----------
+ other
+ A Polars expression representing the second sample
+ """
+ numerator = self._expr.mean() - other.mean()
+ denom = ((self._expr.var() + other.var()) / self._expr.count()).sqrt()
+ return numerator / denom
+
+ def welch_t(self, other: pl.Expr, return_df: bool = True) -> pl.Expr:
+ """
+ Computes the statistics for Welch's t-test. Welch's t-test is often used when
+ the two series do not have the same length. Two series in a dataframe will always
+ have the same length. Here, only non-null values are counted.
+
+ Parameters
+ ----------
+ other
+ A Polars expression representing the second sample
+ return_df
+ Whether to return the degree of freedom or not.
+ """
+ e1 = self._expr.drop_nulls()
+ e2 = other.drop_nulls()
+ numerator = e1.mean() - e2.mean()
+ s1: pl.Expr = e1.var() / e1.count()
+ s2: pl.Expr = e2.var() / e2.count()
+ denom = (s1 + s2).sqrt()
+ if return_df:
+ df_num = (s1 + s2).pow(2)
+ df_denom = s1.pow(2) / (e1.count() - 1) + s2.pow(2) / (e2.count() - 1)
+ return pl.concat_list(numerator / denom, df_num / df_denom)
+ else:
+ return numerator / denom
+
+ def jaccard(self, other: pl.Expr, include_null: bool = False) -> pl.Expr:
+ """
+ Computes jaccard similarity between this column and the other. This will hash entire
+ columns and compares the two hashsets. Note: only integer/str columns can be compared.
+ Input expressions must represent columns of the same dtype.
+
+ Parameters
+ ----------
+ other
+ A Polars expression representing the other column
+ include_null
+ Whether to include null as a distinct element.
+ """
+ return self._expr.register_plugin(
+ lib=lib,
+ symbol="pl_jaccard",
+ args=[other, pl.lit(include_null, dtype=pl.Boolean)],
+ is_elementwise=False,
+ returns_scalar=True,
+ )
+
def cond_entropy(self, other: pl.Expr) -> pl.Expr:
"""
Computes the conditional entropy of self(y) given other. H(y|other).
@@ -276,23 +407,19 @@ def cond_entropy(self, other: pl.Expr) -> pl.Expr:
"""
return self._expr.register_plugin(
- lib=lib,
- symbol="pl_conditional_entropy",
- args=[other],
- is_elementwise=False,
- returns_scalar=True
+ lib=lib, symbol="pl_conditional_entropy", args=[other], is_elementwise=False, returns_scalar=True
)
- def lstsq(self, *others: pl.Expr, add_bias:bool=False) -> pl.Expr:
+ def lstsq(self, *others: pl.Expr, add_bias: bool = False) -> pl.Expr:
"""
- Computes least squares solution to a linear matrix equation. If columns are
+ Computes least squares solution to the equation Ax = y. If columns are
not linearly independent, some numerical issue may occur. E.g you may see
- unrealistic coefficient in the output. This is a `silent` numerical issue during the
- computation.
+ unrealistic coefficient in the output. It is possible to have `silent` numerical
+ issue during computation.
+
+ All positional arguments should be expressions representing predictive variables. This
+ does not support composite expressions like pl.col(["a", "b"]), pl.all(), etc.
- All positional arguments should be expressions representing individual columns. This does
- not support composite expressions like pl.col(["a", "b"]), pl.all(), etc.
-
If add_bias is true, it will be the last coefficient in the output
and output will have length |other| + 1
@@ -309,13 +436,13 @@ def lstsq(self, *others: pl.Expr, add_bias:bool=False) -> pl.Expr:
symbol="pl_lstsq",
args=[pl.lit(add_bias, dtype=pl.Boolean)] + list(others),
is_elementwise=False,
- returns_scalar=True
+ returns_scalar=True,
)
- def fft(self, forward:bool=True) -> pl.Expr:
+ def fft(self, forward: bool = True) -> pl.Expr:
"""
Computes the DST transform of input series using FFT Algorithm. A series of equal length will
- be returned, with elements being the real and complex part of the transformed values.
+ be returned, with elements being the real and complex part of the transformed values.
Parameters
----------
@@ -329,17 +456,13 @@ def fft(self, forward:bool=True) -> pl.Expr:
is_elementwise=True,
)
+
@pl.api.register_expr_namespace("str_ext")
class StrExt:
def __init__(self, expr: pl.Expr):
self._expr: pl.Expr = expr
- def str_jaccard(
- self
- , other: Union[str, pl.Expr]
- , substr_size: int = 2
- , parallel: bool = False
- ) -> pl.Expr:
+ def str_jaccard(self, other: Union[str, pl.Expr], substr_size: int = 2, parallel: bool = False) -> pl.Expr:
"""
Treats substrings of size `substr_size` as a set. And computes the jaccard similarity between
this word and the other.
@@ -369,11 +492,7 @@ def str_jaccard(
is_elementwise=True,
)
- def levenshtein_dist(
- self
- , other: Union[str, pl.Expr]
- , parallel: bool = False
- ) -> pl.Expr:
+ def levenshtein_dist(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr:
"""
Computes the levenshtein distance between this each value in the column with the str other.
@@ -399,11 +518,7 @@ def levenshtein_dist(
is_elementwise=True,
)
- def hamming_dist(
- self
- , other: Union[str, pl.Expr]
- , parallel: bool = False
- ) -> pl.Expr:
+ def hamming_dist(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr:
"""
Computes the hamming distance between two strings. If they do not have the same length, null will
be returned.
@@ -450,17 +565,40 @@ def tokenize(self, pattern: str = r"(?u)\b\w\w+\b", stem: bool = False) -> pl.Ex
.register_plugin(
lib=lib,
symbol="pl_snowball_stem",
+ args=[pl.lit(True, dtype=pl.Boolean), pl.lit(False, dtype=pl.Boolean)],
is_elementwise=True,
- )
+ ) # True to no stop word, False to Parallel
.drop_nulls()
- ).list.unique()
+ )
return out
- def snowball(
- self
- , no_stopwords:bool=True
- , parallel:bool=False
- ) -> pl.Expr:
+ def freq_removal(self, lower: float = 0.05, upper: float = 0.95, parallel: bool = True) -> pl.Expr:
+ """
+ Removes from each document the words that are too frequent or too rare (in the entire dataset). This assumes
+ that the input expression represents lists of strings. E.g. output of tokenize.
+
+ Parameters
+ ----------
+ lower
+ Lower quantile of word counts. Words whose count is below this quantile are removed.
+ upper
+ Upper quantile of word counts. Words whose count is above this quantile are removed.
+ parallel
+ Whether to run word count in parallel. It is not recommended when you are in a group_by
+ context.
+ """
+
+ name = self._expr.meta.output_name(raise_if_undetermined=False)
+ vc = self._expr.list.explode().value_counts(parallel=parallel).sort()
+ lo = vc.struct.field("counts").quantile(lower)
+ u = vc.struct.field("counts").quantile(upper)
+ remove = (
+ vc.filter((vc.struct.field("counts") < lo) | (vc.struct.field("counts") > u)).struct.field(name).implode()
+ )
+
+ return self._expr.list.set_difference(remove)
+
+ def snowball(self, no_stopwords: bool = True, parallel: bool = False) -> pl.Expr:
"""
Applies the snowball stemmer for the column. The column is supposed to be a column of single words.
diff --git a/src/num_ext/expressions.rs b/src/num_ext/expressions.rs
index c092663f..3a8815c7 100644
--- a/src/num_ext/expressions.rs
+++ b/src/num_ext/expressions.rs
@@ -1,12 +1,20 @@
use faer::{prelude::*, MatRef};
use faer::{IntoFaer, IntoNdarray};
-// use faer::polars::{polars_to_faer_f64, Frame};
use ndarray::{Array1, Array2};
use num;
+use num::traits::Inv;
use polars::prelude::*;
use polars_core::prelude::arity::binary_elementwise_values;
use pyo3_polars::derive::polars_expr;
use rustfft::FftPlanner;
+use hashbrown::HashSet;
+
+// use faer::polars::{polars_to_faer_f64, Frame};
+
+// fn numeric_output(input_fields: &[Field]) -> PolarsResult {
+// let field = input_fields[0].clone();
+// Ok(field)
+// }
fn complex_output(_: &[Field]) -> PolarsResult {
let real = Field::new("re", DataType::Float64);
@@ -58,6 +66,114 @@ fn pl_lcm(inputs: &[Series]) -> PolarsResult {
}
}
+
+fn fast_exp_single(s:Series, n:i32) -> Series {
+
+ if n == 0 {
+ let ss = s.f64().unwrap();
+ let out:Float64Chunked = ss.apply_values(|x| {
+ if x == 0. {
+ f64::NAN
+ } else if x.is_infinite() | x.is_nan() {
+ x
+ } else {
+ 1.0
+ }
+ });
+ return out.into_series()
+ } else if n < 0 {
+ return fast_exp_single(1.div(&s), -n)
+ }
+
+ let mut ss = s.clone();
+ let mut m = n;
+ let mut y = Series::from_vec("", vec![1_f64; s.len()]);
+ while m > 0 {
+ if m % 2 == 1 {
+ y = &y * &ss;
+ }
+ ss = &ss * &ss;
+ m >>= 1;
+ }
+ y
+
+ }
+
+ #[inline]
+ fn _fast_exp_pairwise(x:f64, n:u32) -> f64 {
+
+ let mut m = n;
+ let mut x = x;
+ let mut y:f64 = 1.0;
+ while m > 0 {
+ if m % 2 == 1 {
+ y *= x;
+ }
+ x *= x;
+ m >>= 1;
+ }
+ y
+
+}
+
+#[inline]
+fn fast_exp_pairwise(x:f64, n:i32) -> f64 {
+
+ if n == 0 {
+ if x == 0. { // 0^0 is NaN
+ return f64::NAN
+ } else {
+ return 1.
+ }
+ } else if n < 0 {
+ return _fast_exp_pairwise(x.inv(), n.unsigned_abs()) // unsigned_abs: negating i32::MIN would overflow
+ }
+ _fast_exp_pairwise(x, n as u32)
+
+}
+
+
+#[polars_expr(output_type=Float64)]
+fn pl_fast_exp(inputs: &[Series]) -> PolarsResult {
+
+ let s = inputs[0].clone();
+ let exp = inputs[1].i32()?;
+
+ if exp.len() == 1 {
+ let n = exp.get(0).unwrap();
+ if s.dtype().is_numeric() {
+ let ss = s.cast(&DataType::Float64)?;
+ Ok(fast_exp_single(ss, n))
+ } else {
+ Err(PolarsError::ComputeError(
+ "Input column type must be numeric.".into(),
+ ))
+ }
+ } else if s.len() == exp.len() {
+ if s.dtype().is_numeric() {
+ if s.dtype() == &DataType::Float64 {
+ let ca = s.f64()?;
+ let out:Float64Chunked = binary_elementwise_values(ca, exp, fast_exp_pairwise);
+ Ok(out.into_series())
+ } else {
+ let t = s.cast(&DataType::Float64)?;
+ let ca = t.f64()?;
+ let out:Float64Chunked = binary_elementwise_values(ca, exp, fast_exp_pairwise);
+ Ok(out.into_series())
+ }
+ } else {
+ Err(PolarsError::ComputeError(
+ "Input column type must be numeric.".into(),
+ ))
+ }
+ } else {
+ Err(PolarsError::ShapeMismatch(
+ "Inputs must have the same length.".into(),
+ ))
+ }
+
+}
+
// Use QR to solve
fn faer_lstsq_qr(x: MatRef, y: MatRef) -> Result, String> {
let qr = x.qr();
@@ -85,7 +201,8 @@ fn pl_lstsq(inputs: &[Series]) -> PolarsResult {
let add_bias = inputs[1].bool()?;
let add_bias: bool = add_bias.get(0).unwrap();
// y
- let y = inputs[0].f64()?;
+ let y = inputs[0].rechunk(); // if not contiguous, this will do work. Otherwise, just a clone
+ let y = y.f64()?;
let y = y.to_ndarray()?.into_shape((nrows, 1)).unwrap();
let y = y.view().into_faer();
@@ -93,9 +210,9 @@ fn pl_lstsq(inputs: &[Series]) -> PolarsResult {
let mut vec_series: Vec = Vec::with_capacity(inputs[2..].len() + 1);
for (i, s) in inputs[2..].iter().enumerate() {
let t: Series = match s.dtype() {
- DataType::Float64 => s.clone().with_name(&i.to_string()),
+ DataType::Float64 => s.rechunk().with_name(&i.to_string()),
_ => {
- let t = s.clone().cast(&DataType::Float64)?;
+ let t = s.rechunk().cast(&DataType::Float64)?;
t.with_name(&i.to_string())
}
};
@@ -213,3 +330,69 @@ fn pl_fft(inputs: &[Series]) -> PolarsResult {
Ok(fft_struct)
}
+
+#[polars_expr(output_type=Float64)]
+fn pl_jaccard(inputs: &[Series]) -> PolarsResult {
+
+ let include_null = inputs[2].bool()?;
+ let include_null = include_null.get(0).unwrap();
+
+ let (s1, s2) = if include_null {
+ (inputs[0].clone(), inputs[1].clone())
+ } else {
+ let t1 = inputs[0].clone();
+ let t2 = inputs[1].clone();
+ (t1.drop_nulls(), t2.drop_nulls())
+ };
+
+ // let parallel = inputs[3].bool()?;
+ // let parallel = parallel.get(0).unwrap();
+
+ if s1.dtype() != s2.dtype() {
+ return Err(PolarsError::ComputeError(
+ "Input column must have the same type.".into(),
+ ))
+ }
+
+ let (n1, n2, intersection) =
+ if s1.dtype().is_integer() {
+ let ca1 = s1.cast(&DataType::Int64)?;
+ let ca2 = s2.cast(&DataType::Int64)?;
+ let ca1 = ca1.i64()?;
+ let ca2 = ca2.i64()?;
+
+ let hs1: HashSet