diff --git a/include/xtensor/xsort.hpp b/include/xtensor/xsort.hpp index 1ed7fe990..2c0e6f0ea 100644 --- a/include/xtensor/xsort.hpp +++ b/include/xtensor/xsort.hpp @@ -279,15 +279,15 @@ namespace xt { n_iters = std::accumulate(data.shape().begin(), data.shape().end() - 1, std::size_t(1), std::multiplies<>()); - data_secondary_stride = data.strides()[data.dimension() - 2]; - inds_secondary_stride = inds.strides()[inds.dimension() - 2]; + data_secondary_stride = data.shape(data.dimension() - 1); + inds_secondary_stride = inds.shape(inds.dimension() - 1); } else { n_iters = std::accumulate(data.shape().begin() + 1, data.shape().end(), std::size_t(1), std::multiplies<>()); - data_secondary_stride = data.strides()[1]; - inds_secondary_stride = inds.strides()[1]; + data_secondary_stride = data.shape(0); + inds_secondary_stride = inds.shape(0); } auto ptr = data.data(); diff --git a/test/test_xsort.cpp b/test/test_xsort.cpp index fb887d9e5..a92efba67 100644 --- a/test/test_xsort.cpp +++ b/test/test_xsort.cpp @@ -71,6 +71,48 @@ namespace xt } } + TEST(xsort, argsort_zero_stride) + { + { + xt::xtensor A = {{1.4, 1.3, 1.2, 1.1}}; + xt::xtensor bsort = {{0, 0, 0, 0}}; + xt::xtensor fsort = {{3, 2, 1, 0}}; + EXPECT_EQ(bsort, xt::argsort(A, 0)); + EXPECT_EQ(fsort, xt::argsort(A, 1)); + EXPECT_EQ(fsort, xt::argsort(A)); + } + { + xt::xtensor A = {{{1.4, 1.3, 1.2, 1.1}}}; + xt::xtensor bsort = {{{0, 0, 0, 0}}}; + xt::xtensor fsort = {{{3, 2, 1, 0}}}; + EXPECT_EQ(bsort, xt::argsort(A, 0)); + EXPECT_EQ(bsort, xt::argsort(A, 1)); + EXPECT_EQ(fsort, xt::argsort(A, 2)); + EXPECT_EQ(fsort, xt::argsort(A)); + } + } + + TEST(xsort, argsort_zero_stride_column_major) + { + { + xt::xtensor A = {{1.4, 1.3, 1.2, 1.1}}; + xt::xtensor bsort = {{0, 0, 0, 0}}; + xt::xtensor fsort = {{3, 2, 1, 0}}; + EXPECT_EQ(bsort, xt::argsort(A, 0)); + EXPECT_EQ(fsort, xt::argsort(A, 1)); + EXPECT_EQ(fsort, xt::argsort(A)); + } + { + xt::xtensor A = {{{1.4, 1.3, 1.2, 1.1}}}; + xt::xtensor bsort = {{{0, 0, 0, 0}}}; + xt::xtensor fsort = {{{3, 2, 1, 0}}}; + EXPECT_EQ(bsort, xt::argsort(A, 0)); + EXPECT_EQ(bsort, xt::argsort(A, 1)); + EXPECT_EQ(fsort, xt::argsort(A, 2)); + EXPECT_EQ(fsort, xt::argsort(A)); + } + } + TEST(xsort, flatten_argsort) { { @@ -256,7 +298,7 @@ namespace xt auto d = unique(c); EXPECT_EQ(d, bx); - auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0)); + auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0)); xarray ex = {0, 1}; EXPECT_EQ(e, ex); }