Merge pull request #233 from abstractqqq/knn_frequency_cnt
added knn freq cnts
abstractqqq authored Aug 12, 2024
2 parents 8554d8f + 790ed9b commit 1815e4d
Showing 3 changed files with 156 additions and 49 deletions.
136 changes: 118 additions & 18 deletions python/polars_ds/knn_queries.py
@@ -10,11 +10,13 @@

__all__ = [
"query_knn_ptwise",
"query_knn_freq_cnt",
"query_knn_avg",
"is_knn_from",
"within_dist_from",
"query_radius_ptwise",
"query_radius_freq_cnt",
"query_nb_cnt",
"is_knn_from",
"within_dist_from",
]


@@ -60,13 +62,13 @@ def query_knn_ptwise(
Note `sql2` stands for squared l2.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by context.
are running only this expression, and not in group_by() or over() context.
return_dist
If true, return a struct with indices and distances.
eval_mask
Either None or a boolean expression or the name of a boolean column. If not none, this will
only evaluate KNN for rows where this is true. This can speed up computation with K is large
and when only results on a subset are nedded.
only evaluate KNN for rows where this is true. This can speed up computation when only results on a
subset are needed.
data_mask
Either None or a boolean expression or the name of a boolean column. If none, all rows can be
neighbors. If not None, the pool of possible neighbors will be rows where this is true.
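For readers new to these masks, a minimal usage sketch (hypothetical data and column names; assumes `query_knn_ptwise` is re-exported at the package level, the same way the notebook below calls `pds.query_radius_freq_cnt`):

import polars as pl
import polars_ds as pds  # assumed package-level re-export

df = pl.DataFrame({
    "id": pl.Series([0, 1, 2, 3, 4], dtype=pl.UInt32),
    "x1": [0.1, 0.2, 0.8, 0.9, 0.5],
    "x2": [0.0, 0.1, 0.7, 1.0, 0.4],
    "is_candidate": [True, True, True, False, False],  # data_mask: pool of possible neighbors
    "needs_knn": [False, False, False, True, True],     # eval_mask: rows to evaluate KNN for
})

knn = df.select(
    pds.query_knn_ptwise(
        "x1", "x2",
        index="id",
        k=2,
        dist="sql2",
        data_mask="is_candidate",  # only these rows can appear as neighbors
        eval_mask="needs_knn",     # KNN is only computed for these rows
    ).alias("knn")
)
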
@@ -85,7 +87,7 @@ def query_knn_ptwise(
feats: List[pl.Expr] = [str_to_expr(e) for e in features]

skip_data = data_mask is not None
if skip_data:
if skip_data: # true means keep
keep_mask = pl.all_horizontal(str_to_expr(data_mask), *(f.is_not_null() for f in feats))
else:
keep_mask = pl.all_horizontal(f.is_not_null() for f in feats)
@@ -120,9 +122,71 @@
)


def query_knn_avg(
def query_knn_freq_cnt(
*features: StrOrExpr,
target: StrOrExpr,
index: StrOrExpr,
k: int,
dist: Distance = "sql2",
parallel: bool = False,
eval_mask: str | pl.Expr | None = None,
data_mask: str | pl.Expr | None = None,
epsilon: float = 0.0,
max_bound: float = 99999.0,
) -> pl.Expr:
"""
Takes the index column, and uses feature columns to determine the k nearest neighbors
to each row, and finally returns the number of times a row is a KNN of some other point.
This calls `query_knn_ptwise` internally. See the docstring of `query_knn_ptwise` for more info.
Parameters
----------
*features : str | pl.Expr
Other columns used as features
index : str | pl.Expr
The column used as index, must be castable to u32
k : int
Number of neighbors to query
dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`]
Note `sql2` stands for squared l2.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by() or over() context.
eval_mask
Either None or a boolean expression or the name of a boolean column. If not None, this will
only evaluate KNN for rows where this is true. This can speed up computation when only results on a
subset are needed.
data_mask
Either None or a boolean expression or the name of a boolean column. If none, all rows can be
neighbors. If not None, the pool of possible neighbors will be rows where this is true.
epsilon
If > 0, then it is possible to miss a neighbor within epsilon distance away. This parameter
should increase as the dimension of the vector space increases because higher dimensions
allow for errors from more directions.
max_bound
Max distance the neighbors must be within
"""

knn_expr: pl.Expr = query_knn_ptwise(
*features,
index=index,
k=k,
dist=dist,
parallel=parallel,
return_dist=False,
eval_mask=eval_mask,
data_mask=data_mask,
epsilon=epsilon,
max_bound=max_bound,
)
return knn_expr.explode().drop_nulls().value_counts(sort=True, parallel=parallel)
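
A usage sketch of the new function (hypothetical data; assumes the same package-level re-export that the notebook below uses for `pds.query_radius_freq_cnt`, and only the parameters documented above):

import polars as pl
import polars_ds as pds  # assumed re-export

df = pl.DataFrame({
    "index": pl.Series(range(6), dtype=pl.UInt32),
    "x1": [0.10, 0.15, 0.20, 0.80, 0.85, 0.90],
    "x2": [0.00, 0.05, 0.10, 0.70, 0.75, 0.80],
})

# For each row index, count how many times it shows up among the k nearest
# neighbors of the other rows; value_counts(sort=True) puts the most frequent first.
counts = df.select(
    pds.query_knn_freq_cnt(
        "x1", "x2",
        index="index",
        k=2,
        dist="sql2",
    )
).unnest("index")  # struct -> separate columns, as in the notebook's .unnest("index")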


def query_knn_avg(
*features: str | pl.Expr,
target: str | pl.Expr,
k: int,
dist: Distance = "sql2",
weighted: bool = False,
@@ -156,7 +220,7 @@ def query_knn_avg(
an extremely small value, this will default to 1/(1+distance) as weights to avoid division by 0.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by context.
are running only this expression, and not in group_by() or over() context.
min_bound
Min distance (>=) for a neighbor to be part of the average calculation. This prevents "identical"
points from being part of the average and prevents division by 0. Note that this filter is applied
@@ -257,7 +321,7 @@ def within_dist_from(


def is_knn_from(
*features: StrOrExpr,
*features: str | pl.Expr,
pt: Iterable[float],
k: int,
dist: Distance = "sql2",
@@ -320,17 +384,17 @@ def is_knn_from(


def query_radius_ptwise(
*features: StrOrExpr,
index: StrOrExpr,
*features: str | pl.Expr,
index: str | pl.Expr,
r: float,
dist: Distance = "sql2",
sort: bool = True,
parallel: bool = False,
) -> pl.Expr:
"""
Takes the index column, and uses features columns to determine distance, and finds all neighbors
within distance r from each id in the index column. If you only care about neighbor count, you
should use query_nb_cnt, which supports expression for radius.
within distance r from each id. If you only care about neighbor count, you should use
`query_nb_cnt`, which supports expressions for the radius and is much faster.
Note that the index column must be convertible to u32. If you do not have a u32 ID column,
you can generate one using pl.int_range(..), which should be a step before this.
@@ -354,8 +418,9 @@ def query_radius_ptwise(
improve performance by 10-20%.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by context.
are running only this expression, and not in group_by() or over() context.
"""

if r <= 0.0:
raise ValueError("Input `r` must be > 0.")
elif isinstance(r, pl.Expr):
@@ -372,9 +437,44 @@
)


def query_radius_freq_cnt(
*features: str | pl.Expr,
index: str | pl.Expr,
r: float,
dist: Distance = "sql2",
parallel: bool = False,
) -> pl.Expr:
"""
Takes the index column, and uses the feature columns to determine distance, finds all neighbors
within distance r from each index, and finally counts the number of times each point appears
within distance r of other points.
This calls `query_radius_ptwise` internally. See the docstring of `query_radius_ptwise` for more info.
Parameters
----------
*features : str | pl.Expr
Other columns used as features
index : str | pl.Expr
The column used as index, must be castable to u32
r : float
The radius. Must be a scalar value for now.
dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`]
Note `sql2` stands for squared l2.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by() or over() context.
"""
within_radius = query_radius_ptwise(
*features, index=index, r=r, dist=dist, sort=False, parallel=parallel
)

return within_radius.explode().drop_nulls().value_counts(sort=True, parallel=parallel)
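
A compact sketch of the radius-count variant (hypothetical data; it mirrors the new notebook cell at the bottom of this diff):

import polars as pl
import polars_ds as pds

df = pl.DataFrame({
    "index": pl.Series(range(5), dtype=pl.UInt32),
    "x1": [0.00, 0.05, 0.10, 0.90, 0.95],
    "x2": [0.00, 0.05, 0.10, 0.90, 0.95],
})

# Counts, for each row index, how many times it falls within r of other points.
radius_counts = df.select(
    pds.query_radius_freq_cnt(
        "x1", "x2",
        index="index",
        r=0.1,
        dist="sql2",
    )
).unnest("index")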


def query_nb_cnt(
r: Union[float, str, pl.Expr, List[float], "np.ndarray", pl.Series], # noqa: F821
*features: StrOrExpr,
r: float | str | pl.Expr | Iterable[float],
*features: str | pl.Expr,
dist: Distance = "sql2",
parallel: bool = False,
) -> pl.Expr:
@@ -395,7 +495,7 @@
Note `sql2` stands for squared l2.
parallel : bool
Whether to run the distance query in parallel. This is recommended when you
are running only this expression, and not in group_by context.
are running only this expression, and not in group_by() or over() context.
"""
if isinstance(r, (float, int)):
rad = pl.lit(pl.Series(values=[r], dtype=pl.Float64))
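A sketch of the reworked `query_nb_cnt` radius handling (hypothetical data and column names; assumes the package-level re-export): the radius can be a single scalar applied to every row, or an expression/column giving a per-row radius.

import polars as pl
import polars_ds as pds

df = pl.DataFrame({
    "x1": [0.1, 0.2, 0.8, 0.9],
    "x2": [0.0, 0.1, 0.7, 1.0],
    "r": [0.05, 0.05, 0.20, 0.20],  # per-row radius
})

df.select(
    pds.query_nb_cnt(0.1, "x1", "x2", dist="sql2").alias("cnt_scalar_r"),        # one radius for all rows
    pds.query_nb_cnt(pl.col("r"), "x1", "x2", dist="sql2").alias("cnt_col_r"),   # radius taken from a column
)
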
20 changes: 9 additions & 11 deletions src/num/knn.rs
@@ -86,8 +86,8 @@ pub fn matrix_to_leaves_filtered<'a, T: Float + 'static, A: Copy>(
}

// used in all cases but squared l2 (multiple queries)
pub fn dist_from_str<T: Float + 'static>(dist_str: &str) -> Result<DIST<T>, String> {
match dist_str {
pub fn dist_from_str<T: Float + 'static>(dist_str: String) -> Result<DIST<T>, String> {
match dist_str.as_ref() {
"l1" => Ok(DIST::L1),
"l2" => Ok(DIST::L2),
"sql2" => Ok(DIST::SQL2),
@@ -124,7 +124,7 @@ fn pl_knn_avg(
let binding = data.view();
let mut leaves = matrix_to_leaves_filtered(&binding, id, &null_mask);

let tree = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let tree = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d),
Err(e) => Err(e),
}
@@ -265,7 +265,7 @@ fn pl_knn_ptwise(
let data = series_to_ndarray(&inputs[inputs_offset..], IndexOrder::C)?;
let binding = data.view();

let ca = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let ca = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => {
let mut leaves = matrix_to_leaves_filtered(&binding, id, null_mask);
AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d).map(|tree| {
@@ -429,7 +429,7 @@ fn pl_knn_ptwise_w_dist(
let data = series_to_ndarray(&inputs[inputs_offset..], IndexOrder::C)?;
let binding = data.view();

let (ca_nb, ca_dist) = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let (ca_nb, ca_dist) = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => {
let mut leaves = matrix_to_leaves_filtered(&binding, id, null_mask);
AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d).map(|tree| {
@@ -519,7 +519,7 @@ fn pl_query_radius_ptwise(
let binding = data.view();
// Building output

let ca = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let ca = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => {
let mut leaves = matrix_to_leaves(&binding, id);
AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d)
@@ -620,9 +620,6 @@ where
/// The point itself is always considered as a neighbor to itself.
#[polars_expr(output_type=UInt32)]
fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> PolarsResult<Series> {
// Set up params
// let leaf_size = kwargs.leaf_size;
// Set up radius

let radius = inputs[0].f64()?;
let can_parallel = kwargs.parallel && !context.parallel();
@@ -633,7 +630,7 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po
let binding = data.view();
if radius.len() == 1 {
let r = radius.get(0).unwrap();
let ca = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let ca = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => {
let mut leaves = matrix_to_empty_leaves(&binding);
AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d)
@@ -644,7 +641,7 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po
.map_err(|err| PolarsError::ComputeError(err.into()))?;
Ok(ca.with_name("cnt").into_series())
} else if radius.len() == nrows {
let ca = match dist_from_str::<f64>(kwargs.metric.as_str()) {
let ca = match dist_from_str::<f64>(kwargs.metric) {
Ok(d) => {
let mut leaves = matrix_to_empty_leaves(&binding);
AnyKDT::from_leaves(&mut leaves, SplitMethod::MIDPOINT, d)
@@ -660,3 +657,4 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po
))
}
}

49 changes: 29 additions & 20 deletions tests/test.ipynb
@@ -22,29 +22,38 @@
" pds.random(0., 1.).alias(\"x1\"),\n",
" pds.random(0., 1.).alias(\"x2\"),\n",
" pds.random(0., 1.).alias(\"x3\"),\n",
").with_columns(\n",
" pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x1\")).alias(\"x1\"),\n",
" pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x2\")).alias(\"x2\"),\n",
" pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x3\")).alias(\"x3\"),\n",
").with_columns(\n",
" null_ref = pl.any_horizontal(pl.col(\"x1\").is_null(), pl.col(\"x2\").is_null(), pl.col(\"x3\").is_null()),\n",
" y = pl.col(\"x1\") * 0.15 + pl.col(\"x2\") * 0.3 - pl.col(\"x3\") * 1.5 + pds.random() * 0.0001\n",
")\n",
").with_row_index()\n",
"\n",
"window_size = 6\n",
"min_valid_rows = 5\n",
"\n",
"result = df.with_columns(\n",
" pds.query_rolling_lstsq(\n",
"# df = pds.frame(size = size).select(\n",
"# pds.random(0., 1.).alias(\"x1\"),\n",
"# pds.random(0., 1.).alias(\"x2\"),\n",
"# pds.random(0., 1.).alias(\"x3\"),\n",
"# ).with_row_index().with_columns(\n",
"# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x1\")).alias(\"x1\"),\n",
"# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x2\")).alias(\"x2\"),\n",
"# pl.when(pds.random() < 0.2).then(None).otherwise(pl.col(\"x3\")).alias(\"x3\"),\n",
"# ).with_columns(\n",
"# null_ref = pl.any_horizontal(pl.col(\"x1\").is_null(), pl.col(\"x2\").is_null(), pl.col(\"x3\").is_null()),\n",
"# y = pl.col(\"x1\") * 0.15 + pl.col(\"x2\") * 0.3 - pl.col(\"x3\") * 1.5 + pds.random() * 0.0001\n",
"# )\n",
"# df.head()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.select(\n",
" pds.query_radius_freq_cnt(\n",
" \"x1\", \"x2\", \"x3\",\n",
" target = \"y\",\n",
" window_size = window_size,\n",
" min_valid_rows = min_valid_rows,\n",
" null_policy = \"skip\" \n",
" ).alias(\"test\")\n",
").with_columns(\n",
" pl.col(\"test\").is_null().alias(\"is_null\")\n",
")"
" index = \"index\",\n",
" r = 0.1,\n",
" dist = \"sql2\"\n",
" )\n",
").unnest(\"index\")"
]
},
{
