From 7e42a90d27f2d257f35b799dcd4c1182224ebd24 Mon Sep 17 00:00:00 2001 From: koide3 <31344317+koide3@users.noreply.github.com> Date: Thu, 20 Jun 2024 11:25:51 +0900 Subject: [PATCH] parallel batch nearest neighbor search (#68) --- src/python/kdtree.cpp | 31 ++++++++++++++++--- src/test/python_test.py | 67 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/src/python/kdtree.cpp b/src/python/kdtree.cpp index 398588e..37ec158 100644 --- a/src/python/kdtree.cpp +++ b/src/python/kdtree.cpp @@ -58,6 +58,7 @@ void define_kdtree(py::module& m) { k_sq_dist : float The squared distance to the nearest neighbor. )""") + .def( "knn_search", [](const KdTree& kdtree, const Eigen::Vector3d& pt, int k) { @@ -85,11 +86,18 @@ void define_kdtree(py::module& m) { k_sq_dists : NDArray, shape (k,) The squared distances to the k nearest neighbors. )""") + .def( "batch_nearest_neighbor_search", - [](const KdTree& kdtree, const Eigen::MatrixXd& pts) { + [](const KdTree& kdtree, const Eigen::MatrixXd& pts, int num_threads) { + if (pts.cols() != 3 && pts.cols() != 4) { + throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)"); + } + std::vector k_indices(pts.rows(), -1); std::vector k_sq_dists(pts.rows(), std::numeric_limits::max()); + +#pragma omp parallel for num_threads(num_threads) for (int i = 0; i < pts.rows(); ++i) { const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]); if (!found) { @@ -97,16 +105,20 @@ void define_kdtree(py::module& m) { k_sq_dists[i] = std::numeric_limits::max(); } } + return std::make_pair(k_indices, k_sq_dists); }, py::arg("pts"), + py::arg("num_threads") = 1, R"""( Find the nearest neighbors for a batch of points. Parameters ---------- - pts : NDArray, shape (n, 3) + pts : NDArray, shape (n, 3) or (n, 4) The input points. + num_threads : int, optional + The number of threads to use for the search. Default is 1. Returns ------- @@ -115,11 +127,18 @@ void define_kdtree(py::module& m) { k_sq_dists : NDArray, shape (n,) The squared distances to the nearest neighbors for each input point. )""") + .def( "batch_knn_search", - [](const KdTree& kdtree, const Eigen::MatrixXd& pts, int k) { + [](const KdTree& kdtree, const Eigen::MatrixXd& pts, int k, int num_threads) { + if (pts.cols() != 3 && pts.cols() != 4) { + throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)"); + } + std::vector> k_indices(pts.rows(), std::vector(k, -1)); std::vector> k_sq_dists(pts.rows(), std::vector(k, std::numeric_limits::max())); + +#pragma omp parallel for num_threads(num_threads) for (int i = 0; i < pts.rows(); ++i) { const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data()); if (found < k) { @@ -129,19 +148,23 @@ void define_kdtree(py::module& m) { } } } + return std::make_pair(k_indices, k_sq_dists); }, py::arg("pts"), py::arg("k"), + py::arg("num_threads") = 1, R"""( Find the k nearest neighbors for a batch of points. Parameters ---------- - pts : NDArray, shape (n, 3) + pts : NDArray, shape (n, 3) or (n, 4) The input points. k : int The number of nearest neighbors to search for. + num_threads : int, optional + The number of threads to use for the search. Default is 1. Returns ------- diff --git a/src/test/python_test.py b/src/test/python_test.py index 15672a2..e1030f2 100755 --- a/src/test/python_test.py +++ b/src/test/python_test.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright 2024 Kenji Koide # SPDX-License-Identifier: MIT import numpy +from scipy.spatial import KDTree from scipy.spatial.transform import Rotation import small_gicp @@ -188,3 +189,69 @@ def test_registration(load_points): result = small_gicp.align(target_voxelmap, source) verify_result(result.T_target_source, gt_T_target_source) + +# KdTree test +def test_kdtree(load_points): + _, target_raw_numpy, source_raw_numpy = load_points + + target, target_tree = small_gicp.preprocess_points(target_raw_numpy, downsampling_resolution=0.5) + source, source_tree = small_gicp.preprocess_points(source_raw_numpy, downsampling_resolution=0.5) + + target_tree_ref = KDTree(target.points()) + source_tree_ref = KDTree(source.points()) + + def batch_test(points, queries, tree, tree_ref, num_threads): + # test for batch interface + k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1) + k_indices, k_sq_dists = tree.batch_nearest_neighbor_search(queries) + assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists) < 1e-6) + assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices] - queries, axis=1) ** 2 - k_sq_dists) < 1e-6) + + for k in [2, 10]: + k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k) + k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref) + + k_indices, k_sq_dists = tree.batch_knn_search(queries, k, num_threads=num_threads) + k_indices, k_sq_dists = numpy.array(k_indices), numpy.array(k_sq_dists) + + assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists) < 1e-6)) + for i in range(k): + diff = numpy.linalg.norm(points[k_indices[:, i]] - queries, axis=1) ** 2 - k_sq_dists[:, i] + assert(numpy.all(numpy.abs(diff) < 1e-6)) + + # test for single query interface + if num_threads != 1: + return + + k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1) + k_indices2, k_sq_dists2 = [], [] + for query in queries: + found, index, sq_dist = tree.nearest_neighbor_search(query[:3]) + assert found + k_indices2.append(index) + k_sq_dists2.append(sq_dist) + + assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists2) < 1e-6) + assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices2] - queries, axis=1) ** 2 - k_sq_dists2) < 1e-6) + + for k in [2, 10]: + k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k) + k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref) + + k_indices2, k_sq_dists2 = [], [] + for query in queries: + indices, sq_dists = tree.knn_search(query[:3], k) + k_indices2.append(indices) + k_sq_dists2.append(sq_dists) + k_indices2, k_sq_dists2 = numpy.array(k_indices2), numpy.array(k_sq_dists2) + + assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists2) < 1e-6)) + for i in range(k): + diff = numpy.linalg.norm(points[k_indices2[:, i]] - queries, axis=1) ** 2 - k_sq_dists2[:, i] + assert(numpy.all(numpy.abs(diff) < 1e-6)) + + + for num_threads in [1, 2]: + batch_test(target.points(), target.points(), target_tree, target_tree_ref, num_threads=num_threads) + batch_test(target.points(), source.points(), target_tree, target_tree_ref, num_threads=num_threads) + batch_test(source.points(), target.points(), source_tree, source_tree_ref, num_threads=num_threads)