Commit
Merge pull request #288 from abstractqqq/kth_nb_dist
Kth nb dist
abstractqqq authored Nov 18, 2024
2 parents b4eead2 + 96341bf commit 1de6d10
Showing 7 changed files with 190 additions and 36 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
@@ -1,5 +1,8 @@
name: CI

permissions:
id-token: write

on:
push:
branches:
2 changes: 1 addition & 1 deletion examples/basics.ipynb
@@ -2712,7 +2712,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.7"
}
},
"nbformat": 4,
53 changes: 50 additions & 3 deletions python/polars_ds/expr_knn.py
@@ -15,11 +15,60 @@
"query_radius_ptwise",
"query_radius_freq_cnt",
"query_nb_cnt",
"query_dist_from_kth_nb",
"is_knn_from",
"within_dist_from",
]


def query_dist_from_kth_nb(
*features: str | pl.Expr,
k: int,
dist: Distance = "sql2",
parallel: bool = False,
epsilon: float = 0.0,
max_bound: float = 99999.0,
) -> pl.Expr:
"""
Computes the distance of each row to its k-th closest neighbor. This is useful for outlier detection:
if, say, the average distance to the 5th neighbor is 0.1, then a distance of 0.3 to the 5th neighbor
suggests that the point is far from its neighbors, i.e. it occupies a sparse region where sample
points typically do not appear.

This can be 10% faster and more direct than getting the result from `query_knn_ptwise` with
`return_dist = True`.

Parameters
----------
*features : str | pl.Expr
Columns used as features
k : int
Number of neighbors to query
dist : Literal[`l1`, `l2`, `sql2`, `inf`]
Note `sql2` stands for squared l2.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in a group_by() or over() context.
epsilon : float
If > 0, then it is possible to miss a neighbor within epsilon distance. This parameter
should increase as the dimension of the vector space increases, because higher dimensions
allow for errors from more directions.
max_bound : float
Max distance the neighbors must be within
"""
return pl_plugin(
symbol="pl_dist_from_kth_nb",
args=[str_to_expr(e) for e in features],
kwargs={
"k": k,
"metric": str(dist).lower(),
"parallel": parallel,
"skip_eval": False,
"max_bound": max_bound,
"epsilon": epsilon,
},
)
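
[Editor's note: a minimal usage sketch, not part of the diff. It assumes `polars_ds` is imported as `pds`, a DataFrame `df` with feature columns `var1`, `var2`, `var3` as in the new test further down, and an arbitrary 0.99 quantile cutoff for flagging outliers.]

import polars as pl
import polars_ds as pds

# Distance of each row to its 3rd nearest neighbor in (var1, var2, var3) space.
scored = df.with_columns(
pds.query_dist_from_kth_nb("var1", "var2", "var3", k=3, dist="l1").alias("kth_nb_dist")
)

# Rows with an unusually large k-th neighbor distance sit in sparse regions
# and can be flagged as potential outliers.
flagged = scored.with_columns(
(pl.col("kth_nb_dist") > pl.col("kth_nb_dist").quantile(0.99)).alias("is_outlier")
)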


def query_knn_ptwise(
*features: str | pl.Expr,
index: str | pl.Expr,
@@ -107,21 +156,19 @@ def query_knn_ptwise(
"parallel": parallel,
"skip_eval": skip_eval,
"max_bound": max_bound,
"epsilon": abs(epsilon),
"epsilon": 0.0,
}
if return_dist:
return pl_plugin(
symbol="pl_knn_ptwise_w_dist",
args=cols,
kwargs=kwargs,
pass_name_to_apply=True,
)
else:
return pl_plugin(
symbol="pl_knn_ptwise",
args=cols,
kwargs=kwargs,
pass_name_to_apply=True,
)


16 changes: 12 additions & 4 deletions python/polars_ds/num.py
@@ -997,7 +997,7 @@ def digamma(x: str | pl.Expr) -> pl.Expr:
# target: str | pl.Expr,
# n_neighbors: int = 3,
# seed: int | None = None,
# ) -> pl.Expr:
# ) -> float:
# """
# Computes the mutual information between a continuous variable x and a discrete
# target variable. Note: (1) This always assumes `x` is continuous. (2) Unlike Scikit-learn,
@@ -1027,11 +1027,20 @@ def digamma(x: str | pl.Expr) -> pl.Expr:
# is_elementwise=True,
# )

# kwargs = {
# "k": n_neighbors,
# "metric": "l1",
# "parallel": False,
# "skip_eval": False,
# "max_bound": 99999.0,
# "epsilon": 0.
# }

# # This is not really exposed to the user. It does `dist_from_kth_nb` and a `next_down` in one go.
# r = pl_plugin(
# symbol="_pl_dist_from_kth_nb",
# symbol="pl_dist_from_kth_nb",
# args=[c],
# kwargs={"k": n_neighbors, "parallel": False},
# kwargs=kwargs,
# ).over(t)

# label_counts = c.len().over(t)
@@ -1050,7 +1059,6 @@ def digamma(x: str | pl.Expr) -> pl.Expr:
# },
# )

# # return nb_cnt

# # psi in SciPy is the digamma function
# psi_label_counts = pl_plugin(
2 changes: 0 additions & 2 deletions src/arkadia/kdt.rs
@@ -283,7 +283,6 @@ impl<'a, T: Float + DistanceOps + 'static + Debug, A: Copy> KDT<'a, T, A> {
current_max: T,
max_dist_bound: T,
) {
// This is only called if is_leaf. Safe to unwrap.
let mut cur_max = current_max;
for element in self.data.iter() {
let dist = self.d.dist(element.row_vec, point);
@@ -306,7 +305,6 @@ impl<'a, T: Float + DistanceOps + 'static + Debug, A: Copy> KDT<'a, T, A> {

// #[inline(always)]
fn update_nb_within(&self, neighbors: &mut Vec<NB<T, A>>, point: &[T], radius: T) {
// This is only called if is_leaf. Safe to unwrap.
for element in self.data.iter() {
let y = element.row_vec;
let dist = self.d.dist(y, point);
118 changes: 92 additions & 26 deletions src/num_ext/knn.rs
@@ -212,6 +212,73 @@
}
}

#[polars_expr(output_type=Float64)]
fn pl_dist_from_kth_nb(
inputs: &[Series],
context: CallerContext,
kwargs: KDTKwargs,
) -> PolarsResult<Series> {
// Set up params
let k = kwargs.k;
let can_parallel = kwargs.parallel && !context.parallel();
let ncols = inputs.len();
let data = series_to_row_major_slice::<Float64Type>(inputs)?;
match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
Ok(d) => {
let mut leaves: Vec<Leaf<f64, ()>> = data
.chunks_exact(ncols)
.map(|sl| ((), sl).into())
.collect();
// let mut leaves = row_major_slice_to_leaves(&data, ncols, id, null_mask);
let tree = KDT::from_leaves_unchecked(&mut leaves, d);
let nrows = data.len() / ncols;
let ca = if can_parallel {
let n_threads = POOL.current_num_threads();
let splits = split_offsets(nrows, n_threads);
let chunks_iter = splits.into_par_iter()
.map(|(i, n)| {
let start = i * ncols;
let end = (i + n) * ncols;
let slice = &data[start..end];
let mut builder: PrimitiveChunkedBuilder<Float64Type>
= PrimitiveChunkedBuilder::new("".into(), n);

for point in slice.chunks_exact(ncols) {
match tree.knn_bounded(k + 1, point, kwargs.max_bound, kwargs.epsilon) {
Some(mut nbs) => {
builder.append_option(
nbs.pop().map(|nb| nb.to_dist())
);
},
_ => builder.append_null()
}
}
let ca = builder.finish();
ca.downcast_iter().cloned().collect::<Vec<_>>()
});
let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>());
Float64Chunked::from_chunk_iter("".into(), chunks.into_iter().flatten())
} else {
let mut builder: PrimitiveChunkedBuilder<Float64Type>
= PrimitiveChunkedBuilder::new("".into(), nrows);
for point in data.chunks_exact(ncols) {
match tree.knn_bounded(k + 1, point, kwargs.max_bound, kwargs.epsilon) {
Some(mut nbs) => {
builder.append_option(
nbs.pop().map(|nb| nb.to_dist())
);
},
_ => builder.append_null()
}
}
builder.finish()
};
Ok(ca.into_series())
}
Err(e) => Err(PolarsError::ComputeError(e.into())),
}
}
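
[Editor's note, not part of the diff: the kd-tree here indexes every row, so each query point finds itself at distance 0; that is why the code asks for `k + 1` neighbors and pops the last entry of the sorted result. A rough brute-force Python equivalent of what `pl_dist_from_kth_nb` computes, assuming l1 distance and ignoring `max_bound` and `epsilon`:]

import numpy as np

def dist_from_kth_nb_bruteforce(data: np.ndarray, k: int) -> np.ndarray:
    # Pairwise l1 distances, shape (n, n); the diagonal is 0 (each point vs. itself).
    diffs = np.abs(data[:, None, :] - data[None, :, :]).sum(axis=-1)
    # Sorting each row puts the self-distance first, so index k is the k-th true
    # neighbor, matching the k + 1 query followed by pop() in the kd-tree code above.
    return np.sort(diffs, axis=1)[:, k]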

#[polars_expr(output_type_func=list_u32_output)]
fn pl_knn_ptwise(
inputs: &[Series],
@@ -228,8 +295,8 @@ fn pl_knn_ptwise(
let id = id.cont_slice()?;
let nrows = id.len();

// True means no nulls, keep
let null_mask = inputs[1].bool().unwrap();
// True means keep
let keep_mask = inputs[1].bool().unwrap();

let mut inputs_offset = 2;
let eval_mask = if skip_eval {
@@ -246,9 +313,9 @@
let ncols = inputs[inputs_offset..].len();
let data = series_to_row_major_slice::<Float64Type>(&inputs[inputs_offset..])?;

let ca = match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
Ok(d) => {
let mut leaves = row_major_slice_to_leaves_filtered(&data, ncols, id, null_mask);
let mut leaves = row_major_slice_to_leaves_filtered(&data, ncols, id, keep_mask);
let tree = KDT::from_leaves_unchecked(&mut leaves, d);
Ok(knn_ptwise(
tree,
@@ -258,13 +325,11 @@
can_parallel,
kwargs.max_bound,
kwargs.epsilon,
))
).into_series())
}
Err(e) => Err(e),
Err(e) => Err(PolarsError::ComputeError(e.into())),
}
.map_err(|err| PolarsError::ComputeError(err.into()))?;

Ok(ca.into_series())
}

pub fn knn_ptwise_w_dist<'a, Kdt>(
@@ -427,9 +492,8 @@ fn pl_knn_ptwise_w_dist(
kwargs.epsilon,
))
}
Err(e) => Err(e),
}
.map_err(|err| PolarsError::ComputeError(err.into()))?;
Err(e) => Err(PolarsError::ComputeError(e.into())),
}?;

let out = StructChunked::from_series(
"knn_dist".into(),
@@ -512,16 +576,14 @@ fn pl_query_radius_ptwise(
let ncols = inputs[1..].len();
let data = series_to_row_major_slice::<Float64Type>(&inputs[1..])?;
// Building output
let ca = match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
Ok(d) => {
let mut leaves = slice_to_leaves(&data, ncols, id);
let tree = KDT::from_leaves_unchecked(&mut leaves, d);
Ok(query_radius_ptwise(tree, &data, radius, can_parallel, sort))
Ok(query_radius_ptwise(tree, &data, radius, can_parallel, sort).into_series())
}
Err(e) => Err(e),
Err(e) => Err(PolarsError::ComputeError(e.into())),
}
.map_err(|err| PolarsError::ComputeError(err.into()))?;
Ok(ca.into_series())
}

#[inline]
@@ -619,27 +681,31 @@ fn pl_nb_cnt(inputs: &[Series], context: CallerContext, kwargs: KDTKwargs) -> Po

if radius.len() == 1 {
let r = radius.get(0).unwrap();
let ca = match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
Ok(d) => {
let mut leaves = slice_to_empty_leaves(&data, ncols);
let tree = KDT::from_leaves_unchecked(&mut leaves, d);
Ok(query_nb_cnt(tree, &data, r, can_parallel))
Ok(
query_nb_cnt(tree, &data, r, can_parallel)
.with_name("cnt".into())
.into_series()
)
}
Err(e) => Err(e),
Err(e) => Err(PolarsError::ComputeError(e.into())),
}
.map_err(|err| PolarsError::ComputeError(err.into()))?;
Ok(ca.with_name("cnt".into()).into_series())
} else if radius.len() == nrows {
let ca = match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
match DIST::<f64>::new_from_str_informed(kwargs.metric, ncols) {
Ok(d) => {
let mut leaves = slice_to_empty_leaves(&data, ncols);
let tree = KDT::from_leaves_unchecked(&mut leaves, d);
Ok(query_nb_cnt_w_radius(tree, &data, radius, can_parallel))
Ok(
query_nb_cnt_w_radius(tree, &data, radius, can_parallel)
.with_name("cnt".into())
.into_series()
)
}
Err(e) => Err(e),
Err(e) => Err(PolarsError::ComputeError(e.into())),
}
.map_err(|err| PolarsError::ComputeError(err.into()))?;
Ok(ca.with_name("cnt".into()).into_series())
} else {
Err(PolarsError::ShapeMismatch(
"Inputs must have the same length or one of them must be a scalar.".into(),
32 changes: 32 additions & 0 deletions tests/test_many.py
@@ -1813,3 +1813,35 @@ def test_digamma():
scipy_digamma = scipy.special.psi(a)

assert np.all(np.isclose(pds_digamma, scipy_digamma, atol=1e-5))


def test_kth_nb_dist():
size = 2000
df = pl.DataFrame(
{
"id": range(size),
}
).with_columns(
pds.random().alias("var1"),
pds.random().alias("var2"),
pds.random().alias("var3"),
)
# method 1 is what we want to test
# method 2 is assumed to be the truth.
test = (
df.select(
pds.query_dist_from_kth_nb("var1", "var2", "var3", dist="l1", k=3).alias(
"kth_nb_dist_method_1"
),
pds.query_knn_ptwise(
"var1", "var2", "var3", index="id", return_dist=True, k=3, dist="l1"
)
.struct.field("dist")
.list.last()
.alias("kth_nb_dist_method_2"),
)
.select((pl.col("kth_nb_dist_method_1") == pl.col("kth_nb_dist_method_2")).all())
.item(0, 0)
)

assert test is True
