diff --git a/src/python/kdtree.cpp b/src/python/kdtree.cpp index cd05a8d..398588e 100644 --- a/src/python/kdtree.cpp +++ b/src/python/kdtree.cpp @@ -16,11 +16,22 @@ using namespace small_gicp; void define_kdtree(py::module& m) { // KdTree - py::class_, std::shared_ptr>>(m, "KdTree", "KdTree") // + py::class_, std::shared_ptr>>(m, "KdTree") // .def( py::init([](const PointCloud::ConstPtr& points, int num_threads) { return std::make_shared>(points, KdTreeBuilderOMP(num_threads)); }), py::arg("points"), - py::arg("num_threads") = 1) + py::arg("num_threads") = 1, + R"""( + Construct a KdTree from a point cloud. + + Parameters + ---------- + points : PointCloud + The input point cloud. + num_threads : int, optional + The number of threads to use for KdTree construction. Default is 1. + )""") + .def( "nearest_neighbor_search", [](const KdTree& kdtree, const Eigen::Vector3d& pt) { @@ -30,7 +41,23 @@ void define_kdtree(py::module& m) { return std::make_tuple(found, k_index, k_sq_dist); }, py::arg("pt"), - "Search the nearest neighbor. Returns a tuple of found flag, index, and squared distance.") + R"""( + Find the nearest neighbor to a given point. + + Parameters + ---------- + pt : NDArray, shape (3,) + The input point. + + Returns + ------- + found : int + Whether a neighbor was found (1 if found, 0 if not). + k_index : int + The index of the nearest neighbor in the point cloud. + k_sq_dist : float + The squared distance to the nearest neighbor. + )""") .def( "knn_search", [](const KdTree& kdtree, const Eigen::Vector3d& pt, int k) { @@ -41,5 +68,86 @@ void define_kdtree(py::module& m) { }, py::arg("pt"), py::arg("k"), - "Search the k-nearest neighbors. Returns a pair of indices and squared distances."); -} \ No newline at end of file + R"""( + Find the k nearest neighbors to a given point. + + Parameters + ---------- + pt : NDArray, shape (3,) + The input point. + k : int + The number of nearest neighbors to search for. + + Returns + ------- + k_indices : NDArray, shape (k,) + The indices of the k nearest neighbors in the point cloud. + k_sq_dists : NDArray, shape (k,) + The squared distances to the k nearest neighbors. + )""") + .def( + "batch_nearest_neighbor_search", + [](const KdTree& kdtree, const Eigen::MatrixXd& pts) { + std::vector k_indices(pts.rows(), -1); + std::vector k_sq_dists(pts.rows(), std::numeric_limits::max()); + for (int i = 0; i < pts.rows(); ++i) { + const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]); + if (!found) { + k_indices[i] = -1; + k_sq_dists[i] = std::numeric_limits::max(); + } + } + return std::make_pair(k_indices, k_sq_dists); + }, + py::arg("pts"), + R"""( + Find the nearest neighbors for a batch of points. + + Parameters + ---------- + pts : NDArray, shape (n, 3) + The input points. + + Returns + ------- + k_indices : NDArray, shape (n,) + The indices of the nearest neighbors for each input point. + k_sq_dists : NDArray, shape (n,) + The squared distances to the nearest neighbors for each input point. + )""") + .def( + "batch_knn_search", + [](const KdTree& kdtree, const Eigen::MatrixXd& pts, int k) { + std::vector> k_indices(pts.rows(), std::vector(k, -1)); + std::vector> k_sq_dists(pts.rows(), std::vector(k, std::numeric_limits::max())); + for (int i = 0; i < pts.rows(); ++i) { + const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data()); + if (found < k) { + for (size_t j = found; j < k; ++j) { + k_indices[i][j] = -1; + k_sq_dists[i][j] = std::numeric_limits::max(); + } + } + } + return std::make_pair(k_indices, k_sq_dists); + }, + py::arg("pts"), + py::arg("k"), + R"""( + Find the k nearest neighbors for a batch of points. + + Parameters + ---------- + pts : NDArray, shape (n, 3) + The input points. + k : int + The number of nearest neighbors to search for. + + Returns + ------- + k_indices : list of NDArray, shape (n,) + The list of indices of the k nearest neighbors for each input point. + k_sq_dists : list of NDArray, shape (n,) + The list of squared distances to the k nearest neighbors for each input point. + )"""); +}