-
Notifications
You must be signed in to change notification settings - Fork 2
/
kdtree.py
72 lines (60 loc) · 2.2 KB
/
kdtree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from itertools import product
from collections import namedtuple
# Names exported by a star-import of this module; repeat_periodic is not listed.
__all__ = ["query_radius_periodic"]
def repeat_periodic(points, boxsize):
    """Replicate points across their neighbouring periodic images.

    Every point is shifted by each combination of ``{0, -L, +L}`` along every
    axis (``L`` = ``boxsize``), so points of shape ``(m, n)`` become an array
    of shape ``(m, 3**n, n)``. Image ``k`` uses the ``k``-th offset produced by
    ``itertools.product((0, -1, 1), repeat=n)``, i.e. for ``n`` axes::

        (x1, ..., xn)
        (x1, ..., xn-L)
        (x1, ..., xn+L)
        ...
        (x1+L, ..., xn)
        (x1+L, ..., xn-L)
        (x1+L, ..., xn+L)

    Image 0 is therefore always the unshifted point.

    Parameters
    ----------
    points : array-like, shape (..., n)
        Coordinates to replicate; the last axis holds the ``n`` coordinates.
    boxsize : float or array-like
        Box length ``L``: a scalar, or an array broadcastable against the
        ``n`` coordinates (e.g. one length per axis).

    Returns
    -------
    numpy.ndarray, shape (..., 3**n, n)
        ``points`` with a new image axis inserted before the coordinate axis.
    """
    coords = np.asarray(points)
    n_axes = coords.shape[-1]
    # Unit offsets in {0, -1, +1}^n; the all-zero combination comes first.
    unit_offsets = np.array(tuple(product((0, -1, 1), repeat=n_axes)))
    # Insert the image axis so the (3**n, n) offsets broadcast over all points.
    return np.expand_dims(coords, -2) + unit_offsets * boxsize
# Result type of query_radius_periodic. Built once at module level rather than
# on every call, so all results share one class (isinstance works across
# calls) and instances can be pickled by name.
KDTreeQuery = namedtuple('KDTreeQuery', ['count', 'index', 'distance'])


def query_radius_periodic(tree, points, radius, boxsize=None, merge=False):
    """Find tree neighbours within ``radius``, optionally with periodic boxes.

    Periodic boundaries are emulated by querying all ``3**ndim`` shifted
    images of every point (see ``repeat_periodic``) and regrouping the hits
    per original point.

    Parameters
    ----------
    tree : sklearn.neighbors.KDTree instance
        Tree built on the data; the last axis of ``tree.data`` fixes ``ndim``.
    points : array-like, shape (m, ndim)
        An array of points to query.
    radius : float or array-like
        Distance within which neighbors are returned; an array gives one
        radius per query point.
    boxsize : float or array-like, optional
        Periodic box size (scalar or per-axis). ``None`` (default) disables
        periodic boundaries.
    merge : bool
        If True, indices and distances of all query points are concatenated
        into single flat arrays; ``count`` tells how to split them.

    Returns
    -------
    KDTreeQuery
        Named tuple ``(count, index, distance)``; ``count[i]`` is the number
        of neighbours found for ``points[i]``. Without ``merge``: in the
        non-periodic case ``index``/``distance`` are exactly what
        ``tree.query_radius`` returns; in the periodic case they are 2-D
        arrays when every point has the same count, otherwise 1-D object
        arrays holding one array per point.

    Raises
    ------
    ValueError
        If the last dimension of ``points`` differs from the tree's.

    Notes
    -----
    A data point is reported once for every periodic image of the query that
    lies within ``radius`` of it, so keep ``radius`` below half the box size
    to avoid duplicates. Only the nearest image on each side per axis is
    searched, so presumably points and tree data should lie in the same box
    (TODO confirm with callers).
    """
    points = np.asarray(points)
    ndim = tree.data.shape[-1]
    nrep = 3**ndim  # periodic images per query point
    if points.shape[-1] != ndim:
        raise ValueError("Incompatible shape.")

    periodic = boxsize is not None
    if periodic:
        # Rows k*nrep .. k*nrep + nrep - 1 are the images of point k.
        points = repeat_periodic(points, boxsize=boxsize).reshape(-1, ndim)
        # np.ndim rather than np.isscalar, so a 0-d array radius is treated
        # as a scalar instead of being expanded to only nrep entries.
        if np.ndim(radius) > 0:
            # Give every image the radius of the point it was made from.
            radius = np.repeat(radius, nrep)

    idx, dis = tree.query_radius(points, radius, return_distance=True)
    # dtype=int keeps the counts integral even when there are no queries.
    cnt = np.array([len(found) for found in idx], dtype=int)
    if periodic:
        cnt = cnt.reshape(-1, nrep).sum(-1)

    if merge:
        idx = np.concatenate(idx)
        dis = np.concatenate(dis)
    elif periodic:
        # Join the hits of the nrep images of each original point.
        idx = [np.concatenate(row) for row in idx.reshape(-1, nrep)]
        dis = [np.concatenate(row) for row in dis.reshape(-1, nrep)]
        # Equal counts stack into regular 2-D arrays; ragged results become
        # 1-D object arrays. Comparing against cnt[:1] (not cnt[0]) avoids an
        # IndexError when there are no query points.
        dtype = None if np.all(cnt == cnt[:1]) else object
        idx = np.array(idx, dtype=dtype)
        dis = np.array(dis, dtype=dtype)

    return KDTreeQuery(cnt, idx, dis)