From 790ed9b2bafb4550d4e4e9e6b29c724c6a6981cd Mon Sep 17 00:00:00 2001 From: abstractqqq Date: Sun, 11 Aug 2024 21:50:06 -0400 Subject: [PATCH] added knn freq cnts --- python/polars_ds/knn_queries.py | 136 +++++++++++++++++++++++++++----- src/num/knn.rs | 20 +++-- tests/test.ipynb | 49 +++++++----- 3 files changed, 156 insertions(+), 49 deletions(-) diff --git a/python/polars_ds/knn_queries.py b/python/polars_ds/knn_queries.py index cd14c552..dabe3d16 100644 --- a/python/polars_ds/knn_queries.py +++ b/python/polars_ds/knn_queries.py @@ -10,11 +10,13 @@ __all__ = [ "query_knn_ptwise", + "query_knn_freq_cnt", "query_knn_avg", - "is_knn_from", - "within_dist_from", "query_radius_ptwise", + "query_radius_freq_cnt", "query_nb_cnt", + "is_knn_from", + "within_dist_from", ] @@ -60,13 +62,13 @@ def query_knn_ptwise( Note `sql2` stands for squared l2. parallel : bool Whether to run the k-nearest neighbor query in parallel. This is recommended when you - are running only this expression, and not in group_by context. + are running only this expression, and not in group_by() or over() context. return_dist If true, return a struct with indices and distances. eval_mask Either None or a boolean expression or the name of a boolean column. If not none, this will - only evaluate KNN for rows where this is true. This can speed up computation with K is large - and when only results on a subset are nedded. + only evaluate KNN for rows where this is true. This can speed up computation when only results on a + subset are nedded. data_mask Either None or a boolean expression or the name of a boolean column. If none, all rows can be neighbors. If not None, the pool of possible neighbors will be rows where this is true. @@ -85,7 +87,7 @@ def query_knn_ptwise( feats: List[pl.Expr] = [str_to_expr(e) for e in features] skip_data = data_mask is not None - if skip_data: + if skip_data: # true means keep keep_mask = pl.all_horizontal(str_to_expr(data_mask), *(f.is_not_null() for f in feats)) else: keep_mask = pl.all_horizontal(f.is_not_null() for f in feats) @@ -120,9 +122,71 @@ def query_knn_ptwise( ) -def query_knn_avg( +def query_knn_freq_cnt( *features: StrOrExpr, - target: StrOrExpr, + index: StrOrExpr, + k: int, + dist: Distance = "sql2", + parallel: bool = False, + eval_mask: str | pl.Expr | None = None, + data_mask: str | pl.Expr | None = None, + epsilon: float = 0.0, + max_bound: float = 99999.0, +) -> pl.Expr: + """ + Takes the index column, and uses feature columns to determine the k nearest neighbors + to each row, and finally returns the number of times a row is a KNN of some other point. + + This calls `query_knn_ptwise` internally. See the docstring of `query_knn_ptwise` for more info. + + Parameters + ---------- + *features : str | pl.Expr + Other columns used as features + index : str | pl.Expr + The column used as index, must be castable to u32 + k : int + Number of neighbors to query + dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`] + Note `sql2` stands for squared l2. + parallel : bool + Whether to run the k-nearest neighbor query in parallel. This is recommended when you + are running only this expression, and not in group_by() or over() context. + return_dist + If true, return a struct with indices and distances. + eval_mask + Either None or a boolean expression or the name of a boolean column. If not none, this will + only evaluate KNN for rows where this is true. This can speed up computation when only results on a + subset are nedded. + data_mask + Either None or a boolean expression or the name of a boolean column. If none, all rows can be + neighbors. If not None, the pool of possible neighbors will be rows where this is true. + epsilon + If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter + should increase as the dimension of the vector space increases because higher dimensions + allow for errors from more directions. + max_bound + Max distance the neighbors must be within + """ + + knn_expr: pl.Expr = query_knn_ptwise( + *features, + index=index, + k=k, + dist=dist, + parallel=parallel, + return_dist=False, + eval_mask=eval_mask, + data_mask=data_mask, + epsilon=epsilon, + max_bound=max_bound, + ) + return knn_expr.explode().drop_nulls().value_counts(sort=True, parallel=parallel) + + +def query_knn_avg( + *features: str | pl.Expr, + target: str | pl.Expr, k: int, dist: Distance = "sql2", weighted: bool = False, @@ -156,7 +220,7 @@ def query_knn_avg( an extremely small value, this will default to 1/(1+distance) as weights to avoid division by 0. parallel : bool Whether to run the k-nearest neighbor query in parallel. This is recommended when you - are running only this expression, and not in group_by context. + are running only this expression, and not in group_by() or over() context. min_bound Min distance (>=) for a neighbor to be part of the average calculation. This prevents "identical" points from being part of the average and prevents division by 0. Note that this filter is applied @@ -257,7 +321,7 @@ def within_dist_from( def is_knn_from( - *features: StrOrExpr, + *features: str | pl.Expr, pt: Iterable[float], k: int, dist: Distance = "sql2", @@ -320,8 +384,8 @@ def is_knn_from( def query_radius_ptwise( - *features: StrOrExpr, - index: StrOrExpr, + *features: str | pl.Expr, + index: str | pl.Expr, r: float, dist: Distance = "sql2", sort: bool = True, @@ -329,8 +393,8 @@ def query_radius_ptwise( ) -> pl.Expr: """ Takes the index column, and uses features columns to determine distance, and finds all neighbors - within distance r from each id in the index column. If you only care about neighbor count, you - should use query_nb_cnt, which supports expression for radius. + within distance r from each id. If you only care about neighbor count, you should use + `query_nb_cnt`, which supports expression for radius and is way faster. Note that the index column must be convertible to u32. If you do not have a u32 ID column, you can generate one using pl.int_range(..), which should be a step before this. @@ -354,8 +418,9 @@ def query_radius_ptwise( improve performance by 10-20%. parallel : bool Whether to run the k-nearest neighbor query in parallel. This is recommended when you - are running only this expression, and not in group_by context. + are running only this expression, and not in group_by() or over() context. """ + if r <= 0.0: raise ValueError("Input `r` must be > 0.") elif isinstance(r, pl.Expr): @@ -372,9 +437,44 @@ def query_radius_ptwise( ) +def query_radius_freq_cnt( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float, + dist: Distance = "sql2", + parallel: bool = False, +) -> pl.Expr: + """ + Takes the index column, and uses features columns to determine distance, finds all neighbors + within distance r from each index, and finally finds the count of the number of times the point is + within distance r from other points. + + This calls `query_radius_ptwise` internally. See the docstring of `query_radius_ptwise` for more info. + + Parameters + ---------- + *features : str | pl.Expr + Other columns used as features + index : str | pl.Expr + The column used as index, must be castable to u32 + r : float + The radius. Must be a scalar value now. + dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`] + Note `sql2` stands for squared l2. + parallel : bool + Whether to run the k-nearest neighbor query in parallel. This is recommended when you + are running only this expression, and not in group_by() or over() context. + """ + within_radius = query_radius_ptwise( + *features, index=index, r=r, dist=dist, sort=False, parallel=parallel + ) + + return within_radius.explode().drop_nulls().value_counts(sort=True, parallel=parallel) + + def query_nb_cnt( - r: Union[float, str, pl.Expr, List[float], "np.ndarray", pl.Series], # noqa: F821 - *features: StrOrExpr, + r: float | str | pl.Expr | Iterable[float], + *features: str | pl.Expr, dist: Distance = "sql2", parallel: bool = False, ) -> pl.Expr: @@ -395,7 +495,7 @@ def query_nb_cnt( Note `sql2` stands for squared l2. parallel : bool Whether to run the distance query in parallel. This is recommended when you - are running only this expression, and not in group_by context. + are running only this expression, and not in group_by() or over() context. """ if isinstance(r, (float, int)): rad = pl.lit(pl.Series(values=[r], dtype=pl.Float64)) diff --git a/src/num/knn.rs b/src/num/knn.rs index ceddec78..94053617 100644 --- a/src/num/knn.rs +++ b/src/num/knn.rs @@ -86,8 +86,8 @@ pub fn matrix_to_leaves_filtered<'a, T: Float + 'static, A: Copy>( } // used in all cases but squared l2 (multiple queries) -pub fn dist_from_str(dist_str: &str) -> Result, String> { - match dist_str { +pub fn dist_from_str(dist_str: String) -> Result, String> { + match dist_str.as_ref() { "l1" => Ok(DIST::L1), "l2" => Ok(DIST::L2), "sql2" => Ok(DIST::SQL2), @@ -124,7 +124,7 @@ fn pl_knn_avg( let binding = data.view(); let mut leaves = matrix_to_leaves_filtered(&binding, id, &null_mask); - let tree = match dist_from_str::(kwargs.metric.as_str()) { + let tree = match dist_from_str::(kwargs.metric) { Ok(d) => AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d), Err(e) => Err(e), } @@ -265,7 +265,7 @@ fn pl_knn_ptwise( let data = series_to_ndarray(&inputs[inputs_offset..], IndexOrder::C)?; let binding = data.view(); - let ca = match dist_from_str::(kwargs.metric.as_str()) { + let ca = match dist_from_str::(kwargs.metric) { Ok(d) => { let mut leaves = matrix_to_leaves_filtered(&binding, id, null_mask); AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d).map(|tree| { @@ -429,7 +429,7 @@ fn pl_knn_ptwise_w_dist( let data = series_to_ndarray(&inputs[inputs_offset..], IndexOrder::C)?; let binding = data.view(); - let (ca_nb, ca_dist) = match dist_from_str::(kwargs.metric.as_str()) { + let (ca_nb, ca_dist) = match dist_from_str::(kwargs.metric) { Ok(d) => { let mut leaves = matrix_to_leaves_filtered(&binding, id, null_mask); AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d).map(|tree| { @@ -519,7 +519,7 @@ fn pl_query_radius_ptwise( let binding = data.view(); // Building output - let ca = match dist_from_str::(kwargs.metric.as_str()) { + let ca = match dist_from_str::(kwargs.metric) { Ok(d) => { let mut leaves = matrix_to_leaves(&binding, id); AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d) @@ -620,9 +620,6 @@ where /// The point itself is always considered as a neighbor to itself. #[polars_expr(output_type=UInt32)] fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> PolarsResult { - // Set up params - // let leaf_size = kwargs.leaf_size; - // Set up radius let radius = inputs[0].f64()?; let can_parallel = kwargs.parallel && !context.parallel(); @@ -633,7 +630,7 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po let binding = data.view(); if radius.len() == 1 { let r = radius.get(0).unwrap(); - let ca = match dist_from_str::(kwargs.metric.as_str()) { + let ca = match dist_from_str::(kwargs.metric) { Ok(d) => { let mut leaves = matrix_to_empty_leaves(&binding); AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d) @@ -644,7 +641,7 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po .map_err(|err| PolarsError::ComputeError(err.into()))?; Ok(ca.with_name("cnt").into_series()) } else if radius.len() == nrows { - let ca = match dist_from_str::(kwargs.metric.as_str()) { + let ca = match dist_from_str::(kwargs.metric) { Ok(d) => { let mut leaves = matrix_to_empty_leaves(&binding); AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d) @@ -660,3 +657,4 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po )) } } + diff --git a/tests/test.ipynb b/tests/test.ipynb index 408a7278..334be422 100644 --- a/tests/test.ipynb +++ b/tests/test.ipynb @@ -22,29 +22,38 @@ " pds.random(0., 1.).alias(\"x1\"),\n", " pds.random(0., 1.).alias(\"x2\"),\n", " pds.random(0., 1.).alias(\"x3\"),\n", - ").with_columns(\n", - " pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x1\")).alias(\"x1\"),\n", - " pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x2\")).alias(\"x2\"),\n", - " pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x3\")).alias(\"x3\"),\n", - ").with_columns(\n", - " null_ref = pl.any_horizontal(pl.col(\"x1\").is_null(), pl.col(\"x2\").is_null(), pl.col(\"x3\").is_null()),\n", - " y = pl.col(\"x1\") * 0.15 + pl.col(\"x2\") * 0.3 - pl.col(\"x3\") * 1.5 + pds.random() * 0.0001\n", - ")\n", + ").with_row_index()\n", "\n", - "window_size = 6\n", - "min_valid_rows = 5\n", "\n", - "result = df.with_columns(\n", - " pds.query_rolling_lstsq(\n", + "# df = pds.frame(size = size).select(\n", + "# pds.random(0., 1.).alias(\"x1\"),\n", + "# pds.random(0., 1.).alias(\"x2\"),\n", + "# pds.random(0., 1.).alias(\"x3\"),\n", + "# ).with_row_index().with_columns(\n", + "# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x1\")).alias(\"x1\"),\n", + "# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x2\")).alias(\"x2\"),\n", + "# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x3\")).alias(\"x3\"),\n", + "# ).with_columns(\n", + "# null_ref = pl.any_horizontal(pl.col(\"x1\").is_null(), pl.col(\"x2\").is_null(), pl.col(\"x3\").is_null()),\n", + "# y = pl.col(\"x1\") * 0.15 + pl.col(\"x2\") * 0.3 - pl.col(\"x3\") * 1.5 + pds.random() * 0.0001\n", + "# )\n", + "# df.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df.select(\n", + " pds.query_radius_freq_cnt(\n", " \"x1\", \"x2\", \"x3\",\n", - " target = \"y\",\n", - " window_size = window_size,\n", - " min_valid_rows = min_valid_rows,\n", - " null_policy = \"skip\" \n", - " ).alias(\"test\")\n", - ").with_columns(\n", - " pl.col(\"test\").is_null().alias(\"is_null\")\n", - ")" + " index = \"index\",\n", + " r = 0.1,\n", + " dist = \"sql2\"\n", + " )\n", + ").unnest(\"index\")" ] }, {