Skip to content

Commit

Permalink
Merge pull request #2238 from tdegeus/argsort
Browse files Browse the repository at this point in the history
argsort: catching zeros stride leading axis (bugfix)
  • Loading branch information
JohanMabille authored May 17, 2021
2 parents 1b36ad6 + fcf6ade commit a8f1a6a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
8 changes: 4 additions & 4 deletions include/xtensor/xsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,15 @@ namespace xt
{
n_iters = std::accumulate(data.shape().begin(), data.shape().end() - 1,
std::size_t(1), std::multiplies<>());
data_secondary_stride = data.strides()[data.dimension() - 2];
inds_secondary_stride = inds.strides()[inds.dimension() - 2];
data_secondary_stride = data.shape(data.dimension() - 1);
inds_secondary_stride = inds.shape(inds.dimension() - 1);
}
else
{
n_iters = std::accumulate(data.shape().begin() + 1, data.shape().end(),
std::size_t(1), std::multiplies<>());
data_secondary_stride = data.strides()[1];
inds_secondary_stride = inds.strides()[1];
data_secondary_stride = data.shape(0);
inds_secondary_stride = inds.shape(0);
}

auto ptr = data.data();
Expand Down
44 changes: 43 additions & 1 deletion test/test_xsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,48 @@ namespace xt
}
}

TEST(xsort, argsort_zero_stride)
{
{
xt::xtensor<double, 2> A = {{1.4, 1.3, 1.2, 1.1}};
xt::xtensor<size_t, 2> bsort = {{0, 0, 0, 0}};
xt::xtensor<size_t, 2> fsort = {{3, 2, 1, 0}};
EXPECT_EQ(bsort, xt::argsort(A, 0));
EXPECT_EQ(fsort, xt::argsort(A, 1));
EXPECT_EQ(fsort, xt::argsort(A));
}
{
xt::xtensor<double, 3> A = {{{1.4, 1.3, 1.2, 1.1}}};
xt::xtensor<size_t, 3> bsort = {{{0, 0, 0, 0}}};
xt::xtensor<size_t, 3> fsort = {{{3, 2, 1, 0}}};
EXPECT_EQ(bsort, xt::argsort(A, 0));
EXPECT_EQ(bsort, xt::argsort(A, 1));
EXPECT_EQ(fsort, xt::argsort(A, 2));
EXPECT_EQ(fsort, xt::argsort(A));
}
}

TEST(xsort, argsort_zero_stride_column_major)
{
{
xt::xtensor<double, 2, xt::layout_type::column_major> A = {{1.4, 1.3, 1.2, 1.1}};
xt::xtensor<size_t, 2, xt::layout_type::column_major> bsort = {{0, 0, 0, 0}};
xt::xtensor<size_t, 2, xt::layout_type::column_major> fsort = {{3, 2, 1, 0}};
EXPECT_EQ(bsort, xt::argsort(A, 0));
EXPECT_EQ(fsort, xt::argsort(A, 1));
EXPECT_EQ(fsort, xt::argsort(A));
}
{
xt::xtensor<double, 3, xt::layout_type::column_major> A = {{{1.4, 1.3, 1.2, 1.1}}};
xt::xtensor<size_t, 3, xt::layout_type::column_major> bsort = {{{0, 0, 0, 0}}};
xt::xtensor<size_t, 3, xt::layout_type::column_major> fsort = {{{3, 2, 1, 0}}};
EXPECT_EQ(bsort, xt::argsort(A, 0));
EXPECT_EQ(bsort, xt::argsort(A, 1));
EXPECT_EQ(fsort, xt::argsort(A, 2));
EXPECT_EQ(fsort, xt::argsort(A));
}
}

TEST(xsort, flatten_argsort)
{
{
Expand Down Expand Up @@ -256,7 +298,7 @@ namespace xt
auto d = unique(c);
EXPECT_EQ(d, bx);

auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0));
auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0));
xarray<double> ex = {0, 1};
EXPECT_EQ(e, ex);
}
Expand Down

0 comments on commit a8f1a6a

Please sign in to comment.