Shrike is a 100% pure Julia package for building ensembles of random projection trees. Random projection trees are a generalization of KD-Trees and are used to quickly approximate nearest neighbors or build k-nearest-neighbor graphs. They adapt to the low-dimensional structure that is often present in high-dimensional data.
The implementation here is based on the MRPT algorithm. This package also includes optimizations for knn-graph creation and has built-in support for multithreading.
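To make the idea concrete, a single split of a random projection tree can be sketched as follows. This is an illustration only, not Shrike's actual implementation: the helper name rp_split is hypothetical, and the sketch assumes a dense Gaussian direction with a median split for simplicity.

```julia
using Random, LinearAlgebra, Statistics

# Hypothetical helper: split the columns of X into two groups by projecting
# every point onto one random direction and cutting at the median projection.
function rp_split(X::AbstractMatrix, rng::AbstractRNG)
    d = size(X, 1)
    r = randn(rng, d)              # random direction in the input space
    proj = X' * r                  # one projection value per column (datapoint)
    m = median(proj)
    left = findall(<=(m), proj)    # indices of points on or below the median
    right = findall(>(m), proj)    # indices of points above the median
    return left, right
end

rng = MersenneTwister(42)
X = rand(rng, 100, 10000)
left, right = rp_split(X, rng)
```

Applying the same split recursively to each half yields a tree, and repeating the construction with fresh random directions yields an ensemble.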
To install just type
] add Shrike
in the REPL or
using Pkg
Pkg.add("Shrike")
To build an ensemble of random projection trees, use the ShrikeIndex type.
using Shrike
maxk = 100
X = rand(100, 10000)
shi = ShrikeIndex(X, maxk; depth=8, ntrees=10)
- The type accepts a matrix of data, X, where each column represents a datapoint.
- maxk represents the maximum number of nearest neighbors you will be able to find with this index. maxk is used to set a safe depth for the tree. You can also construct an index without this parameter if you need to.
- depth describes the number of times each random projection tree will split the data. Leaf nodes in the tree contain about npoints / 2^depth data points. Increasing depth increases speed but decreases accuracy. By default, the index sets depth as large as possible.
- ntrees controls the number of trees in the ensemble. More trees means more accuracy but more memory.
In this case, since we need an index that can find the 100 nearest neighbors, setting depth equal to 8 would result in some leaf nodes with fewer than 100 points. The index will infer this using maxk and set the depth to be as large as possible given maxk. In this case, depth = 6.
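The arithmetic behind depth = 6 can be checked with a short sketch. The helper safe_depth below is hypothetical and not part of Shrike's API; it assumes the rule is "the largest depth for which an average leaf, npoints / 2^depth, still holds at least maxk points", which reproduces the value quoted above but may not match the package's exact rule.

```julia
# Hypothetical helper (an assumption, not Shrike's code): largest depth such
# that the average leaf size npoints / 2^depth is still at least maxk.
safe_depth(npoints::Integer, maxk::Integer) = floor(Int, log2(npoints / maxk))

npoints, maxk = 10_000, 100
d = safe_depth(npoints, maxk)
println("depth = ", d, ", points per leaf is about ", npoints / 2^d)
```

With 10000 points, depth 6 leaves about 156 points per leaf, while depth 7 would leave about 78, which is below maxk = 100.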
To query the index for the 10 approximate nearest neighbors, use:
k = 10
q = X[:, 1]
approx_nn = ann(shi, q, k; vote_cutoff=2)
- The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Each tree "votes" for all points in relevant leaf nodes. If there aren't many points in the leaves, and there aren't many trees, the odds of a point receiving more than one vote are low. Thus, when depth is large and ntrees is less than 5, it is recommended to set vote_cutoff = 1.
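The voting step described above can be sketched in plain Julia. The functions candidates_by_vote and linear_search are hypothetical names used for illustration, not Shrike's internals; the sketch assumes each tree reports the indices of the points that share a leaf with the query.

```julia
using LinearAlgebra

# Hypothetical sketch: count one vote per tree for every point found in that
# tree's leaf, then keep the points that reach the cutoff.
function candidates_by_vote(leaves::Vector{Vector{Int}}, npoints::Int, vote_cutoff::Int)
    votes = zeros(Int, npoints)
    for leaf in leaves, i in leaf
        votes[i] += 1
    end
    return findall(>=(vote_cutoff), votes)
end

# Exact (linear) search restricted to the candidate set.
function linear_search(X::AbstractMatrix, q::AbstractVector, cand::Vector{Int}, k::Int)
    dists = [norm(view(X, :, i) .- q) for i in cand]
    order = partialsortperm(dists, 1:min(k, length(cand)))
    return cand[order]
end

leaves = [[1, 2, 3, 4], [2, 3, 5], [3, 4, 6]]   # toy leaves from three trees
loose = candidates_by_vote(leaves, 6, 1)        # every point seen by at least one tree
strict = candidates_by_vote(leaves, 6, 2)       # only points seen by two or more trees
```

A higher cutoff shrinks the candidate set, which is why it speeds up the linear search and can also drop true neighbors.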
This package includes fast algorithms to generate k-nearest-neighbor graphs and has specialized functions for this purpose. It uses neighbor of neighbor exploration (outlined here) to efficiently improve the accuracy of a knn-graph.
Nearest neighbor graphs are used to give a sparse topology to large datasets. Their structure can be used to project the data onto a lower dimensional manifold, to cluster datapoints with community detection algorithms or to perform other analyses.
To generate nearest neighbor graphs:
using Shrike
X = rand(100, 10000)
shi = ShrikeIndex(X; depth=6, ntrees=5)
k = 10
g = knngraph(shi, k; vote_cutoff=1, ne_iters=1, gtype=SimpleDiGraph)
- The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search.
- ne_iters controls how many iterations of neighbor exploration the algorithm will undergo. Successive iterations are increasingly fast. It is recommended to use more iterations of neighbor exploration when the number of trees is small and fewer when many trees are used.
- The gtype parameter allows the user to specify a LightGraphs.jl graph type to return. gtype=identity returns a sparse adjacency matrix.
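One round of neighbor-of-neighbor exploration can be sketched as below. This is a simplified illustration, not the package's implementation: explore_once is a hypothetical name, and the sketch assumes every row of the neighbor array holds k distinct indices that do not include the point itself.

```julia
using LinearAlgebra

# Hypothetical sketch: for each point, pool its current neighbors with the
# neighbors of those neighbors, then keep the k closest candidates.
# nn is an npoints x k matrix; row i holds the current guesses for X[:, i].
function explore_once(X::AbstractMatrix, nn::Matrix{Int})
    npoints, k = size(nn)
    out = similar(nn)
    for i in 1:npoints
        cand = Set{Int}(nn[i, :])
        for j in nn[i, :]
            union!(cand, nn[j, :])     # add the neighbors of each neighbor
        end
        delete!(cand, i)               # a point is not its own neighbor
        c = collect(cand)
        d = [norm(view(X, :, i) .- view(X, :, j)) for j in c]
        out[i, :] = c[partialsortperm(d, 1:k)]
    end
    return out
end

X = reshape(Float64.(1:6), 1, 6)            # six points on a line
nn0 = [3 5; 1 4; 2 6; 3 1; 4 2; 5 1]        # deliberately poor initial guesses
nn1 = explore_once(X, nn0)
```

Because the candidate pool always contains the current neighbors, a round can only keep or improve each row, which is consistent with using more rounds when few trees are available.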
If an array of nearest neighbor indices is preferred,
nn = allknn(shi, k; vote_cutoff=1, ne_iters=0)
can be used to generate an shi.npoints x k array of integer indices, where nn[i, :] corresponds to the nearest neighbors of X[:, i]. The keyword arguments work in the same way as in knngraph (outlined above).
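If a graph structure is needed later, an index array of this shape can be converted by hand. The sketch below uses only the SparseArrays standard library; adjacency_from_nn is a hypothetical helper, and its output (unit weights, directed edges) is an assumption that may differ from what knngraph returns.

```julia
using SparseArrays

# Hypothetical helper: turn an npoints x k neighbor-index array into a sparse
# directed adjacency matrix with A[i, j] = 1 when j is listed as a neighbor of i.
function adjacency_from_nn(nn::AbstractMatrix{<:Integer})
    npoints, k = size(nn)
    rows = repeat(1:npoints, outer=k)   # row index of each entry of nn, column-major
    cols = vec(nn)                      # neighbor index of each entry, column-major
    return sparse(rows, cols, ones(Int, length(cols)), npoints, npoints)
end

nn = [2 3; 1 3; 2 4; 3 2]   # toy 4 x 2 neighbor array
A = adjacency_from_nn(nn)
```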
Shrike has built-in support for multithreading. To allocate multiple threads, start julia with the --threads flag:
user@sys:~$ julia --threads 4
To see this at work, consider a small scale example:
user@sys:~$ cmd="using Shrike; shi=ShrikeIndex(rand(100, 10000)); @time knngraph(shi, 10, ne_iters=1)"
user@sys:~$ julia -e "$cmd"
12.373127 seconds (8.66 M allocations: 4.510 GiB, 6.85% gc time, 18.88% compilation time)
user@sys:~$ julia --threads 4 -e "$cmd"
6.306410 seconds (8.67 M allocations: 4.498 GiB, 13.12% gc time, 31.64% compilation time)
(This assumes that Shrike is installed.)
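To confirm that a session actually received the requested threads, the count can be checked from inside Julia. The snippet below is generic Julia using Base.Threads and says nothing about how Shrike distributes its own work; the toy loop is only there to show a threaded loop running.

```julia
using Base.Threads

println("threads available: ", nthreads())

# Toy parallel loop: each iteration writes to its own slot, so no locking is needed.
squares = zeros(Int, 8)
@threads for i in 1:8
    squares[i] = i^2
end
```

When julia is started without the --threads flag, nthreads() typically reports 1, and the loop simply runs serially.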
This package was compared to the original mrpt C++ implementation (on which this algorithm was based), annoy, a popular package for approximate nearest neighbors, and NearestNeighbors.jl, a Julia package for nearest neighbor search. The benchmarks were written in the spirit of ann-benchmarks, a repository for comparing different approximate nearest neighbor algorithms. The datasets used for the benchmark were taken directly from ann-benchmarks. The following are links to the HDF5 files in question: FashionMNIST, SIFT, MNIST and GIST. The benchmarks below were run on a compute cluster, restricting all algorithms to a single thread.
In this plot, up and to the right is better (faster queries, better recall). Each point represents a parameter combination. For full documentation of the parameters run and the timing methods, consult the original scripts located in the benchmark/ directory.
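For readers unfamiliar with the recall axis, the metric can be sketched as below. This assumes recall means the fraction of the true k nearest neighbors that the approximate result recovered; the helper name recall is hypothetical, and the benchmark scripts may compute it differently.

```julia
# Hypothetical helper: fraction of the exact neighbors that also appear in the
# approximate result.
recall(approx::AbstractVector{<:Integer}, exact::AbstractVector{<:Integer}) =
    length(intersect(approx, exact)) / length(exact)

exact  = [4, 8, 15, 16, 23]    # toy ground-truth neighbor indices
approx = [4, 8, 15, 42, 99]    # toy approximate result
r = recall(approx, exact)      # 3 of the 5 true neighbors were found
```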
This plot illustrates that, for this dataset, Shrike has better performance on most parameter combinations. Compare this to SIFT, below, where some parameter combinations are not as strong. We speculate that this has to do with the high dimensionality of points in FashionMNIST (d=784), compared to the lower dimensionality of SIFT (d=128).
It is important to note that NearestNeighbors.jl was designed to return the exact k-nearest-neighbors as quickly as possible, and does not approximate, hence the high accuracy and lower speed.
The takeaway here is that Shrike is fast! It is possibly a little faster than the original C++ implementation. Go Julia! We should note that Shrike was not benchmarked against state-of-the-art algorithms for approximate nearest neighbor search. These algorithms are faster than annoy and mrpt, but unfortunately, the developers of Shrike aren't familiar with these algorithms.