diff --git a/Makefile b/Makefile index 18137a8..f2dcae6 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,7 @@ .PHONY : docs docs : rm -rf docs/build/ - sphinx-autobuild -b html --watch simplextree/ docs/source/ docs/build/ - + .PHONY : run-checks run-checks : isort --check . diff --git a/include/UnionFind.h b/include/UnionFind.h new file mode 100644 index 0000000..06dd7d8 --- /dev/null +++ b/include/UnionFind.h @@ -0,0 +1,75 @@ +//---------------------------------------------------------------------- +// Disjoint-set data structure +// File: union_find.h +//---------------------------------------------------------------------- +// Copyright (c) 2018 Matt Piekenbrock. All Rights Reserved. +// +// Class definition based off of data-structure described here: +// https://en.wikipedia.org/wiki/Disjoint-set_data_structure + +#ifndef UNIONFIND_H_ +#define UNIONFIND_H_ + +#include // size_t +#include // vector +#include // iota +#include // transform + +struct UnionFind { + using idx_t = std::size_t; + using idx_v = std::vector< size_t >; + idx_t size; + mutable idx_v parent, rank; + + UnionFind(const idx_t _size) : size(_size), parent(_size), rank(_size){ + std::iota(parent.begin(), parent.end(), 0); + } + + // Main operations + void Union(const idx_t x, const idx_t y){ + if (x >= size || y >= size){ return; } + const idx_t xRoot = Find(x), yRoot = Find(y); + if (xRoot == yRoot){ return; } + else if (rank[xRoot] > rank[yRoot]) { parent[yRoot] = xRoot; } + else if (rank[xRoot] < rank[yRoot]) { parent[xRoot] = yRoot; } + else if (rank[xRoot] == rank[yRoot]) { + parent[yRoot] = parent[xRoot]; + rank[xRoot] = rank[xRoot] + 1; + } + } + idx_t Find(const idx_t x) const { + if (x >= size || parent[x] == x){ return x; } + else { + parent[x] = Find(parent[x]); + return parent[x]; + } + } + void AddSets(const idx_t n_sets){ + parent.resize(size + n_sets); + std::iota(parent.begin() + size, parent.end(), size); // parent initialized incrementally + rank.resize(size + n_sets, 0); // rank all 0 + size += n_sets; + } + + // Convenience functions + void UnionAll(const idx_v& idx){ + if (idx.size() <= 1){ return; } + const idx_t n_pairs = idx.size()-1; + for (idx_t i = 0; i < n_pairs; ++i){ Union(idx[i], idx[i+1]); } + } + idx_v FindAll(const idx_v& idx){ + if (idx.size() == 0){ return idx_v(); } + idx_v cc = idx_v(idx.size()); + std::transform(idx.begin(), idx.end(), cc.begin(), [this](const size_t i){ + return(Find(i)); + }); + return(cc); + } + idx_v ConnectedComponents(){ + idx_v cc = idx_v(size); + for (size_t i = 0; i < size; ++i){ cc[i] = Find(i); } + return(cc); + } +}; // class UnionFind + +#endif diff --git a/include/combinatorial.h b/include/combinatorial.h new file mode 100644 index 0000000..23a5f23 --- /dev/null +++ b/include/combinatorial.h @@ -0,0 +1,617 @@ +// combinatorial.h +// Contains routines for combinatorics-related tasks +// The combinations and permutations generation code is copyright Howard Hinnant, taken from: https://github.com/HowardHinnant/combinations/blob/master/combinations.h +#ifndef COMBINATORIAL_H +#define COMBINATORIAL_H + +#include // uint_fast64_t +#include +// #include // span (C++20) +#include // round, sqrt +#include // midpoint, accumulate +#include // vector +#include +#include +#include +#include +#include +#include + +using std::begin; +using std::end; +using std::vector; +using std::size_t; + +namespace combinatorial { + using index_t = uint_fast64_t; + + template + using it_diff_t = typename std::iterator_traits::difference_type; + + // Rotates two discontinuous ranges to put *first2 where *first1 is. + // If last1 == first2 this would be equivalent to rotate(first1, first2, last2), + // but instead the rotate "jumps" over the discontinuity [last1, first2) - + // which need not be a valid range. + // In order to make it faster, the length of [first1, last1) is passed in as d1, + // and d2 must be the length of [first2, last2). + // In a perfect world the d1 > d2 case would have used swap_ranges and + // reverse_iterator, but reverse_iterator is too inefficient. + template + void rotate_discontinuous( + It first1, It last1, it_diff_t< It > d1, + It first2, It last2, it_diff_t< It > d2) + { + using std::swap; + if (d1 <= d2){ std::rotate(first2, std::swap_ranges(first1, last1, first2), last2); } + else { + It i1 = last1; + while (first2 != last2) + swap(*--i1, *--last2); + std::rotate(first1, i1, last1); + } + } + + // Call f() for each combination of the elements [first1, last1) + [first2, last2) + // swapped/rotated into the range [first1, last1). As long as f() returns + // false, continue for every combination and then return [first1, last1) and + // [first2, last2) to their original state. If f() returns true, return + // immediately. + // Does the absolute mininum amount of swapping to accomplish its task. + // If f() always returns false it will be called (d1+d2)!/(d1!*d2!) times. + template < typename It, typename Lambda > + bool combine_discontinuous( + It first1, It last1, it_diff_t< It > d1, + It first2, It last2, it_diff_t< It > d2, + Lambda&& f, it_diff_t< It > d = 0) + { + using D = it_diff_t< It >; + using std::swap; + if (d1 == 0 || d2 == 0){ return f(); } + if (d1 == 1) { + for (It i2 = first2; i2 != last2; ++i2) { + if (f()){ return true; } + swap(*first1, *i2); + } + } + else { + It f1p = std::next(first1), i2 = first2; + for (D d22 = d2; i2 != last2; ++i2, --d22){ + if (combine_discontinuous(f1p, last1, d1-1, i2, last2, d22, f, d+1)) + return true; + swap(*first1, *i2); + } + } + if (f()){ return true; } + if (d != 0){ rotate_discontinuous(first1, last1, d1, std::next(first2), last2, d2-1); } + else { rotate_discontinuous(first1, last1, d1, first2, last2, d2); } + return false; + } + + template < typename Lambda, typename It > + struct bound_range { + Lambda f_; + It first_, last_; + bound_range(Lambda& f, It first, It last) : f_(f), first_(first), last_(last) {} + bool operator()(){ return f_(first_, last_); } + bool operator()(It, It) { return f_(first_, last_); } + }; + + template + Function for_each_combination(It first, It mid, It last, Function&& f) { + bound_range wfunc(f, first, mid); + combine_discontinuous(first, mid, std::distance(first, mid), + mid, last, std::distance(mid, last), + wfunc); + return std::move(f); + } + + template + constexpr auto make_index_dispatcher(std::index_sequence) { + return [] (auto&& f) { (f(std::integral_constant{}), ...); }; + }; + + template + constexpr auto make_index_dispatcher() { + return make_index_dispatcher(std::make_index_sequence< N >{}); + }; + + template + struct tuple_n{ + template< typename...Args> using type = typename tuple_n::template type; + }; + + // Modified from: https://stackoverflow.com/questions/38885406/produce-stdtuple-of-same-type-in-compile-time-given-its-length-by-a-template-a + template + struct tuple_n { + template using type = std::tuple; + }; + template < typename T, size_t I > using tuple_of = typename tuple_n::template type<>; + + // Constexpr binomial coefficient using recursive formulation + template < size_t n, size_t k > + constexpr auto bc_recursive() noexcept { + if constexpr ( n == k || k == 0 ){ return(1); } + else if constexpr (n == 0 || k > n){ return(0); } + else { + return (n * bc_recursive< n - 1, k - 1>()) / k; + } + } + + // Baseline from: https://stackoverflow.com/questions/44718971/calculate-binomial-coffeficient-very-reliably + // Requires O(min{k,n-k}), uses pascals triangle approach (+ degenerate cases) + constexpr inline size_t binom(size_t n, size_t k) noexcept { + return + ( k> n )? 0 : // out of range + (k==0 || k==n )? 1 : // edge + (k==1 || k==n-1)? n : // first + ( k+k < n )? // recursive: + (binom(n-1,k-1) * n)/k : // path to k=1 is faster + (binom(n-1,k) * n)/(n-k); // path to k=n-1 is faster + } + + // Non-cached version of the binomial coefficient using floating point algorithm + // Requires O(k), uses very simple loop + [[nodiscard]] + inline size_t binomial_coeff_(const double n, const size_t k) noexcept { + // std::cout << "here2"; + double bc = n; + for (size_t i = 2; i <= k; ++i){ bc *= (n+1-i)/i; } + return(static_cast< size_t >(std::round(bc))); + } + + // Table to cache low values of the binomial coefficient + template< size_t max_n, size_t max_k, typename value_t = index_t > + struct BinomialCoefficientTable { + size_t pre_n = 0; + size_t pre_k = 0; + value_t combinations[max_k][max_n+1]; + vector< vector< value_t > > BT; + + constexpr BinomialCoefficientTable() : combinations() { + auto n_dispatcher = make_index_dispatcher< max_n+1 >(); + auto k_dispatcher = make_index_dispatcher< max_k >(); + n_dispatcher([&](auto i) { + k_dispatcher([&](auto j){ + combinations[j][i] = bc_recursive< i, j >(); + }); + }); + } + + // Evaluate general binomial coefficient, using cached table if possible + value_t operator()(const index_t n, const index_t k) const { + // std::cout << "INFO: " << pre_n << ", " << n << ":" << pre_k << ", " << k << " : " << BT.size() << std::endl; + if (n < max_n && k < max_k){ return combinations[k][n]; } // compile-time computed table + if (n <= pre_n && k <= pre_k){ return BT[k][n]; } // runtime computed extension table + if (k == 0 || n == k){ return 1; } + if (n < k){ return 0; } + if (k == 2){ return static_cast< value_t >((n*(n-1))/2); } + if (k == 1){ return n; } + // return binom(n, k); + // return binomial_coeff_(n,std::min(k,n-k)); + return static_cast< value_t >(binomial_coeff_(n,std::min(k,n-k))); + } + + void precompute(index_t n, index_t k){ + // std::cout << "here" << std::endl; + pre_n = n; + pre_k = k; + BT = vector< vector< index_t > >(k + 1, vector< index_t >(n + 1, 0)); + for (index_t i = 0; i <= n; ++i) { + BT[0][i] = 1; + for (index_t j = 1; j < std::min(i, k + 1); ++j){ + BT[j][i] = BT[j - 1][i - 1] + BT[j][i - 1]; + } + if (i <= k) { BT[i][i] = 1; }; + } + } + + // Fast but unsafe access to a precompute table + [[nodiscard]] + constexpr auto at(index_t n, index_t k) noexcept -> index_t { + return BT[k][n]; + } + + }; // BinomialCoefficientTable + + // Build the cached table + static auto BC = BinomialCoefficientTable< 64, 3 >(); + static bool keep_table_alive = false; + + // Wrapper to choose between cached and non-cached version of the Binomial Coefficient + template< bool safe = true > + constexpr size_t BinomialCoefficient(const size_t n, const size_t k){ + if constexpr(safe){ + return BC(n,k); + } else { + return BC.at(n,k); + } + } + + #if __cplusplus >= 202002L + // C++20 (and later) code + // constexpr midpoint midpoint + using std::midpoint; + #else + template < class Integer > + constexpr Integer midpoint(Integer a, Integer b) noexcept { + return (a+b)/2; + } + #endif + + // All inclusive range binary search + // Compare must return -1 for <(key, index), 0 for ==(key, index), and 1 for >(key, index) + // Guaranteed to return an index in [0, n-1] representing the lower_bound + template< typename T, typename Compare > [[nodiscard]] + int binary_search(const T key, size_t n, Compare p) noexcept { + int low = 0, high = n - 1, best = 0; + while( low <= high ){ + int mid = int{ midpoint(low, high) }; + auto cmp = p(key, mid); + if (cmp == 0){ + while(p(key, mid + 1) == 0){ ++mid; } + return(mid); + } + if (cmp < 0){ high = mid - 1; } + else { + low = mid + 1; + best = mid; + } + } + return(best); + } + + // ----- Combinatorial Number System functions ----- + template< std::integral I, typename Compare > + void sort_contiguous(vector< I >& S, const size_t modulus, Compare comp){ + for (size_t i = 0; i < S.size(); i += modulus){ + std::sort(S.begin()+i, S.begin()+i+modulus, comp); + } + } + + // Lexicographically rank 2-subsets + [[nodiscard]] + constexpr auto rank_lex_2(index_t i, index_t j, const index_t n) noexcept { + if (j < i){ std::swap(i,j); } + return index_t(n*i - i*(i+1)/2 + j - i - 1); + } + + // #include + // Lexicographically rank k-subsets + template< bool safe = true, typename InputIter > + [[nodiscard]] + inline index_t rank_lex_k(InputIter s, const size_t n, const size_t k, const index_t N){ + index_t i = k; + // std::cout << std::endl; + const index_t index = std::accumulate(s, s+k, 0, [n, &i](index_t val, index_t num){ + // std::cout << BinomialCoefficient((n-1) - num, i) << ", "; + return val + BinomialCoefficient< safe >((n-1) - num, i--); + }); + // std::cout << std::endl; + const index_t combinadic = (N-1) - index; // Apply the dual index mapping + return combinadic; + } + + // Rank a stream of integers (lexicographically) + template< bool safe = true, typename InputIt, typename OutputIt > + inline void rank_lex(InputIt s, const InputIt e, const size_t n, const size_t k, OutputIt out){ + switch (k){ + case 2:{ + for (; s != e; s += k){ + *out++ = rank_lex_2(*s, *(s+1), n); + } + break; + } + default: { + const index_t N = BinomialCoefficient< safe >(n, k); + for (; s != e; s += k){ + *out++ = rank_lex_k< safe >(s, n, k, N); + } + break; + } + } + } + + // should be in reverse colexicographical + template< bool safe = true > + [[nodiscard]] + constexpr auto rank_colex_2(index_t i, index_t j) noexcept { + assert(i > j); // should be in colex order! + // return BinomialCoefficient< safe >(j, 2) + i; + return j*(j-1)/2 + i; + // const index_t index = std::accumulate(s, s+k, 0, [&i](index_t val, index_t num){ + // return val + BinomialCoefficient< safe >(num, i--); + // }); + // return index; + } + + // Colexicographically rank k-subsets + // assumes each k tuple of s is in colex order! + template< bool safe = true, typename InputIter > + [[nodiscard]] + constexpr auto rank_colex_k(InputIter s, const size_t k) noexcept { + index_t i = k; + const index_t index = std::accumulate(s, s+k, 0, [&i](index_t val, index_t num){ + return val + BinomialCoefficient< safe >(num, i--); + }); + return index; + } + + template< bool safe = true, typename InputIt, typename OutputIt > + inline void rank_colex(InputIt s, const InputIt e, [[maybe_unused]] const size_t n, const size_t k, OutputIt out){ + switch (k){ + case 2:{ + for (; s != e; s += k){ + *out++ = rank_colex_2(*s, *(s+1)); + } + break; + } + default: { + for (; s != e; s += k){ + *out++ = rank_colex_k< safe >(s, k); + } + break; + } + } + } + + // colex bijection from a lexicographical order + // index_t i = 1; + // const index_t index = std::accumulate(s, s+k, 0, [&i](index_t val, index_t num){ + // return val + BinomialCoefficient< safe >(num, i++); + // }); + // return index; + + template< bool colex = true, bool safe = true, typename InputIt > + inline auto rank_comb(InputIt s, const size_t n, const size_t k){ + if constexpr(colex){ + return rank_colex_k< safe >(s, k); + } else { + const index_t N = BinomialCoefficient< safe >(n, k); + return rank_lex_k< safe >(s, n, k, N); + } + } + + template< bool colex = true, bool safe = true, typename InputIt, typename OutputIt > + inline void rank_combs(InputIt s, const InputIt e, const size_t n, const size_t k, OutputIt out){ + if constexpr(colex){ + for (; s != e; s += k){ + *out++ = rank_colex_k< safe >(s, k); + } + } else { + const index_t N = BinomialCoefficient< safe >(n, k); + for (; s != e; s += k){ + *out++ = rank_lex_k< safe >(s, n, k, N); + } + } + } + + // Lexicographically unrank 2-subsets + template< typename OutputIt > + inline auto unrank_lex_2(const index_t r, const index_t n, OutputIt out) noexcept { + auto i = static_cast< index_t >( (n - 2 - floor(sqrt(-8*r + 4*n*(n-1)-7)/2.0 - 0.5)) ); + auto j = static_cast< index_t >( r + i + 1 - n*(n-1)/2 + (n-i)*((n-i)-1)/2 ); + *out++ = i; // equivalent to *out = i; ++i; + *out++ = j; // equivalent to *out = j; ++j; + } + + // Lexicographically unrank k-subsets [ O(log n) version ] + template< bool safe = true, typename OutputIterator > + inline void unrank_lex_k(index_t r, const size_t n, const size_t k, OutputIterator out) noexcept { + const size_t N = combinatorial::BinomialCoefficient< safe >(n, k); + r = (N-1) - r; + // auto S = std::vector< size_t >(k); + for (size_t ki = k; ki > 0; --ki){ + int offset = binary_search(r, n, [ki](const auto& key, int index) -> int { + auto c = combinatorial::BinomialCoefficient< safe >(index, ki); + return(key == c ? 0 : (key < c ? -1 : 1)); + }); + r -= combinatorial::BinomialCoefficient< safe >(offset, ki); + *out++ = (n-1) - offset; + } + } + + // Lexicographically unrank subsets wrapper + template< bool safe = true, typename InputIt, typename OutputIt > + inline void unrank_lex(InputIt s, const InputIt e, const size_t n, const size_t k, OutputIt out){ + switch(k){ + case 2: + for (; s != e; ++s){ unrank_lex_2(*s, n, out); } + break; + default: + for (; s != e; ++s){ unrank_lex_k< safe >(*s, n, k, out); } + break; + } + } + + template + [[nodiscard]] + index_t get_max(index_t top, const index_t bottom, const Predicate pred) noexcept { + if (!pred(top)) { + index_t count = top - bottom; + while (count > 0) { + index_t step = count >> 1, mid = top - step; + if (!pred(mid)) { + top = mid - 1; + count -= step + 1; + } else { + count = step; + } + } + } + return top; + } + + // From: Kruchinin, Vladimir, et al. "Unranking Small Combinations of a Large Set in Co-Lexicographic Order." Algorithms 15.2 (2022): 36. + // return std::ceil(m * exp(log(r)/m + log(2*pi*m)/2*m + 1/(12*m*m) - 1/(360*pow(m,4)) - 1) + (m-1)/2); + [[nodiscard]] + constexpr auto find_k(index_t r, index_t m) noexcept -> int { + if (r == 0 || m == 0){ return m - 1; } + else if (m == 1){ return r - 1; } + else if (m == 2){ return std::ceil(std::sqrt(1 + 8*r)/2) - 1; } + else if (m == 3){ return std::ceil(std::pow(6*r, 1/3.)) - 1; } + else { + return m - 1; + } + } + + template< bool safe = true > + [[nodiscard]] + index_t get_max_vertex(const index_t r, const index_t k, const index_t n) noexcept { + // Binary searches in the range [k-1, n] for the largest index _w_ satisfying r >= C(w,k) + // return get_max(n, k-1, [&](index_t w) -> bool { return r >= BinomialCoefficient< safe >(w, k); }); + + const int lb = find_k(r,k); + // assert(BinomialCoefficient(lb, k) <= r); + return BinomialCoefficient< safe >(lb+1, k) > r ? + lb : + ( BinomialCoefficient< safe >(lb+2, k) > r ? + lb + 1 : + ( BinomialCoefficient< safe >(lb+3, k) > r ? + lb + 2 : + get_max(n, lb+3, [&](index_t w) -> bool { return r >= BinomialCoefficient< safe >(w, k); }) + ) + ); + } + + template < bool safe = true, typename InputIt, typename OutputIt > + void unrank_colex(InputIt s, const InputIt e, const index_t n, const index_t k, OutputIt out) { + for (index_t N = n - 1; s != e; ++s, N = n - 1){ + index_t r = *s; + for (index_t d = k; d > 1; --d) { + N = get_max_vertex< safe >(r, d, N); + // std::cout << "r: " << r << ", d: " << d << ", N: " << N << std::endl; + *out++ = N; + r -= BinomialCoefficient< safe >(N, d); + } + *out++ = r; + } + } + + // Unrank subsets wrapper + template< bool colex = true, bool safe = true, typename InputIt, typename OutputIt > + inline void unrank_combs(InputIt s, const InputIt e, const size_t n, const size_t k, OutputIt out){ + if constexpr(colex){ + unrank_colex< safe >(s, e, n, k, out); + } else { + unrank_lex< safe >(s, e, n, k, out); + } + } + +} // namespace combinatorial + + + + + + + +// // Lexicographically unrank subsets wrapper +// template< size_t k, typename InputIt, typename Lambda > +// inline void lex_unrank_f(InputIt s, const InputIt e, const size_t n, Lambda f){ +// if constexpr (k == 2){ +// std::array< I, 2 > edge; +// for (; s != e; ++s){ +// lex_unrank_2(*s, n, edge.begin()); +// f(edge); +// } +// } else if (k == 3){ +// std::array< I, 3 > triangle; +// for (; s != e; ++s){ +// lex_unrank_k(*s, n, 3, triangle.begin()); +// f(triangle); +// } +// } else { +// std::array< I, k > simplex; +// for (; s != e; ++s){ +// lex_unrank_k(*s, n, k, simplex.begin()); +// f(simplex); +// } +// } +// } + + +// [[nodiscard]] +// inline auto lex_unrank_2_array(const index_t r, const index_t n) noexcept -> std::array< I, 2 > { +// auto i = static_cast< index_t >( (n - 2 - floor(sqrt(-8*r + 4*n*(n-1)-7)/2.0 - 0.5)) ); +// auto j = static_cast< index_t >( r + i + 1 - n*(n-1)/2 + (n-i)*((n-i)-1)/2 ); +// return(std::array< I, 2 >{ i, j }); +// } + +// [[nodiscard]] +// inline auto lex_unrank(const size_t rank, const size_t n, const size_t k) -> std::vector< index_t > { +// if (k == 2){ +// auto a = lex_unrank_2_array(rank, n); +// std::vector< index_t > out(a.begin(), a.end()); +// return(out); +// } else { +// std::vector< index_t > out; +// out.reserve(k); +// lex_unrank_k(rank, n, k, std::back_inserter(out)); +// return(out); +// } +// } + + +// template< typename Lambda > +// void apply_boundary(const size_t r, const size_t n, const size_t k, Lambda f){ +// // Given a p-simplex's rank representing a tuple of size p+1, enumerates the ranks of its (p-1)-faces, calling Lambda(*) on its rank +// using combinatorial::I; +// switch(k){ +// case 0: { return; } +// case 1: { +// f(r); +// return; +// } +// case 2: { +// auto p_vertices = std::array< I, 2 >(); +// lex_unrank_2(static_cast< index_t >(r), static_cast< index_t >(n), begin(p_vertices)); +// f(p_vertices[0]); +// f(p_vertices[1]); +// return; +// } +// case 3: { +// auto p_vertices = std::array< I, 3 >(); +// lex_unrank_k(r, n, k, begin(p_vertices)); +// f(lex_rank_2(p_vertices[0], p_vertices[1], n)); +// f(lex_rank_2(p_vertices[0], p_vertices[2], n)); +// f(lex_rank_2(p_vertices[1], p_vertices[2], n)); +// return; +// } +// default: { +// auto p_vertices = std::vector< index_t >(0, k); +// lex_unrank_k(r, n, k, p_vertices.begin()); +// const index_t N = BinomialCoefficient(n, k); +// combinatorial::for_each_combination(begin(p_vertices), begin(p_vertices)+2, end(p_vertices), [&](auto a, auto b){ +// f(lex_rank_k(a, n, k, N)); +// return false; +// }); +// return; +// } +// } +// } // apply boundary + +// template< typename OutputIt > +// void boundary(const size_t p_rank, const size_t n, const size_t k, OutputIt out){ +// apply_boundary(p_rank, n, k, [&out](auto face_rank){ +// *out = face_rank; +// out++; +// }); +// } +// } // namespace combinatorial + + // // Lexicographically unrank k-subsets + // template< typename OutputIterator > + // inline void lex_unrank_k(index_t r, const size_t k, const size_t n, OutputIterator out){ + // auto subset = std::vector< size_t >(k); + // size_t x = 1; + // for (size_t i = 1; i <= k; ++i){ + // while(r >= BinomialCoefficient(n-x, k-i)){ + // r -= BinomialCoefficient(n-x, k-i); + // x += 1; + // } + // *out++ = (x - 1); + // x += 1; + // } + // } + +#endif \ No newline at end of file diff --git a/include/nerve.cpp b/include/nerve.cpp new file mode 100644 index 0000000..853a416 --- /dev/null +++ b/include/nerve.cpp @@ -0,0 +1,181 @@ +#include +using namespace Rcpp; + +// [[Rcpp::plugins(cpp11)]] +#include "simplextree.h" + +// Given list of integer vectors +// [[Rcpp::export]] +bool nfold_intersection(vector< vector< int > > x, const size_t n){ + using it_t = vector< int >::iterator; + auto ranges = vector< std::pair< it_t, it_t > >(); + std::transform(begin(x), end(x), std::back_inserter(ranges), [](vector< int >& cs){ + return std::make_pair(cs.begin(), cs.end()); + }); + bool is_connected = n_intersects(ranges, n); + return(is_connected); +} + +// Computes the nerve up to dimension k +// [[Rcpp::export]] +void nerve_expand(SEXP stx, vector< size_t > ids, vector< vector< int > > cover, const size_t k, const size_t threshold){ + const size_t n_sets = cover.size(); + if (ids.size() != n_sets){ stop("Invalid id/cover combination."); } + + // Extract the simplex tree + SimplexTree& st = *(Rcpp::XPtr< SimplexTree >(stx)); + + // Inserts vertices + std::array< idx_t, 1 > v; + for (auto v_id: ids){ + v[0] = v_id; + st.insert_it(begin(v), end(v), st.root.get(), 0); + } + + // Extract the range pairs + using it_t = vector< int >::iterator; + using range_t = std::pair< it_t, it_t >; + auto ranges = map< size_t, range_t >(); + size_t i = 0; + for (auto& c_set: cover){ ranges.emplace(ids[i++], std::make_pair(begin(c_set), end(c_set))); } + + // First insert all the edges w/ a common intersection + using it_t2 = typename vector< size_t >::iterator; + for_each_combination(begin(ids), begin(ids)+2, end(ids), [&st, &ranges, threshold](it_t2 b, it_t2 e){ + auto edge = std::make_pair(*b, *std::next(b, 1)); + vector< range_t > sets = { ranges[edge.first], ranges[edge.second] }; + bool valid_edge = n_intersects(sets, threshold); + if (valid_edge){ + st.insert_it(b, e, st.root.get(), 0); + } + return false; // always continue + }); + + // Then perform the conditional k-expansion + st.expansion_f(k, [&](node_ptr parent, idx_t depth, idx_t label){ + + // Collect simplex to test + auto k_simplex = st.full_simplex(parent, depth); + k_simplex.push_back(label); + + // Extract current set of ranges + auto current_sets = vector< std::pair< it_t, it_t > >(); + for (auto c_label: k_simplex){ + auto it = ranges.find(c_label); + if (it != ranges.end()){ current_sets.push_back(it->second); } + } + + // Test that their intersection is at least 'threshold'; if so, insert + if ((current_sets.size() == k_simplex.size()) && n_intersects(current_sets, threshold)){ + std::array< idx_t, 1 > int_label = { label }; + st.insert_it(begin(int_label), end(int_label), parent, depth); + } + }); + + return; // Return nothing +} + +// [[Rcpp::export]] +void nerve_expand_f(SEXP stx, vector< size_t > ids, Function include_f, const size_t k){ + + // Extract the simplex tree + SimplexTree& st = *(Rcpp::XPtr< SimplexTree >(stx)); + + // Inserts vertices + std::array< idx_t, 1 > v; + for (auto v_id: ids){ + v[0] = v_id; + st.insert_it(begin(v), end(v), st.root.get(), 0); + } + + // First insert all the edges w/ a common intersection + using it_t = vector< size_t >::iterator; + for_each_combination(begin(ids), begin(ids)+2, end(ids), [&st, &include_f](it_t b, it_t e){ + IntegerVector edge = IntegerVector(b, e); + LogicalVector valid_check = include_f(edge); // This is needed to coerce correctly to bool + bool valid_edge = is_true(all(valid_check)); + if (valid_edge){ st.insert_it(b, e, st.root.get(), 0); } + return false; // always continue + }); + + // Then perform the conditional k-expansion + st.expansion_f(k, [&](node_ptr parent, idx_t depth, idx_t label){ + + // Collect simplex to test + auto k_simplex = st.full_simplex(parent, depth); + k_simplex.push_back(label); + + LogicalVector valid_check = include_f(k_simplex); // This is needed to coerce correctly to bool + bool valid = is_true(all(valid_check)); + + // Test that their intersection is at least 'threshold'; if so, insert + if (valid){ + std::array< idx_t, 1 > int_label = { label }; + st.insert_it(begin(int_label), end(int_label), parent, depth); + } + }); + + return; // Return nothing +} + +/*** R +library(simplextree) + +# st <- simplex_tree(combn(4,2)) +# simplextree:::nerve_expand(st$as_XPtr(), 2) + +# set.seed(1234) +# st <- simplex_tree(as.list(seq(3))) +# simplextree:::nerve_expand(st$as_XPtr(), ids = st$vertices, cover = list(c(1,2,3), c(3, 4, 5, 6), c(3, 6, 7)), k = 2, threshold = 1) +# +# +# set.seed(1234) +# alphabet <- seq(50) +# cover <- lapply(seq(15), function(i){ +# set_size <- as.integer(runif(n = 1, min = 1, max = 15)) +# sample(alphabet, size = set_size, replace = FALSE) +# }) +# st <- simplex_tree(as.list(seq(length(cover)))) +# simplextree:::nerve_expand(st$as_XPtr(), ids = st$vertices, cover = cover, k = 3, threshold = 1) + + +# simplextree:::nerve_comb(st$as_XPtr(), ids = seq(3), cover = list(c(1,2,3), c(3, 4, 5, 6), c(3, 6, 7)), k = 10, threshold = 1) +st +# +# st <- simplex_tree() +# simplextree:::nerve_comb(st$as_XPtr(), list(c(1), c(3, 4, 5, 6)), seq(2), k = 5, n = 1) +# st +# +# st <- simplex_tree() +# simplextree:::nerve_comb(st$as_XPtr(), list(c(1, 2), c(3, 4, 5, 6, 1), c(3, 1, 2)), seq(3), k = 5, n = 1) +# st +# +# st <- simplex_tree() +# simplextree:::nerve_comb(st$as_XPtr(), list(c(1), c(3, 4, 5, 6, 1), c(3, 1, 2)), seq(3), k = 2, n = 1) +# st +# +# set.seed(1234) +# cover <- lapply(seq(15), function(i){ sample(seq(as.integer(runif(1, min = 1, max = 30)))) }) +# st <- simplex_tree() +# simplextree:::nerve_comb(st$as_XPtr(), cover, seq(length(cover)), k = 1, threshold = 1) +# +# +# all(combn(length(cover), 3, function(idx){ length(Reduce(function(x, y){ intersect(x, y) }, x = cover[idx])) > 1 })) +# +# combn(length(cover), 2, function(idx){ +# simplextree:::nfold_intersection(cover[idx], 2) +# }) +# +# +# +# simplextree:::nerve_comb(st$as_XPtr(), cover, seq(length(cover)), k = 1, n = 1) +# +# +# simplextree:::nfold_intersection(list(c(1,2,3), c(3, 4, 5, 6)), 1) + +## Test if there is at least a single element in common +# nerve_cpp(st$as_XPtr(), list(c(1,2,3), c(3, 4, 5, 6)), 1) +# nerve_cpp(st$as_XPtr(), list(sample(seq(100)), sample(seq(150)), sample(seq(200))), 101) +# nerve_cpp(st$as_XPtr(), list(c(1,2,3), c(3, 4, 5, 6)), 1) +# nerve_cpp(st$as_XPtr(), list(c(2,3,1), c(5, 3, 4, 6), c(4, 6, 8, 3)), 1) +*/ diff --git a/include/simplextree.h b/include/simplextree.h new file mode 100644 index 0000000..f17a12e --- /dev/null +++ b/include/simplextree.h @@ -0,0 +1,263 @@ +// simplextree.h +// Author: Matt Piekenbrock +// License: MIT +// This package provides a simple implementation of the Simplex tree data structure using Rcpp + STL +// The simplex tree was originally introduced in the following paper: +// Boissonnat, Jean-Daniel, and Clement Maria. "The simplex tree: An efficient data structure for general simplicial complexes." Algorithmica 70.3 (2014): 406-427. + +#ifndef SIMPLEXTREE_H_ +#define SIMPLEXTREE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "UnionFind.h" +#include "utility/discrete.h" +#include "utility/combinations.h" +// #include "utility/delegate.hpp" +#include "utility/delegate2.hpp" +#include "utility/set_utilities.h" + +// #include + +// Type for simplex labels +typedef std::size_t idx_t; + +// Maximum expected size of simplices +static constexpr size_t array_threshold = 9; + +// Buffer type +using splex_t = std::vector< idx_t, short_alloc< idx_t, 16, alignof(idx_t)> >; +using splex_alloc_t = typename splex_t::allocator_type::arena_type; + +// Aliases +using std::array; +using std::tuple; +using std::vector; +using std::map; +using std::size_t; +using std::begin; +using std::end; +using std::set; +using std::find; +using std::get; + +template +std::unique_ptr make_unique(Args&&... args){ + return std::unique_ptr(new T(std::forward(args)...)); +} + +// Simplex tree data structure. +// The Simplex Tree is a normal trie structure, with additional restrictions on the storage order, +// and a auxiliary map that is used to map 'cousin' simplexes at varying depths. +struct SimplexTree { + struct node; + using node_ptr = node*; + using node_uptr = std::unique_ptr< node >; + + struct less_np_label { + bool operator()(const node_ptr& lhs, const node_uptr& rhs) { + return lhs->label < rhs->label; + } + bool operator()(const node_uptr& lhs, const node_ptr& rhs) { + return lhs->label < rhs->label; + } + }; + + struct less_ptr { + bool operator() (const node_uptr& lhs, const node_uptr& rhs) const { return (*lhs) < (*rhs); } + }; + using node_set_t = set< node_uptr, less_ptr >; + using simplex_t = vector< idx_t >; + using cousin_map_t = std::map< idx_t, vector< node_ptr > >; + using difference_type = std::ptrdiff_t; + using size_type = std::size_t; + using value_type = node_ptr; + + // Fields + node_uptr root; // empty face; initialized to id = 0, parent = nullptr + vector< cousin_map_t > level_map; // adjacency map between cousins + std::array< size_t, 32 > n_simplexes = { { 0 } }; + // vector< size_t > n_simplexes; // tracks the number of simplices if each order + size_t tree_max_depth; // maximum tree depth; largest path from any given leaf to the root. The depth of the root is 0. + size_t max_id; // maximum vertex id used so far. Only needed by the id generator. + size_t id_policy; // policy type to generate new ids + + // Node structure stored by the simplex tree. Contains the following fields: + // label := integer index type representing the id of simplex it represents + // parent := (shared) node pointer to its parent in the trie + // children := connected simplexes whose labels > the current simplex's label + struct node { + idx_t label; + node* parent; + node_set_t children; + node(idx_t id, node_ptr c_parent) : label(id), parent(c_parent){ } + node(const node& other) = delete; + bool operator== (const node& rhs) const noexcept { // equivalence by pointer address + return (this == &rhs); + } + bool operator< (const node& rhs) const { return (label < rhs.label); } // order by label + + // struct iterator { + // using difference_type = std::ptrdiff_t; + // using value_type = node_ptr; + // using pointer = node_ptr*; + // using reference = node_ptr&; + // using iterator_category = std::forward_iterator_tag; + // std::reference_wrapper< node_ptr > cn; + // iterator(node_ptr current) : cn(std::ref(current)) { }; + // bool operator==(const iterator& t) const { return cn == t.cn; } + // bool operator!=(const iterator& t) const { return cn != t.cn; } + // auto operator*() { return cn.get(); }; + // auto operator++() { + // cn = std::ref(cn.get()->parent); + // return(*this); + // }; + // }; + // auto begin() { return iterator(this); }; + // auto end() { return iterator(nullptr); }; + }; + + SimplexTree(const SimplexTree&); + SimplexTree& operator=(const SimplexTree&); + + // Adds node ptr to cousin map + void add_cousin(node_ptr cn, const idx_t depth){ + if (depth_index(depth) >= level_map.size()){ + level_map.resize(depth_index(depth) + 1); + } + auto& label_map = level_map[depth_index(depth)][cn->label]; + auto it = std::find(begin(label_map), end(label_map), cn); + if (it == end(label_map)){ label_map.push_back(cn); } // insert + } + + // Removes node ptr from cousin map + void remove_cousin(node_ptr cn, const idx_t depth){ + if (depth_index(depth) >= level_map.size()){ return; } + auto& depth_map = level_map[depth_index(depth)]; + auto cousin_it = depth_map.find(cn->label); + if (cousin_it != end(depth_map)){ + auto& v = cousin_it->second; + v.erase(std::remove(v.begin(), v.end(), cn), v.end()); + } + } + + template < typename Iter > + auto append_node(Iter pos, node_ptr cn, idx_t label, size_t depth) -> node_set_t::iterator; + + // Checks if cousins exist + bool cousins_exist(const idx_t label, const idx_t depth) const noexcept { + if (depth_index(depth) >= level_map.size()){ return false; } + return level_map[depth_index(depth)].find(label) != end(level_map[depth_index(depth)]); + } + + const cousin_map_t::mapped_type& cousins(const idx_t label, const idx_t depth) const { + return level_map[depth_index(depth)].at(label); + } + + template < typename Lambda > + void traverse_cousins(const idx_t label, const idx_t depth, Lambda f) const { + if (depth_index(depth) >= level_map.size()){ return; } + if (cousins_exist(label, depth)){ + const auto& c_cousins = level_map[depth_index(depth)].at(label); + std::for_each(begin(c_cousins), end(c_cousins), f); + } + }; + + // Constructor + SimplexTree() : root(new node(-1, nullptr)), tree_max_depth(0), max_id(0), id_policy(0) { }; + + // Generates a new set of vertex ids, according to the given rule. + auto generate_ids(size_t) -> vector< size_t >; + auto degree(idx_t) const -> size_t; + auto adjacent_vertices(const idx_t) const -> simplex_t; + auto record_new_simplexes(const idx_t k, const int n) -> void;// record keeping + auto dimension() const -> idx_t { return tree_max_depth == 0 ? 0 : tree_max_depth - 1; } + + template< typename Iterable > + auto insert(Iterable v) -> void; + + template< bool use_lex = false, typename Iter > + auto insert_it(Iter, Iter, node_ptr, const idx_t) -> void; + + template< typename Iterable > + auto find(Iterable v) const -> node_ptr; + + template< typename Iter > + auto find_it(Iter, Iter, node_ptr cn) const -> node_ptr; + + auto find_by_id(const node_set_t&, idx_t) const -> node_ptr; + + auto remove(node_ptr cn) -> void; + auto remove_leaf(node_ptr, idx_t) -> void; + auto remove_subtree(node_ptr parent) -> void; + + template < typename Lambda > + void traverse_facets(node_ptr, Lambda) const; + + // Utility + auto is_face(simplex_t, simplex_t) const -> bool; + auto depth(node_ptr cn) const -> size_t; + auto max_depth(node_ptr cn) const -> size_t; + auto connected_components() const -> vector< idx_t >; + auto reindex(vector< idx_t >) -> void; + auto is_tree() const -> bool; + auto get_vertices() const -> vector< idx_t >; + auto clear() -> void; + + // Modifying the complex w/ higher order operations + auto collapse(node_ptr, node_ptr) -> bool; + auto vertex_collapse(idx_t, idx_t, idx_t) -> bool; + auto vertex_collapse(node_ptr, node_ptr, node_ptr) -> bool; + auto contract(simplex_t) -> void; + + auto expansion(const idx_t k) -> void; + + template < typename Lambda > + auto expansion_f(const idx_t, Lambda&&) -> void; + + template < typename Lambda > + auto expand_f(node_set_t&, const idx_t, size_t, Lambda&&) -> void; + + template < typename Lambda > + void traverse_up(node_ptr, const size_t, Lambda&&) const noexcept; + + template< typename OutputIt > + void full_simplex_out(node_ptr, const idx_t, OutputIt) const noexcept; + auto full_simplex(node_ptr cn, const idx_t depth = 0) const noexcept -> simplex_t; + + template < typename T > // Assumes T is pointer type + static constexpr auto node_label(T& cn) -> idx_t { return cn->label; } + template < typename T > // Assumes T is pointer type + static constexpr auto node_children(T& cn) -> node_set_t& { return cn->children; } + static constexpr auto depth_index(const idx_t depth) noexcept -> idx_t { return(depth - 2); } + + // Printing + template < typename OutputStream > void print_tree(OutputStream&) const; + template < typename OutputStream > void print_cousins(OutputStream&) const; + template < typename OutputStream > void print_level(OutputStream&, node_ptr, idx_t) const; + template < typename OutputStream > void print_subtree(OutputStream&, node_ptr) const; + template < typename OutputStream > void print_simplex(OutputStream&, node_ptr, bool newline = true) const; + + + // Policy for generating ids + std::string get_id_policy() const; + void set_id_policy(std::string); + +}; + +#include "simplextree/st_iterators.hpp" +#include "simplextree/st_filtration.hpp" +#include "simplextree/st.hpp" + + +#endif + + diff --git a/include/simplextree/st.hpp b/include/simplextree/st.hpp new file mode 100644 index 0000000..375deec --- /dev/null +++ b/include/simplextree/st.hpp @@ -0,0 +1,858 @@ +#ifndef SIMPLEXTREE_HPP_ +#define SIMPLEXTREE_HPP_ + +#include "simplextree.h" + + +// --------- Begin C++ only API --------- +// These functions are only available through the included header, and thus can only be accessed +// on the C++ side. R-facing functions are exported through the module. +using node_ptr = SimplexTree::node_ptr; +using simplex_t = SimplexTree::simplex_t; +using namespace st; + +// Copy constructor +inline SimplexTree::SimplexTree(const SimplexTree& sc) : root(new node(-1, nullptr)), tree_max_depth(0), max_id(0), id_policy(0) { + auto max_tr = st::maximal< true >(&sc, sc.root.get()); + traverse(max_tr, [this](node_ptr cn, idx_t depth, simplex_t sigma){ + insert_it(begin(sigma), end(sigma), root.get(), 0); + return true; + }); + id_policy = sc.id_policy; +}; + +// Assignment operator +inline SimplexTree& SimplexTree::operator=(const SimplexTree& sc) { + auto max_tr = st::maximal< true >(&sc, sc.root.get()); + traverse(max_tr, [this](node_ptr cn, idx_t depth, simplex_t sigma){ + insert_it(begin(sigma), end(sigma), root.get(), 0); + return true; + }); + id_policy = sc.id_policy; + return *this; +}; + +// Clear the entire tree by removing all subtrees rooted at the vertices +// Use labels in finding other iterator will be invalidated +inline void SimplexTree::clear(){ + root.reset(new node(-1, nullptr)); + level_map.clear(); + n_simplexes.fill(0); + // n_simplexes.clear(); + tree_max_depth = 0; + max_id = 0; +} + +inline std::string SimplexTree::get_id_policy() const{ + return id_policy == 0 ? std::string("compressed") : std::string("unique"); +} + +inline void SimplexTree::set_id_policy(std::string policy){ + if (policy == "compressed"){ id_policy = 0; } + else if (policy == "unique"){ id_policy = 1; } +} + +// Returns an integer vector of new ids vertex which can be used to insert 0-simplices +// If compress is set to true, the ids are chosen as the first n unoccupied ids found by iterating +// through the current set of vertices. Otherwise, is compress is false, a maximum id value is +// maintained, such that new ids generated must exceed that value. +inline vector< idx_t > SimplexTree::generate_ids(size_t n){ + if (id_policy == 0){ + vector< idx_t > new_ids = vector< idx_t >(); + idx_t max = root->children.size() + n; + for (idx_t cc = 0; cc < max && new_ids.size() < n; ++cc){ + if (find_by_id(root->children, cc) == nullptr){ + new_ids.push_back(cc); + } + } + return(new_ids); + } else if (id_policy == 1) { + // TODO: get rid of max_lementn + auto vid = node_label(*std::max_element(begin(root->children), end(root->children))); + if (max_id < vid){ max_id = vid; } + vector< idx_t > new_ids(n); + std::iota(begin(new_ids), end(new_ids), max_id+1); + max_id = new_ids.back(); + return(new_ids); + } + return vector< idx_t >(0); +} + +// Returns the degree of a node with a given id +inline size_t SimplexTree::degree(idx_t vid) const{ + auto cn = find_by_id(root->children, vid); + if (cn == nullptr) { return(0); } + else { + size_t res_deg = cn->children.size(); // Labels with id < v + traverse_cousins(vid, 2, [&res_deg](node_ptr cousin){ res_deg += 1; }); +// auto it = level_map.find(encode_node(vid, 2)); +// if (it != level_map.end()){ +// const auto& cousins = (*it).second; +// for (const auto& ch: cousins){ +// res_deg += node_children(ch).size(); +// } +// } + return(res_deg); + } +} + +// Returns the degree of a node with a given id +// inline vector< size_t > SimplexTree::degree(vector< idx_t > vids) const{ +// vector< size_t > res = vector< size_t >(); +// for (auto id: vids){ +// node_ptr cn = find_by_id(root->children, id); +// if (cn == nullptr) { res.push_back(0); } +// else { +// size_t res_deg = 0; +// // auto it = level_map.find(std::to_string(id) + "-2"); // Labels with id > v +// auto it = level_map.find(encode_node(id, 2)); +// if (it != level_map.end()){ +// const auto& cousins = (*it).second; +// for (const auto& ch: cousins){ +// res_deg += node_children(ch).size(); +// } +// } +// res_deg += node_children(cn).size(); // Labels with id < v +// res.push_back(res_deg); +// } +// } +// return(res); +// } + +// Search the level map (cousins) to quickly get the adjacency relations. +// The set of adjacency relations are the 0-simplexes connected to a given vertex v. +inline vector< idx_t > SimplexTree::adjacent_vertices(const size_t v) const { + + // Resulting vector to return + vector< idx_t > res = vector< idx_t >(); + + // First extract the vertices which labels > v by checking edges + //std::string key = std::to_string(v) + "-2"; + if (cousins_exist(v, 2)){ + traverse_cousins(v, 2, [&res](node_ptr cousin){ + res.push_back(node_label(cousin->parent)); + }); + } + // Then get the vertices with labels < v + node_ptr cn = find_by_id(root->children, v); + if (cn != nullptr){ + for (const auto& ch: node_children(cn)){ + res.push_back(node_label(ch)); + } + } + + // Return + vector< idx_t >::iterator tmp = std::unique(res.begin(), res.end()); + res.resize( std::distance(res.begin(), tmp) ); + return(res); +} + +// Modifies the number of simplices at dimension k by +/- n. Shrinks the n_simplexes array as needed. +inline void SimplexTree::record_new_simplexes(const idx_t k, const int n){ + if (k >= 32){ std::invalid_argument("Invalid dimension to record."); } + n_simplexes[k] += n; + auto first_zero = std::find(n_simplexes.begin(), n_simplexes.end(), 0); + tree_max_depth = std::distance(n_simplexes.begin(), first_zero); + // if (n_simplexes.size() < k+1){ n_simplexes.resize(k+1); } + // n_simplexes.at(k) += n; + // while(n_simplexes.back() == 0 && n_simplexes.size() > 0){ n_simplexes.resize(n_simplexes.size() - 1); } + // tree_max_depth = n_simplexes.size(); +} + +// Remove a child node directly from the parent, if it exists +// This will check that the child is a leaf, and if so, then it will: +// 1) Remove the child from the parents children map +// 2) Remove the child from the level map +inline void SimplexTree::remove_leaf(node_ptr parent, idx_t child_label){ + if (parent == nullptr){ return; } + const idx_t child_depth = depth(parent) + 1; + auto child_it = std::find_if(begin(parent->children), end(parent->children), [child_label](const node_uptr& cn)->bool{ return(cn->label == child_label); }); + if (child_it != end(parent->children)){ + // Remove from level map + auto child = (*child_it).get(); // copy regular node_ptr + remove_cousin(child, child_depth); + + // Remove from parents children + parent->children.erase(child_it); + record_new_simplexes(child_depth-1, -1); + } +} + +// Removes an entire subtree rooted as 'sroot', including 'sroot' itself; calls 'remove_leaf' recursively. +inline void SimplexTree::remove_subtree(node_ptr sroot){ + if (sroot == nullptr){ return; } + if (sroot->children.empty()){ remove_leaf(sroot->parent, sroot->label); } // remove self + else { + // Remark: make sure to use labels instead of iterator here, otherwise the iterator will be invalidated. + vector< node_ptr > nc(sroot->children.size()); + std::transform(begin(node_children(sroot)), end(node_children(sroot)), begin(nc), [](const node_uptr& u_np){ + return u_np.get(); + }); + for (auto cn: nc){ + remove_subtree(find_by_id(node_children(sroot), node_label(cn))); + } + // Remove self + if (sroot && sroot != root.get()){ remove_leaf(sroot->parent, sroot->label); } + } +} + +// First removes all the cofaces of a given simplex, including the simplex itself. +inline void SimplexTree::remove(node_ptr cn){ + if (cn != nullptr && cn != root.get()){ + auto cr = st::coface_roots< false >(this, cn); + SmallVector< node_ptr >::allocator_type::arena_type arena; + SmallVector< node_ptr > co_v{ arena }; + std::transform(cr.begin(), cr.end(), std::back_inserter(co_v), [](std::tuple< node_ptr, idx_t >& cn){ return(get< 0 >(cn)); }); + for (auto co_n: co_v){ + remove_subtree(co_n); + } + } +}; + +template< typename Iter > +inline auto SimplexTree::append_node(Iter pos, node_ptr cn, idx_t label, size_t depth) -> node_set_t::iterator { + auto new_it = cn->children.emplace_hint(pos, make_unique< node >(label, cn)); + add_cousin((*new_it).get(), depth); + record_new_simplexes(depth-1, 1); + return(new_it); +} + +// Inserts man simplices of a fixed dimension +// Assumes Iterator to int types are correct type to avoid casts + sorted +// template< typename Iter, size_t d > +// inline void SimplexTree::insert_fast(Iter s, Iter e){ +// +// if constexpr(d == 0){ +// while(s != e){ +// node_children(root).emplace(make_unique< node >(*s, root)) +// ++s; +// } +// } else if constexpr ( d == 1 ){ +// const auto label1 = *s; +// const auto label2 = *(s+1); +// // const auto val_end = std::upper_bound(s, e, label); // wrong [1, 2, 1, 3, 1, 4]... +// auto v_it = node_children(root).find(label); +// if (v_it == end(node_children(root))){ v_it = append_node(v_it, root, label1, 1); } +// if (v_it != end(node_children(root))){ +// v_it-> +// } +// } +// if (it == end(c_node->children)){ +// auto new_it = c_node->children.emplace_hint(it, make_unique< node >(label, c_node)); +// if (child_depth > 1){ // keep track of nodes which share ids at the same depth +// add_cousin((*new_it).get(), child_depth); +// } +// record_new_simplexes(child_depth-1, 1); +// } +// } + +// Create a set of (i)-simplexes as children of the current node, if they don't already exist +// depth == (depth of c_node) +template< bool lex_order, typename Iter > +inline void SimplexTree::insert_it(Iter s, Iter e, node_ptr c_node, const idx_t depth){ + if (s == e || c_node == nullptr){ return; } + // using it_t = typename Iter::value_t; + + const idx_t child_depth = depth+1; + std::for_each(s, e, [this, &c_node, child_depth](idx_t label){ + // if constexpr (lex_order){ + // auto new_it = c_node->children.emplace_hint(c_node->children.end(), make_unique< node >(label, c_node)); + // if (child_depth > 1){ // keep track of nodes which share ids at the same depth + // // level_map[encode_node(label, child_depth)].push_back((*new_it).get()); + // add_cousin((*new_it).get(), child_depth); + // } + // record_new_simplexes(child_depth-1, 1); + // } else { + auto it = std::find_if(begin(node_children(c_node)), end(node_children(c_node)), [label](const node_uptr& cn){ + return(cn->label == label); + }); + if (it == end(c_node->children)){ + auto new_it = c_node->children.emplace_hint(it, make_unique< node >(label, c_node)); + if (child_depth > 1){ // keep track of nodes which share ids at the same depth + add_cousin((*new_it).get(), child_depth); + } + record_new_simplexes(child_depth-1, 1); + } + // } + }); + + // Recurse on the subtrees of the current node + idx_t j = 1; + std::for_each(s, e, [&](idx_t label){ + insert_it(std::next(s, j), e, find_by_id(c_node->children, label), depth+1); + ++j; + }); +} + +// Wrapper to find a vertex from the top nodes +template< typename Iterable > +inline void SimplexTree::insert(Iterable v) { + static_assert(st::detail::has_begin< Iterable >::value, "Must be iterable object."); + auto b = v.begin(), e = v.end(); + std::sort(b, e); // Demand sorted labels + e = std::unique(b, e); // Demand unique labels + insert_it(b, e, root.get(), 0); +} + +// Create a set of (i)-simplexes as children of the current node, if they don't already exist +// inline void SimplexTree::insert(idx_t* labels, const size_t i, const size_t n_keys, node_ptr c_node, const idx_t depth){ +// if (i >= n_keys || labels == nullptr || c_node == nullptr){ return; } // base case + safety checks +// idx_t child_depth = depth + 1; // depth refers to parent depth, child_depth to child depth +// for (int j = i; j < n_keys; ++j){ +// using ut = decltype(*begin(node_children(c_node))); +// const auto compare_node_id = [&labels, &j](ut& cn) -> bool { +// return(cn->label == labels[j]); +// }; +// auto it = std::find_if(begin(node_children(c_node)), end(node_children(c_node)), compare_node_id); +// // auto it = std::find_if(begin(node_children(c_node)), end(node_children(c_node)), eq_node_id(labels[j])); +// if (it == end(c_node->children)){ // doesn't exist yet +// auto new_it = c_node->children.emplace_hint(it, std::make_unique< node >(labels[j], c_node)); +// record_new_simplexes(depth, 1); +// if (child_depth > tree_max_depth){ tree_max_depth = child_depth; } +// if (child_depth > 1){ // keep track of nodes which share ids at the same depth +// add_cousin((*new_it).get(), child_depth); +// // level_map[encode_node(labels[j], child_depth)].push_back((*new_it).get()); +// } +// } +// } +// // Recurse on the subtrees of the current node +// for (int j = i; j < n_keys; ++j){ +// node_ptr child_node = find_by_id(c_node->children, labels[j]); +// insert(labels, j + 1, n_keys, child_node, child_depth); +// } +// } + +// Overloaded in the case where a single (1-length vector) label is given +inline SimplexTree::node_ptr SimplexTree::find_by_id(const node_set_t& level, idx_t label) const{ + auto it = std::lower_bound(begin(level), end(level), label, [](const node_uptr& np, const idx_t id){ + return np->label < id; + }); + return (it != end(level) && (*it)->label == label) ? (*it).get() : nullptr; +} + +// Wrapper to find a vertex from the top nodes +template< typename Iterable > +inline SimplexTree::node_ptr SimplexTree::find(Iterable v) const { + // static_assert(std::is_integral<>) + auto b = v.begin(), e = v.end(); + std::sort(b, e); // Demand sorted labels + e = std::unique(b, e); // Demand unique labels + return find_it(b, e, root.get()); +} + +// Find iterator version +template< typename Iter > +inline SimplexTree::node_ptr SimplexTree::find_it(Iter s, Iter e, node_ptr cn) const { + for (; s != e && cn != nullptr; ++s){ + cn = find_by_id(cn->children, *s); + if (cn == nullptr){ return nullptr; } + } + return(cn); +} + +// Recursively calculate the depth of a node +inline size_t SimplexTree::depth(node_ptr cn) const { + if (cn == nullptr || cn == root.get()){ return 0; } + size_t d; + for (d = 1; cn && cn->parent != root.get(); ++d){ + cn = cn->parent; + } + return d; +} + +// Utility to get the maximum height / longest path from any given node. +inline size_t SimplexTree::max_depth(node_ptr cn) const { + auto dfs = st::preorder< false >(this, cn); + idx_t max_d = 0; + traverse(dfs, [&max_d](node_ptr np, idx_t depth){ + if (depth > max_d){ max_d = depth; } + return true; + }); + // traverse_node_pairs(dfs, [&max_d](std::pair< node_ptr, idx_t > np){ + // if (np.second > max_d){ max_d = np.second; } + // }); + return(max_d); +} + +// Print the whole tree. +template < typename OutputStream > +inline void SimplexTree::print_tree(OutputStream& os) const { + print_subtree(os, root.get()); +} + +template < typename OutputStream > +inline void SimplexTree::print_cousins(OutputStream& os) const { + auto labels = get_vertices(); + for (idx_t c_depth = 2; c_depth <= tree_max_depth; ++c_depth){ + for (auto &label: labels){ + if (cousins_exist(label, c_depth)){ + os << "(last=" << label << ", depth=" << c_depth << "): "; + traverse_cousins(label, c_depth, [this, &os](node_ptr cousin){ + print_simplex(os, cousin, false); + os << " "; + }); + os << std::endl; + } + } + } +}; + +// Basic breadth-first printing. Each level is prefixed with '.' number of times, followed by the +// the ids of the nodes at that breadth-level enclosed within parenthesis, e.g. ..( 2 3 4 ) +template < typename OutputStream > +inline void SimplexTree::print_subtree(OutputStream& os, node_ptr cn) const { + for (const auto& ch: cn->children){ + idx_t h = max_depth(ch.get())-1; + os << ch->label << " (h = " << h << "): "; + for (size_t i = 1; i <= h; ++i){ + for (size_t j = 1; j <= i; ++j){ os << "."; } + os << "("; + print_level(os, ch.get(), i); + os << " )"; + } + os << std::endl; + } +} + +// Prints a given level of the tree +template < typename OutputStream > +inline void SimplexTree::print_level(OutputStream& os, node_ptr cn, idx_t level) const{ + if (cn == nullptr || cn->parent == nullptr) return; + if (level == 0) { os << " " << cn->label; } + else if (level > 0 && (!cn->children.empty())) { + for (const auto& ch: cn->children){ + print_level(os, ch.get(), level-1); + } + } +} + +// Prints an individual simplex +template < typename OutputStream > +inline void SimplexTree::print_simplex(OutputStream& os, node_ptr cn, bool newline) const { + simplex_t si = full_simplex(cn); + os << "{ "; + std::for_each(si.begin(), si.end(), [&os](const idx_t i){ os << i << " "; }); + os << "}"; + if (newline){ os << std::endl; } +} + +// Performs an expansion of order k, reconstructing the k-skeleton flag complex via an in-depth expansion of the 1-skeleton. +inline void SimplexTree::expansion(const idx_t k){ + expansion_f(k, [this](node_ptr parent, idx_t depth, idx_t label){ + std::array< idx_t, 1 > int_label = { label }; + insert_it(begin(int_label), end(int_label), parent, depth); + }); +} + +template < typename Lambda > +inline void SimplexTree::expansion_f(const idx_t k, Lambda&& f){ + for (auto& cn: node_children(root)){ + if (!node_children(cn).empty()){ + expand_f(cn->children, k-1, 2, f); + } + } +} + +// Expand operation checks A \cap N^+(vj) \neq \emptyset +// If they have a non-empty intersection, then the intersection is added as a child to the head node. +template < typename Lambda > +inline void SimplexTree::expand_f(node_set_t& c_set, const idx_t k, size_t depth, Lambda&& f){ + if (k == 0 || c_set.empty()){ return; } + // Traverse the children + auto siblings = std::next(begin(c_set), 1); + SmallVector< node_ptr >::allocator_type::arena_type arena1; + SmallVector< node_ptr > intersection { arena1 }; + for (auto& cn: c_set){ + node_ptr top_v = find_by_id(root->children, cn->label); + if (top_v != nullptr && (!top_v->children.empty())){ + + // Temporary + SmallVector< node_ptr >::allocator_type::arena_type arena2; + SmallVector< node_ptr > sib_ptrs { arena2 } ; + std::transform(siblings, end(c_set), std::back_inserter(sib_ptrs), [](const node_uptr& n){ + return (node_ptr) n.get(); + }); + + // Get the intersection + intersection.clear(); + std::set_intersection( + begin(sib_ptrs), end(sib_ptrs), + begin(top_v->children), end(top_v->children), + std::back_inserter(intersection), + less_np_label() + ); + + // Insert and recursively expand + if (intersection.size() > 0){ + for (auto& int_node: intersection){ + auto face = find_by_id(cn->children, int_node->label); + if (face == nullptr){ + f((node_ptr) cn.get(), depth, int_node->label); + } + } + expand_f(cn->children, k-1, depth+1, f); // recurse + } + } + if (siblings != end(c_set)){ ++siblings; } + } +} + + +inline void SimplexTree::reindex(vector< idx_t > target_ids){ + if (n_simplexes.at(0) != target_ids.size()){ throw std::invalid_argument("target id vector must match the size of the number of 0-simplices."); } + if (!std::is_sorted(begin(target_ids), end(target_ids))){ throw std::invalid_argument("target ids must be ordered."); } + if (std::unique(begin(target_ids), end(target_ids)) != end(target_ids)){ throw std::invalid_argument("target ids must all unique."); } + + // Create the map between vertex ids -> target ids + auto id_map = std::map< idx_t, idx_t >(); + auto vertex_ids = get_vertices(); + for (size_t i = 0; i < vertex_ids.size(); ++i){ + id_map.emplace_hint(end(id_map), vertex_ids[i], target_ids[i]); + } + + // Apply the map + auto tr = st::preorder< false >(this); + st::traverse(tr, [&id_map](node_ptr cn, idx_t depth){ + cn->label = id_map[cn->label]; + return true; + }); + + // Remap the cousins + for (size_t d = 2; d < tree_max_depth; ++d){ + auto& cousins = level_map.at(depth_index(d)); + for (auto v_id: vertex_ids){ + auto it = cousins.find(v_id); + if (it != cousins.end()){ + auto kv = std::make_pair(std::move(it->first), std::move(it->second)); // copy + cousins.erase(it); + kv.first = id_map[v_id]; + cousins.insert(kv); + } + // TODO: Include this when support for std::extract improves + // auto node = cousins.extract(idx_t(v_id)); + // if (!node.empty()){ + // node.key() = id_map[v_id]; + // cousins.insert(std::move(node)); + // } + } + } +} + +// Given two simplices tau and sigma, checks to see if tau is a face of sigma +// Assumes tau and sigma are both sorted. +inline bool SimplexTree::is_face(vector< idx_t > tau, vector< idx_t > sigma) const { + auto tau_np = find(tau); + auto sigma_np = find(sigma); + if (tau_np != nullptr && sigma_np != nullptr){ + return std::includes(sigma.begin(), sigma.end(), tau.begin(), tau.end()); + } + return false; +} + +// Returns a vector of simplices representing the cofaces of a given simplex 'sigma' +// First, all simplices with d > depth(sigma) which end in the same vertex label as sigma are found. +// Let each of these nodes 'n_j'. There are two condition to test whether n_j is a coface of sigma: +// 1) n_j is a leaf := n_j is a coface of the current node +// 2) n_j has children := every node in the subtree rooted at n_j is a coface of the current node. +// In the second case, any node in the subtree rooted at n_j is a coface of sigma. +// Note that this procedure returns only the roots of these subtrees. +// inline vector< SimplexTree::node_ptr > SimplexTree::locate_cofaces(node_ptr cn) const { +// vector< idx_t > c_word = full_simplex(cn); +// const size_t h = c_word.size(); +// set< node_ptr > cofaces = { cn }; // a simplex cofaces include the simplex itself +// for (idx_t i = tree_max_depth; i > h; --i){ +// for (auto& n_j: node_cousins(cn, i)){ +// if (is_face(c_word, full_simplex(n_j))){ +// cofaces.insert(n_j); // insert roots only +// } +// } +// } +// vector< node_ptr > output(cofaces.begin(), cofaces.end()); +// return output; +// } + +// Given a node 'sigma', returns a vector of all the nodes part of the subtree of sigma, +// including sigma. +// inline vector< SimplexTree::node_ptr > SimplexTree::expand_subtree(node_ptr sigma) const { +// vector< node_ptr > subtree_nodes = vector< node_ptr >(); +// for (auto& node: dfs< false >(this, sigma)){ +// subtree_nodes.push_back(get< 0 >(node)); +// } +// return(subtree_nodes); +// } + +// Expand a given vector of subtrees, collected all of the simplices under these trees. +// inline vector< node_ptr > SimplexTree::expand_subtrees(vector< node_ptr > roots) const { +// vector< node_ptr > faces = vector< node_ptr >(); +// for (auto& subtree_root: roots){ +// vector< node_ptr > tmp = expand_subtree(subtree_root); +// faces.insert(faces.end(), tmp.begin(), tmp.end()); +// } +// return(faces); +// } + +inline bool SimplexTree::vertex_collapse(idx_t v1, idx_t v2, idx_t v3){ + node_ptr vp1 = find_by_id(root->children, v1); + node_ptr vp2 = find_by_id(root->children, v2); + node_ptr vt = find_by_id(root->children, v3); + return vertex_collapse(vp1, vp2, vt); // collapse the free pair (vp1, vp2) --> vt +} + +// Vertex collapse - A vertex collapse, in this sense, is the result of applying a +// peicewise map f to all vertices sigma \in K, where given a pair (u,v) -> w, +// f is defined as: +// f(x) = { (1) w if x in { u, v }, (2) x o.w. } +inline bool SimplexTree::vertex_collapse(node_ptr vp1, node_ptr vp2, node_ptr vt){ + // Lambda to do the mapping + vector< simplex_t > to_insert; + const auto map_collapse = [&to_insert, vt](simplex_t si, node_ptr vp){ + std::replace(begin(si), end(si), vp->label, vt->label); + to_insert.push_back(si); + }; + + // Enumerates the cofaces of each node, performing the maps on each one + for (auto& cn: cofaces< true >(this, vp1)){ map_collapse(get< 2 >(cn), vp1); } + for (auto& cn: cofaces< true >(this, vp2)){ map_collapse(get< 2 >(cn), vp2); } + + // Insert all the mapped simplices, remove vertices if they exist + for (auto& sigma: to_insert){ insert(sigma); } + if (vp1 != vt) { remove(find_by_id(root->children, vp1->label)); } + if (vp2 != vt) { remove(find_by_id(root->children, vp2->label)); } + return true; +} + +// Elementary collapse - only capable of collapsing sigma through tau, and only if tau has sigma +// as its only coface. There are technically two cases, either tau and sigma are both leaves or +// tau contains sigma as its unique child. Both can be handled by removing sigma first, then tau. +inline bool SimplexTree::collapse(node_ptr tau, node_ptr sigma){ + if (tau == nullptr || sigma == nullptr){ return false; } + auto tau_cofaces = st::cofaces< false >(this, tau); + bool sigma_only_coface = true; + traverse(tau_cofaces, [&tau, &sigma, &sigma_only_coface](node_ptr coface, idx_t depth){ + sigma_only_coface &= (coface == tau) || (coface == sigma); + return(sigma_only_coface); + }); + if (sigma_only_coface){ + remove_leaf(sigma->parent, sigma->label); + remove_leaf(tau->parent, tau->label); + return(true); + } + return(false); +} + +// Returns the connected components given by the simplicial complex +inline vector< idx_t > SimplexTree::connected_components() const{ + + // Provide means of mapping vertex ids to index values + vector< idx_t > v = get_vertices(); // vertices are ordered, so lower_bound is valid + const auto idx_of = [&v](const idx_t val) { return(std::distance(begin(v), std::lower_bound(begin(v), end(v), val))); }; + + // Traverse the edges, unioning vertices + UnionFind uf = UnionFind(root->children.size()); + traverse(st::k_simplices< false >(this, root.get(), 1), [&idx_of, &uf](node_ptr cn, idx_t d){ + uf.Union(idx_of(cn->label), idx_of(cn->parent->label)); + return true; + }); + + // Create the connected components + std::transform(begin(v), end(v), begin(v), idx_of); + return(uf.FindAll(v)); +} + +// Edge contraction +inline void SimplexTree::contract(vector< idx_t > edge){ + vector< simplex_t > to_remove; + vector< simplex_t > to_insert; + traverse(st::preorder< true >(this, root.get()), [this, edge, &to_remove, &to_insert](node_ptr np, idx_t depth, simplex_t sigma){ + const idx_t va = edge[0], vb = edge[1]; + if (np->label == vb){ // only consider simplices which contain v_lb + bool includes_a = std::find(sigma.begin(), sigma.end(), va) != sigma.end(); + if (includes_a){ // case 1: sigma includes both v_la and v_lb + to_remove.push_back(sigma); // add whole simplex to remove list, and we're done. + } else { // case 2: sigma includes v_lb, but not v_la + // Insert new simplices with v_la --> v_lb, identity otherwise + const auto local_preorder = st::preorder< true >(this, np); + traverse(local_preorder, [&to_insert, va, vb](node_ptr end, idx_t depth, simplex_t tau){ + std::replace(tau.begin(), tau.end(), vb, va); + to_insert.push_back(tau); // si will be sorted upon insertion + return true; + }); + to_remove.push_back(sigma); + } + } + return true; + }); + + // for (auto& edge: to_remove){ print_simplex(std::cout, edge, true); } + + // Remove the simplices containing vb + for (auto& edge: to_remove){ remove(find(edge)); } + for (auto& edge: to_insert){ insert(edge); } +} + +template < typename Lambda > // Assume lambda is boolean return +void SimplexTree::traverse_up(node_ptr cn, const size_t depth, Lambda&& f) const noexcept { + if (cn == nullptr || cn->parent == nullptr){ return; }; + switch(depth){ + case 6: + f(cn); + cn = cn->parent; + case 5: + f(cn); + cn = cn->parent; + case 4: + f(cn); + cn = cn->parent; + case 3: + f(cn); + cn = cn->parent; + case 2: + f(cn); + cn = cn->parent; + case 1: + f(cn); + break; + default: + idx_t d = 0; + while (cn != root.get() && cn->parent != nullptr && d <= tree_max_depth){ + f(cn); + cn = cn->parent; + ++d; + } + break; + } +} + +template< typename OutputIt > +inline void SimplexTree::full_simplex_out(node_ptr cn, const idx_t depth, OutputIt out) const noexcept { + if (cn == nullptr || cn == root.get()){ return; } + if (depth == 0){ + std::deque< idx_t > labels; + traverse_up(cn, depth, [&labels](node_ptr np){ labels.push_front(np->label); }); + std::move(begin(labels), end(labels), out); + } else { + splex_alloc_t a; + splex_t labels{a}; + labels.resize(depth); + size_t i = 1; + traverse_up(cn, depth, [&depth, &i, &labels](node_ptr np){ labels.at(depth - (i++)) = np->label; }); + std::move(begin(labels), end(labels), out); + } +} + +inline simplex_t SimplexTree::full_simplex(node_ptr cn, const idx_t depth) const noexcept { + simplex_t labels; + labels.reserve(depth); + full_simplex_out(cn, depth, std::back_inserter(labels)); + return(labels); +} + +// // Serialize the simplex tree +// inline vector< vector< idx_t > > SimplexTree::serialize() const{ +// using simplex_t = vector< idx_t >; +// vector< simplex_t > minimal; +// traverse(st::maximal< true >(this, root.get()), [&minimal](node_ptr cn, idx_t depth, simplex_t sigma){ +// minimal.push_back(sigma); +// return true; +// }); +// return(minimal); +// } + +// Deserialization +// inline void SimplexTree::deserialize(vector< vector< idx_t > > simplices){ +// for (auto& sigma: simplices){ insert_simplex(sigma); } +// } + +// template +// inline void SimplexTree::traverse_facets(node_ptr s, Lambda f) const{ +// +// // Constants +// const simplex_t sigma = full_simplex(s); +// const size_t sigma_depth = sigma.size(); +// +// // Conditions for recursion +// const auto valid_eval = [sigma_depth](auto& cn) -> bool { return get< 1 >(cn) <= (sigma_depth - 1); }; +// const auto valid_children = [&sigma, sigma_depth](auto& cn){ +// return (get< 1 >(cn) <= (sigma_depth - 2)) && get< 0 >(cn)->label >= sigma.at(get< 1 >(cn)-1); +// }; +// +// // Facet search +// if (sigma_depth <= 1){ return; } +// else if (sigma_depth == 2){ +// // The facet of an edge is just its two vertices +// f(find_by_id(root->children, s->label)); +// f(s->parent, 1); +// return; +// } else { +// size_t c_depth = sigma_depth - 1; +// node_ptr cn = s->parent; +// simplex_t tau = sigma; +// +// while (cn != root.get()){ +// // Check self +// tau[c_depth-1] = cn->label; +// if (c_depth == sigma_depth-1 && std::includes(begin(sigma), end(sigma), begin(tau), begin(tau)+c_depth)){ +// f(cn, c_depth); +// } +// +// // Look at the siblings +// using ut = decltype(*begin(node_children(root))); +// auto tn = std::find_if(begin(cn->parent->children), end(cn->parent->children), eq_node_id(cn->label)); +// std::advance(tn,1); +// +// // Check siblings + their children up to facet depth +// for (; tn != end(cn->parent->children); ++tn){ // TODO: end loop if tn != face +// simplex_t c_tau = full_simplex((*tn).get()); +// c_tau.resize(sigma_depth); +// auto tr = st::preorder(this, (*tn).get(), valid_eval, valid_children); +// traverse(tr, [&c_tau, &sigma, &f](node_ptr t, idx_t d){ +// c_tau[d-1] = t->label; +// if ((d == c_tau.size()-1) && std::includes(begin(sigma), end(sigma), begin(c_tau), begin(c_tau)+d)){ +// f(t, d); +// } +// return true; +// }); +// } +// +// // Move up +// cn = cn->parent; +// c_depth--; +// } +// } +// } + +inline vector< idx_t > SimplexTree::get_vertices() const{ + if (tree_max_depth == 0){ return vector< idx_t >(0); } + vector< idx_t > v; + v.reserve(n_simplexes[0]); + for (auto& cn: node_children(root)){ v.push_back(node_label(cn)); } + return v; +} + +// Returns whether the graph is acyclic. +inline bool SimplexTree::is_tree() const{ + if (tree_max_depth == 0){ return false; } + UnionFind ds = UnionFind(n_simplexes.at(0)); + + // Traverse the 1-skeleton, unioning all edges. If any of them are part of the same CC, there is a cycle. + const vector< idx_t > v = get_vertices(); + const auto index_of = [&v](const idx_t vid) -> size_t{ return std::distance(begin(v), std::find(begin(v), end(v), vid)); }; + + // Apply DFS w/ UnionFind. If a cycle is detected, no more recursive evaluations are performed. + bool has_cycle = false; + auto st_dfs = st::k_simplices< true >(this, root.get(), 1); + for (auto& cn: st_dfs){ + const auto si = get< 2 >(cn); + idx_t i1 = index_of(si.at(0)), i2 = index_of(si.at(1)); + if (ds.Find(i1) == ds.Find(i2)){ + has_cycle = true; + break; + } + ds.Union(i1, i2); + } + return !has_cycle; +} + + + +#endif diff --git a/include/simplextree/st_filtration.hpp b/include/simplextree/st_filtration.hpp new file mode 100644 index 0000000..e8b5aad --- /dev/null +++ b/include/simplextree/st_filtration.hpp @@ -0,0 +1,391 @@ +#ifndef ST_FILTRATION_HPP_ +#define ST_FILTRATION_HPP_ + +#include "simplextree.h" + +// Intermediate struct to enable faster filtration building +struct weighted_simplex { + node_ptr np; + size_t depth; + double weight; +}; + +// Indexed simplex +struct indexed_simplex { + int parent_idx; // index of its parent simplex in tree + idx_t label; // last(sigma) + double value; // diameter/weight of the simplex +}; + +// Lexicographical comparison for simplices +struct s_lex_less { + bool operator()(const simplex_t& s1, const simplex_t &s2) const { + if (s1.size() == s2.size()){ + return std::lexicographical_compare(begin(s1), end(s1), begin(s2), end(s2)); + } + return s1.size() < s2.size(); + } +}; + +// Weighted simplex lexicographically-refined comparison +struct ws_lex_less { + SimplexTree* st; + explicit ws_lex_less(SimplexTree* st_) : st(st_){ } + bool operator()(const weighted_simplex& s1, const weighted_simplex& s2) const { + if (s1.weight == s2.weight){ + if (s1.depth == s2.depth){ + auto s1_simplex = st->full_simplex(s1.np, s1.depth); + auto s2_simplex = st->full_simplex(s2.np, s2.depth); + return s_lex_less()(s1_simplex, s2_simplex); + } + return s1.depth < s2.depth; + } + return s1.weight < s2.weight; + } +}; + +// Computes the maximum weight of a given simplex given +// inline double max_weight_sparse(simplex_t sigma, const vector< double >& D, const size_t n) noexcept { +// switch(sigma.size()){ +// case 0: case 1: { return(0.0); } +// case 2: { return(D[to_natural_2(sigma[0], sigma[1], n)]); } +// default: { +// double weight = 0.0; +// for_each_combination(sigma.begin(), sigma.begin()+2, sigma.end(), [&D, &weight, n](auto& it1, auto& it2){ +// const size_t idx = to_natural_2(*it1, *std::next(it1), n); +// if (D[idx] > weight){ weight = D[idx]; } +// return false; +// }); +// return(weight); +// } +// }; +// } + +// Computes the maximum weight of a simplex 'sigma' by finding the highest +// weighted edge representing a face of sigma in the distance vector 'D' +// Note: sigma must constain 0-based contiguous indices here to use D +// inline double max_weight_dense(simplex_t sigma, const vector< double >& D, const size_t n) noexcept { +// switch(sigma.size()){ +// case 0: case 1: { return(0.0); } +// case 2: { return(D[to_natural_2(sigma[0], sigma[1], n)]); } +// default: { +// double weight = 0.0; +// for_each_combination(sigma.begin(), sigma.begin()+2, sigma.end(), [&D, &weight, n](auto& it1, auto& it2){ +// const size_t idx = to_natural_2(*it1, *std::next(it1), n); +// if (D[idx] > weight){ weight = D[idx]; } +// return false; +// }); +// return(weight); +// } +// }; +// } + +// Use mixin to define a filtration +class Filtration : public SimplexTree { +public: + // Filtration-specific fields + vector< bool > included; + vector< indexed_simplex > fc; + + // Constructor + Filtration() : SimplexTree() {} + + // copy over the simplex tree + void initialize(const SimplexTree& sc){ + auto max_tr = st::maximal< true >(&sc, sc.root.get()); + traverse(max_tr, [this](node_ptr cn, idx_t depth, simplex_t sigma){ + this->insert_it(begin(sigma), end(sigma), root.get(), 0); + return true; + }); + id_policy = sc.id_policy; + } + + // Filtration building methods + void flag_filtration(const vector< double >&, const bool fixed=false); + + // Iterating through the filtration + template < typename Lambda > + void traverse_filtration(size_t, size_t, Lambda&&); + + // Changing the state of complex + void threshold_value(double); + void threshold_index(size_t); + + // Querying information about the filtration + vector< simplex_t > simplices() const; + vector< double > weights() const; + vector< idx_t > dimensions() const; + size_t current_index() const; + double current_value() const; + + // Internal helpers + vector< size_t > simplex_idx(const size_t) const; + + template< typename Lambda > + void apply_idx(size_t, Lambda&&) const; + + // template < typename Iter > + // simplex_t expand_simplex(Iter, Iter) const; + simplex_t expand_simplex(simplex_t) const; + +}; // End filtration + +// Given sorted vector 'ref', matches elements of 'x' with 'ref', returning +// a vector of the same length as 'x' with the matched indices. +template < typename T > +inline vector< T > match(const vector< T >& x, const vector< T >& ref){ + vector< T > indices; + indices.reserve(x.size()); + for(auto& elem: x){ + auto idx = std::distance(begin(ref), std::lower_bound(begin(ref), end(ref), elem)); + indices.push_back(idx); + } + return(indices); +} + +struct sorted_edges { + using it_t = vector< size_t >::iterator; + vector< size_t > keys; + const vector< double >& values; + const vector< size_t > vertices; + + sorted_edges(Filtration* st, const vector< double >& weights) : values(weights), vertices(st->get_vertices()) { + const size_t n = vertices.size(); + auto edge_traversal = st::k_simplices< true >(dynamic_cast< SimplexTree* >(st), st->root.get(), 1); + st::traverse(edge_traversal, [this, n](node_ptr np, idx_t depth, simplex_t edge){ + auto eid = match(edge, vertices); + keys.push_back(to_natural_2(eid[0], eid[1], n)); + return true; + }); + if (!std::is_sorted(keys.begin(), keys.end())){ throw std::invalid_argument("keys not ordered."); } + } + + // Given a simplex 'sigma' whose values are vertex ids in 'vertices', calculates the maximum weight + double max_weight(simplex_t sigma){ + auto v_ids = match(sigma, vertices); + double weight = 0.0; + for_each_combination(v_ids.begin(), v_ids.begin()+2, v_ids.end(), [this, &weight](it_t it1, it_t it2){ + const size_t idx = to_natural_2(*it1, *std::next(it1), vertices.size()); + const auto key_idx = std::lower_bound(keys.begin(), keys.end(), idx); + const double ew = values[std::distance(keys.begin(), key_idx)]; + if (ew > weight){ weight = ew; } + return false; + }); + return(weight); + } +}; + + +// Given a dimension k and set of weighted edges (u,v) representing the weights of the ordered edges in the trie, +// constructs a std::function object which accepts as an argument some weight value 'epsilon' and +// returns the simplex tree object. +inline void Filtration::flag_filtration(const vector< double >& D, const bool fixed){ + if (this->tree_max_depth <= 1){ return; } + if (D.size() != this->n_simplexes.at(1)){ throw std::invalid_argument("Must have one weight per edge."); } + + // 0. Create the sorted map between the edges and their weights + sorted_edges se = sorted_edges(this, D); + + // 1. Calculate simplex weights from the edge weights + const size_t ns = std::accumulate(begin(this->n_simplexes), end(this->n_simplexes), 0, std::plus< size_t >()); + std::vector< weighted_simplex > w_simplices; + w_simplices.reserve(ns); + + size_t i = 0; + traverse(st::level_order< true >(this), [&w_simplices, &D, &i, &se](node_ptr cn, idx_t d, simplex_t sigma){ + double c_weight = d == 1 ? 0.0 : (d == 2 ? D.at(i++) : se.max_weight(sigma)); + weighted_simplex ws = { cn, d, c_weight }; + w_simplices.push_back(ws); + return true; + }); + + // 2. Sort simplices by weight to create the filtration + std::sort(begin(w_simplices), end(w_simplices), ws_lex_less(this)); + + // 3. Index the simplices + fc.clear(); + fc.reserve(w_simplices.size()); + i = 0; + for (auto sigma_it = begin(w_simplices); sigma_it != end(w_simplices); ++sigma_it){ + auto& sigma = *sigma_it; + indexed_simplex tau; + tau.label = sigma.np->label; + tau.value = sigma.weight; + + // Find the index of sigma's parent, or 0 if sigma if the parent is the empty face. + if (sigma.np->parent == this->root.get() || sigma.np == this->root.get()){ + tau.parent_idx = -1; + } else { + const auto sp = sigma.np->parent; + auto p_it = std::find_if(begin(w_simplices), sigma_it, [&sp](const weighted_simplex& si){ + return(si.np == sp); + }); + if (p_it == sigma_it){ throw std::range_error("sigma detected itself as the parent!"); } + tau.parent_idx = std::distance(begin(w_simplices), p_it); + } + fc.push_back(tau); + } + + // Set state of the filtration to the max + included = vector< bool >(fc.size(), true); +} + +// Accepts lambda that takes iterator start and end +template< typename Lambda > +inline void Filtration::apply_idx(size_t idx, Lambda&& f) const { + if (idx >= fc.size()){ throw std::out_of_range("Bad simplex index"); } + + SmallVector< size_t >::allocator_type::arena_type arena; + SmallVector< size_t > indices{ arena }; + indices.push_back(idx); + while (fc[idx].parent_idx != -1){ + idx = fc[idx].parent_idx; + indices.push_back(idx); + } + + std::for_each(indices.rbegin(), indices.rend(), [&f](size_t index){ + f(index); + }); +} + +// Returns the indices of where the labels that make up the simplex +// at index 'idx' are in the filtration in ascending order. +inline vector< size_t > Filtration::simplex_idx(size_t idx) const { + if (idx >= fc.size()){ throw std::out_of_range("Bad simplex index"); } + vector< size_t > expanded = { idx }; + while (fc[idx].parent_idx != -1){ + idx = fc[idx].parent_idx; + expanded.push_back(idx); + } + std::reverse(begin(expanded), end(expanded)); + return(expanded); +} + +// Given a vector of indices, expand the indices to form the simplex +// template < typename Iter > +// inline vector< idx_t > Filtration::expand_simplex(Iter s, Iter e) const { +// using index_t = typename std::iterator_traits< Iter >::value_type; +// static_assert(std::is_integral< index_t >::value, "Integral type required for indexing."); +// simplex_t sigma(std::distance(s, e)); +// std::transform(s, e, begin(sigma), [this](auto i){ std::cout << i << std::endl; return(fc.at(i).label); }); +// return(sigma); +// } + +// inline void Filtration::get_simplex(const size_t idx, SmallVector< size_t >& out) const { +// out.resize(0); +// out.push_back(fc[idx].label); +// size_t pidx = fc[idx].parent_idx; +// while (pidx != -1){ +// out.push_back(fc[pidx].label); +// pidx = fc[pidx].parent_idx; +// } +// std::reverse(out.begin(), out.end()); +// } +// template< typename Iter, typename OutputIt > +// inline void Filtration::expand_it(Iter s, Iter e, OutputIt out) const { +// for(; s != e; ++s){ +// *out = fc[*s].label; +// } +// } + +inline simplex_t Filtration::expand_simplex(vector< size_t > indices) const { + std::transform(begin(indices), end(indices), begin(indices), [this](size_t i){ return(fc.at(i).label); }); + return(indices); +} + + +template < typename Lambda > +inline void Filtration::traverse_filtration(size_t a, size_t b, Lambda&& f){ + if (b > fc.size()) { b = fc.size(); }; + if (a == b){ return; } + + SmallVector< size_t >::allocator_type::arena_type arena; + SmallVector< size_t > expanded{ arena }; + expanded.reserve(tree_max_depth); + + const auto apply_f = [this, &expanded, &f](const size_t i){ + apply_idx(i, [this, &expanded](size_t index){ expanded.push_back(fc[index].label); }); + f(i, expanded.begin(), expanded.end()); + expanded.resize(0); + }; + + if (a < b){ + for (size_t i = a; i < b; ++i){ apply_f(i); } + // for (size_t i = a; i < b; ++i){ f(i, expand_simplex(simplex_idx(i))); } + } + if (a > b){ + int i = a >= fc.size() ? fc.size() - 1 : a; // i possibly negative! + //for (; i >= b && i >= 0; --i){ f(i, expand_simplex(simplex_idx(i))); } + for (; i >= int(b) && i >= 0; --i){ apply_f(i); } + } + return; +} + + +// Get index corresponding to eps (inclusive) +inline void Filtration::threshold_value(double eps){ + using IS = indexed_simplex; + const auto eps_it = std::lower_bound(begin(fc), end(fc), eps, [](const IS s, double val) -> bool { + return(s.value <= val); + }); + const size_t eps_idx = std::distance( begin(fc), eps_it ); + threshold_index(eps_idx); +} + +inline void Filtration::threshold_index(size_t req_index){ + const size_t c_idx = current_index(); + const bool is_increasing = c_idx < req_index; + using it_t = SmallVector< size_t >::iterator; + traverse_filtration(c_idx, req_index, [this, is_increasing](const size_t i, it_t s, it_t e){ + included.at(i) = is_increasing; + is_increasing ? insert_it(s,e,root.get(),0) : remove(find_it(s,e,root.get())); + }); +} + +// Returns the current index in the filtration +inline size_t Filtration::current_index() const { + if (included.size() == 0){ return 0; } + const size_t current_idx = std::distance( + begin(included), std::find(begin(included), end(included), false) + ); + return(current_idx); +} + +inline double Filtration::current_value() const { + const double INF = std::numeric_limits< double >::infinity(); + if (included.size() == 0){ return -INF; } + const size_t c_idx = current_index(); + return(c_idx == fc.size() ? INF : fc[c_idx].value); +} + +// Returns the simplices in the filtration in a list +inline vector< vector< idx_t > > Filtration::simplices() const { + const size_t n = current_index(); + vector< vector< idx_t > > simplices(n); + for (size_t i = 0; i < n; ++i){ + simplices[i] = expand_simplex(simplex_idx(i)); + } + return simplices; +} + +// Retrieves the filtration weights +inline vector< double > Filtration::weights() const{ + const size_t n = current_index(); + vector< double > weights = vector< double >(n); + for (size_t i = 0; i < n; ++i){ + weights[i] = fc[i].value; + } + return weights; +} + +inline vector< idx_t > Filtration::dimensions() const{ + const size_t n = current_index(); + vector< idx_t > dims = vector< idx_t >(n); + for (size_t i = 0; i < n; ++i){ + dims[i] = simplex_idx(i).size()-1; + } + return dims; +} + +#endif + diff --git a/include/simplextree/st_iterators.hpp b/include/simplextree/st_iterators.hpp new file mode 100644 index 0000000..f243441 --- /dev/null +++ b/include/simplextree/st_iterators.hpp @@ -0,0 +1,1121 @@ +#ifndef ST_ITERS_H +#define ST_ITERS_H + +#include "simplextree.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using std::get; +using simplex_t = SimplexTree::simplex_t; +using node_ptr = SimplexTree::node_ptr; +using node_uptr = SimplexTree::node_uptr; + + +// #include +// +// template +// constexpr std::string_view +// type_name() { +// std::string_view name, prefix, suffix; +// #ifdef __clang__ +// name = __PRETTY_FUNCTION__; +// prefix = "std::string_view type_name() [T = "; +// suffix = "]"; +// #elif defined(__GNUC__) +// name = __PRETTY_FUNCTION__; +// prefix = "constexpr std::string_view type_name() [with T = "; +// suffix = "; std::string_view = std::basic_string_view]"; +// #elif defined(_MSC_VER) +// name = __FUNCSIG__; +// prefix = "class std::basic_string_view > __cdecl type_name<"; +// suffix = ">(void)"; +// #endif +// name.remove_prefix(prefix.size()); +// name.remove_suffix(suffix.size()); +// return name; +// } + +// Simplextree namespace +namespace st { + +// detail namespace contains internal template boilerplate, including several implementations of the detection idiom + namespace detail { + template + struct has_dereference { + template + static constexpr auto test_dereference(int) -> decltype(std::declval().operator*(), bool()) { + return true; + } + template + static constexpr bool test_dereference(...) { + return false; + } + static constexpr bool value = test_dereference(int()); + }; + + // Detection idom applied to member function operator== + template + struct has_equality { + template + static constexpr auto test_equality(int) -> decltype(std::declval().operator==(), bool()) { + return true; + } + template + static constexpr bool test_equality(...) { + return false; + } + static constexpr bool value = test_equality(int()); + }; + + // Detection idiom applied to member function operator!= + template + struct has_not_equality { + template + static constexpr auto test_not_equality(int) -> decltype(std::declval().operator!=(), bool()) { + return true; + } + template + static constexpr bool test_not_equality(...) { + return false; + } + static constexpr bool value = test_not_equality(int()); + }; + + // Detection idiom applied to member function operator++ + template + struct has_increment { + template + static constexpr auto test_increment(int) -> decltype(std::declval().operator++(), bool()) { + return true; + } + template + static constexpr bool test_increment(...) { + return false; + } + static constexpr bool value = test_increment(int()); + }; + + template + struct has_begin { + template + static constexpr auto test_begin(int) -> decltype(std::declval().begin(), bool()) { return true; } + template + static constexpr bool test_begin(...) { return false; } + static constexpr bool value = test_begin(int()); + }; + + template + struct has_end { + template + static constexpr auto test_end(int) -> decltype(std::declval().end(), bool()) { return true; } + template + static constexpr bool test_end(...) { return false; } + static constexpr bool value = test_end(int()); + }; + + template + struct has_update_simplex { + template + static constexpr auto test_update_simplex(int) -> decltype(std::declval().has_update_simplex(), bool()) { return true; } + template + static constexpr bool test_update_simplex(...) { return false; } + static constexpr bool value = test_update_simplex(int()); + }; + }; // end namespace detail + + template < bool ts, template< bool > class Derived > + struct TraversalInterface { + using d_type = Derived< ts >; + using d_node = std::tuple< node_ptr, idx_t >; + using t_node = typename std::conditional< ts, std::tuple< node_ptr, idx_t, simplex_t >, d_node >::type; + // using pair_pred_t = delegate< bool (t_node&) >; // delegate appears to not work + using pair_pred_t = std::function; + + public: + static const bool is_tracking = ts; + static const idx_t NP = 0; + static const idx_t DEPTH = 1; + static const idx_t LABELS = 2; + node_ptr init; + const SimplexTree* st; + pair_pred_t p1 { [](t_node& cn){ return true; } }; + pair_pred_t p2 { [](t_node& cn){ return true; } }; + + TraversalInterface() = default; + TraversalInterface(const SimplexTree* st_) : st(st_){ + init = nullptr; + } + TraversalInterface(const SimplexTree* st_, node_ptr start) : st(st_){ + init = start; + }; + template< typename P1, typename P2 > + TraversalInterface(const SimplexTree* st_, node_ptr start, P1 pred1, P2 pred2) : init(start), st(st_), p1(pred1), p2(pred2) { + // p1.set(pred1); + // p2.set(pred2); + init = start; + }; + + struct iterator { + using difference_type = std::ptrdiff_t; + using value_type = t_node; + using pointer = t_node*; + using reference = t_node&; + using iterator_category = std::forward_iterator_tag; + + static constexpr d_node sentinel() { return std::make_tuple(nullptr, 0); }; + + using d_iter = typename d_type::iterator; + std::reference_wrapper< d_type > info; + d_node current; + simplex_t labels; + t_node output; + + // Constructors + iterator(d_type& dd) : info(dd){ + labels = simplex_t(); + labels.reserve(dd.st->tree_max_depth); + }; + + constexpr d_type& base() const { return(this->info.get()); }; + constexpr const SimplexTree& trie() const { return(*(this->info.get().st)); }; + + template < bool T = ts > + typename std::enable_if< T == true, t_node& >::type + current_t_node(){ + output = std::tuple_cat(current, std::make_tuple(labels)); + return(output); + } + template < bool T = ts> + typename std::enable_if< T == false, t_node& >::type + current_t_node(){ return(current); } + + // Operators + bool operator==(const d_iter& t) const { return (get< 0 >(t.current) == get< 0 >(current)); } + bool operator!=(const d_iter& t) const { return !(*this == t); } + t_node& operator*() { return current_t_node(); }; + }; + + // methods within TraversalInterface can use template to access members of Derived + // auto begin() { return static_cast< d_type* >(this)->begin(); }; + // auto end() { return static_cast< d_type* >(this)->end(); }; + + static constexpr bool derived_has_begin = detail::has_begin< Derived< ts > >::value; + static constexpr bool derived_has_end = detail::has_end< Derived< ts > >::value; + + // template < typename std::enable_if< true, int >::type = 0 > + // auto begin(){ + // return static_cast< d_type* >(this)->begin(); + // }; + // template < typename std::enable_if< false, int >::type = 0 > + // auto begin(){ + // return static_cast< d_type* >(this)->begin(); + // }; + + // template < typename std::enable_if< !detail::has_begin< d_type >::value, int >::type = 0 > + // auto begin() { + // return static_cast< d_type* >(this)->begin(); + // }; + // -> std::enable_if< false, decltype(static_cast< d_type* >(this)->begin()) > + // + // auto end() -> std::enable_if< detail::has_begin< d_type >::value, decltype(d_type::iterator(static_cast< d_type& >(*this), nullptr)) >{ + // return d_type::iterator(static_cast< d_type& >(*this), nullptr); + // }; + // auto end() -> std::enable_if< !detail::has_begin< d_type >::value, decltype(static_cast< d_type* >(this)->begin()) >{ + // return static_cast< d_type* >(this)->end(); + // }; + + // Tag dispatching + // using d_iter = typename Derived< ts >::iterator; + // template struct TD; + + // template struct TD; + // + // auto begin(std::true_type) { + // return static_cast< d_type* >(this)->begin(); + // }; + // + // auto begin(std::false_type) { + // using d_iter = typename Derived< ts >::iterator; + // return d_iter(static_cast< d_type& >(*this), nullptr); + // }; + // + // auto begin() { + // auto b = begin(derived_has_begin); + // std::cout << type_name() << std::endl; + // return b; + // }; + // + // // Tag dispatching + // auto end(std::true_type) { + // return static_cast< d_type* >(this)->end(); + // }; + // + // auto end(std::false_type) { + // using d_iter = typename d_type::iterator; + // return d_iter(static_cast< d_type& >(*this), nullptr); + // }; + // + // auto end() { + // return end(derived_has_end); + // }; + + // auto end() { + // if constexpr(derived_has_end){ + // return static_cast< d_type* >(this)->end(); + // } else { + // using d_iter = typename d_type::iterator; + // return d_iter(static_cast< d_type& >(*this), nullptr); + // } + // }; + }; // end Traversal Interface + + // https://eli.thegreenplace.net/2011/05/17/the-curiously-recurring-template-pattern-in-c + + template < class T, std::size_t BufSize = 64 > + using small_vec = std::vector>; + + // Preorder traversal iterator + template < bool ts = false > + struct preorder : TraversalInterface< ts, preorder > { + using B = TraversalInterface< ts, st::preorder >; + using B::init; + using B::st; + using B::p1; + using B::p2; + using B::NP; + using B::DEPTH; + using B::LABELS; + + // Constructors + preorder(const SimplexTree* st_) : TraversalInterface< ts, preorder >(st_, st_->root.get()) {}; + preorder(const SimplexTree* st_, node_ptr start) : TraversalInterface< ts, preorder >(st_, start) {}; + template< typename P1, typename P2 > + preorder(const SimplexTree* st_, node_ptr start, P1 pred1, P2 pred2) : TraversalInterface< ts, preorder >(st_, start, pred1, pred2){ + // p1.set(pred1); + // p2.set(pred2); + }; + + void reset(node_ptr start){ + init = start; + } + + // Iterator type + struct iterator : public TraversalInterface< ts, preorder >::iterator { + using Bit = typename B::iterator; + using Bit::current; + using Bit::labels; + using Bit::info; + using Bit::base; + using Bit::trie; + using Bit::current_t_node; + using Bit::sentinel; + + // DFS specific data structures + using dn_t = typename B::d_node; + std::stack< typename B::d_node > node_stack; + + // small_vec< dn_t >::allocator_type::arena_type a; + // small_vec< dn_t > v = {a}; + // std::stack< dn_t, small_vec< dn_t > > node_stack = {v}; + + // Iterator constructor + iterator(preorder& dd, node_ptr cn = nullptr) : TraversalInterface< ts, preorder >::iterator(dd){ + const idx_t d = dd.st->depth(cn); + current = std::make_tuple(cn, d); + labels = dd.st->full_simplex(cn, d); + } + + // Increment operator is all that is needed by default + auto operator++() -> decltype(*this){ + do { + if (get< NP >(current) != nullptr && base().p2(current_t_node())) { + const auto& ch = get< NP >(current)->children; + for (auto cn_up = ch.rbegin(); cn_up != ch.rend(); ++cn_up){ + node_stack.push(std::make_tuple((*cn_up).get(), get< DEPTH >(current)+1)); + } + } + if (node_stack.empty()){ + current = sentinel(); + } else { + current = node_stack.top(); + node_stack.pop(); + } + update_simplex(); + } while (!base().p1(current_t_node()) && get< NP >(current) != nullptr); + return(*this); + } + + template < bool T = ts > + auto update_simplex() -> typename std::enable_if< T == true, void >::type { + // std::cout << "update_simplex: " << base().init << std::endl; + if (get< NP >(current) != nullptr && get< DEPTH >(current) > 0){ + labels.resize(get< DEPTH >(current)); + labels.at(get< DEPTH >(current)-1) = get< NP >(current)->label; + } + } + template < bool T = ts > + auto update_simplex() -> typename std::enable_if< T == false, void >::type { + return; + } + }; + + // Only start at a non-root node, else return the end + auto begin() -> decltype(iterator(*this, init)) { + if (init == st->root.get()){ + return st->n_simplexes.empty() ? iterator(*this, nullptr) : ++iterator(*this, st->root.get()); + } else { + return iterator(*this, init); + } + }; + auto end() -> decltype(iterator(*this, nullptr)){ + return iterator(*this, nullptr); + } + + }; + + // level order traversal iterator + template < bool ts = false > + struct level_order : TraversalInterface< ts, level_order > { + using B = TraversalInterface< ts, st::level_order >; + using B::init; + using B::st; + using B::p1; + using B::p2; + using B::NP; + using B::DEPTH; + using B::LABELS; + using d_node = typename TraversalInterface< ts, level_order >::d_node; + + // Constructors + level_order(const SimplexTree* st_) : TraversalInterface< ts, level_order >(st_, st_->root.get()) {}; + level_order(const SimplexTree* st_, node_ptr start) : TraversalInterface< ts, level_order >(st_, start) {}; + template< typename P1, typename P2 > + level_order(const SimplexTree* st_, node_ptr start, P1 pred1, P2 pred2) : TraversalInterface< ts, level_order >(st_, start, pred1, pred2){ + // p1.set(pred1); + // p2.set(pred2); + }; + + void reset(node_ptr start){ init = start; } + + // Iterator type + struct iterator : public TraversalInterface< ts, level_order >::iterator { + using Bit = typename B::iterator; + using Bit::current; + using Bit::labels; + using Bit::info; + using Bit::base; + using Bit::trie; + using Bit::current_t_node; + using Bit::sentinel; + + // BFS specific data structures + std::queue< typename B::d_node > node_queue; + + // Iterator constructor + iterator(level_order& dd, node_ptr cn = nullptr) : TraversalInterface< ts, level_order >::iterator(dd){ + current = std::make_tuple(cn, dd.st->depth(cn)); + update_simplex(); + } + + // Increment operator is all that is needed by default + auto operator++() -> decltype(*this){ + do { + if (get< NP >(current) != nullptr && base().p2(current_t_node())) { + const auto& ch = get< NP >(current)->children; + for (auto cn_up = ch.begin(); cn_up != ch.end(); ++cn_up){ + node_queue.emplace(std::make_tuple((*cn_up).get(), get< DEPTH >(current)+1)); + } + } + if (node_queue.empty()){ + current = sentinel(); + } else { + current = node_queue.front(); + node_queue.pop(); + } + update_simplex(); + } while (!base().p1(current_t_node()) && get< NP >(current) != nullptr); + return(*this); + } + + template < bool T = ts > + typename std::enable_if< T == true, void >::type + update_simplex(){ labels = trie().full_simplex(get< NP >(current), get< DEPTH >(current)); } + + template < bool T = ts > + typename std::enable_if< T == false, void >::type + update_simplex(){ return; }; + };// iterator + + // Only start at a non-root node, else return the end + auto begin() -> decltype(iterator(*this, init)){ + if (init == st->root.get()){ + return st->n_simplexes.empty() ? iterator(*this, nullptr) : ++iterator(*this, st->root.get()); + } else { + return iterator(*this, init); + } + }; + auto end() -> decltype(iterator(*this, nullptr)) { return iterator(*this, nullptr); } + }; + + + // Coface-roots search + template < bool ts = false > + struct coface_roots : TraversalInterface< ts, coface_roots > { + using B = TraversalInterface< ts, st::coface_roots >; + using B::init; + using B::st; + using B::p1; + using B::p2; + using B::NP; + using B::DEPTH; + using B::LABELS; + using d_node = typename TraversalInterface< ts, coface_roots >::d_node; + + // Constructors + coface_roots() : TraversalInterface< ts, coface_roots >() {}; + coface_roots(const SimplexTree* st_, node_ptr start = nullptr) : TraversalInterface< ts, coface_roots >(st_, start) {}; + template< typename P1, typename P2 > + coface_roots(const SimplexTree* st_, node_ptr start, P1 pred1, P2 pred2) : TraversalInterface< ts, coface_roots >(st_, start, pred1, pred2){ + // p1.set(pred1); + // p2.set(pred2); + }; + + // Iterator type + struct iterator : public TraversalInterface< ts, coface_roots >::iterator { + using Bit = typename B::iterator; + using Bit::current; + using Bit::labels; + using Bit::info; + using Bit::base; + using Bit::trie; + using Bit::current_t_node; + using Bit::sentinel; + + // coface-roots-specific structures + simplex_t start_coface_s; + size_t c_level_key = 0; // the current level_map key + size_t c_level_idx = 0; // the current level_map index + + // Iterator constructor + iterator(coface_roots& dd, node_ptr cn) : TraversalInterface< ts, coface_roots >::iterator(dd){ + if (cn == dd.st->root.get()){ throw std::invalid_argument("Invalid given coface."); }; + const size_t c_depth = dd.st->depth(cn); + start_coface_s = dd.st->full_simplex(cn, c_depth); + current = std::make_tuple(cn, c_depth); + update_simplex(); + ++get< DEPTH >(current); // start at next depth + } + + // Finds the next coface of a given face starting at some offset + key, or nullptr if there is none + std::pair< node_ptr, bool > next_coface(simplex_t face, size_t& offset, idx_t depth){ + auto& st = trie(); + bool has_cousins = st.cousins_exist(base().init->label, depth); + + // If the key doesn't exist or the cousins are empty, return the end + if (!has_cousins || offset >= st.cousins(base().init->label, depth).size()){ + return std::make_pair(nullptr, false); + } + + // Else the cousins exist; see if any of the are a coface + const auto& c_cousins = st.cousins(base().init->label, depth); + auto it = c_cousins.cbegin(); //+ offset, should be smaller than cousin size + // if (offset >= std::distance(c_cousins.begin(), c_cousins.end())){ + // return std::make_pair(nullptr, false); + // } + std::advance(it, offset); + const auto coface_it = std::find_if(it, c_cousins.cend(), [&st, &face, depth](const node_ptr np){ + return st.is_face(face, st.full_simplex(np, depth)); + }); + + // If it exists, return it, otherwise return sentinel + offset += (std::distance(it, coface_it)+1); + return coface_it != c_cousins.end() ? std::make_pair(*coface_it, true) : std::make_pair(nullptr, false); + } + + // Increment operator is all that is needed by default + auto operator++() -> decltype(*this){ + // If root was given, end the traversal immediately. + if (get< NP >(current) == trie().root.get() || get< NP >(current) == nullptr){ + current = sentinel(); + return(*this); + } + + // While the coface doesn't exist at the current depth, increment depth + std::pair< node_ptr, bool > coface = next_coface(start_coface_s, c_level_idx, get< DEPTH >(current)); + while (coface.second == false && get< DEPTH >(current) <= trie().tree_max_depth){ + c_level_idx = 0; + get< DEPTH >(current)++; + coface = next_coface(start_coface_s, c_level_idx, get< DEPTH >(current)); + } + + // If it doesn't exist, we're done + if (coface.second == false){ + current = sentinel(); + } else { + get< NP >(current) = coface.first; + // c_level_idx++; // Shouldnt be +1, should be + whatever the distance in find_if is + // std::cout << "c level ind: " << c_level_idx << std::endl; + } + update_simplex(); + return *this; + }; // operator++ + + template < bool T = ts > + auto update_simplex() -> typename std::enable_if< T == true, void >::type { + labels = trie().full_simplex(get< NP >(current), get< DEPTH >(current)); + } + + template < bool T = ts > + auto update_simplex() -> typename std::enable_if< T == false, void >::type { + return; + }; + + }; // iterator + + // Only start at a non-root node, else return the end + auto begin() -> decltype(iterator(*this, nullptr)) { + if (init == st->root.get() || init == nullptr){ return iterator(*this, nullptr); } + else { return iterator(*this, init); } + }; + auto end() -> decltype(iterator(*this, nullptr)) { return iterator(*this, nullptr); }; + + }; // end coface_roots + + // ---- coface iterator ------ + template < bool ts = false > + struct cofaces : TraversalInterface< ts, cofaces > { + using B = TraversalInterface< ts, st::cofaces >; + using B::init; + using B::st; + using B::p1; + using B::p2; + using B::NP; + using B::DEPTH; + using B::LABELS; + using d_node = typename TraversalInterface< ts, cofaces >::d_node; + + cofaces(const SimplexTree* st, node_ptr start) : TraversalInterface< ts, cofaces >(st, start){ } + + struct iterator : public TraversalInterface< ts, cofaces >::iterator { + using Bit = typename B::iterator; + using Bit::current; + using Bit::labels; + using Bit::info; + using Bit::base; + using Bit::trie; + using Bit::current_t_node; + using Bit::sentinel; + + using preorder_it = decltype(std::declval< preorder< ts > >().begin()); + using coface_root_it = decltype(std::declval< coface_roots< false > >().begin()); + + // coface-roots-specific structures + coface_roots< false > roots; + coface_root_it c_root; + preorder< ts > subtree; + preorder_it c_node; + + // Iterator constructor + iterator(cofaces& dd, node_ptr cn) : TraversalInterface< ts, cofaces >::iterator(dd), + roots(coface_roots< false >(dd.st, cn)), c_root(roots, cn), subtree(preorder< ts >(dd.st, cn)), c_node(subtree.begin()) { + current = std::make_tuple(cn, dd.st->depth(cn)); + update_simplex(); + } + + // Increment operator is all that is needed by default + auto operator++() -> decltype(*this) { + // Need to increment iterator by one and return the result + // Logically what needs to happen is: + // - if next node is end of subtree + // - ... then if next root is end of coface_roots, we're finished return end + // - else next root is not end of coface_roots, advance root and reset subtree + // else while next node is not at end of subtree + // - advance one in subtree, report it back + if (get< NP >(*c_root) == trie().root.get()){ ++c_root; } + if (std::next(c_node) == subtree.end()){ + if (c_root == roots.end()){ + current = sentinel(); + } else { + ++c_root; + subtree.reset(get< NP >(*c_root)); + c_node = subtree.begin(); + current = std::make_tuple(get< NP >(*c_node), get< DEPTH >(*c_node)); + } + } else { + ++c_node; + current = std::make_tuple(get< NP >(*c_node), get< DEPTH >(*c_node)); + } + update_simplex(); + return(*this); + }; + + template < bool T = ts > + typename std::enable_if< T == true, void >::type + update_simplex(){ labels = get< LABELS >(*c_node); } + + template < bool T = ts > + typename std::enable_if< T == false, void >::type + update_simplex(){ return; }; + + }; // iterator + + // Only start at a non-root node, else return the end + auto begin() -> decltype(iterator(*this, init)) { return iterator(*this, init); }; + auto end() -> decltype(iterator(*this, nullptr)) { return iterator(*this, nullptr); }; + }; // cofaces + + // ---- expansion iterator ------ + // template < bool ts = false > + // struct expansion : TraversalInterface< ts, expansion > { + // using B = TraversalInterface< ts, expansion >; + // using B::init, B::st, B::p1, B::p2, B::NP, B::DEPTH, B::LABELS; + // using d_node = typename TraversalInterface< ts, expansion >::d_node; + // + // expansion(const SimplexTree* st, node_ptr start) : TraversalInterface< ts, expansion >(st, start){ } + // + // struct iterator : public TraversalInterface< ts, expansion >::iterator { + // using Bit = typename B::iterator; + // using Bit::current, Bit::labels, Bit::info, Bit::update_simplex, Bit::base, Bit::trie, Bit::current_t_node, Bit::sentinel; + // + // // expansion-specific structures + // const size_t k; + // std::reference_wrapper< node_set_t > c_set; + // node_set_t::iterator c_sib; + // node_set_t::iterator c_node; + // vector< node_ptr > intersection; + // + // // Iterator constructor + // iterator(expansion& dd, node_set_t& exp_set, size_t dim) : TraversalInterface< ts, expansion >::iterator(dd), k(dim) { + // // current = std::make_tuple(cn, dd.st->depth(cn)); + // // update_simplex(); + // // node_set_t& c_set, const idx_t k + // c_set = std::ref(exp_set); + // c_node = begin(c_set); + // c_sib = std::next(c_set.begin(), 1); + // + // } + // + // node_ptr next_top_node() { + // if (c_node == end(c_node)){ return(nullptr); } + // node_ptr top_v = find_vertex((*c_node)->label); + // while (top_v != nullptr && c_node != end(c_node) && top_v->children.empty()){ + // std::advance(c_node,1); + // if (c_node == end(c_node)){ + // top_v <- nullptr; + // } else { + // top_v = find_vertex((*c_node)->label); + // } + // } + // return(top_v); + // } + // + // // Increment operator is all that is needed by default + // auto& operator++(){ + // // Postcondition: iterator is incremented by one, returns the result + // node_ptr top_v = next_top_node(); + // node_ptr cn = (*c_node).get(); + // + // if (cn != nullptr && top_v != nullptr){ + // vector< node_ptr > sib_ptrs; + // std::transform(c_sib, end(c_set), std::back_inserter(sib_ptrs), [](const auto& n){ + // return n.get(); + // }); + // + // // Get the intersection + // intersection.clear(); + // std::set_intersection( + // begin(sib_ptrs), end(sib_ptrs), + // begin(top_v->children), end(top_v->children), + // std::back_inserter(intersection), + // [](auto& sib_n, auto& child_n) -> bool { + // return node_label(sib_n) < node_label(child_n); + // } + // ); + // + // // Insert and recursively expand + // if (intersection.size() > 0){ + // std::array< idx_t, 1 > int_label; + // for (auto& int_node: intersection){ + // int_label[0] = int_node->label; + // auto np = find_it(begin(int_label), end(int_label), cn); + // if (np == nullptr){ // not found, need to insert it + // current = { cn, depth(cn) }; + // + // } + // } + // + // expand(cn->children, k-1); // recurse + // } + // if (siblings != end(c_set)){ ++siblings; } + // + // } + // + // update_simplex(); + // return(*this); + // }; + // + // // Doesn't work with regular update method, so override + // constexpr void update_simplex(){ + // if constexpr (ts){ + // labels = get< LABELS >(*c_node); + // } + // }; + // }; // iterator + // + // // Only start at a non-root node, else return the end + // auto begin(){ return iterator(*this, init); }; + // auto end(){ return iterator(*this, nullptr); }; + // }; // expansion + + + // K-skeleton iterator + template < bool ts = false > + struct k_skeleton : preorder< ts > { + using P = preorder< ts >; + using B = TraversalInterface< ts, st::preorder >; + using B::init; + using t_node = typename B::t_node; + using iterator_t = typename P::iterator; + + k_skeleton(const SimplexTree* st, node_ptr start, const size_t k) + : preorder< ts >( + st, start, + [k](t_node& cn) -> bool { return get< 1 >(cn) <= (k+1); }, + [k](t_node& cn) -> bool { return get< 1 >(cn) <= k; } + ) + {}; + // Only start at a non-root node, else return the end + auto begin() -> iterator_t { + return static_cast< P& >(*this).begin(); + //return static_cast< P& >(*this).begin(); + // return iterator_t(static_cast< P& >(*this), (node_ptr) init); + }; + auto end() -> iterator_t { + return static_cast< P& >(*this).end(); + // return iterator_t(static_cast< P& >(*this), nullptr); + } + }; + + template < bool ts = false > + struct k_simplices : preorder< ts > { + using P = preorder< ts >; + using B = TraversalInterface< ts, st::preorder >; + using B::init; + using t_node = typename B::t_node; + using iterator_t = typename P::iterator; + + //using d_node = tuple< node_ptr, idx_t >; + // const static auto valid_eval = [](const size_t k) { + // return([k](t_node& cn) -> bool { return get< 1 >(cn) == (k+1); }); + // }; + // const static auto valid_children = [](const size_t k) { + // return([k](t_node& cn) -> bool{ return get< 1 >(cn) <= k; }); + // }; + k_simplices(const SimplexTree* st, node_ptr start, const size_t k) + : preorder< ts >( + st, start, + [k](t_node& cn) -> bool { return get< 1 >(cn) == (k+1); }, + [k](t_node& cn) -> bool { return get< 1 >(cn) <= k; } + ) + {}; + // Only start at a non-root node, else return the end + auto begin() -> iterator_t { + return static_cast< P& >(*this).begin(); + // return iterator_t(static_cast< P& >(*this), (node_ptr) init); + }; + auto end() -> iterator_t { + return static_cast< P& >(*this).end(); + //return iterator_t(static_cast< P& >(*this), nullptr); + } + }; + + + template < bool ts = false > + struct maximal : preorder< ts > { + using P = preorder< ts >; + using B = TraversalInterface< ts, st::preorder >; + using B::init; + using t_node = typename B::t_node; + using iterator_t = typename P::iterator; + // Check if a given node has no children + // const static bool has_no_children(const t_node& cn){ return(get< 0 >(cn)->children.empty()); } + // + // // Check cn is only coface + // const static bool is_only_root(const SimplexTree* st, const t_node& cn){ + // auto cr = coface_roots< ts >(st, get< 0 >(cn)); + // return(std::next(cr.begin()) == cr.end()); + // } + // + // A valid member in a preorder-traversal is just + // const static auto valid_eval = [](const SimplexTree* st) { + // return([&st](t_node& cn) -> bool { + // return get<0>(cn) == nullptr ? false : has_no_children(cn) && is_only_root(st, cn); + // }); + // }; + // const static auto always_true = [](t_node& cn) -> bool { return true; }; + + // Constructor is all that is needed + maximal(const SimplexTree* st, node_ptr start) : + preorder< ts >(st, start, [st](t_node& cn) -> bool { + node_ptr np = get< 0 >(cn); + if (np != nullptr && np != st->root.get()){ + auto cr = coface_roots< false >(st, np); + return(np->children.empty() && std::next(cr.begin()) == cr.end()); + } + return false; + }, [](t_node& cn) -> bool { return true; }) + {}; + auto begin() -> iterator_t { + return static_cast< P& >(*this).begin(); + // return iterator_t(static_cast< P& >(*this), (node_ptr) init); + }; + auto end() -> iterator_t { + return static_cast< P& >(*this).end(); + // return iterator_t(static_cast< P& >(*this), nullptr); + } + + }; + + // Checks for empty intersection + inline bool empty_intersection(const vector< idx_t > x, const vector< idx_t > y){ + vector::const_iterator i = x.begin(), j = y.begin(); + while (i != x.end() && j != y.end()){ + if (*i<*j) ++i; else if(*j<*i) ++j; else return false; + } + return true; + } + + template< typename tnode_t > + std::function< bool(tnode_t&) > link_condition(const SimplexTree* st, const node_ptr s_np){ + const simplex_t s = st->full_simplex(s_np); + return([st, s](tnode_t& cn) -> bool { + bool is_link = false; + const simplex_t t = st->full_simplex(get< 0 >(cn)); + bool is_disjoint = empty_intersection(t, s); + if (is_disjoint){ + vector< idx_t > pot_link; + std::set_union(s.begin(), s.end(), t.begin(), t.end(), std::back_inserter(pot_link)); + if (st->find_it(begin(pot_link), end(pot_link), st->root.get()) != nullptr){ + is_link = true; + } + } + return(is_link); + }); + } + + template < bool ts = false > + struct link : preorder< ts > { + using P = preorder< ts >; + using B = TraversalInterface< ts, st::preorder >; + using B::init; + using t_node = typename B::t_node; + using iterator_t = typename P::iterator; + + // template < bool T = ts > + // static auto get_simplex2(const SimplexTree* st, t_node& cn) -> std::enable_if< T == true, simplex_t > { + // return(get< 2 >(cn)); + // } + // template < bool T = ts > + // static auto get_simplex2(const SimplexTree* st, t_node& cn) -> std::enable_if< T == false, simplex_t > { + // return(st->full_simplex(get< 0 >(cn))); + // } + + // A valid member in a preorder-traversal is just + // const auto valid_eval = [](const SimplexTree* st, const node_ptr s_np) { + // const simplex_t s = st->full_simplex(s_np); + // return([st, s](t_node& cn) -> bool { + // bool is_link = false; + // const simplex_t t = st->full_simplex(get< 0 >(cn)); + // bool is_disjoint = empty_intersection(t, s); + // if (is_disjoint){ + // vector< idx_t > pot_link; + // std::set_union(s.begin(), s.end(), t.begin(), t.end(), std::back_inserter(pot_link)); + // if (st->find_it(begin(pot_link), end(pot_link), st->root.get()) != nullptr){ + // is_link = true; + // } + // } + // return(is_link); + // }); + // }; + // static const auto always_true = [](t_node& cn) -> bool { return true; }; + + + // [st, &start](t_node& cn) -> bool { + // bool is_link = false; + // std::cout << "testing: \n"; + // std::cout << "1: " << start << std::endl; + // st->print_simplex(std::cout, get<0>(cn), true); + // st->print_simplex(std::cout, start, true); + // // std::cout << "2: " << start << std::endl; + // const simplex_t s = st->full_simplex(start); + // const simplex_t t = st->full_simplex(get< 0 >(cn)); + // // std::cout << "3: " << start << std::endl; + // bool is_disjoint = empty_intersection(t, s); + // if (is_disjoint){ + // vector< idx_t > pot_link; + // std::set_union(s.begin(), s.end(), t.begin(), t.end(), std::back_inserter(pot_link)); + // if (st->find_it(pot_link.begin(), pot_link.end(), st->root.get()) != nullptr){ + // is_link = true; + // } + // } + // // std::cout << "4: " << start << std::endl; + // std::cout << "is_link: " << is_link << std::endl; + // return(is_link); + // } + + // Constructor is all that is needed + link(const SimplexTree* st, const node_ptr start) : + preorder< ts >( + st, st->root.get(), + link_condition< t_node >(st, start), + [](const t_node& cn) -> bool { return true; } + ) + {}; + auto begin() -> iterator_t { + return static_cast< P& >(*this).begin();// maybe change this to not be P& + // return iterator_t(static_cast< P& >(*this), (node_ptr) init); + }; + auto end() -> iterator_t { + return static_cast< P& >(*this).end(); + // return iterator_t(static_cast< P& >(*this), nullptr); + } + }; // link + + template< typename tnode_t > + std::function face_condition(const SimplexTree* st, node_ptr start){ + const simplex_t sigma = st->full_simplex(start); + return [st, sigma](tnode_t& cn) -> bool { + node_ptr np = get<0>(cn); + if (np == nullptr || np == st->root.get()){ return false; } + simplex_t tau = st->full_simplex(np, get<1>(cn)); + + return std::includes(sigma.begin(), sigma.end(), tau.begin(), tau.end()); + }; + }; + + template< typename tnode_t > + std::function face_condition2(const SimplexTree* st, node_ptr start){ + const size_t d = st->depth(start); + return [d](tnode_t& cn) -> bool { return get< 1 >(cn) <= d; }; + }; + + template < bool ts = false > + struct faces : level_order< ts > { + using P = level_order< ts >; + using B = TraversalInterface< ts, st::level_order >; + using B::init; + using t_node = typename B::t_node; + using iterator_t = typename P::iterator; + + // A valid member in a preorder-traversal is just + // const static auto valid_eval = [](const SimplexTree* st, node_ptr start) { + // simplex_t sigma = st->full_simplex(start); + // return([sigma](t_node& cn) -> bool { + // return SimplexTree::is_face(get< 2 >(cn), sigma); + // }); + // }; + // const static auto valid_children = [](const SimplexTree* st, node_ptr start) { + // idx_t k = st->depth(start); + // return([k](t_node& cn) -> bool{ return get< 1 >(cn) <= k; }); + // }; + // Constructor is all that is needed + faces(const SimplexTree* st, node_ptr start) : + level_order< ts >( + st, st->root.get(), + // [](t_node& cn) -> bool { return true; }, + face_condition< t_node >(st, start), + face_condition2< t_node >(st, start) + // [&st, &start](t_node& cn) ->bool { return get< 1 >(cn) <= st->depth(start); } + // face_condition< t_node >(st, start), + // [&st, &start](t_node& cn) ->bool { return get< 1 >(cn) <= st->depth(start); } + ) + {}; + auto begin() -> iterator_t { + return static_cast< P& >(*this).begin(); + // return iterator_t(static_cast< P& >(*this), (node_ptr) init); + }; + auto end() -> iterator_t { + return static_cast< P& >(*this).end(); + // return iterator_t(static_cast< P& >(*this), nullptr); + } + }; + + + template + auto get_node_ptr(T& cn) -> node_ptr { return std::get< 0 >(cn); } + + template + auto get_depth(T& cn) -> idx_t { return std::get< 1 >(cn); } + + template + auto get_simplex(T& cn) -> simplex_t { return std::get< 2 >(cn); } + + // Generic traversal function which unpacks the tuple and allows for early termination of the iterable + template + auto traverse(Iterable traversal, Lambda f) -> typename std::enable_if< Iterable::is_tracking, void >::type { + for (auto& cn: traversal){ + bool should_continue = f(get_node_ptr(cn), get_depth(cn), get_simplex(cn)); + if (!should_continue){ break; } + } + } + template + auto traverse(Iterable traversal, Lambda f) -> typename std::enable_if< !Iterable::is_tracking, void >::type { + for (auto& cn: traversal){ + bool should_continue = f(get_node_ptr(cn), get_depth(cn)); + if (!should_continue){ break; } + } + } + + + // template < class Iterable, typename Lambda > + // void traverse_node_pairs(Iterable traversal, Lambda&& f) -> typename std::enable_if< Iterable::is_tracking, void >::type { + // for (auto& cn: traversal){ + // f(std::make_pair(get_node_ptr(cn), get_depth(cn))); + // } + // } + // template < class Iterable, typename Lambda > + // void traverse_node_pairs(Iterable traversal, Lambda&& f) -> typename std::enable_if< !Iterable::is_tracking, void >::type { + // for (auto& cn: traversal){ + // f(std::make_pair(get_node_ptr(cn), get_depth(cn))); + // } + // } + // + // template < class Iterable, typename Lambda > + // void traverse_simplices(Iterable traversal, Lambda&& f) -> typename std::enable_if< Iterable::is_tracking, void >::type { + // for (auto& cn: traversal){ + // f(get_simplex(cn)); + // } + // } + // template < class Iterable, typename Lambda > + // void traverse_simplices(Iterable traversal, Lambda&& f) -> typename std::enable_if< !Iterable::is_tracking, void >::type { + // for (auto& cn: traversal){ + // f(get_simplex(cn)); + // } + // } + + // Generic traversal function which unpacks the tuple and allows for early termination of the iterable + // template + // decltype(auto) generate(Iterable traversal, Lambda f){ + // using V = typename Iterable::value_type; + // using T = decltype( std::declval< Lambda >()(V) ); + // vector< T > result; + // for (auto& cn: traversal){ + // bool should_continue = std::apply(f, cn); + // if (!should_continue){ break; } + // } + // } + + + +}; // end namespace st + + +#endif + + diff --git a/include/splex_ranges.h b/include/splex_ranges.h new file mode 100644 index 0000000..f384422 --- /dev/null +++ b/include/splex_ranges.h @@ -0,0 +1,341 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "combinatorial.h" + +typedef float value_t; +typedef int64_t index_t; +using std::vector; +using std::array; +using namespace combinatorial; + +template< uint8_t dim, bool colex = true, typename index_t = uint_fast64_t > +struct RankRange { + static_assert(std::is_integral< index_t >::value, "Must be integral"); + const size_t n; + const vector< index_t > ranks; + mutable array< uint16_t, dim + 1 > labels; // intermediate storage + + using iterator = typename vector< index_t >::iterator; + using const_iterator = typename vector< index_t >::const_iterator; + + RankRange(const vector< index_t > _ranks, const size_t _n) : n(_n), ranks(_ranks){ + combinatorial::BC.precompute(_n, dim+1); + } + ~RankRange(){ + combinatorial::BC.BT.clear(); + combinatorial::BC.BT.shrink_to_fit(); + combinatorial::BC.pre_n = 0; + combinatorial::BC.pre_k = 0; + } + + struct RankLabelIterator { + const size_t _n; + const_iterator _it; + array< uint16_t, dim + 1 >& _labels; + + RankLabelIterator(const size_t __n, const_iterator it, array< uint16_t, dim + 1 >& labels) : _n(__n), _it(it), _labels(labels){}; + + uint16_t* operator*() const { + if constexpr (colex){ + unrank_colex< false >(_it, _it+1, _n, dim+1, (uint16_t*) &_labels[0]); + } else { + unrank_lex< true, false >(_it, _it+1, _n, dim+1, (uint16_t*) &_labels[0]); + } + return &_labels[0]; + } + void operator++() { _it++; } + bool operator!=(RankLabelIterator o) const { return _it != o._it; } + + template< bool ranks = false, typename Lambda > + void boundary(Lambda&& f){ + if constexpr (!ranks){ + uint16_t* c_labels = this->operator*(); + combinatorial::for_each_combination(c_labels, c_labels + dim, c_labels + dim + 1, [&](auto b, [[maybe_unused]] auto e){ + f(b, e); + return false; + }); + } else if constexpr (ranks && colex){ + index_t idx_below = *_it; + index_t idx_above = 0; + index_t j = _n - 1, k = dim; + for (index_t ki = 0; ki <= dim; ++ki){ + j = combinatorial::get_max_vertex< false >(idx_below, k + 1, j); + index_t c = combinatorial::BinomialCoefficient< false >(j, k + 1); + index_t face_index = idx_above - c + idx_below; + idx_below -= c; + idx_above += combinatorial::BinomialCoefficient< false >(j, k); + --k; + f(face_index); + } + } else { + uint16_t* c_labels = this->operator*(); + const index_t N = combinatorial::BinomialCoefficient< false >(_n, dim); + for_each_combination(c_labels, c_labels + dim, c_labels + dim + 1, [&](auto b, [[maybe_unused]] auto e){ + f(combinatorial::rank_lex_k< false >(b, _n, dim, N)); + return false; + }); + // throw std::invalid_argument("Haven't implemented yet."); + } + } + }; + + auto begin() { + return RankLabelIterator(n, ranks.begin(), labels); + } + auto end() { + return RankLabelIterator(n, ranks.end(), labels); + } +}; + + +// Stores the simplex labels as-is, in the prescribed ordering +template< uint8_t dim, bool colex = true, typename index_t = uint16_t > +struct SimplexRange { + static constexpr bool colex_order = colex; + using iterator = typename vector< index_t >::iterator; + using const_iterator = typename vector< index_t >::const_iterator; + static_assert(std::is_integral< index_t >::value, "Must be integral"); + + // Fields + const vector< uint16_t > s_labels; // the actual labels + const size_t n; // alphabet size + + SimplexRange(const vector< uint16_t > S, const size_t _n) : s_labels(S), n(_n) { + // TODO: check labels explicitly and sort them for lex or colex + + combinatorial::BC.precompute(n, dim+1); + } + + ~SimplexRange(){ + if (!combinatorial::keep_table_alive){ + combinatorial::BC.BT.clear(); + combinatorial::BC.BT.shrink_to_fit(); + combinatorial::BC.pre_n = 0; + combinatorial::BC.pre_k = 0; + } + } + + struct SimplexLabelIterator { + using iterator_category = std::forward_iterator_tag; + // using iterator_category = std::random_access_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = uint16_t; + using pointer = uint16_t*; + using reference = uint16_t&; + + const index_t _n; + const index_t _N; + const_iterator _it; + SimplexLabelIterator(const_iterator it, const index_t __n) : _n(__n), _N(combinatorial::BinomialCoefficient< true >(_n, dim)), _it(it) {}; + SimplexLabelIterator& operator=(const SimplexLabelIterator& it){ + _it = it._it; + return *this; + } + uint16_t* operator*() noexcept { return (uint16_t*) &(*_it); } + constexpr void operator++() noexcept { _it += (dim+1); } + constexpr void operator--() noexcept { _it -= (dim+1); } + constexpr bool operator!=(SimplexLabelIterator o) const noexcept { return _it != o._it; } + // operator SimplexLabelIterator() const { return SimplexLabelIterator(_it, _n); } + + // Boilerplate + // constexpr SimplexLabelIterator& operator++() noexcept { _it += (dim+1); return *this; } + // constexpr SimplexLabelIterator& operator--() noexcept { _it -= (dim+1); return *this; } + // constexpr SimplexLabelIterator operator+(const difference_type& diff) const noexcept { return SimplexLabelIterator(_it + diff*(dim+1), _n); } + // constexpr SimplexLabelIterator operator-(const difference_type& diff) const noexcept { return SimplexLabelIterator(_it - diff*(dim+1), _n); } + // constexpr reference operator[] (const difference_type& offset) const noexcept { return *(_it + offset*(dim+1)); } + // constexpr difference_type operator-(const SimplexLabelIterator& it) const { return this->_it - it._it; } + + template< typename Lambda > + void boundary_labels(Lambda&& f){ + uint16_t* labels = this->operator*(); + combinatorial::for_each_combination(labels, labels + dim, labels + dim + 1, [&f](auto b, auto e){ + f(b,e); + return false; + }); + } + + // boundary ranks (lambda version) + template< typename Lambda > + void boundary_ranks(Lambda&& f){ + uint16_t* labels = this->operator*(); + if constexpr (dim == 1){ // labels size == 2 + f(labels[0]); + f(labels[1]); + } else if constexpr (dim == 2){ + if constexpr(colex){ + f(rank_colex_2(labels[0], labels[1])); + f(rank_colex_2(labels[0], labels[2])); + f(rank_colex_2(labels[1], labels[2])); + } else { + f(rank_lex_2(labels[0], labels[1], _n)); + f(rank_lex_2(labels[0], labels[2], _n)); + f(rank_lex_2(labels[1], labels[2], _n)); + } + } else { + for_each_combination(labels, labels + dim, labels + dim + 1, [&](auto b, [[maybe_unused]] auto e){ + f(rank_comb< colex, false >(b,_n,dim)); + return false; + }); + } + } // boundary_ranks + + // boundary ranks (tuple version) + auto boundary_ranks() -> tuple_of< uint64_t, dim+1 >{ + uint16_t* labels = this->operator*(); + if constexpr (dim == 1){ + return std::make_tuple(labels[0], labels[1]); + } else if constexpr(dim == 2){ + if constexpr(colex){ + return std::make_tuple(rank_colex_2(labels[0], labels[1]), rank_colex_2(labels[0], labels[2]), rank_colex_2(labels[1], labels[2])); + } else { + return std::make_tuple(rank_lex_2(labels[0], labels[1], _n), rank_lex_2(labels[0], labels[2], _n), rank_lex_2(labels[1], labels[2], _n)); + } + } else { + std::array< uint64_t, dim > a; + size_t i = 0; + for_each_combination(labels, labels + dim, labels + dim + 1, [&](auto b, [[maybe_unused]] auto e){ + a[i++] = rank_comb< colex, false >(b,_n,dim); + return false; + }); + return std::tuple_cat(a); + } + } + }; + + [[nodiscard]] + constexpr auto begin() const noexcept { + return SimplexLabelIterator(s_labels.begin(), n); + } + [[nodiscard]] + constexpr auto end() const noexcept { + return SimplexLabelIterator(s_labels.end(), n); + } +}; + +// template< uint8_t dim, typename index_t = uint16_t > +// struct SimplexBoundaryRange { +// static_assert(std::is_integral< index_t >::value, "Must be integral"); +// const vector< uint16_t > s_labels; +// using iterator = typename vector< index_t >::iterator; +// using const_iterator = typename vector< index_t >::const_iterator; + +// SimplexBoundaryRange(const vector< uint16_t > S) : s_labels(S) {} + +// struct SimplexBoundaryIterator { +// const_iterator _it; +// SimplexBoundaryIterator(const_iterator it) : _it(it) {}; +// constexpr uint16_t* operator*() const noexcept { +// return (uint16_t*) &(*_it); +// } +// constexpr void operator++() noexcept { _it += (dim+1); } +// constexpr bool operator!=(SimplexBoundaryIterator o) const noexcept { return _it != o._it; } +// }; + +// [[nodiscard]] +// constexpr auto begin() const noexcept { +// return SimplexBoundaryIterator(s_labels.begin()); +// } +// [[nodiscard]] +// constexpr auto end() const noexcept { +// return SimplexBoundaryIterator(s_labels.end()); +// } +// }; + +// struct binomial_coeff_table { +// std::vector> B; +// binomial_coeff_table(index_t n, index_t k) : B(k + 1, std::vector(n + 1, 0)) { +// for (index_t i = 0; i <= n; ++i) { +// B[0][i] = 1; +// for (index_t j = 1; j < std::min(i, k + 1); ++j) +// B[j][i] = B[j - 1][i - 1] + B[j][i - 1]; +// if (i <= k) B[i][i] = 1; +// //check_overflow(B[std::min(i >> 1, k)][i]); +// } +// } + +// index_t operator()(index_t n, index_t k) const { +// assert(k < B.size() && n < B[k].size() && n >= k - 1); +// return B[k][n]; +// } +// }; + +// template +// index_t get_max(index_t top, const index_t bottom, const Predicate pred) { +// if (!pred(top)) { +// index_t count = top - bottom; +// while (count > 0) { +// index_t step = count >> 1, mid = top - step; +// if (!pred(mid)) { +// top = mid - 1; +// count -= step + 1; +// } else { +// count = step; +// } +// } +// } +// return top; +// } + +// index_t get_max_vertex(const index_t idx, const index_t k, const index_t n, const binomial_coeff_table& B) { +// return get_max(n, k - 1, [&](index_t w) -> bool { return (B(w, k) <= idx); }); +// } + +// template +// OutputIterator get_simplex_vertices( +// index_t idx, +// const index_t dim, +// index_t n, +// OutputIterator out, +// const binomial_coeff_table& B +// ) { +// --n; +// for (index_t k = dim + 1; k > 1; --k) { +// n = get_max_vertex(idx, k, n, B); +// *out++ = n; +// idx -= B(n, k); +// } +// *out = idx; +// return out; +// } + +// struct simplex_boundary_enumerator { +// index_t idx_below, idx_above, j, k; +// index_t dim; +// const binomial_coeff_table& B; + +// public: +// simplex_boundary_enumerator(const index_t i, const index_t _dim, const index_t n, const binomial_coeff_table& _bt) +// : idx_below(i), idx_above(0), j(n - 1), k(_dim), B(_bt){} + +// simplex_boundary_enumerator(const index_t _dim, const index_t n, const binomial_coeff_table& _bt) +// : simplex_boundary_enumerator(-1, _dim, n, _bt) {} + +// void set_simplex(const index_t i, const index_t _dim, const index_t n) { +// idx_below = i; +// idx_above = 0; +// j = n - 1; +// k = _dim; +// dim = _dim; +// } + +// bool has_next() { return (k >= 0); } + +// index_t next() { +// j = get_max_vertex(idx_below, k + 1, j, B); +// index_t face_index = idx_above - B(j, k + 1) + idx_below; +// idx_below -= B(j, k + 1); +// idx_above += B(j, k); +// --k; +// return face_index; +// } +// }; \ No newline at end of file diff --git a/include/utility/combinations.h b/include/utility/combinations.h new file mode 100644 index 0000000..39508f6 --- /dev/null +++ b/include/utility/combinations.h @@ -0,0 +1,106 @@ +#ifndef COMBINATIONS_H +#define COMBINATIONS_H + +// (C) Copyright Howard Hinnant 2005-2011. +// Use, modification and distribution are subject to the Boost Software License, +// Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt). +// +// See http://www.boost.org/libs/type_traits for most recent version including documentation. + +// This code was adapted from Howard Hinnants excellent combinations header: +// https://github.com/HowardHinnant/combinations +// The original copyright is included above. + +#include +#include +#include +#include +#include + +namespace detail { + template + using it_diff_t = typename std::iterator_traits::difference_type; + + // Rotates two discontinuous ranges to put *first2 where *first1 is. + // Adapted from: https://github.com/HowardHinnant/combinations + template + void rotate_discontinuous( + It first1, It last1, it_diff_t< It > d1, + It first2, It last2, it_diff_t< It > d2) + { + using std::swap; + if (d1 <= d2){ std::rotate(first2, std::swap_ranges(first1, last1, first2), last2); } + else { + It i1 = last1; + while (first2 != last2){ swap(*--i1, *--last2); } + std::rotate(first1, i1, last1); + } + } + + // Call f() for each combination of the elements [first1, last1) + [first2, last2) + // swapped/rotated into the range [first1, last1). + template < typename It, typename Lambda > + bool combine_discontinuous( + It first1, It last1, it_diff_t< It > d1, + It first2, It last2, it_diff_t< It > d2, + Lambda&& f, it_diff_t< It > d = 0) + { + using D = it_diff_t< It >; + using std::swap; + if (d1 == 0 || d2 == 0){ return f(); } + if (d1 == 1) { + for (It i2 = first2; i2 != last2; ++i2) { + if (f()){ return true; } + swap(*first1, *i2); + } + } + else { + It f1p = std::next(first1), i2 = first2; + for (D d22 = d2; i2 != last2; ++i2, --d22){ + if (combine_discontinuous(f1p, last1, d1-1, i2, last2, d22, f, d+1)) + return true; + swap(*first1, *i2); + } + } + if (f()){ return true; } + if (d != 0){ rotate_discontinuous(first1, last1, d1, std::next(first2), last2, d2-1); } + else { rotate_discontinuous(first1, last1, d1, first2, last2, d2); } + return false; + } + + template < typename Lambda, class... Ts > + struct NullaryPredicate { + Lambda f_; + std::tuple< Ts... > params; + NullaryPredicate(Lambda& f, Ts ... args) : f_(f), params(std::make_tuple(std::move(args)...)) {}; + bool operator()() { return f_(params); } + }; +}; // namespace detail + +using namespace detail; + +template < class It, class Function > +Function for_each_combination(It first, It mid, It last, Function&& f) { + combine_discontinuous(first, mid, std::distance(first, mid), + mid, last, std::distance(mid, last), + [&f, &first, &mid]() -> bool { return f(first, mid); }); + return std::move(f); +} + +template < class I, class Function > +void for_each_combination_idx(I n, I k, Function&& f) { + static_assert(std::is_integral::value, "Must be integral type."); + using It = typename std::vector< I >::iterator; + std::vector< I > seq_n(n); + std::iota(begin(seq_n), end(seq_n), 0); + for_each_combination(begin(seq_n), begin(seq_n)+k, end(seq_n), [&f](It a, It b) -> bool { + std::vector< I > cc(a, b); + f(cc); + return false; + }); + return; +} + +#endif // COMBINATIONS_H + diff --git a/include/utility/delegate.hpp b/include/utility/delegate.hpp new file mode 100644 index 0000000..596a83f --- /dev/null +++ b/include/utility/delegate.hpp @@ -0,0 +1,357 @@ +// Delegate class implementation +// Taken from: https://codereview.stackexchange.com/questions/14730/impossibly-fast-delegate-in-c11 +#pragma once +#ifndef DELEGATE_HPP +#define DELEGATE_HPP + +#include +#include +#include +#include +#include + +template class delegate; + +template +class delegate { + using stub_ptr_type = R (*)(void*, A&&...); + + delegate(void* const o, stub_ptr_type const m) noexcept : + object_ptr_(o), + stub_ptr_(m) + { + } + +public: + delegate() = default; + delegate(delegate const&) = default; + delegate(delegate&&) = default; + delegate(::std::nullptr_t const) noexcept : delegate() { } + + template {}>::type> + explicit delegate(C const* const o) noexcept : + object_ptr_(const_cast(o)){} + + template {}>::type> + explicit delegate(C const& o) noexcept : + object_ptr_(const_cast(&o)){} + + template + delegate(C* const object_ptr, R (C::* const method_ptr)(A...)) + { + *this = from(object_ptr, method_ptr); + } + + template + delegate(C* const object_ptr, R (C::* const method_ptr)(A...) const) + { + *this = from(object_ptr, method_ptr); + } + + template + delegate(C& object, R (C::* const method_ptr)(A...)) + { + *this = from(object, method_ptr); + } + + template + delegate(C const& object, R (C::* const method_ptr)(A...) const) + { + *this = from(object, method_ptr); + } + + template < + typename T, + typename = typename ::std::enable_if< + !::std::is_same::type>{} + >::type + > + delegate(T&& f) : + store_(operator new(sizeof(typename ::std::decay::type)), + functor_deleter::type>), + store_size_(sizeof(typename ::std::decay::type)) + { + using functor_type = typename ::std::decay::type; + + new (store_.get()) functor_type(::std::forward(f)); + + object_ptr_ = store_.get(); + + stub_ptr_ = functor_stub; + + deleter_ = deleter_stub; + } + + delegate& operator=(delegate const&) = default; + + delegate& operator=(delegate&&) = default; + + template + delegate& operator=(R (C::* const rhs)(A...)) + { + return *this = from(static_cast(object_ptr_), rhs); + } + + template + delegate& operator=(R (C::* const rhs)(A...) const) + { + return *this = from(static_cast(object_ptr_), rhs); + } + + template < + typename T, + typename = typename ::std::enable_if< + !::std::is_same::type>{} + >::type + > + delegate& operator=(T&& f) + { + using functor_type = typename ::std::decay::type; + + if ((sizeof(functor_type) > store_size_) || !store_.unique()) + { + store_.reset(operator new(sizeof(functor_type)), + functor_deleter); + + store_size_ = sizeof(functor_type); + } + else + { + deleter_(store_.get()); + } + + new (store_.get()) functor_type(::std::forward(f)); + + object_ptr_ = store_.get(); + + stub_ptr_ = functor_stub; + + deleter_ = deleter_stub; + + return *this; + } + + template + static delegate from() noexcept + { + return { nullptr, function_stub }; + } + + template + static delegate from(C* const object_ptr) noexcept + { + return { object_ptr, method_stub }; + } + + template + static delegate from(C const* const object_ptr) noexcept + { + return { const_cast(object_ptr), const_method_stub }; + } + + template + static delegate from(C& object) noexcept + { + return { &object, method_stub }; + } + + template + static delegate from(C const& object) noexcept + { + return { const_cast(&object), const_method_stub }; + } + + template + static delegate from(T&& f) + { + return ::std::forward(f); + } + + static delegate from(R (* const function_ptr)(A...)) + { + return function_ptr; + } + + template + using member_pair = + ::std::pair; + + template + using const_member_pair = + ::std::pair; + + template + static delegate from(C* const object_ptr, + R (C::* const method_ptr)(A...)) + { + return member_pair(object_ptr, method_ptr); + } + + template + static delegate from(C const* const object_ptr, + R (C::* const method_ptr)(A...) const) + { + return const_member_pair(object_ptr, method_ptr); + } + + template + static delegate from(C& object, R (C::* const method_ptr)(A...)) + { + return member_pair(&object, method_ptr); + } + + template + static delegate from(C const& object, + R (C::* const method_ptr)(A...) const) + { + return const_member_pair(&object, method_ptr); + } + + void reset() { stub_ptr_ = nullptr; store_.reset(); } + + void reset_stub() noexcept { stub_ptr_ = nullptr; } + + void swap(delegate& other) noexcept { ::std::swap(*this, other); } + + bool operator==(delegate const& rhs) const noexcept + { + return (object_ptr_ == rhs.object_ptr_) && (stub_ptr_ == rhs.stub_ptr_); + } + + bool operator!=(delegate const& rhs) const noexcept + { + return !operator==(rhs); + } + + bool operator<(delegate const& rhs) const noexcept + { + return (object_ptr_ < rhs.object_ptr_) || + ((object_ptr_ == rhs.object_ptr_) && (stub_ptr_ < rhs.stub_ptr_)); + } + + bool operator==(::std::nullptr_t const) const noexcept + { + return !stub_ptr_; + } + + bool operator!=(::std::nullptr_t const) const noexcept + { + return stub_ptr_; + } + + explicit operator bool() const noexcept { return stub_ptr_; } + + R operator()(A... args) const + { +// assert(stub_ptr); + return stub_ptr_(object_ptr_, ::std::forward(args)...); + } + +private: + friend struct ::std::hash; + + using deleter_type = void (*)(void*); + + void* object_ptr_; + stub_ptr_type stub_ptr_{}; + + deleter_type deleter_; + + ::std::shared_ptr store_; + ::std::size_t store_size_; + + template + static void functor_deleter(void* const p) + { + static_cast(p)->~T(); + + operator delete(p); + } + + template + static void deleter_stub(void* const p) + { + static_cast(p)->~T(); + } + + template + static R function_stub(void* const, A&&... args) + { + return function_ptr(::std::forward(args)...); + } + + template + static R method_stub(void* const object_ptr, A&&... args) + { + return (static_cast(object_ptr)->*method_ptr)( + ::std::forward(args)...); + } + + template + static R const_method_stub(void* const object_ptr, A&&... args) + { + return (static_cast(object_ptr)->*method_ptr)( + ::std::forward(args)...); + } + + template + struct is_member_pair : std::false_type { }; + + template + struct is_member_pair< ::std::pair > : std::true_type + { + }; + + template + struct is_const_member_pair : std::false_type { }; + + template + struct is_const_member_pair< ::std::pair > : std::true_type + { + }; + + template + static typename ::std::enable_if< + !(is_member_pair{} || + is_const_member_pair{}), + R + >::type + functor_stub(void* const object_ptr, A&&... args) + { + return (*static_cast(object_ptr))(::std::forward(args)...); + } + + template + static typename ::std::enable_if< + is_member_pair{} || + is_const_member_pair{}, + R + >::type + functor_stub(void* const object_ptr, A&&... args) + { + return (static_cast(object_ptr)->first->* + static_cast(object_ptr)->second)(::std::forward(args)...); + } +}; + +namespace std +{ + template + struct hash<::delegate > + { + size_t operator()(::delegate const& d) const noexcept + { + auto const seed(hash()(d.object_ptr_)); + + return hash::stub_ptr_type>()( + d.stub_ptr_) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + }; +} + +#endif // DELEGATE_HPP + diff --git a/include/utility/delegate2.hpp b/include/utility/delegate2.hpp new file mode 100644 index 0000000..3e674a3 --- /dev/null +++ b/include/utility/delegate2.hpp @@ -0,0 +1,1039 @@ +/* + * delegate.h + * + * Created on: 8 aug. 2017 + * Author: Mikael Rosbacke + */ + +#ifndef DELEGATE_DELEGATE_HPP_ +#define DELEGATE_DELEGATE_HPP_ + +/** + * Storage of a callable object for functors, free and member functions. + * + * Design intent is to do part of what std::function do but + * without heap allocation, virtual function call and with minimal + * memory footprint. (2 pointers). + * Use case is embedded systems where previously raw function pointers + * where used, using a void* pointer to point out a particular structure/object. + * + * The price is less generality. The user must keep objects alive since the + * delegate only store a pointer to them. This is true for both + * member functions and functors. + * + * Free functions and member functions are supplied as compile time template + * arguments. It is required to generate the correct intermediate functions. + * + * Once constructed, the delegate should behave as any pointer: + * - Can be copied freely. + * - Can be compared to same type delegates and nullptr. + * - Can be reassigned. + * - Can be called. + * + * A default constructed delegate compare equal to nullptr. However it can be + * called with the default behavior being to do nothing and return a default + * constructed return object. + * + * Two overload sets are provided for construction: + * - set : Set an existing delegate with new pointer values. + * - make : Construct a new delegate. + * + * The following types of callables are supported: + * - Functors. + * - Member functions. + * - Free functions. (Do not use the stored void* value) + * - (Special) A free function with a void* extra first argument. That will + * be passed the void* value set at delegate construction, + * in addition to the arguments supplied to the call. + * + * Const correctness: + * The delegate models a pointer in const correctness. The constness of the + * delegate is different that the constness of the called objects. + * Calling a member function or operator() require them to be const to be able + * to call a const object. + * + * Both the member function and the object to call on must be set + * at the same time. This is required to maintain const correctness guarantees. + * The constness of the object or the member function is not part of the + * delegate type. + * + * There is a class MemFkn for storing pointers to member functions which + * will keep track of constness. It allow taking the address of a member + * function at one point and store it. At a later time the MemFkn object and an + * object can be set to a delegate for later call. + * + * The delegate do not allow storing pointers to r-value references + * (temporary objects) for member and functor construction. + */ + +#ifdef _MSVC_LANG +#define DELEGATE_CPP_VERSION _MSVC_LANG +#define DELEGATE_ALWAYS_INLINE __forceinline +#else +#define DELEGATE_CPP_VERSION __cplusplus +#define DELEGATE_ALWAYS_INLINE __attribute__((always_inline)) +#endif + +#if DELEGATE_CPP_VERSION < 201103L +#error "Require at least C++11 to compile delegate" +#endif + +#if DELEGATE_CPP_VERSION >= 201402L +#define DELEGATE_CXX14CONSTEXPR constexpr +#else +#define DELEGATE_CXX14CONSTEXPR +#endif + +namespace details +{ + +using nullptr_t = decltype(nullptr); + +template +T +nullReturnFunction() +{ + return T{}; +} +template <> +inline void +nullReturnFunction() +{ + return; +} + +template +class common; + +template +class common +{ + public: + using FknPtr = R (*)(Args...); + union DataPtr; + struct FknStore; + + using Trampoline = R (*)(DataPtr const&, Args...); + + union DataPtr { + constexpr DataPtr() = default; + constexpr DataPtr(void* p) noexcept : v_ptr(p){}; + + static constexpr bool equal(Trampoline fkn, const DataPtr& lhs, + const DataPtr& rhs) noexcept + { + // Ugly, but the m_ptr part should be optimized away on normal + // platforms. + return (fkn == doRuntimeFkn ? (lhs.fkn_ptr == rhs.fkn_ptr) + : (lhs.v_ptr == rhs.v_ptr)); + } + static constexpr bool less(Trampoline fkn, const DataPtr& lhs, + const DataPtr& rhs) noexcept + { + // Ugly, but the m_ptr part should be optimized away on normal + // platforms. + return fkn == doRuntimeFkn ? (lhs.fkn_ptr < rhs.fkn_ptr) + : (lhs.v_ptr < rhs.v_ptr); + } + constexpr void* ptr() const noexcept + { + return v_ptr; + } + + private: + // Whole reason for 'private' is to make sure fkn_ptr is only + // used when FknStore uses 'doRuntimeFkn' as adapter fkn. Will break + // Aliasing rules otherwise in equal/less comparisons. + friend struct FknStore; + constexpr DataPtr(FknPtr p) noexcept : fkn_ptr(p){}; + + static R doRuntimeFkn(DataPtr const& o_arg, Args... args) + { + FknPtr fkn = o_arg.fkn_ptr; + return fkn(args...); + } + + void* v_ptr = nullptr; + FknPtr fkn_ptr; + }; + + inline static constexpr R doNullCB(DataPtr const&, Args...) + { + return nullReturnFunction(); + } + + struct FknStore + { + constexpr FknStore() = default; + constexpr FknStore(Trampoline fkn, void* ptr) + : m_fkn(fkn), m_data(ptr){}; + constexpr FknStore(FknPtr ptr) + : m_fkn(ptr ? &DataPtr::doRuntimeFkn : &doNullCB), m_data(ptr) + { + } + + constexpr bool null() const noexcept + { + return m_fkn == &doNullCB; + } + + Trampoline m_fkn = &doNullCB; + DataPtr m_data; + }; + + // Adapter function for the member + object calling. + template + inline static constexpr R doMemberCB(DataPtr const& o, Args... args) + { + return (static_cast(o.ptr())->*memFkn)(args...); + } + + // Adapter function for the member + object calling. + template + inline static constexpr R doConstMemberCB(DataPtr const& o, Args... args) + { + return (static_cast(o.ptr())->*memFkn)(args...); + } + + static constexpr bool equal(const FknStore& lhs, + const FknStore& rhs) noexcept + { + return lhs.m_fkn == rhs.m_fkn && + DataPtr::equal(lhs.m_fkn, lhs.m_data, rhs.m_data); + } + + static constexpr bool less(const FknStore& lhs, + const FknStore& rhs) noexcept + { + return (lhs.null() && !rhs.null()) || lhs.m_fkn < rhs.m_fkn || + (lhs.m_fkn == rhs.m_fkn && + DataPtr::less(lhs.m_fkn, lhs.m_data, rhs.m_data)); + } + + // Helper struct to deduce member function types and const. + template + struct DeduceMemberType; + + template + struct DeduceMemberType + { + using ObjType = T; + static constexpr bool cnst = false; + static constexpr Trampoline trampoline = + &common::template doMemberCB; + static constexpr void* castPtr(ObjType* obj) noexcept + { + return static_cast(obj); + } + }; + template + struct DeduceMemberType + { + using ObjType = T; + static constexpr bool cnst = true; + static constexpr Trampoline trampoline = + &common::template doConstMemberCB; + static constexpr void* castPtr(ObjType const* obj) noexcept + { + return const_cast(static_cast(obj)); + }; + }; +}; + +} // namespace details + +template +class delegate; + +/** + * Simple adapter class for holding member functions and + * allow them to be called similar to std::invoke. + * + * @param T Object type this member can be called on. + * @param cnst Force stored member function to be const. + * @param Signature The signature required to call this member function. + * + * Notes on the cnst arguments. If false, we require the member to be called + * only on non const objects. Normally we match only on non const member + * functions. When true we require member functions to be const, so they can be + * called on both non-const/const objects. To assign a const function to a + * mem_fkn which is false, use make_from_const. A separate function name is + * needed to avoid ambiguous overloads w.r.t. constness. + * + * Current idea: Give both T and constness separately to be part of the type. + * - Allow us to disambiguate const/non-const member functions when taking the + * address. + * + * Discarded idea: T can be const/non-const. If const is not part of stored + * type, we can not later use constness to disambiguate taking the address. + * Still, constness of function is different from constness of object. + * + * -> + */ +template +class mem_fkn; + +template +class mem_fkn_base; + +template +class mem_fkn_base +{ + protected: + using common = details::common; + using DataPtr = typename common::DataPtr; + using Trampoline = typename common::Trampoline; + + constexpr mem_fkn_base(Trampoline fkn) : fknPtr(fkn){}; + constexpr mem_fkn_base() = default; + + Trampoline fknPtr = common::doNullCB; + + void setPtr(Trampoline t) + { + fknPtr = t; + } + + public: + constexpr Trampoline ptr() const + { + return fknPtr; + } + + constexpr bool null() const noexcept + { + return fknPtr == common::doNullCB; + } + // Return true if a function pointer is stored. + constexpr explicit operator bool() const noexcept + { + return !null(); + } + static constexpr bool equal(const mem_fkn_base& lhs, + const mem_fkn_base& rhs) + { + return lhs.fknPtr == rhs.fknPtr; + } + static constexpr bool less(const mem_fkn_base& lhs, const mem_fkn_base& rhs) + { + return (lhs.null() && !rhs.null()) || (lhs.fknPtr < rhs.fknPtr); + } +}; + +template +class mem_fkn : public mem_fkn_base +{ + using Base = mem_fkn_base; + static constexpr const bool cnst = false; + using common = details::common; + using DataPtr = typename common::DataPtr; + using Trampoline = typename common::Trampoline; + + constexpr mem_fkn(Trampoline fkn) : mem_fkn_base(fkn){}; + + public: + constexpr mem_fkn() = default; + + template + constexpr R invoke(U&& o, Args... args) const noexcept + { + return Base::fknPtr(DataPtr{static_cast(&o)}, args...); + } + + template + DELEGATE_CXX14CONSTEXPR mem_fkn& set() noexcept + { + Base::setPtr(&common::template doMemberCB); + return *this; + } + template + DELEGATE_CXX14CONSTEXPR mem_fkn& set_from_const() noexcept + { + Base::fknPtr = &common::template doConstMemberCB; + return *this; + } + template + static constexpr mem_fkn make() noexcept + { + return mem_fkn{&common::template doMemberCB}; + } + template + static constexpr mem_fkn make_from_const() noexcept + { + return mem_fkn{&common::template doConstMemberCB}; + } +}; + +template +class mem_fkn : public mem_fkn_base +{ + using Base = mem_fkn_base; + using common = details::common; + using Trampoline = typename common::Trampoline; + + constexpr mem_fkn(Trampoline fkn) : Base(fkn){}; + + public: + constexpr mem_fkn() = default; + + constexpr R invoke(T const& o, Args... args) const + { + return Base::fknPtr(const_cast(&o), args...); + } + + template + static constexpr mem_fkn make() noexcept + { + return mem_fkn{&common::template doConstMemberCB}; + } + + template + DELEGATE_CXX14CONSTEXPR mem_fkn& set() noexcept + { + Base::fknPtr = &common::template doConstMemberCB; + return *this; + } +}; + +template +bool +operator==(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return lhs.equal(lhs, rhs); +} + +template +bool +operator!=(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return !(lhs == rhs); +} + +template +bool +operator<(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return lhs.less(lhs, rhs); +} + +template +bool +operator<=(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return lhs.less(lhs, rhs) || lhs.equal(lhs, rhs); +} + +template +bool +operator>=(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return lhs.less(rhs, lhs) || lhs.equal(lhs, rhs); +} + +template +bool +operator>(const mem_fkn& lhs, const mem_fkn& rhs) +{ + return rhs.less(rhs, lhs); +} + +/** + * Class for storing the callable object + * Stores a pointer to an adapter function and a void* pointer to the + * Object. + * + * @param R type of the return value from calling the callback. + * @param Args Argument list to the function when calling the callback. + */ +template +class delegate +{ + public: + using common = details::common; + using DataPtr = typename common::DataPtr; + using FknStore = typename common::FknStore; + + // Signature for call to the delegate. + using FknPtr = typename common::FknPtr; + + // Type of the function pointer for the trampoline functions. + using Trampoline = typename common::Trampoline; + + // Adaptor function for the case where void* is not forwarded + // to the caller. (Just a normal function pointer.) + template + inline static R doFreeCB(DataPtr const&, Args... args) + { + return freeFkn(args...); + } + + // Adapter function for when the stored object is a pointer to a + // callable object (stored elsewhere). Call it using operator(). + template + inline static R doFunctor(DataPtr const& o_arg, Args... args) + { + auto obj = static_cast(o_arg.ptr()); + return (*obj)(args...); + } + + template + inline static R doConstFunctor(DataPtr const& o_arg, Args... args) + { + const Functor* obj = static_cast(o_arg.ptr()); + return (*obj)(args...); + } + + // Adapter function for the free function with extra first arg + // in the called function, set at delegate construction. + template + inline static R dofreeFknWithObjectRef(DataPtr const& o, Args... args) + { + T* obj = static_cast(o.ptr()); + return freeFkn(*obj, args...); + } + + // Adapter function for the free function with extra first arg + // in the called function, set at delegate construction. + template + inline static R dofreeFknWithObjectConstRef(DataPtr const& o, Args... args) + { + T const* obj = static_cast(o.ptr()); + return freeFkn(*obj, args...); + } + + // Adapter function for the free function with extra first void*arg + // in the called function, set at delegate construction. + template + inline static R dofreeFknWithVoidPtr(DataPtr const& o, Args... args) + { + return freeFkn(o.ptr(), args...); + } + + // Adapter function for the free function with extra first arg + // in the called function, set at delegate construction. + template + inline static R dofreeFknWithVoidConstPtr(DataPtr const& o, Args... args) + { + return freeFkn(static_cast(o.ptr()), args...); + } + + public: + // Default construct with stored ptr == nullptr. + constexpr delegate() = default; + + // General simple function pointer handling. Will accept stateless lambdas. + // Do note it is less easy to optimize compared to static function setup. + // Handle nullptr / 0 arguments as well. + constexpr delegate(FknPtr fkn) noexcept : m_data(fkn) {} + + /** + * Allow writing extensions with separate make/trampoline function. + * We restrict to allowing 'void*Í„' as data pointer to make sure + * not tripping up on union aliasing issues in e.g. equal function. + * THe trampoline function can only access the v_ptr member of the passed + * union. + */ + constexpr delegate(Trampoline tFkn, void* datap) noexcept + : m_data(tFkn, datap) + { + } + + ~delegate() = default; + + DELEGATE_CXX14CONSTEXPR delegate& operator=(FknPtr fkn) noexcept + { + m_data = FknStore{fkn}; + return *this; + } + + // Call the stored function. Requires: bool(*this) == true; + // Will call trampoline fkn which will call the final fkn. + DELEGATE_ALWAYS_INLINE constexpr R operator()(Args... args) const + { + return m_data.m_fkn(m_data.m_data, args...); + } + + constexpr bool null() const noexcept + { + return m_data.null(); + } + + static constexpr bool equal(const delegate& lhs, + const delegate& rhs) noexcept + { + return common::equal(lhs.m_data, rhs.m_data); + } + + // Helper Functor for passing into std functions etc. + struct Equal + { + constexpr bool operator()(const delegate& lhs, + const delegate& rhs) const noexcept + { + return equal(lhs, rhs); + } + }; + + // Define a total order for purpose of sorting in maps etc. + // Do not define operators since this is not a natural total order. + // It will vary randomly depending on where symbols end up etc. + static constexpr bool less(const delegate& lhs, + const delegate& rhs) noexcept + { + return common::less(lhs.m_data, rhs.m_data); + // Ugly, but the m_ptr part should be optimized away on normal + // platforms. + } + + // Helper Functor for passing into std::map et.al. + struct Less + { + constexpr bool operator()(const delegate& lhs, + const delegate& rhs) const noexcept + { + return less(lhs, rhs); + } + }; + + // Return true if a function pointer is stored. + constexpr explicit operator bool() const noexcept + { + return !null(); + } + + DELEGATE_CXX14CONSTEXPR void clear() noexcept + { + m_data = FknStore{}; + } + + /** + * Create a callback to a free function with a specific type on + * the pointer. + */ + template + DELEGATE_CXX14CONSTEXPR delegate& set() noexcept + { + m_data = FknStore(&doFreeCB, nullptr); + return *this; + } + + /** + * Create a callback to a member function to a given object. + */ + template + DELEGATE_CXX14CONSTEXPR delegate& set(T& tr) noexcept + { + m_data = FknStore(&common::template doMemberCB, + static_cast(&tr)); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set(T const& tr) noexcept + { + m_data = FknStore(&common::template doConstMemberCB, + const_cast(static_cast(&tr))); + return *this; + } + + // Delete r-values. Not interested in temporaries. + template + DELEGATE_CXX14CONSTEXPR delegate& set(T&&) = delete; + + template + DELEGATE_CXX14CONSTEXPR delegate& set(T&&) = delete; + + /** + * Create a callback to a Functor or a lambda. + * NOTE : Only a pointer to the functor is stored. The + * user must ensure the functor is still valid at call time. + * Hence, we do not accept functor r-values. + */ + template + DELEGATE_CXX14CONSTEXPR delegate& set(T& tr) noexcept + { + m_data = FknStore(&doFunctor, static_cast(&tr)); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set(T const& tr) noexcept + { + m_data = FknStore(&doConstFunctor, + const_cast(static_cast(&tr))); + return *this; + } + + // Do not allow temporaries to be stored. + template + constexpr delegate& set(T&&) const = delete; + + DELEGATE_CXX14CONSTEXPR delegate& set(FknPtr fkn) noexcept + { + m_data = FknStore(fkn); + return *this; + } + + /** + * Combine a MemFkn with an object to set this delegate. + */ + template + DELEGATE_CXX14CONSTEXPR delegate& + set(mem_fkn const& f, T& o) noexcept + { + m_data = FknStore(f.ptr(), static_cast(&o)); + return *this; + } + template + DELEGATE_CXX14CONSTEXPR delegate& + set(mem_fkn const& f, T const& o) = delete; + + template + DELEGATE_CXX14CONSTEXPR delegate& set(mem_fkn const& f, + T const& o) noexcept + { + m_data = + FknStore(f.ptr(), const_cast(static_cast(&o))); + return *this; + } + template + DELEGATE_CXX14CONSTEXPR delegate& set(mem_fkn const& f, + T& o) noexcept + { + return set(f, static_cast(o)); + } + + template + DELEGATE_CXX14CONSTEXPR delegate& + set(mem_fkn const& f, T&& o) = delete; + template + DELEGATE_CXX14CONSTEXPR delegate& set(mem_fkn const& f, + T&& o) = delete; + + // C++17 allow template for non type template arguments. + // Use to avoid specifying object type. +#if __cplusplus >= 201703 + + template + constexpr delegate& + set(typename common::template DeduceMemberType::ObjType& obj) noexcept + { + using DM = + typename common::template DeduceMemberType; + m_data = FknStore(DM::trampoline, DM::castPtr(&obj)); + return *this; + } + template + constexpr delegate& set(typename common::template DeduceMemberType< + decltype(mFkn), mFkn>::ObjType const& obj) noexcept + { + using DM = + typename common::template DeduceMemberType; + m_data = FknStore(DM::trampoline, DM::castPtr(&obj)); + return *this; + } + template + constexpr delegate& set(typename common::template DeduceMemberType< + decltype(mFkn), mFkn>::ObjType&& obj) = delete; +#endif + + DELEGATE_CXX14CONSTEXPR delegate& set_fkn(FknPtr fkn) noexcept + { + return set(fkn); + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_void(void* ctx) noexcept + { + m_data = FknStore(&dofreeFknWithVoidPtr, ctx); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_void(decltype(nullptr)) noexcept + { + m_data = FknStore(&dofreeFknWithVoidPtr, nullptr); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& + set_free_with_void(void const* ctx) noexcept + { + m_data = + FknStore(&dofreeFknWithVoidConstPtr, const_cast(ctx)); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T& o) noexcept + { + m_data = FknStore(&dofreeFknWithObjectRef, + const_cast(static_cast(&o))); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T& o) noexcept + { + m_data = FknStore(&dofreeFknWithObjectConstRef, + static_cast(&o)); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T const& o) noexcept + { + m_data = FknStore(&dofreeFknWithObjectConstRef, + const_cast(static_cast(&o))); + return *this; + } + + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T const&) = delete; + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T&&) = delete; + template + DELEGATE_CXX14CONSTEXPR delegate& set_free_with_object(T&&) = delete; + + /** + * Create a callback to a free function with a specific type on + * the pointer. + */ + template + static constexpr delegate make() noexcept + { + // Note: template arg can never be nullptr. + return delegate{&doFreeCB, static_cast(nullptr)}; + } + + /** + * Create a callback to a member function to a given object. + */ + template + static constexpr delegate make(T& o) noexcept + { + return delegate{&common::template doMemberCB, + static_cast(&o)}; + } + + template + static constexpr delegate make(const T& o) noexcept + { + return delegate{&common::template doConstMemberCB, + const_cast(static_cast(&o))}; + } + + /** + * Create a callback to a Functor or a lambda. + * NOTE : Only a pointer to the functor is stored. The + * user must ensure the functor is still valid at call time. + * Hence, we do not accept functor r-values. + */ + template + static constexpr delegate make(T& o) noexcept + { + return delegate{&doFunctor, static_cast(&o)}; + } + template + static constexpr delegate make(T const& o) noexcept + { + return delegate{&doConstFunctor, + const_cast(static_cast(&o))}; + } + template + static constexpr delegate make(T&& object) = delete; + + static constexpr delegate make(FknPtr fkn) noexcept + { + return delegate{fkn}; + } + + static constexpr delegate make_fkn(FknPtr fkn) noexcept + { + return delegate{fkn}; + } + + template + static constexpr delegate make_free_with_void(void* ctx) noexcept + { + return delegate{&dofreeFknWithVoidPtr, ctx}; + } + + template + static constexpr delegate make_free_with_void(void const* ctx) noexcept + { + return delegate{&dofreeFknWithVoidConstPtr, + const_cast(ctx)}; + } + + template + static constexpr delegate make_free_with_void(decltype(nullptr)) noexcept + { + return delegate{&dofreeFknWithVoidPtr, nullptr}; + } + + /** + * Create a delegate to a free function, where the first argument is + * assumed to be a reference to the object supplied as argument here. + * The return value and rest of the argument must match the signature + * of the delegate. + */ + template + static constexpr delegate make_free_with_object(T& o) noexcept + { + return delegate{&dofreeFknWithObjectRef, + static_cast(&o)}; + } + + template + static constexpr delegate make_free_with_object(T& o) noexcept + { + return delegate{&dofreeFknWithObjectConstRef, + static_cast(&o)}; + } + + template + static constexpr delegate make_free_with_object(T const& o) noexcept + { + return delegate{&dofreeFknWithObjectConstRef, + const_cast(static_cast(&o))}; + } + + template + static constexpr delegate make_free_with_object(T const&) = delete; + template + static constexpr delegate make_free_with_object(T&&) = delete; + template + static constexpr delegate make_free_with_object(T&&) = delete; + + template + static constexpr delegate make(mem_fkn f, + T& o) noexcept + { + return delegate{f.ptr(), static_cast(&o)}; + } + template + static constexpr delegate make(mem_fkn, + T const&) = delete; + + template + static constexpr delegate make(mem_fkn f, + T const& o) noexcept + { + return delegate{f.ptr(), + const_cast(static_cast(&o))}; + } + template + static constexpr delegate make(mem_fkn f, + T& o) noexcept + { + return delegate{f.ptr(), static_cast(&o)}; + } + + template + static constexpr delegate make(mem_fkn, T&&) = delete; + + // C++17 allow template for non type template arguments. + // Use to avoid specifying object type. (Getting a bit hairy here...) +#if __cplusplus >= 201703 + template + static constexpr delegate make(typename common::template DeduceMemberType< + decltype(mFkn), mFkn>::ObjType& obj) noexcept + { + using DM = + typename common::template DeduceMemberType; + return delegate{DM::trampoline, DM::castPtr(&obj)}; + } + template + static constexpr delegate + make(typename common::template DeduceMemberType< + decltype(mFkn), mFkn>::ObjType const& obj) noexcept + { + using DM = + typename common::template DeduceMemberType; + return delegate{DM::trampoline, DM::castPtr(&obj)}; + } + + // Temporaries not allowed. + template + static constexpr delegate + make(typename common::template DeduceMemberType::ObjType&&) = delete; + +#endif + + private: + FknStore m_data; +}; + +template +constexpr bool +operator==(const delegate& lhs, + const delegate& rhs) noexcept +{ + return delegate::equal(lhs, rhs); +} + +template +constexpr bool +operator!=(const delegate& lhs, + const delegate& rhs) noexcept +{ + return !(lhs == rhs); +} + +// Bite the bullet, this is how unique_ptr handle nullptr_t. +template +constexpr bool +operator==(details::nullptr_t, const delegate& rhs) noexcept +{ + return rhs.null(); +} + +template +constexpr bool +operator!=(details::nullptr_t lhs, const delegate& rhs) noexcept +{ + return !(lhs == rhs); +} + +template +constexpr bool +operator==(const delegate& lhs, details::nullptr_t) noexcept +{ + return lhs.null(); +} + +template +constexpr bool +operator!=(const delegate& lhs, details::nullptr_t rhs) noexcept +{ + return !(lhs == rhs); +} + +// No ordering operators ( operator< etc) defined. This delegate +// represent several classes of pointers and is not a naturally +// ordered type. Use members less, Less for explicit ordering. + +/** + * Helper macro to create a delegate for calling a member function. + * Example of use: + * + * auto cb = DELEGATE_MKMEM(void(), &SomeClass::memberFunction, obj); + * + * where 'obj' is of type 'SomeClass'. + * + * @param signature Template parameter for the delegate. + * @param memFknPtr address of member function pointer. C++ require + * full name path with addressof operator (&) + * @object object which the member function should be called on. + */ +#define DELEGATE_MKMEM(signature, memFknPtr, object) \ + (delegate::make, \ + memFkn>(object)) + +#undef DELEGATE_14CONSTEXPR + +#endif /* UTILITY_CALLBACK_H_ */ \ No newline at end of file diff --git a/include/utility/discrete.h b/include/utility/discrete.h new file mode 100644 index 0000000..99c97c6 --- /dev/null +++ b/include/utility/discrete.h @@ -0,0 +1,217 @@ +#ifndef DISCRETE_H_ +#define DISCRETE_H_ + +#include "short_alloc.h" // stack-based allocation helpers +#include // assertions +#include +#include // round, sqrt, etc. + +using std::floor; +using std::sqrt; +using std::round; + +template < class T, std::size_t BufSize = 32 > +using SmallVector = std::vector>; + +// Integral-type to use in the combinadics +#include +using cc_int_t = uint_fast64_t; + +// Szudziks pairing function. Takes as input two unsigned integral types (a, b), and uniquely +// maps the pair (a, b) to a distinct number c, where c is possibly a different integral type +template < typename T1 = uint_fast32_t, typename T2 = uint_fast64_t > +constexpr inline T2 szudzik_pair(T1 x, T1 y){ + static_assert(std::is_integral::value, "Integral-type required as a range storage type."); + static_assert(std::is_unsigned::value, "Integral-type required as a range storage type."); + return static_cast< T2 >(x >= y ? x * x + x + y : x + y * y); +} +template < typename T1 = uint_fast32_t, typename T2 = uint_fast64_t > +inline std::pair< T1, T1 > szudzik_unpair(T2 z) { + static_assert(std::is_integral::value, "Integral-type required as a range storage type."); + static_assert(std::is_unsigned::value, "Integral-type required as a range storage type."); + T2 sqrtz = std::floor(std::sqrt(z)); + T2 sqz = sqrtz * sqrtz; + return ((z - sqz) >= sqrtz) ? + std::make_pair(static_cast< T1 >(sqrtz), static_cast< T1 >(z - sqz - sqrtz)) : + std::make_pair(static_cast< T1 >(z - sqz), static_cast< T1 >(sqrtz)); +} + +// constexpr implicitly inlined +template < bool i_less_j = false > +constexpr auto to_natural_2(cc_int_t i, cc_int_t j, cc_int_t n) noexcept -> + typename std::enable_if< i_less_j == true, cc_int_t >::type { + return cc_int_t(n*i - i*(i+1)/2 + j - i - 1); +} + +template < bool i_less_j = false > +constexpr auto to_natural_2(cc_int_t i, cc_int_t j, cc_int_t n) noexcept -> + typename std::enable_if< i_less_j == false, cc_int_t >::type { + return i < j ? cc_int_t(n*i - i*(i+1)/2 + j - i - 1) : cc_int_t(n*j - j*(j+1)/2 + i - j - 1); +} + +// 0-based +inline std::array< cc_int_t, 2 > to_subscript_2(const cc_int_t x, const cc_int_t n) noexcept { + auto i = static_cast< cc_int_t >( (n - 2 - floor(sqrt(-8*x + 4*n*(n-1)-7)/2.0 - 0.5)) ); + auto j = static_cast< cc_int_t >( x + i + 1 - n*(n-1)/2 + (n-i)*((n-i)-1)/2 ); + return (std::array< cc_int_t, 2 >{ i, j }); +} + +// static constexpr size_t max_choose = 16; +// static constexpr std::array< size_t, 120 > BC = { 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,3,6,10,15,21,28,36,45,55,66,78,91,105,4,10,20,35,56,84,120,165,220,286,364,455,5,15,35,70,126,210,330,495,715,1001,1365,6,21,56,126,252,462,792,1287,2002,3003,7,28,84,210,462,924,1716,3003,5005,8,36,120,330,792,1716,3432,6435,9,45,165,495,1287,3003,6435,10,55,220,715,2002,5005,11,66,286,1001,3003,12,78,364,1365,13,91,455,14,105,15 }; + +template < size_t n, size_t k > +constexpr auto bc_recursive() noexcept -> typename std::enable_if< k == 0, size_t >::type { + return 1; +} +template < size_t n, size_t k > +constexpr auto bc_recursive() noexcept -> typename std::enable_if< (k > 0), size_t >::type { + return (n * bc_recursive< n - 1, k - 1>()) / k; +} + +// template< size_t n, size_t k > +// struct BinomialCoefficientTable { +// size_t combinations[n+1][k]; +// constexpr BinomialCoefficientTable() : combinations() { +// // auto n_dispatcher = make_index_dispatcher< n+1 >(); +// // auto k_dispatcher = make_index_dispatcher< k >(); +// // n_dispatcher([&](auto i) { +// // k_dispatcher([&](auto j){ +// // combinations[i][j] = bc_recursive< i, j>(); +// // }); +// // }); +// combinations[0][0] = bc_recursive< 0, 0 >(); +// combinations[0][1] = bc_recursive< 0, 1 >(); +// combinations[0][2] = bc_recursive< 0, 2 >(); +// combinations[1][0] = bc_recursive< 0, 0 >(); +// combinations[1][1] = bc_recursive< 0, 1 >(); +// combinations[1][2] = bc_recursive< 0, 2 >(); +// combinations[2][0] = bc_recursive< 0, 0 >(); +// combinations[2][1] = bc_recursive< 0, 1 >(); +// combinations[2][2] = bc_recursive< 0, 2 >(); +// } +// }; +// static constexpr auto BC = BinomialCoefficientTable< max_choose, max_choose >(); + +// Recursive binomial coefficient dispatcher; should be compiled to support up to max_choose - 1 at compile time + // return BinomialCoefficient(n-1, k-1) + BinomialCoefficient(n-1, std::min(n-1-k, k)); +inline size_t binomial_coeff_(const double n, size_t k) noexcept { + double bc = n; + for (size_t i = 2; i <= k; ++i){ bc *= (n+1-i)/i; } + return(static_cast< size_t >(std::round(bc))); +} + +static constexpr size_t max_comb = 16; +static constexpr std::array< size_t, 120 > BC = { 2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,3,6,10,15,21,28,36,45,55,66,78,91,105,120,4,10,20,35,56,84,120,165,220,286,364,455,560,5,15,35,70,126,210,330,495,715,1001,1365,1820,6,21,56,126,252,462,792,1287,2002,3003,4368,7,28,84,210,462,924,1716,3003,5005,8008,8,36,120,330,792,1716,3432,6435,11440,9,45,165,495,1287,3003,6435,12870,10,55,220,715,2002,5005,11440,11,66,286,1001,3003,8008,12,78,364,1365,4368,13,91,455,1820,14,105,560,15,120,16 }; + +inline size_t BinomialCoefficient(const size_t n, const size_t k){ + if (k == 0 || n == k){ return 1; } + if (n < k){ return 0; } + return n < max_comb ? BC[to_natural_2< true >(k-1,n-1,max_comb)] : binomial_coeff_(n,std::min(k,n-k)); +} + +// Given a binomial coefficient 'x' representing (n choose 2), finds 'n' +inline size_t inv_choose_2(const size_t x) noexcept { + const size_t a = floor(sqrt(2*x)); + const size_t b = ceil(sqrt(2*x)+2); + SmallVector< size_t >::allocator_type::arena_type arena; + SmallVector< size_t > rng{ arena }; + rng.resize((b - a) + 1); + std::iota(begin(rng), end(rng), a); + auto it = std::find_if(begin(rng), end(rng), [x](size_t n){ return(x == BinomialCoefficient(n, 2)); }); + return it == end(rng) ? 0 : *it; +} + +// converts to natural number +template < typename Iter > +inline size_t to_natural_k(Iter s, Iter e, const cc_int_t n, const cc_int_t k) { + if (n == k){ return(0); } + + // Apply the dual index mapping + const cc_int_t N = BinomialCoefficient(n,k); + cc_int_t i = k; + const cc_int_t index = std::accumulate(s, e, 0, [n, &i](cc_int_t val, cc_int_t num){ + return val + BinomialCoefficient((n-1) - num, i--); + }); + const cc_int_t combinadic = (N-1) - index; + return(combinadic); +} +// if (combinadic >= N || index > (N-1)){ +// // throw std::out_of_range ("Combinadic mapping failed."); +// // std::cout << "n: " << n << ", k: " << k << ", N: " << N << ", combinadic: " << combinadic << ", index: " << index << std::endl; +// // std::for_each(s,e, [](auto el){ std::cout << el << ","; }); +// // std::cout << std::endl; +// } + +// 0-based conversion of (n choose k) combinadic subscripts to natural number +template < bool check_inputs = true, typename Iter, typename Lambda > +inline void to_natural(Iter s, Iter e, const size_t n, const size_t k, Lambda&& f) { + // if constexpr (check_inputs){ + // if (k > n) { throw std::out_of_range ("Combinadic out of range (k > n)"); } + // if (std::distance(s,e) % k != 0){ throw std::out_of_range ("Invalid input; not aligned to k."); } + // bool any_overflow = std::any_of(s, e, [n](auto& i){ return(i < 0 || i >= n); }); + // if (any_overflow){ throw std::out_of_range ("Invalid combinadic input."); } + // } + if (n == k){ while (s != e){ f(0); s += k; } } + while (s != e){ + switch(k){ + case 2: + f(to_natural_2(*s, *std::next(s), n)); + break; + default: + f(to_natural_k(s, s+k, n, k)); + break; + } + s += k; + } +} + +// Converts each value between [s,e) to its corresponding combinadic +template< bool check_inputs = true, typename Iter, typename Lambda > +inline void to_subscript(Iter s, Iter e, const size_t n, const size_t k, Lambda&& f) { + + // If check inputs true (default), make sure the input is valid. + // if constexpr (check_inputs){ + // if (k > n) { throw std::out_of_range ("Combinadic out of range."); } + // const size_t N = BinomialCoefficient(n, k); + // bool any_overflow = std::any_of(s, e, [N](auto& i){ return(i < 0 || i >= N); }); + // if (any_overflow){ throw std::out_of_range ("Invalid combinadic input."); } + // } + + SmallVector< cc_int_t >::allocator_type::arena_type a; + SmallVector< cc_int_t > combination{ a }; + combination.resize(k); + switch(k){ + case 2:{ + using id_t = typename Iter::value_type; + std::array< cc_int_t, 2 > cc; + std::for_each(s, e, [n, &cc, &f, &combination](id_t i){ + cc = to_subscript_2(i, n); + std::move(cc.begin(), cc.end(), combination.begin()); + f(combination); + }); + break; + } + default: { + using id_t = typename Iter::value_type; + const size_t N = BinomialCoefficient(n, k); + std::for_each(s, e, [&](id_t m){ + m = ((N-1)-m); + auto guess = 0; + auto pc = n; + for (auto j = k; j > 0; --j){ + cc_int_t value = m + 1; + for (; value > m; pc = guess, --guess){ + guess = pc - 1; + value = BinomialCoefficient(guess, j); + } + m = m - value; + combination[k-j] = (n-1)-guess-1; + } + f(combination); + }); + break; + } + } +} + +#endif diff --git a/include/utility/set_utilities.h b/include/utility/set_utilities.h new file mode 100644 index 0000000..63c4b1b --- /dev/null +++ b/include/utility/set_utilities.h @@ -0,0 +1,244 @@ +// nerve_utility.cpp +// Contains utility functions related to the nerve construction + +#include +#include +#include +#include +#include + +using std::vector; +using std::begin; +using std::end; +using std::pair; +using std::unordered_map; +using std::unordered_set; + + +// Safely advance iterator to prevent passing the end +template +Iter safe_advance(const Iter& curr, const Iter& end, Incr n) { + size_t remaining(std::distance(curr, end)); + if (remaining < size_t(n)) { n = remaining; } + return std::next(curr, n); +} + +// Moves through two ordered sets, returning a boolean indicating if they are disjoint +// returns true on first element found in both sets, otherwise iterates through both. +template +bool disjoint_sorted(Iter it_a, const Iter a_end, Iter it_b, const Iter b_end) { + while (it_a != a_end && it_b != b_end) { + switch (*it_a == *it_b ? 0 : *it_a < *it_b ? -1 : 1) { + case 0: + return false; + case -1: + it_a = std::lower_bound(++it_a, a_end, *it_b); + break; + case 1: + it_b = std::lower_bound(++it_b, b_end, *it_a); + break; + } + } + return true; +} + +// Given a set of intervals as pairs, checks if any of them are disjoint from the others +template < typename T > +bool intervals_disjoint(vector< pair< T, T > > intervals){ + if (intervals.size() <= 1){ return(true); } + + auto interval_ids = vector< std::pair< T, T > >(); + T i = 0; + for (auto& interval: intervals){ + interval_ids.push_back( std::make_pair(i, interval.first) ); + interval_ids.push_back( std::make_pair(i, interval.second) ); + ++i; + } + + // Sort by the values of value types + using rng_t = std::pair< T, T >; + std::stable_sort(begin(interval_ids), end(interval_ids), [](const rng_t& p1, const rng_t& p2){ + return p1.second < p2.second; + }); + + // Check if any adjacent values are equal + auto adj_it = std::adjacent_find(begin(interval_ids), end(interval_ids), [](const rng_t& p1, const rng_t& p2){ + return(p1.second == p2.second); + }); + if (adj_it != end(interval_ids)){ return(false); } + + // Check if sequence is strictly increasing + auto sinc_it = std::adjacent_find(begin(interval_ids), end(interval_ids), [](const rng_t& p1, const rng_t& p2){ + return(p1.first > p2.first); + }); + if (sinc_it != end(interval_ids)){ return(false); } + + // Otherwise they are disjoint + return(true); +} + +// Given two random-access iterator ranges, (a1, a2), (b1, b2), return a boolean indicating +// whether or not they have a non-empty intersection. Does not assumes either is sorted. +template +bool intersects_nonempty(Iter a1, Iter a2, Iter b1, Iter b2){ + using it_cat = typename std::iterator_traits::iterator_category; + static_assert(std::is_same::value, "Iterator type must be random-access."); + using T = typename std::iterator_traits::value_type; + + // Either empty == they do not have an intersection + const size_t a_sz = std::distance(a1, a2); + const size_t b_sz = std::distance(b1, b2); + if (a_sz == 0 || b_sz == 0) { return false; } + + // a is much smaller than b => partial sort b, then do binary search on b for each element of a + if (a_sz * 100 < b_sz) { + vector b_sort(b_sz); + std::partial_sort_copy(b1, b2, begin(b_sort), end(b_sort)); // partial-sorted elements of y copied to y_sort + while (a1 != a2){ + if (std::binary_search(begin(b_sort), end(b_sort), T(*a1))) { return(true); } + ++a1; + } + return(false); + } else if (b_sz * 100 < a_sz) { // Opposite case + vector a_sort(a_sz); + std::partial_sort_copy(a1, a2, begin(a_sort), end(a_sort)); // partial-sorted elements of y copied to y_sort + while (b1 != b2){ + if (std::binary_search(begin(a_sort), end(a_sort), T(*b1))) { return(true); } + ++b1; + } + return(false); + } + + // Otherwise, sort both, then use lower_bound type approach to potentially skip massive sections. + vector a_sort(a_sz), b_sort(b_sz); + std::partial_sort_copy(a1, a2, begin(a_sort), end(a_sort)); // partial-sorted elements of y copied to y_sort + std::partial_sort_copy(b1, b2, begin(b_sort), end(b_sort)); // partial-sorted elements of y copied to y_sort + return !disjoint_sorted(a_sort, b_sort); +} + +// Checks if there are at least n elements common to all given sorted ranges +template < typename Iter > +bool n_intersects_sorted(vector< pair< Iter, Iter > > ranges, const size_t n){ + using T = typename Iter::value_type; + if (n == 0){ return(true); } + if (ranges.size() <= 1){ return(false); } + + // Sort by size, then fold a set_intersection + using rng_t = pair< Iter, Iter >; + std::sort(begin(ranges), end(ranges), [](rng_t& p1, rng_t& p2){ + return std::distance(p1.first, p1.second) < std::distance(p2.first, p2.second); + }); + + // Fold a sorted intersection + auto common = vector< T >(); + std::set_intersection(ranges[0].first, ranges[0].second, ranges[1].first, ranges[1].second, std::back_inserter(common)); + for (size_t i = 2; i < ranges.size(); ++i){ + auto aux = vector< T >(); + std::set_intersection(begin(common), end(common), ranges[i].first, ranges[i].second, std::back_inserter(aux)); + if (aux.size() < n){ + return(false); + } + common.resize(aux.size()); + std::move(begin(aux), end(aux), begin(common)); + } + return(common.size() >= n); +} + +// Checks if there are at least n elements common to all given unsorted ranges +template < typename Iter > +bool n_intersects_unsorted(vector< pair< Iter, Iter > > ranges, const size_t n, const size_t k=32){ + using T = typename Iter::value_type; + + // Use lambda to determine when finished + using rng_t = pair< Iter, Iter >; + const auto finished = [&ranges](){ + size_t n_finished = std::accumulate(begin(ranges), end(ranges), (size_t) 0, [](size_t val, rng_t& p){ return p.first == p.second; }); + return(n_finished >= 2); + }; + + auto mv = std::set< T >(); // the values mutally common to all the ranges + auto vc = std::map< T, size_t >(); // values counts; i.e. values -> the # times they've been seen + const auto m = ranges.size(); + while(!finished()){ + + // Partial sort to put minimum element at the beginning of each range + for (auto& rng: ranges){ + std::nth_element(rng.first, safe_advance(rng.first, rng.second, 1), rng.second); + } + + // Extract range with minimum first element + auto rng_it = std::min_element(begin(ranges), end(ranges), [](rng_t& p1, rng_t &p2){ + if (p1.first == p1.second){ return false; } + if (p2.first == p2.second){ return true; } + return(*p1.first < *p2.first); + }); + + // If the range with the minimum element is an empty range, no candidate ranges available + if (rng_it->first == rng_it->second){ return(false); } + const size_t min_idx = std::distance(begin(ranges), rng_it); + auto rng = ranges[min_idx]; + + // Partial sort up to the end + std::nth_element(rng.first, safe_advance(rng.first, rng.second, k), rng.second); + + // Go through the current range + const auto e = safe_advance(rng.first, rng.second, k); + for (auto b = rng.first; b != e; ++b){ + // auto value_it = std::lower_bound(begin(vc), end(vc), *b, [](auto& p, auto v){ return p.first < v; }); + auto value_it = vc.find(*b); + if (value_it != end(vc) && ++(value_it->second) == m){ + mv.insert(value_it->first); // Increase number of things found in mutual + } else { + vc.emplace_hint(value_it, *b, 1); + } + } + + // If the common intersection grows to threshold, we're done + if (mv.size() >= n){ return true; } + + // Always replace beginning range with it's next beginning + ranges[min_idx].first = e; + } + return(false); +} + +// Tests a set of ranges to see if they all have at least n elements in their intersection. +template < typename Iter > +bool n_intersects(const vector< pair< Iter, Iter > >& ranges, const size_t n){ + using T = typename Iter::value_type; + using rng_t = pair< Iter, Iter >; + + // Check if any of the ranges don't even have n elements + const bool too_small = std::any_of(begin(ranges), end(ranges), [n](const rng_t& rng){ return size_t(std::distance(rng.first, rng.second)) < size_t(n); }); + if (too_small){ return(false); } + + // Do linear O(n) scan to determine if everything is sorted + const bool is_sorted = std::all_of(begin(ranges), end(ranges), [](const rng_t& rng){ return std::is_sorted(rng.first, rng.second); }); + + // Collect (min,max) of each range + auto minmaxes = vector< std::pair< T, T > >(); + minmaxes.reserve(ranges.size()); + std::transform(begin(ranges), end(ranges), std::back_inserter(minmaxes), [is_sorted](const rng_t& rng){ + if (is_sorted){ + auto min = *rng.first; + auto max = std::distance(rng.first, rng.second) == 1 ? *rng.first : *std::prev(rng.second); + return std::make_pair(min,max); + } else { + auto mm = std::minmax_element(rng.first, rng.second); + return std::make_pair(*mm.first, *mm.second); + } + }); + + // Do initial check to see if they are all disjoint + if (intervals_disjoint(minmaxes)){ + return(false); + } + + // Use the appropriate intersection check + if (is_sorted && n == 1 && ranges.size() == 2){ + bool disjoint = disjoint_sorted(ranges[0].first, ranges[0].second, ranges[1].first, ranges[1].second); + return(!disjoint); + } else { + return is_sorted ? n_intersects_sorted(ranges, n) : n_intersects_unsorted(ranges, n); + } +} diff --git a/include/utility/short_alloc.h b/include/utility/short_alloc.h new file mode 100644 index 0000000..a8c72ad --- /dev/null +++ b/include/utility/short_alloc.h @@ -0,0 +1,172 @@ +#ifndef SHORT_ALLOC_H +#define SHORT_ALLOC_H + +// The MIT License (MIT) +// +// Copyright (c) 2015 Howard Hinnant +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// This code was adapted from Howard Hinnants excellent arena-based allocator header: +// https://howardhinnant.github.io/short_alloc.h +// The original copyright is included above. + +// To use alloca portably +#ifdef __GNUC__ +/* Includes GCC, clang and Intel compilers */ +# undef alloca +# define alloca(x) __builtin_alloca((x)) +#elif defined(__sun) || defined(_AIX) +/* this is necessary for Solaris 10 and AIX 6: */ +# include +#endif + + +#include +#include + +template +class arena +{ + alignas(alignment) char buf_[N]; + char* ptr_; + +public: + ~arena() {ptr_ = nullptr;} + arena() noexcept : ptr_(buf_) {} + arena(const arena&) = delete; + arena& operator=(const arena&) = delete; + + template char* allocate(std::size_t n); + void deallocate(char* p, std::size_t n) noexcept; + + static constexpr std::size_t size() noexcept {return N;} + std::size_t used() const noexcept {return static_cast(ptr_ - buf_);} + void reset() noexcept {ptr_ = buf_;} + +private: + static + std::size_t + align_up(std::size_t n) noexcept + {return (n + (alignment-1)) & ~(alignment-1);} + + bool + pointer_in_buffer(char* p) noexcept + {return buf_ <= p && p <= buf_ + N;} +}; + +template +template +char* +arena::allocate(std::size_t n) +{ + static_assert(ReqAlign <= alignment, "alignment is too small for this arena"); + assert(pointer_in_buffer(ptr_) && "short_alloc has outlived arena"); + auto const aligned_n = align_up(n); + if (static_cast(buf_ + N - ptr_) >= aligned_n) + { + char* r = ptr_; + ptr_ += aligned_n; + return r; + } + + static_assert(alignment <= alignof(std::max_align_t), "you've chosen an " + "alignment that is larger than alignof(std::max_align_t), and " + "cannot be guaranteed by normal operator new"); + return static_cast(::operator new(n)); +} + +template +void +arena::deallocate(char* p, std::size_t n) noexcept +{ + assert(pointer_in_buffer(ptr_) && "short_alloc has outlived arena"); + if (pointer_in_buffer(p)) + { + n = align_up(n); + if (p + n == ptr_) + ptr_ = p; + } + else + ::operator delete(p); +} + +template +class short_alloc +{ +public: + using value_type = T; + static auto constexpr alignment = Align; + static auto constexpr size = N; + using arena_type = arena; + +private: + arena_type& a_; + +public: + short_alloc(const short_alloc&) = default; + short_alloc& operator=(const short_alloc&) = delete; + + short_alloc(arena_type& a) noexcept : a_(a) + { + static_assert(size % alignment == 0, + "size N needs to be a multiple of alignment Align"); + } + template + short_alloc(const short_alloc& a) noexcept + : a_(a.a_) {} + + template struct rebind {using other = short_alloc<_Up, N, alignment>;}; + + T* allocate(std::size_t n) + { + return reinterpret_cast(a_.template allocate(n*sizeof(T))); + } + void deallocate(T* p, std::size_t n) noexcept + { + a_.deallocate(reinterpret_cast(p), n*sizeof(T)); + } + + template + friend + bool + operator==(const short_alloc& x, const short_alloc& y) noexcept; + + template friend class short_alloc; +}; + +template +inline +bool +operator==(const short_alloc& x, const short_alloc& y) noexcept +{ + return N == M && A1 == A2 && &x.a_ == &y.a_; +} + +template +inline +bool +operator!=(const short_alloc& x, const short_alloc& y) noexcept +{ + return !(x == y); +} + +#endif // SHORT_ALLOC_H + diff --git a/pyproject.toml b/pyproject.toml index 29bffbd..09ffde2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,23 +3,23 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [project] -# See https://setuptools.pypa.io/en/latest/userguide/quickstart.html for more project configuration options. name = "simplextree" dynamic = ["version"] readme = "README.md" classifiers = [ - "Intended Audience :: Science/Research", - "Development Status :: 3 - Alpha", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Intended Audience :: Science/Research", + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3" ] authors = [ - {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"} + { name = "Matt Piekenbrock", email = "matt.piekenbrock@gmail.com" } ] requires-python = ">=3.8" dependencies = [ - # Add your own dependencies here + "numpy", + "scipy", + "more_itertools" ] license = {file = "LICENSE"} @@ -31,34 +31,28 @@ Changelog = "https://github.com/peekxc/simplextree-py/blob/main/CHANGELOG.md" [project.optional-dependencies] dev = [ - "ruff", - "mypy>=1.0,<1.5", - "black>=23.0,<24.0", - "isort>=5.12,<5.13", - "pytest", - "pytest-sphinx", - "pytest-cov", - "twine>=1.11.0", - "build", - "setuptools", - "wheel", - "Sphinx>=4.3.0,<7.1.0", - "furo==2023.7.26", - "myst-parser>=1.0,<2.1", - "sphinx-copybutton==0.5.2", - "sphinx-autobuild==2021.3.14", - "sphinx-autodoc-typehints==1.23.3", - "packaging" + "ruff", + "mypy>=1.0,<1.5", + "black>=23.0,<24.0", + "isort>=5.12,<5.13", + "pytest", + "pytest-sphinx", + "pytest-cov", + "twine>=1.11.0", + "build", + "setuptools", + "wheel", + "packaging" ] [tool.setuptools.packages.find] exclude = [ - "*.tests", - "*.tests.*", - "tests.*", - "tests", - "docs*", - "scripts*" + "*.tests", + "*.tests.*", + "tests.*", + "tests", + "docs*", + "scripts*" ] [tool.setuptools] @@ -75,14 +69,14 @@ line-length = 100 include = '\.pyi?$' exclude = ''' ( - __pycache__ - | \.git - | \.mypy_cache - | \.pytest_cache - | \.vscode - | \.venv - | \bdist\b - | \bdoc\b + __pycache__ + | \.git + | \.mypy_cache + | \.pytest_cache + | \.vscode + | \.venv + | \bdist\b + | \bdoc\b ) ''' @@ -113,8 +107,8 @@ strict_optional = false [tool.pytest.ini_options] testpaths = "tests/" python_classes = [ - "Test*", - "*Test" + "Test*", + "*Test" ] log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" log_level = "DEBUG" diff --git a/simplextree/SimplexTree.py b/simplextree/SimplexTree.py new file mode 100644 index 0000000..75af7d5 --- /dev/null +++ b/simplextree/SimplexTree.py @@ -0,0 +1,313 @@ +from __future__ import annotations +from typing import * +from numbers import Integral +from numpy.typing import ArrayLike + +import numpy as np +from . import _simplextree as st_mod + +class SimplexTree(st_mod.SimplexTree): + """ + SimplexTree provides lightweight wrapper around a Simplex Tree data structure: an ordered, trie-like structure whose nodes are in bijection with the faces of the complex. + This class exposes a native extension module wrapping a simplex tree implemented with modern C++. + + The Simplex Tree was originally introduced in the paper + > Boissonnat, Jean-Daniel, and ClĂ©ment Maria. "The simplex tree: An efficient data structure for general simplicial complexes." Algorithmica 70.3 (2014): 406-427. + + Attributes: + n_simplices (ndarray): number of simplices + dimension (int): maximal dimension of the complex + id_policy (str): policy for generating new vertex ids + + Attributes: Properties: + vertices (ndarray): vertices of the complex + """ + def __init__(self, simplices: Iterable[Collection] = None) -> None: + st_mod.SimplexTree.__init__(self) + if simplices is not None: + self.insert(simplices) + return None + + def insert(self, simplices: Iterable[Collection]) -> None: + """ + Inserts simplices into the Simplex Tree. + + By definition, inserting a simplex also inserts all of its faces. If the simplex already exists in the complex, the tree is not modified. + + Parameters: + simplices: Iterable of simplices to insert (each of which are SimplexLike) + + ::: {.callout-note} + If the iterable is an 2-dim np.ndarray, then a p-simplex is inserted along each contiguous p+1 stride. + Otherwise, each element of the iterable to casted to a Simplex and then inserted into the tree. + ::: + """ + if isinstance(simplices, np.ndarray): + simplices = np.sort(simplices, axis=1).astype(np.uint16) + assert simplices.ndim in [1,2], "dimensions should be 1 or 2" + self._insert(simplices) + elif isinstance(simplices, Iterable): + self._insert_list(list(map(lambda x: np.asarray(x, dtype=np.uint16), simplices))) + else: + raise ValueError("Invalid type given") + + def remove(self, simplices: Iterable[Collection]): + """ + Removes simplices into the Simplex Tree. + + By definition, removing a face also removes all of its cofaces. If the simplex does not exist in the complex, the tree is not modified. + + Parameters: + simplices: + Iterable of simplices to insert (each of which are SimplexLike). + + ::: {.callout-note} + If the iterable is an 2-dim np.ndarray, then a p-simplex is removed along each contiguous p+1 stride. + Otherwise, each element of the iterable to casted to a Simplex and then removed from the tree. + ::: + + Examples: + st = SimplexTree([range(3)]) + print(st) + st.remove([[0,1]]) + print(st) + """ + if isinstance(simplices, np.ndarray): + simplices = np.sort(simplices, axis=1).astype(np.uint16) + assert simplices.ndim in [1,2], "dimensions should be 1 or 2" + self._remove(simplices) + elif isinstance(simplices, Iterable): + self._remove_list(list(map(lambda x: np.asarray(x, dtype=np.uint16), simplices))) + else: + raise ValueError("Invalid type given") + + def find(self, simplices: Iterable[Collection]): + """ + Finds whether simplices exist in Simplex Tree. + + Parameters: + simplices: Iterable of simplices to insert (each of which are SimplexLike) + + Returns: + found (ndarray) : boolean array indicating whether each simplex was found in the complex + + ::: {.callout-note} + If the iterable is an 2-dim np.ndarray, then the p-simplex to find is given by each contiguous p+1 stride. + Otherwise, each element of the iterable to casted to a Simplex and then searched for in the tree. + ::: + """ + if isinstance(simplices, np.ndarray): + simplices = np.array(simplices, dtype=np.int16) + assert simplices.ndim in [1,2], "dimensions should be 1 or 2" + return self._find(simplices) + elif isinstance(simplices, Iterable): + return self._find_list([tuple(s) for s in simplices]) + else: + raise ValueError("Invalid type given") + + def adjacent(self, simplices: Iterable[Collection]): + """Checks for adjacencies between simplices.""" + return self._adjacent(list(map(lambda x: np.asarray(x, dtype=np.uint16), simplices))) + + def collapse(self, tau: Collection, sigma: Collection) -> None: + """Performs an elementary collapse on two given simplices. + + Checks whether its possible to collapse $\\sigma$ through $\\tau$, and if so, both simplices are removed. + A simplex $\\sigma$ is said to be collapsible through one of its faces $\\tau$ if $\\sigma$ is the only coface of $\\tau$ (excluding $\\tau$ itself). + + Parameters: + sigma : maximal simplex to collapse + tau : face of sigma to collapse + + Returns: + bool: whether the pair was collapsed + + Examples: + + from splex import SimplexTree + st = SimplexTree([[0,1,2]]) + print(st) + + st.collapse([1,2], [0,1,2]) + + print(st) + """ + # assert tau in boundary(sigma), f"Simplex {tau} is not in the boundary of simplex {sigma}" + success = self._collapse(tau, sigma) + return success + + def vertex_collapse(self, u: int, v: int, w: int): + """Maps a pair of vertices into a single vertex. + + Parameters: + u (int): the first vertex in the free pair. + v (int): the second vertex in the free pair. + w (int): the target vertex to collapse to. + """ + u,v,w = int(u), int(v), int(w) + assert all([isinstance(e, Integral) for e in [u,v,w]]), f"Unknown vertex types given; must be integral" + self._vertex_collapse(u,v,w) + + def degree(self, vertices: Optional[ArrayLike] = None) -> Union[ArrayLike, int]: + """Computes the degree of select vertices in the trie. + + Parameters: + vertices (ArrayLike): Retrieves vertex degrees + If no vertices are specified, all degrees are computed. Non-existing vertices by default have degree 0. + + Returns: + list: degree of each vertex id given in 'vertices' + """ + if vertices is None: + return self._degree_default() + elif isinstance(vertices, Iterable): + vertices = np.fromiter(iter(vertices), dtype=np.int16) + assert vertices.ndim == 1, "Invalid shape given; Must be flattened array of vertex ids" + return self._degree(vertices) + else: + raise ValueError(f"Invalid type {type(vertices)} given") + + # PREORDER = 0, LEVEL_ORDER = 1, FACES = 2, COFACES = 3, COFACE_ROOTS = 4, + # K_SKELETON = 5, K_SIMPLICES = 6, MAXIMAL = 7, LINK = 8 + def traverse(self, order: str = "preorder", f: Callable = print, sigma: Collection = [], p: int = 0) -> None: + """Traverses the simplex tree in the specified order, calling 'f' on each simplex encountered. + + Supported traversals: + - breadth-first / level order ("bfs", "levelorder") + - depth-first / prefix ("dfs", "preorder") + - faces ("faces") + - cofaces ("cofaces") + - coface roots + - p-skeleton + - p-simplices + - maximal simplices + - link + To select one of these options, set order to one of ["bfs", "levelorder", "dfs", "preorder"] + + Parameters: + order : the type of traversal to do + f : a function to evaluate on every simplex in the traversal. Defaults to print. + sigma : simplex to start the traversal at, where applicable. Defaults to the root node (empty set) + p : dimension of simplices to restrict to, where applicable. + """ + # todo: handle kwargs + assert isinstance(order, str) + order = order.lower() + if order in ["dfs", "preorder"]: + order = 0 + elif order in ["bfs", "level_order", "levelorder"]: + order = 1 + elif order == "faces": + order = 2 + elif order == "cofaces": + order = 3 + elif order == "coface_roots": + order = 4 + elif order == "p-skeleton": + order = 5 + elif order == "p-simplices": + order = 6 + elif order == "maximal": + order = 7 + elif order == "link": + order = 8 + else: + raise ValueError(f"Unknown order '{order}' specified") + self._traverse(order, lambda s: f(s), sigma, p) # order, f, init, k + + def cofaces(self, sigma: Collection = []) -> list[Collection]: + """Returns the p-cofaces of a given simplex. + + Parameters: + p : coface dimension to restrict to + sigma : the simplex to obtain cofaces of + + Returns: + list: the p-cofaces of sigma + """ + if sigma == [] or len(sigma) == 0: + return self.simplices() + F = [] + self._traverse(3, lambda s: F.append(s), sigma, 0) # order, f, init, k + return F + + def coface_roots(self, sigma: Collection = []) -> Iterable[Collection]: + """Returns the simplex 'roots' of a given simplex whose subtrees generate its cofaces.""" + F = [] + self._traverse(4, lambda s: F.append(s), sigma, 0) # order, f, init, k + return F + + def skeleton(self, p: int = None, sigma: Collection = []) -> Iterable[Collection]: + """Returns the simplices in the p-skeleton of the complex.""" + F = [] + self._traverse(5, lambda s: F.append(s), sigma, p) + return F + + def simplices(self, p: int = None) -> Iterable[Collection]: + """Returns the p-simplices in the complex.""" + F = [] + if p is None: + self._traverse(1, lambda s: F.append(s), [], 0) # order, f, init, k + else: + assert isinstance(int(p), int), f"Invalid argument type '{type(p)}', must be integral" + self._traverse(6, lambda s: F.append(s), [], p) # order, f, init, k + return F + + def faces(self, p: int = None, **kwargs) -> Iterable[Collection]: + """Wrapper for simplices function.""" + return self.simplices(p) + + def maximal(self) -> Iterable[Collection]: + """Returns the maximal simplices in the complex.""" + F = [] + self._traverse(7, lambda s: F.append(s), [], 0) + return F + + def link(self, sigma: Collection = []) -> Iterable[Collection]: + """Returns all simplices in the link of a given simplex.""" + F = [] + self._traverse(8, lambda s: F.append(s), sigma, 0) + return F + + def expand(self, k: int) -> None: + """Performs a k-expansion of the complex. + + This function is particularly useful for expanding clique complexes beyond their 1-skeleton. + + Parameters: + k : maximum dimension to expand to. + + Examples: + + from splex import SimplexTree + from itertools import combinations + st = SimplexTree(combinations(range(8), 2)) + print(st) + + st.expand(k=2) + print(st) + """ + assert int(k) >= 0, f"Invalid expansion dimension k={k} given" + self._expand(int(k)) + + def __repr__(self) -> str: + if len(self.n_simplices) == 0: + return "< Empty simplex tree >" + return f"Simplex Tree with {tuple(self.n_simplices)} {tuple(range(0,self.dimension+1))}-simplices" + + def __iter__(self) -> Iterator[Collection]: + yield from self.simplices() + + def __contains__(self, s: Collection) -> bool: + return bool(self.find([s])[0]) + + def __len__(self) -> int: + return int(sum(self.n_simplices)) + + def card(self, p: int = None): + """Returns the cardinality of various skeleta of the complex.""" + if p is None: + return tuple(self.n_simplices) + else: + assert isinstance(p, int), "Invalid p" + return 0 if p < 0 or p >= len(self.n_simplices) else self.n_simplices[p] diff --git a/simplextree/UnionFind.cpp b/simplextree/UnionFind.cpp new file mode 100644 index 0000000..3f84ac9 --- /dev/null +++ b/simplextree/UnionFind.cpp @@ -0,0 +1,37 @@ +//---------------------------------------------------------------------- +// Disjoint-set data structure +// File: union_find.cpp +//---------------------------------------------------------------------- +// Copyright (c) 2018 Matt Piekenbrock. All Rights Reserved. +// +// Class definition based off of data-structure described here: +// https://en.wikipedia.org/wiki/Disjoint-set_data_structure +#include +#include +#include +#include +namespace py = pybind11; +using namespace pybind11::literals; + +#include "UnionFind.h" + +void printCC(UnionFind& uf){ + for (size_t j = 0; j < uf.size; ++j){ py::print(uf.Find(j), " "); } + py::print("\n"); +} + +PYBIND11_MODULE(_union_find, m) { + py::class_(m, "UnionFind") + .def(py::init< size_t >()) + .def_readonly( "size", &UnionFind::size) + .def_readonly( "parent", &UnionFind::parent) + .def_readonly( "rank", &UnionFind::rank) + .def("print", &printCC ) + .def("connected_components", &UnionFind::ConnectedComponents) + .def("find", &UnionFind::Find) + .def("find_all", &UnionFind::FindAll) + .def("union", &UnionFind::Union) + .def("union_all", &UnionFind::UnionAll) + .def("add_sets", &UnionFind::AddSets) + ; +} diff --git a/simplextree/UnionFind.py b/simplextree/UnionFind.py new file mode 100644 index 0000000..8d2c4b8 --- /dev/null +++ b/simplextree/UnionFind.py @@ -0,0 +1,8 @@ +from __future__ import annotations +from typing import * +from _union_find import UnionFind as UnionFindCpp + +class UnionFind(UnionFindCpp): + """ Union Find data structure """ + def __init__(n: int): + pass diff --git a/simplextree/combinatorial.cpp b/simplextree/combinatorial.cpp new file mode 100644 index 0000000..fa9a9fa --- /dev/null +++ b/simplextree/combinatorial.cpp @@ -0,0 +1,59 @@ +#include "combinatorial.h" +#include +#include +#include +#include +// #include +#include +namespace py = pybind11; + +#include // std::back_inserter +#include // std::vector +#include // std::copy +using std::vector; + +auto rank_combs(py::array_t< uint16_t > combs, const int n, const int k, bool colex = true) -> py::array_t< uint64_t > { + py::buffer_info buffer = combs.request(); + uint16_t* p = static_cast< uint16_t* >(buffer.ptr); + const size_t N = buffer.size; + vector< uint64_t > ranks; + ranks.reserve(static_cast< uint64_t >(N/k)); + auto out = std::back_inserter(ranks); + if (colex) { + combinatorial::rank_lex(p, p+N, size_t(n), size_t(k), out); + } else { + combinatorial::rank_colex(p, p+N, size_t(n), size_t(k), out); + } + return py::cast(ranks); +} + +// auto unrank_combs(py::array_t< int > ranks, const int n, const int k) -> py::array_t< int > { +// py::buffer_info buffer = ranks.request(); +// int* r = static_cast< int* >(buffer.ptr); +// const size_t N = buffer.size; +// vector< int > simplices; +// simplices.reserve(static_cast< int >(N*k)); +// auto out = std::back_inserter(simplices); +// combinatorial::unrank_lex(r, r+N, size_t(n), size_t(k), out); +// return py::cast(simplices); +// } + +// auto boundary_ranks(const int p_rank, const int n, const int k) -> py::array_t< int > { +// vector< int > face_ranks = vector< int >(); +// combinatorial::apply_boundary(p_rank, n, k, [&face_ranks](size_t r){ +// face_ranks.push_back(r); +// }); +// return py::cast(face_ranks); +// } + +// Package: pip install --no-deps --no-build-isolation --editable . +// Compile: clang -Wall -fPIC -c src/pbsig/combinatorial.cpp -std=c++20 -Iextern/pybind11/include -isystem /Users/mpiekenbrock/opt/miniconda3/envs/pbsig/include -I/Users/mpiekenbrock/opt/miniconda3/envs/pbsig/include/python3.9 +PYBIND11_MODULE(_combinatorial, m) { + m.doc() = "Combinatorial module"; + m.def("rank_combs", &rank_combs); + // m.def("unrank_combs", &unrank_combs); + // m.def("boundary_ranks", &boundary_ranks); + // m.def("interval_cost", &pairwise_cost); + // m.def("vectorized_func", py::vectorize(my_func));s + //m.def("call_go", &call_go); +} \ No newline at end of file diff --git a/simplextree/combinatorial.py b/simplextree/combinatorial.py new file mode 100644 index 0000000..6cd00e1 --- /dev/null +++ b/simplextree/combinatorial.py @@ -0,0 +1,129 @@ +import numpy as np +from typing import * +from itertools import * +from numbers import Integral +from math import comb, factorial +from .Simplex import Simplex +import _combinatorial as comb_mod + + +## Also: https://stackoverflow.com/questions/1942328/add-a-member-variable-method-to-a-python-generator +## See: https://stackoverflow.com/questions/48349929/numpy-convertible-class-that-correctly-converts-to-ndarray-from-inside-a-sequenc +class SimplexWrapper: + def __init__(self, g: Generator, d: int, dtype = None): + ## Precondition: g is a generator of SimplexConvertibles all of the same length + # head, self.simplices = spy(g) + self.simplices = g + # self.simplices = list(g) + # d = len(head[0]) + if d == 0: + self.s_dtype = np.uint16 if dtype is None else dtype + else: + self.s_dtype = (np.uint16, d+1) if dtype is None else (dtype, d+1) + + def __iter__(self) -> Iterator: + return map(Simplex, self.simplices) + + def __array__(self) -> np.ndarray: + return np.fromiter(iter(self), dtype=self.s_dtype) + +def rank_C2(i: int, j: int, n: int) -> int: + i, j = (j, i) if j < i else (i, j) + return(int(n*i - i*(i+1)/2 + j - i - 1)) + +def unrank_C2(x: int, n: int) -> tuple: + i = int(n - 2 - np.floor(np.sqrt(-8*x + 4*n*(n-1)-7)/2.0 - 0.5)) + j = int(x + i + 1 - n*(n-1)/2 + (n-i)*((n-i)-1)/2) + return(i,j) + +def unrank_lex(r: int, k: int, n: int): + result = [0]*k + x = 1 + for i in range(1, k+1): + while(r >= comb(n-x, k-i)): + r -= comb(n-x, k-i) + x += 1 + result[i-1] = (x - 1) + x += 1 + return tuple(result) + +def rank_lex(c: Iterable, n: int) -> int: + c = tuple(sorted(c)) + k = len(c) + index = sum([comb(int(n-ci-1),int(k-i)) for i,ci in enumerate(c)]) + #index = sum([comb((n-1)-cc, kk) for cc,kk in zip(c, reversed(range(1, len(c)+1)))]) + return int(comb(n, k) - index - 1) + +def rank_colex(c: Iterable) -> int: + c = tuple(sorted(c)) + k = len(c) + #return sum([comb(ci, i+1) for i,ci in zip(reversed(range(len(c))), reversed(c))]) + return sum([comb(ci,k-i) for i,ci in enumerate(reversed(c))]) + +def unrank_colex(r: int, k: int) -> np.ndarray: + """ + Unranks a k-combinations rank 'r' back into the original combination in colex order + + From: Unranking Small Combinations of a Large Set in Co-Lexicographic Order + """ + c = [0]*k + for i in reversed(range(1, k+1)): + m = i + while r >= comb(m,i): + m += 1 + c[i-1] = m-1 + r -= comb(m-1,i) + return tuple(c) + + +def rank_combs(C: Iterable[tuple], n: int = None, order: str = ["colex", "lex"]): + """ + Ranks k-combinations to integer ranks in either lexicographic or colexicographical order + + Parameters: + C : Iterable of combinations + n : cardinality of the set (lex order only) + order : the bijection to use + + Returns: + list : unsigned integers ranks in the chosen order. + """ + if (isinstance(order, list) and order == ["colex", "lex"]) or order == "colex": + return [rank_colex(c) for c in C] + else: + assert n is not None, "Cardinality of set must be supplied for lexicographical ranking" + return [rank_lex(c, n) for c in C] + +def unrank_combs(R: Iterable, k: Union[int, Iterable], n: int = None, order: str = ["colex", "lex"]): + """ + Unranks integer ranks to k-combinations in either lexicographic or colexicographical order + + Parameters: + R : Iterable of integer ranks + n : cardinality of the set (only required for lex order) + order : the bijection to use + + Returns: + list : k-combinations derived from R + """ + if (isinstance(order, list) and order == ["colex", "lex"]) or order == "colex": + if isinstance(k, Integral): + return SimplexWrapper((unrank_colex(r, k) for r in R), d=k-1) + else: + assert len(k) == len(R), "If 'k' is an iterable it must match the size of 'R'" + return [unrank_colex(r, l) for l, r in zip(k,R)] + else: + assert n is not None, "Cardinality of set must be supplied for lexicographical ranking" + if isinstance(k, Integral): + assert k > 0, f"Invalid cardinality {k}" + if k == 1: + return SimplexWrapper((r[0] for r in R), d=0) + if k == 2: + return SimplexWrapper((unrank_C2(r, n) for r in R), d=1) + # return [unrank_C2(r, n) for r in R] + else: + return SimplexWrapper((unrank_lex(r, k, n) for r in R), d=k-1) + # return [unrank_lex(r, k, n) for r in R] + else: + assert len(k) == len(R), "If 'k' is an iterable it must match the size of 'R'" + return [unrank_lex(r, l) for l, r in zip(k,R)] \ No newline at end of file diff --git a/simplextree/predicates.py b/simplextree/predicates.py new file mode 100644 index 0000000..c5390e9 --- /dev/null +++ b/simplextree/predicates.py @@ -0,0 +1,93 @@ +import numbers +import numpy as np +from typing import * +from numpy.typing import ArrayLike +from more_itertools import spy +from operator import itemgetter + +from math import comb, factorial +# from .combinatorial import * + +def inverse_choose(x: int, k: int): + """Inverse binomial coefficient (approximately). + + This function *attempts* to find the integer _n_ such that binom(n,k) = x, where _binom_ is the binomial coefficient: + + binom(n,k) := n!/(k! * (n-k)!) + + For k <= 2, an efficient iterative approach is used and the result is exact. For k > 2, the same approach is + used if x > 10e7; otherwise, an approximation is used based on the formula from this stack exchange post: + + https://math.stackexchange.com/questions/103377/how-to-reverse-the-n-choose-k-formula + """ + assert k >= 1, "k must be >= 1" + if k == 1: return(x) + if k == 2: + rng = np.array(list(range(int(np.floor(np.sqrt(2*x))), int(np.ceil(np.sqrt(2*x)+2) + 1)))) + final_n = rng[np.nonzero(np.array([comb(n, 2) for n in rng]) == x)[0].item()] + else: + # From: https://math.stackexchange.com/questions/103377/how-to-reverse-the-n-choose-k-formula + if x < 10**7: + lb = (factorial(k)*x)**(1/k) + potential_n = np.array(list(range(int(np.floor(lb)), int(np.ceil(lb+k)+1)))) + idx = np.nonzero(np.array([comb(n, k) for n in potential_n]) == x)[0].item() + final_n = potential_n[idx] + else: + lb = np.floor((4**k)/(2*k + 1)) + C, n = factorial(k)*x, 1 + while n**k < C: n = n*2 + m = (np.nonzero( np.array(list(range(1, n+1)))**k >= C )[0])[0].item() + potential_n = np.array(list(range(int(np.max([m, 2*k])), int(m+k+1)))) + if len(potential_n) == 0: + raise ValueError(f"Failed to invert C(n,{k}) = {x}") + ind = np.nonzero(np.array([comb(n, k) for n in potential_n]) == x)[0].item() + final_n = potential_n[ind] + return(final_n) + +def is_repeatable(x: Iterable) -> bool: + """Checks whether _x_ is Iterable and repeateable as an Iterable (generators fail this test).""" + return not(iter(x) is x) + +def is_simplex_like(x: Any) -> bool: + is_collection = isinstance(x, SimplexConvertible) # is a Collection supporting __contains__, __iter__, and __len__ + if is_collection: + return is_repeatable(x) and all([isinstance(v, Integral) for v in x]) + return False + +def is_complex_like(x: Any) -> bool: + if isinstance(x, ComplexLike): # is iterable + Sized + item, iterable = spy(x) + return is_simplex_like(item[0]) + return False + +def is_filtration_like(x: Any) -> bool: + is_collection = isinstance(x, FiltrationLike) # Collection + Sequence + .index + if is_collection: + # return is_complex_like(map(itemgetter(1), x)) + item, iterable = spy(x) + return len(item[0]) == 2 and is_simplex_like(item[0][1]) + return False + +def is_array_convertible(x: Any) -> bool: + return hasattr(x, "__array__") + +def is_distance_matrix(x: ArrayLike) -> bool: + """Checks whether _x_ is a distance matrix, i.e. is square, symmetric, and that the diagonal is all 0.""" + x = np.array(x, copy=False) + is_square = x.ndim == 2 and (x.shape[0] == x.shape[1]) + return(False if not(is_square) else np.all(np.diag(x) == 0)) + +def is_pairwise_distances(x: ArrayLike) -> bool: + """Checks whether 'x' is a 1-d array of pairwise distances.""" + x = np.array(x, copy=False) # don't use asanyarray here + if x.ndim > 1: return(False) + n = inverse_choose(len(x), 2) + return(x.ndim == 1 and n == int(n)) + +def is_point_cloud(x: ArrayLike) -> bool: + """Checks whether 'x' is a 2-d array of points""" + return(isinstance(x, np.ndarray) and x.ndim == 2) + +def is_dist_like(x: ArrayLike): + """Checks whether _x_ is any recognizable distance object.""" + return(is_distance_matrix(x) or is_pairwise_distances(x)) \ No newline at end of file diff --git a/simplextree/simplextree_module.cpp b/simplextree/simplextree_module.cpp new file mode 100644 index 0000000..b04fbc8 --- /dev/null +++ b/simplextree/simplextree_module.cpp @@ -0,0 +1,566 @@ +#include +#include +#include +#include +namespace py = pybind11; +using namespace pybind11::literals; + +#include "simplextree.h" +using simplex_t = SimplexTree::simplex_t; + +// Generic function to handle various vector types +template < typename Lambda > +void vector_handler(SimplexTree& st, const py::array_t< idx_t >& simplices, Lambda&& f){ + py::buffer_info s_buffer = simplices.request(); + if (s_buffer.ndim == 1){ + // py::print(s_buffer.shape[0]); + const size_t n = s_buffer.shape[0]; + idx_t* s = static_cast< idx_t* >(s_buffer.ptr); + for (size_t i = 0; i < n; ++i){ + f(s+i, s+i+1); + } + // st.insert_it< true >(s, s+s_buffer.shape[0], st.root.get(), 0); + } else if (s_buffer.ndim == 2) { + // const size_t d = static_cast< size_t >(s_buffer.shape[1]); + if (s_buffer.strides[0] <= 0){ return; } + const size_t d = static_cast< size_t >(s_buffer.shape[1]); + const size_t n = static_cast< size_t >(s_buffer.shape[0]); + idx_t* s = static_cast< idx_t* >(s_buffer.ptr); + // py::print("Strides: ", s_buffer.strides[0], s_buffer.strides[1], ", ", "size: ", s_buffer.size, ", shape: (", s_buffer.shape[0], s_buffer.shape[1], ")"); + for (size_t i = 0; i < n; ++i){ + f(s+(d*i), s+d*(i+1)); + // st.insert_it< true >(s+(d*i), s+(d*i)+1, st.root.get(), 0); + } + } +} + +// TODO: accept py::buffer? +void insert_(SimplexTree& st, const py::array_t< idx_t >& simplices){ + vector_handler(st, simplices, [&st](idx_t* b, idx_t* e){ + // py::print(py::cast(simplex_t(b,e))); + st.insert_it< true >(b, e, st.root.get(), 0); + }); +} +void insert_list(SimplexTree& st, std::list< simplex_t > L){ + for (auto s: L){ st.insert(simplex_t(s)); } +} + +void remove_(SimplexTree& st, const py::array_t< idx_t >& simplices){ + vector_handler(st, simplices, [&st](idx_t* b, idx_t* e){ + st.remove(st.find(simplex_t(b, e))); + }); +} +void remove_list(SimplexTree& st, std::list< simplex_t > L){ + for (auto s: L){ + st.remove(st.find(simplex_t(s))); + } +} + +// Vectorized find +[[nodiscard]] +auto find_(SimplexTree& st, const py::array_t< idx_t >& simplices) noexcept -> py::array_t< bool > { + std::vector< bool > v; + vector_handler(st, simplices, [&st, &v](idx_t* b, idx_t* e){ + node_ptr np = st.find(simplex_t(b, e)); + v.push_back(np != st.root.get() && np != nullptr); + }); + return(py::cast(v)); +} + +[[nodiscard]] +auto find_list(SimplexTree& st, std::list< simplex_t > L) -> py::array_t< bool > { + std::vector< bool > v; + for (auto s: L){ + v.push_back(st.find(simplex_t(s))); + } + return py::cast(v); +} + + +bool collapse_(SimplexTree& st, const vector< idx_t >& tau, const vector< idx_t >& sigma){ + return st.collapse(st.find(tau), st.find(sigma)); +} + +auto get_k_simplices(SimplexTree& st, const size_t k) -> py::array_t< idx_t > { + vector< idx_t > res; + if (st.n_simplexes.size() <= k){ return py::cast(res); } + const size_t ns = st.n_simplexes.at(k); + res.reserve(ns*(k+1)); + auto tr = st::k_simplices< true >(&st, st.root.get(), k); + traverse(tr, [&res](node_ptr cn, idx_t depth, simplex_t sigma){ + res.insert(res.end(), sigma.begin(), sigma.end()); + return true; + }); + py::array_t< idx_t > out = py::cast(res); + std::array< size_t, 2 > shape = { ns, k+1 }; + return(out.reshape(shape)); +} + +// Retrieve the vertices by their label +vector< idx_t > get_vertices(const SimplexTree& st) { + if (st.n_simplexes.size() == 0){ return vector< idx_t >(); } //IntegerVector(); } + vector< idx_t > v; + v.reserve(st.n_simplexes.at(0)); + for (auto& cn: st.root->children){ + v.push_back(cn->label); + } + return(v); +} +auto get_edges(SimplexTree& st) -> py::array_t< idx_t > { return get_k_simplices(st, 1); } +auto get_triangles(SimplexTree& st) -> py::array_t< idx_t > { return get_k_simplices(st, 2); } +auto get_quads(SimplexTree& st) -> py::array_t< idx_t > { return get_k_simplices(st, 3); } + + +// Exports the 1-skeleton as an adjacency matrix +// auto as_adjacency_matrix(SimplexTree& stp, int p = 0) -> py::array_t< idx_t > { +// const auto& vertices = st.root->children; +// const size_t n = vertices.size(); +// res = .. + +// // Fill in the adjacency matrix +// size_t i = 0; +// for (auto& vi: vertices){ +// for (auto& vj: vi->children){ +// auto it = std::lower_bound(begin(vertices), end(vertices), vj->label, [](const node_uptr& cn, const idx_t label){ +// return cn->label < label; +// }); +// const size_t j = std::distance(begin(vertices), it); +// res.at(i, j) = res.at(j, i) = 1; +// } +// ++i; +// } +// return(res); +// } + + +auto degree_(SimplexTree& st, vector< idx_t > ids) -> py::array_t< idx_t > { + vector< idx_t > res(ids.size()); + std::transform(begin(ids), end(ids), begin(res), [&st](int id){ + return st.degree(static_cast< idx_t >(id)); + }); + return py::cast(res); +} + +py::array_t< idx_t > degree_default(SimplexTree& st) { + return(degree_(st, get_vertices(st))); +} + +py::list adjacent_(const SimplexTree& st, vector< idx_t > ids = vector< idx_t >()){ + if (ids.size() == 0){ ids = get_vertices(st); } + vector< vector< idx_t > > res(ids.size()); + std::transform(begin(ids), end(ids), begin(res), [&st](int id){ + return st.adjacent_vertices(static_cast< idx_t >(id)); + }); + return py::cast(res); +} + + +void print_tree(const SimplexTree& st){ + py::scoped_ostream_redirect stream( + std::cout, // std::ostream& + py::module_::import("sys").attr("stdout") // Python output + ); + st.print_tree(std::cout); +} +void print_cousins(const SimplexTree& st){ + py::scoped_ostream_redirect stream( + std::cout, // std::ostream& + py::module_::import("sys").attr("stdout") // Python output + ); + st.print_cousins(std::cout); +} + +auto simplex_counts(const SimplexTree& st) -> py::array_t< size_t > { + auto zero_it = std::find(st.n_simplexes.begin(), st.n_simplexes.end(), 0); + auto ne = std::vector< size_t >(st.n_simplexes.begin(), zero_it); + return(py::cast(ne)); +} + + +// void make_flag_filtration(Filtration* st, const NumericVector& D){ +// if (st->n_simplexes.size() <= 1){ return; } +// const size_t ne = st->n_simplexes.at(1); +// const auto v = st->get_vertices(); +// const size_t N = BinomialCoefficient(v.size(), 2); +// if (size_t(ne) == size_t(D.size())){ +// vector< double > weights(D.begin(), D.end()); +// st->flag_filtration(weights, false); +// } else if (size_t(D.size()) == size_t(N)){ // full distance vector passed in +// auto edge_iter = st::k_simplices< true >(st, st->root.get(), 1); +// vector< double > weights; +// weights.reserve(ne); +// st::traverse(edge_iter, [&weights, &D, &v](node_ptr np, idx_t depth, simplex_t sigma){ +// auto v1 = sigma[0], v2 = sigma[1]; +// auto idx1 = std::distance(begin(v), std::lower_bound(begin(v), end(v), v1)); +// auto idx2 = std::distance(begin(v), std::lower_bound(begin(v), end(v), v2)); +// auto dist_idx = to_natural_2(idx1, idx2, v.size()); +// weights.push_back(D[dist_idx]); +// return true; +// }); +// st->flag_filtration(weights, false); +// } else { +// throw std::invalid_argument("Flag filtrations require a vector of distances for each edge or a 'dist' object"); +// } +// } +// +// void test_filtration(Filtration* st, const size_t i){ +// auto ind = st->simplex_idx(i); +// for (auto idx: ind){ Rcout << idx << ", "; } +// Rcout << std::endl; +// auto sigma = st->expand_simplex(begin(ind), end(ind)); +// for (auto& label: sigma){ Rcout << label << ", "; } +// Rcout << std::endl; +// } + + + +// PYBIND11_MODULE(filtration_module, m) { +// py::class_< SimplexTree >("SimplexTree") +// .def(py::init<>()); +// // .method( "as_XPtr", &as_XPtr) +// .property("n_simplices", &simplex_counts, "Gets simplex counts") +// // .field_readonly("n_simplices", &SimplexTree::n_simplexes) +// .property("dimension", &SimplexTree::dimension) +// .property("id_policy", &SimplexTree::get_id_policy, &SimplexTree::set_id_policy) +// // .property("vertices", &get_vertices, "Returns the vertex labels as an integer vector.") +// // .property("edges", &get_edges, "Returns the edges as an integer matrix.") +// // .property("triangles", &get_triangles, "Returns the 2-simplices as an integer matrix.") +// // .property("quads", &get_quads, "Returns the 3-simplices as an integer matrix.") +// .property("connected_components", &SimplexTree::connected_components) +// // .def( "print_tree", &print_tree ) +// // .def( "print_cousins", &print_cousins ) +// .def( "clear", &SimplexTree::clear) +// .def( "generate_ids", &SimplexTree::generate_ids) +// .def( "reindex", &SimplexTree::reindex) +// // .def( "adjacent", &adjacent_R) +// // .def( "degree", °ree_R) +// // .def( "insert", &insert_R) +// // .def( "insert_lex", &insert_lex) +// // .def( "remove", &remove_R) +// // .def( "find", &find_R) +// .def( "expand", &SimplexTree::expansion ) +// .def( "collapse", &collapse_R) +// // .def( "vertex_collapse", (bool (SimplexTree::*)(idx_t, idx_t, idx_t))(&SimplexTree::vertex_collapse)) +// .def( "contract", &SimplexTree::contract) +// .def( "is_tree", &SimplexTree::is_tree) +// // .def( "as_adjacency_matrix", &as_adjacency_matrix) +// // .def( "as_adjacency_list", &as_adjacency_list) +// // .def( "as_edge_list", &as_edge_list) +// ; +// // Rcpp::class_< Filtration >("Filtration") +// // .derives< SimplexTree >("SimplexTree") +// // .constructor() +// // .method("init_tree", &init_filtration) +// // .field("included", &Filtration::included) +// // .property("current_index", &Filtration::current_index) +// // .property("current_value", &Filtration::current_value) +// // .property("simplices", &get_simplices, "Returns the simplices in the filtration") +// // .property("weights", &Filtration::weights, "Returns the weights in the filtration") +// // .property("dimensions", &Filtration::dimensions, "Returns the dimensions of the simplices in the filtration") +// // .method("flag_filtration", &make_flag_filtration, "Constructs a flag filtration") +// // .method("threshold_value", &Filtration::threshold_value) +// // .method("threshold_index", &Filtration::threshold_index) +// // ; +// } + +// --- Begin functional exports + helpers --- + +// #include +// // [[Rcpp::export]] +// NumericVector profile(SEXP st){ +// Rcpp::XPtr< SimplexTree > st_ptr(st); +// Timer timer; +// timer.step("start"); +// st_ptr->expansion(2); +// timer.step("expansion"); + +// NumericVector res(timer); +// const size_t n = 1000; +// for (size_t i=0; i < size_t(res.size()); ++i) { res[i] = res[i] / n; } +// return res; +// } + + +// bool contains_arg(vector< std::string > v, std::string arg_name){ +// return(std::any_of(v.begin(), v.end(), [&arg_name](const std::string arg){ +// return arg == arg_name; +// })); +// }; + +// // From: https://stackoverflow.com/questions/56465550/how-to-concatenate-lists-in-rcpp +// List cLists(List x, List y) { +// int nsize = x.size(); +// int msize = y.size(); +// List out(nsize + msize); + +// CharacterVector xnames = x.names(); +// CharacterVector ynames = y.names(); +// CharacterVector outnames(nsize + msize); +// out.attr("names") = outnames; +// for(int i = 0; i < nsize; i++) { +// out[i] = x[i]; +// outnames[i] = xnames[i]; +// } +// for(int i = 0; i < msize; i++) { +// out[nsize+i] = y[i]; +// outnames[nsize+i] = ynames[i]; +// } +// return(out); +// } + +// The types of traversal supported +enum TRAVERSAL_TYPE { + PREORDER = 0, LEVEL_ORDER = 1, FACES = 2, COFACES = 3, COFACE_ROOTS = 4, K_SKELETON = 5, + K_SIMPLICES = 6, MAXIMAL = 7, LINK = 8 +}; +// const size_t N_TRAVERSALS = 9; + +// Exports a list with the parameters for a preorder traversal +// py::list parameterize_(SimplexTree& st, vector< idx_t > sigma, std::string type, Rcpp::Nullable args){ + +// if (type == "preorder" || type == "dfs") { param_res["traversal_type"] = int(PREORDER); } +// else if (type == "level_order" || type == "bfs") { param_res["traversal_type"] = int(LEVEL_ORDER); } +// else if (type == "cofaces" || type == "star") { param_res["traversal_type"] = int(COFACES); } +// else if (type == "coface_roots") { param_res["traversal_type"] = int(COFACE_ROOTS); } +// else if (type == "link"){ param_res["traversal_type"] = int(LINK); } +// else if (type == "k_skeleton" || type == "skeleton"){ param_res["traversal_type"] = int(K_SKELETON); } +// else if (type == "k_simplices" || type == "maximal-skeleton"){ param_res["traversal_type"] = int(K_SIMPLICES); } +// else if (type == "maximal"){ param_res["traversal_type"] = int(MAXIMAL); } +// else if(type == "faces"){ param_res["traversal_type"] = int(FACES); } +// else { stop("Iteration 'type' is invalid. Please use one of: preorder, level_order, faces, cofaces, star, link, skeleton, or maximal-skeleton"); } +// param_res.attr("class") = "st_traversal"; +// return(param_res); +// } + +using param_pack = typename std::tuple< SimplexTree*, node_ptr, TRAVERSAL_TYPE >; + +// Traverse some aspect of the simplex tree, given parameters +// template < class Lambda > +// void traverse_switch(SimplexTree& st, param_pack&& pp, List args, Lambda&& f){ +// auto args_str = as< vector< std::string > >(args.names()); +// SimplexTree* st = get< 0 >(pp); +// node_ptr init = get< 1 >(pp); +// TRAVERSAL_TYPE tt = get< 2 >(pp); +// switch(tt){ +// case PREORDER: { +// auto tr = st::preorder< true >(st, init); +// traverse(tr, f); +// break; +// } +// case LEVEL_ORDER: { +// auto tr = st::level_order< true >(st, init); +// traverse(tr, f); +// break; +// } +// case FACES: { +// auto tr = st::faces< true >(st, init); +// traverse(tr, f); +// break; +// } +// case COFACES: { +// auto tr = st::cofaces< true >(st, init); +// traverse(tr, f); +// break; +// } +// case COFACE_ROOTS: { +// auto tr = st::coface_roots< true >(st, init); +// traverse(tr, f); +// break; +// } +// case K_SKELETON: { +// if (!contains_arg(args_str, "k")){ stop("Expecting dimension 'k' to be passed."); } +// idx_t k = args["k"]; +// auto tr = st::k_skeleton< true >(st, init, k); +// traverse(tr, f); +// break; +// } +// case K_SIMPLICES: { +// if (!contains_arg(args_str, "k")){ stop("Expecting dimension 'k' to be passed."); } +// idx_t k = args["k"]; +// auto tr = st::k_simplices< true >(st, init, k); +// traverse(tr, f); +// break; +// } +// case MAXIMAL: { +// auto tr = st::maximal< true >(st, init); +// traverse(tr, f); +// break; +// } +// case LINK: { +// auto tr = st::link< true >(st, init); +// traverse(tr, f); +// break; +// } +// } +// } + +// // To validate the traversal parameters +// param_pack validate_params(List args){ +// // Extract parameters +// auto args_str = as< vector< std::string > >(args.names()); + +// // Extract tree +// if (!contains_arg(args_str, ".ptr")){ stop("Simplex tree pointer missing."); } +// SEXP xptr = args[".ptr"]; +// if (TYPEOF(xptr) != EXTPTRSXP || R_ExternalPtrAddr(xptr) == NULL){ +// stop("Invalid pointer to simplex tree."); +// } +// XPtr< SimplexTree > st(xptr); // Unwrap XPtr + +// // Extract initial simplex +// node_ptr init = nullptr; +// if (!contains_arg(args_str, "sigma")){ init = st->root.get(); } +// else { +// IntegerVector sigma = args["sigma"]; +// init = st->find(sigma); +// if (init == nullptr){ init = st->root.get(); } +// } +// if (init == nullptr){ stop("Invalid starting simplex"); } + +// // Extract traversal type +// size_t tt = (size_t) args["traversal_type"]; +// if (tt < 0 || tt >= N_TRAVERSALS){ stop("Unknown traversal type."); } + +// return(std::make_tuple(static_cast< SimplexTree* >(st), init, static_cast< TRAVERSAL_TYPE >(tt))); +// } + +void traverse_(SimplexTree& stree, const size_t order, py::function f, simplex_t init = simplex_t(), const size_t k = 0){ + node_ptr base = init.size() == 0 ? stree.root.get() : stree.find(init); + // py::print("Starting from root?", init.size() == 0 ? "Y" : "N", ", k=",k); + const auto apply_f = [&f](node_ptr cn, idx_t depth, simplex_t s){ + f(py::cast(s)); + return true; + }; + switch(order){ + case PREORDER: { + auto tr = st::preorder< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case LEVEL_ORDER: { + auto tr = st::level_order< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case FACES: { + auto tr = st::faces< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case COFACES: { + auto tr = st::cofaces< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case COFACE_ROOTS: { + auto tr = st::coface_roots< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case K_SKELETON: { + auto tr = st::k_skeleton< true >(&stree, base, k); + traverse(tr, apply_f); + break; + } + case K_SIMPLICES: { + auto tr = st::k_simplices< true >(&stree, base, k); + traverse(tr, apply_f); + break; + } + case MAXIMAL: { + auto tr = st::maximal< true >(&stree, base); + traverse(tr, apply_f); + break; + } + case LINK: { + auto tr = st::link< true >(&stree, base); + traverse(tr, apply_f); + break; + } + } +} + +// // [[Rcpp::export]] +// List ltraverse_R(List args, Function f){ +// List res = List(); +// auto run_Rf = [&f, &res](node_ptr cn, idx_t d, simplex_t tau){ +// res.push_back(f(wrap(tau))); +// return(true); +// }; +// traverse_switch(validate_params(args), args, run_Rf); +// return(res); +// } + + +#include +#include + +// void expand_f_bernoulli(SimplexTree& stx, const size_t k, const double p){ +// SimplexTree& st = *(Rcpp::XPtr< SimplexTree >(stx)); + +// // Random number generator +// std::random_device random_device; +// std::mt19937 random_engine(random_device()); +// std::uniform_real_distribution< double > bernoulli(0.0, 1.0); + +// // Perform Bernoulli trials for given k, with success probability p +// st.expansion_f(k, [&](node_ptr parent, idx_t depth, idx_t label){ +// double q = bernoulli(random_engine); +// if (p == 1.0 | q < p){ // if successful trial +// std::array< idx_t, 1 > int_label = { label }; +// st.insert_it(begin(int_label), end(int_label), parent, depth); +// } +// }); +// } + + + +// Expands st conditionally based on f +// void expansion_f(SimplexTree& st, const size_t k, py::function f){ +// const auto do_expand = [&](node_ptr np, auto depth, auto label){ +// // get simplex +// return true; +// }; +// st.expansion_f(k, f); +// } + + +// pip install --no-deps --no-build-isolation --editable . +// clang -Wall -fPIC -c src/simplicial/simplextree_module.cpp -std=c++17 -I/Users/mpiekenbrock/pbsig/extern/pybind11/include -I/Users/mpiekenbrock/simplicial/src/simplicial/include -I/Users/mpiekenbrock/opt/miniconda3/envs/pbsig/include/python3.9 +PYBIND11_MODULE(_simplextree, m) { + + py::class_(m, "SimplexTree") + .def(py::init<>()) + .def_property_readonly("n_simplices", &simplex_counts) + .def_property_readonly("dimension", &SimplexTree::dimension) + .def_property("id_policy", &SimplexTree::get_id_policy, &SimplexTree::set_id_policy) + .def_property_readonly("vertices", &get_vertices) + .def_property_readonly("edges", &get_edges) + .def_property_readonly("triangles", &get_triangles) + .def_property_readonly("quads", &get_quads) + .def_property_readonly("connected_components", &SimplexTree::connected_components) + .def( "print_tree", &print_tree ) + .def( "print_cousins", &print_cousins ) + .def( "clear", &SimplexTree::clear) + .def( "_degree", °ree_) + .def( "_degree_default", °ree_default) + // .def( "degree", static_cast< py::array_t< idx_t >(SimplexTree::*)() >(°ree_default), "degree") + .def("_insert", &insert_) + .def("_insert_list", &insert_list) + // .def( "insert_lex", &insert_lex) + .def( "_remove", &remove_) + .def( "_remove_list", &remove_list) + .def( "_find", &find_) + .def( "_find_list", &find_list) + .def( "_adjacent", &adjacent_) + .def( "_collapse", &collapse_) + .def( "generate_ids", &SimplexTree::generate_ids) + .def( "_reindex", &SimplexTree::reindex) + .def( "_expand", &SimplexTree::expansion ) + .def( "_vertex_collapse", (bool (SimplexTree::*)(idx_t, idx_t, idx_t))(&SimplexTree::vertex_collapse)) + .def( "_contract", &SimplexTree::contract) + .def( "is_tree", &SimplexTree::is_tree) + .def( "_traverse", &traverse_) + // .def( "as_adjacency_matrix", &as_adjacency_matrix) + ; +} \ No newline at end of file diff --git a/simplextree/version.py b/simplextree/version.py index 7c1aef3..4d255a4 100644 --- a/simplextree/version.py +++ b/simplextree/version.py @@ -1,11 +1,7 @@ +## SemVer: https://semver.org/#is-v123-a-semantic-version _MAJOR = "0" _MINOR = "1" -# On main and in a nightly release the patch should be one ahead of the last -# released build. _PATCH = "0" -# This is mainly for nightly builds which have the suffix ".dev$DATE". See -# https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = "" - VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)