Skip to content

Commit

Permalink
rapidfuzz backend
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq committed Nov 28, 2023
1 parent ef42115 commit cd42988
Show file tree
Hide file tree
Showing 11 changed files with 448 additions and 293 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ hashbrown = {version = "0.14.2", features=["nightly"]}
# rustfft = "6.1.0"
itertools = "0.12.0"
aho-corasick = "1.1.2"
strsim = "0.10.0" # Consider alternatives
rand = {version = "0.8.5"} # Simd support feature seems to be broken atm
rand_distr = "0.4.3"
realfft = "3.3.0"
rapidfuzz = "0.3.0"

[target.'cfg(any(not(target_os = "linux"), use_mimalloc))'.dependencies]
mimalloc = { version = "0.1", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion python/polars_ds/stats_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def rand_int(
) -> pl.Expr:
"""
Generates random integers uniformly from the range [low, high). Throws an error if low == high
or if (low is None and high is None and use_ref_nunique == False).
or if low is None and high is None and use_ref_nunique == False.
This treats self as the reference column.
Expand Down
129 changes: 106 additions & 23 deletions python/polars_ds/str_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from polars.utils.udfs import _get_shared_lib_location
from .type_alias import AhoCorasickMatchKind
import warnings
# from polars.type_aliases import IntoExpr

lib = _get_shared_lib_location(__file__)

Expand All @@ -13,6 +12,17 @@ class StrExt:
def __init__(self, expr: pl.Expr):
self._expr: pl.Expr = expr

def is_stopword(self) -> pl.Expr:
"""
Checks whether the string is a stopword or not.
"""
self._expr.register_plugin(
lib=lib,
symbol="pl_is_stopword",
args=[],
is_elementwise=True,
)

def extract_numbers(
self, ignore_comma: bool = False, join_by: str = "", dtype: pl.DataType = pl.Utf8
) -> pl.Expr:
Expand Down Expand Up @@ -194,7 +204,7 @@ def str_jaccard(
return self._expr.register_plugin(
lib=lib,
symbol="pl_str_jaccard",
args=[other_, pl.lit(substr_size, dtype=pl.UInt32), pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(substr_size, pl.UInt32), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand Down Expand Up @@ -226,7 +236,7 @@ def sorensen_dice(
return self._expr.register_plugin(
lib=lib,
symbol="pl_sorensen_dice",
args=[other_, pl.lit(substr_size, dtype=pl.UInt32), pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(substr_size, pl.UInt32), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand All @@ -251,14 +261,14 @@ def overlap_coeff(
when used with other expressions or in group_by/over context.
"""
if isinstance(other, str):
other_ = pl.lit(other, dtype=pl.Utf8)
other_ = pl.lit(other, pl.Utf8)
else:
other_ = other

return self._expr.register_plugin(
lib=lib,
symbol="pl_overlap_coeff",
args=[other_, pl.lit(substr_size, dtype=pl.UInt32), pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(substr_size, pl.UInt32), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand Down Expand Up @@ -289,14 +299,14 @@ def levenshtein(
return self._expr.register_plugin(
lib=lib,
symbol="pl_levenshtein_sim",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)
else:
return self._expr.register_plugin(
lib=lib,
symbol="pl_levenshtein",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand All @@ -323,15 +333,15 @@ def levenshtein_within(
when used with other expressions or in group_by/over context.
"""
if isinstance(other, str):
other_ = pl.lit(other, dtype=pl.Utf8)
other_ = pl.lit(other, pl.Utf8)
else:
other_ = other

bound = pl.lit(abs(bound), dtype=pl.UInt32)
bound = pl.lit(abs(bound), pl.UInt32)
return self._expr.register_plugin(
lib=lib,
symbol="pl_levenshtein_within",
args=[other_, bound, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, bound, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand Down Expand Up @@ -362,20 +372,58 @@ def d_levenshtein(
return self._expr.register_plugin(
lib=lib,
symbol="pl_d_levenshtein_sim",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)
else:
return self._expr.register_plugin(
lib=lib,
symbol="pl_d_levenshtein",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

def osa(
self, other: Union[str, pl.Expr], parallel: bool = False, return_sim: bool = False
) -> pl.Expr:
"""
Computes the Optimal String Alignment distance between this and the other str.
Parameters
----------
other
If this is a string, then the entire column will be compared with this string. If this
is an expression, then an element-wise OSA distance computation between this column
and the other (given by the expression) will be performed.
parallel
Whether to run the comparisons in parallel. Note that this is not always faster, especially
when used with other expressions or in group_by/over context.
return_sim
If true, return normalized OSA similarity.
"""
if isinstance(other, str):
other_ = pl.lit(other, dtype=pl.Utf8)
else:
other_ = other

if return_sim:
return self._expr.register_plugin(
lib=lib,
symbol="pl_osa_sim",
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)
else:
return self._expr.register_plugin(
lib=lib,
symbol="pl_osa",
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

def jaro(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr:
"""
Computes the Jaro similarity between this and the other str.
Computes the Jaro similarity between this and the other str. Jaro distance = 1 - Jaro sim.
Parameters
----------
Expand All @@ -395,11 +443,44 @@ def jaro(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr:
return self._expr.register_plugin(
lib=lib,
symbol="pl_jaro",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

def jw(
self, other: Union[str, pl.Expr], weight: float = 0.1, parallel: bool = False
) -> pl.Expr:
"""
Computes the Jaro-Winker similarity between this and the other str.
Jaro-Winkler distance = 1 - Jaro-Winkler sim.
Parameters
----------
other
If this is a string, then the entire column will be compared with this string. If this
is an expression, then an element-wise Levenshtein distance computation between this column
and the other (given by the expression) will be performed.
weight
Weight for prefix. A typical value is 0.1.
parallel
Whether to run the comparisons in parallel. Note that this is not always faster, especially
when used with other expressions or in group_by/over context.
"""
if isinstance(other, str):
other_ = pl.lit(other, pl.Utf8)
else:
other_ = other

return self._expr.register_plugin(
lib=lib,
symbol="pl_jw",
args=[other_, pl.lit(weight, pl.Float64), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

def hamming(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr:
def hamming(
self, other: Union[str, pl.Expr], pad: bool = False, parallel: bool = False
) -> pl.Expr:
"""
Computes the hamming distance between two strings. If they do not have the same length, null will
be returned.
Expand All @@ -410,6 +491,8 @@ def hamming(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr
If this is a string, then the entire column will be compared with this string. If this
is an expression, then an element-wise hamming distance computation between this column
and the other (given by the expression) will be performed.
pad
Whether to pad the string when lengths are not equal.
parallel
Whether to run the comparisons in parallel. Note that this is not always faster, especially
when used with other expressions or in group_by/over context.
Expand All @@ -422,7 +505,7 @@ def hamming(self, other: Union[str, pl.Expr], parallel: bool = False) -> pl.Expr
return self._expr.register_plugin(
lib=lib,
symbol="pl_hamming",
args=[other_, pl.lit(parallel, dtype=pl.Boolean)],
args=[other_, pl.lit(pad, pl.Boolean), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand All @@ -446,7 +529,7 @@ def tokenize(self, pattern: str = r"(?u)\b\w\w+\b", stem: bool = False) -> pl.Ex
.register_plugin(
lib=lib,
symbol="pl_snowball_stem",
args=[pl.lit(True, dtype=pl.Boolean), pl.lit(False, dtype=pl.Boolean)],
args=[pl.lit(True, pl.Boolean), pl.lit(False, pl.Boolean)],
is_elementwise=True,
) # True to no stop word, False to Parallel
.drop_nulls()
Expand Down Expand Up @@ -498,7 +581,7 @@ def snowball(self, no_stopwords: bool = True, parallel: bool = False) -> pl.Expr
return self._expr.register_plugin(
lib=lib,
symbol="pl_snowball_stem",
args=[pl.lit(no_stopwords, dtype=pl.Boolean), pl.lit(parallel, dtype=pl.Boolean)],
args=[pl.lit(no_stopwords, pl.Boolean), pl.lit(parallel, pl.Boolean)],
is_elementwise=True,
)

Expand Down Expand Up @@ -530,14 +613,14 @@ def ac_match(
# then this will be slow (doubling vec capacity)
warnings.warn("Argument `case_sensitive` does not seem to work right now.")
warnings.warn(
"This function is unstable and may subject to change and may not perform well if there are more than "
"This function is unstable and is subject to change and may not perform well if there are more than "
"20 matches. Read the source code or contact the author for more information. The most difficulty part "
"is to design an output API that works well with Polars, which is harder than one might think."
)

pat = pl.Series(patterns, dtype=pl.Utf8)
cs = pl.lit(case_sensitive, dtype=pl.Boolean)
mk = pl.lit(match_kind, dtype=pl.Utf8)
cs = pl.lit(case_sensitive, pl.Boolean)
mk = pl.lit(match_kind, pl.Utf8)
if return_str:
return self._expr.register_plugin(
lib=lib,
Expand Down Expand Up @@ -571,13 +654,13 @@ def ac_replace(
Whether to run the comparisons in parallel. Note that this is not always faster, especially
when used with other expressions or in group_by/over context.
"""
if (len(replacements) == 0) | (len(patterns) == 0):
if (len(replacements) == 0) or (len(patterns) == 0):
return self._expr

mlen = min(len(patterns), len(replacements))
pat = pl.Series(patterns[:mlen], dtype=pl.Utf8)
rpl = pl.Series(replacements[:mlen], dtype=pl.Utf8)
par = pl.lit(parallel, dtype=pl.Boolean)
par = pl.lit(parallel, pl.Boolean)
return self._expr.register_plugin(
lib=lib,
symbol="pl_ac_replace",
Expand Down
Loading

0 comments on commit cd42988

Please sign in to comment.