Skip to content

Commit

Permalink
Merge pull request #48 from abstractqqq/better_arguments
Browse files Browse the repository at this point in the history
some improvements
  • Loading branch information
abstractqqq authored Jan 4, 2024
2 parents f246b13 + 581ed65 commit 581677c
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 477 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ serde = {version = "*", features=["derive"]}
ndarray = {version="0.15.6", features=["rayon"]} # see if we can get rid of this
hashbrown = {version = "0.14.2", features=["nightly"]}
itertools = "0.12.0"
aho-corasick = "1.1"
rand = {version = "0.8.5"} # Simd support feature seems to be broken atm
rand_distr = "0.4.3"
realfft = "3.3.0"
Expand Down
228 changes: 113 additions & 115 deletions examples/basics.ipynb

Large diffs are not rendered by default.

95 changes: 73 additions & 22 deletions python/polars_ds/stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import polars as pl
from .type_alias import Alternative
from typing import Optional
from typing import Optional, Union
from polars.utils.udfs import _get_shared_lib_location
# from polars.type_aliases import IntoExpr

Expand Down Expand Up @@ -165,7 +165,8 @@ def normal_test(self) -> pl.Expr:

def ks_stats(self, var: pl.Expr) -> pl.Expr:
"""
Computes two-sided KS statistics with other. Currently it only returns the statistics.
Computes two-sided KS statistics with other. Currently it only returns the statistics. This will
sanitize data (only non-null finite values are used) before doing the computation.
Parameters
----------
Expand Down Expand Up @@ -228,8 +229,8 @@ def chi2(self, var: pl.Expr) -> pl.Expr:

def rand_int(
self,
low: Optional[int] = 0,
high: Optional[int] = 10,
low: Union[int, pl.Expr] = 0,
high: Optional[Union[int, pl.Expr]] = 10,
respect_null: bool = False,
use_ref: bool = False,
) -> pl.Expr:
Expand All @@ -242,17 +243,31 @@ def rand_int(
Parameters
----------
low
Lower end of random sample. None will be replaced 0.
Lower end of random sample. If high is none, low will be set to 0.
high
Higher end of random sample. None will be replaced n_unique of reference.
Higher end of random sample. If this is None, then it will be replaced by n_unique of the reference.
respect_null
If true, null in reference column will be null in the new column
"""
if (low is None) & (high is None):
raise ValueError("Either low or high must be set.")

lo = pl.lit(low, dtype=pl.Int32)
hi = self._expr.n_unique.cast(pl.UInt32) if high is None else pl.lit(high, dtype=pl.Int32)
if high is None:
lo = pl.lit(0, dtype=pl.Int32)
hi = self._expr.n_unique.cast(pl.UInt32)
else:
if isinstance(low, pl.Expr):
lo = low
elif isinstance(low, int):
lo = pl.lit(low, dtype=pl.Int32)
else:
raise ValueError("Input `low` must be expression or int.")

if isinstance(high, pl.Expr):
hi = high
elif isinstance(high, int):
hi = pl.lit(high, dtype=pl.Int32)
else:
raise ValueError("Input `high` must be expression or int.")

resp = pl.lit(respect_null, dtype=pl.Boolean)
return self._expr.register_plugin(
lib=_lib,
Expand All @@ -263,7 +278,10 @@ def rand_int(
)

def sample_uniform(
self, low: Optional[float] = None, high: Optional[float] = None, respect_null: bool = False
self,
low: Optional[Union[float, pl.Expr]] = None,
high: Optional[Union[float, pl.Expr]] = None,
respect_null: bool = False,
) -> pl.Expr:
"""
Creates self.len() many random points sampled from a uniform distribution within [low, high).
Expand All @@ -281,8 +299,20 @@ def sample_uniform(
If true, null in reference column will be null in the new column
"""

lo = self._expr.min() if low is None else pl.lit(low, dtype=pl.Float64)
hi = self._expr.max() if high is None else pl.lit(high, dtype=pl.Float64)
if isinstance(low, pl.Expr):
lo = low
elif isinstance(low, float):
lo = pl.lit(low, dtype=pl.Float64)
else:
lo = self._expr.min()

if isinstance(high, pl.Expr):
hi = high
elif isinstance(high, float):
hi = pl.lit(high, dtype=pl.Float64)
else:
hi = self._expr.max()

resp = pl.lit(respect_null, dtype=pl.Boolean)
return self._expr.register_plugin(
lib=_lib,
Expand Down Expand Up @@ -319,33 +349,43 @@ def sample_binomial(self, n: int, p: float, respect_null: bool = False) -> pl.Ex
returns_scalar=False,
)

def sample_exp(self, lam: Optional[float] = None, respect_null: bool = False) -> pl.Expr:
def sample_exp(
self, lambda_: Optional[Union[float, pl.Expr]] = None, respect_null: bool = False
) -> pl.Expr:
"""
Creates self.len() many random points sampled from an exponential distribution with the given lambda.
This treats self as the reference column.
Parameters
----------
lam
lambda_
lambda in an exponential distribution. If None, it will be 1/reference col's mean. Note that
lambda < 0 will throw an error and lambda = 0 will only return infinity.
respect_null
If true, null in reference column will be null in the new column
"""
if isinstance(lambda_, pl.Expr):
la = lambda_
elif isinstance(lambda_, float):
la = pl.lit(lambda_, dtype=pl.Float64)
else:
la = 1.0 / self._expr.mean()

lamb = (1.0 / self._expr.mean()) if lam is None else pl.lit(lam, dtype=pl.Float64)
resp = pl.lit(respect_null, dtype=pl.Boolean)
return self._expr.register_plugin(
lib=_lib,
symbol="pl_sample_exp",
args=[lamb, resp],
args=[la, resp],
is_elementwise=True,
returns_scalar=False,
)

def sample_normal(
self, mean: Optional[float] = None, std: Optional[float] = None, respect_null: bool = False
self,
mean: Optional[Union[float, pl.Expr]] = None,
std: Optional[Union[float, pl.Expr]] = None,
respect_null: bool = False,
) -> pl.Expr:
"""
Creates self.len() many random points sampled from a normal distribution with the given
Expand All @@ -362,9 +402,20 @@ def sample_normal(
respect_null
If true, null in reference column will be null in the new column
"""
if isinstance(mean, pl.Expr):
me = mean
elif isinstance(mean, (float, int)):
me = pl.lit(mean, dtype=pl.Float64)
else:
me = self._expr.mean()

if isinstance(std, pl.Expr):
st = std
elif isinstance(std, (float, int)):
st = pl.lit(std, dtype=pl.Float64)
else:
st = self._expr.std()

me = self._expr.mean() if mean is None else pl.lit(mean, dtype=pl.Float64)
st = self._expr.std() if std is None else pl.lit(std, dtype=pl.Float64)
resp = pl.lit(respect_null, dtype=pl.Boolean)
return self._expr.register_plugin(
lib=_lib,
Expand Down Expand Up @@ -395,8 +446,8 @@ def rand_str(
respect_null
If true, null in reference column will be null in the new column
"""
if max_size <= 0:
raise ValueError("Input `max_size` must be positive.")
if min_size <= 0 or (max_size < min_size):
raise ValueError("String size must be positive and max_size must be >= min_size.")

min_s = pl.lit(min_size, dtype=pl.UInt32)
max_s = pl.lit(max_size, dtype=pl.UInt32)
Expand Down
91 changes: 0 additions & 91 deletions python/polars_ds/str2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import polars as pl
from typing import Union, Optional, Literal
from polars.utils.udfs import _get_shared_lib_location
from .type_alias import AhoCorasickMatchKind
import warnings

_lib = _get_shared_lib_location(__file__)

Expand Down Expand Up @@ -780,95 +778,6 @@ def snowball(self, no_stopwords: bool = True, parallel: bool = False) -> pl.Expr
is_elementwise=True,
)

def ac_match(
    self,
    patterns: list[str],
    case_sensitive: bool = False,
    match_kind: AhoCorasickMatchKind = "standard",
    return_str: bool = False,
) -> pl.Expr:
    """
    Match the given patterns against this string expression with the Aho-Corasick algorithm.

    For every string, the indices (positions in `patterns`) of the matched patterns are
    returned in match order. E.g. if patterns 2, 1, 3 match in that order, the result is
    [1, 0, 2].

    Polars >= 0.20 ships native Aho-Corasick support built on the same backend package, with a
    different function API. See polars's str.contains_any and str.replace_many.

    Parameters
    ----------
    patterns
        A list of strs, which are the patterns to be matched
    case_sensitive
        Whether the match should be case sensitive. Default is false. Not working now.
    match_kind
        One of `standard`, `left_most_first`, or `left_most_longest`. For more information, see
        https://docs.rs/aho-corasick/latest/aho_corasick/enum.MatchKind.html. Any other input
        will be treated as standard.
    return_str
        If true, the `pl_ac_match_str` plugin is used instead of `pl_ac_match`.
        NOTE(review): presumably this yields the matched strings rather than their indices;
        confirm against the plugin implementation.
    """

    # The per-list value capacity is currently hard-coded to 20 on the plugin side. With more
    # than 20 matches the underlying vec has to grow (capacity doubling), which is slow.
    warnings.warn("Argument `case_sensitive` does not seem to work right now.")
    warnings.warn(
        "This function is unstable and is subject to change and may not perform well if there are more than "
        "20 matches. Read the source code or contact the author for more information. The most difficulty part "
        "is to design an output API that works well with Polars, which is harder than one might think."
    )

    # Both plugin entry points take the same arguments; only the symbol differs.
    plugin_symbol = "pl_ac_match_str" if return_str else "pl_ac_match"
    plugin_args = [
        pl.Series(patterns, dtype=pl.Utf8),
        pl.lit(case_sensitive, pl.Boolean),
        pl.lit(match_kind, pl.Utf8),
    ]
    return self._expr.register_plugin(
        lib=_lib,
        symbol=plugin_symbol,
        args=plugin_args,
        is_elementwise=True,
    )

def ac_replace(
    self, patterns: list[str], replacements: list[str], parallel: bool = False
) -> pl.Expr:
    """
    Replace the given patterns using the Aho-Corasick algorithm.

    `patterns` and `replacements` are paired by position. If their lengths differ, both are
    truncated to the shorter length. If either one is empty, the expression is returned
    unchanged. If an error happens during replacement, None will be returned.

    Polars >= 0.20 ships native Aho-Corasick support built on the same backend package, with a
    different function API. See polars's str.contains_any and str.replace_many.

    Parameters
    ----------
    patterns
        A list of strs, which are the patterns to be matched
    replacements
        A list of strs to replace the patterns with
    parallel
        Whether to run the comparisons in parallel. Note that this is not always faster,
        especially when used with other expressions or in group_by/over context.
    """
    # Number of usable (pattern, replacement) pairs; zero means at least one list is empty.
    n_pairs = min(len(patterns), len(replacements))
    if n_pairs == 0:
        return self._expr

    plugin_args = [
        pl.Series(patterns[:n_pairs], dtype=pl.Utf8),
        pl.Series(replacements[:n_pairs], dtype=pl.Utf8),
        pl.lit(parallel, pl.Boolean),
    ]
    return self._expr.register_plugin(
        lib=_lib,
        symbol="pl_ac_replace",
        args=plugin_args,
        is_elementwise=True,
    )

def to_camel_case(self) -> pl.Expr:
"""Turns itself into camel case. E.g. helloWorld"""
return self._expr.register_plugin(
Expand Down
1 change: 0 additions & 1 deletion python/polars_ds/type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@


DetrendMethod: TypeAlias = Literal["linear", "mean"]
AhoCorasickMatchKind: TypeAlias = Literal["standard", "left_most_first", "left_most_longest"]
Alternative: TypeAlias = Literal["two-sided", "less", "greater"]
Distance = Literal["l1", "l2", "inf", "h", "haversine"]
Loading

0 comments on commit 581677c

Please sign in to comment.