Skip to content

Commit

Permalink
add fixed_shape::operator==
Browse files Browse the repository at this point in the history
  • Loading branch information
tailsu committed May 17, 2021
1 parent 41be223 commit be35a26
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
14 changes: 14 additions & 0 deletions include/xtensor/xstorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,19 @@ namespace xt
return sizeof...(X) == 0;
}

template <std::size_t... Y>
XTENSOR_FIXED_SHAPE_CONSTEXPR bool operator==(const fixed_shape<Y...>& b) const {
if (size() != b.size()) {
return false;
}
for (std::size_t i = 0; i < size(); ++i) {
if ((*this)[i] != b[i]) {
return false;
}
}
return true;
}

private:

XTENSOR_CONSTEXPR_ENHANCED_STATIC cast_type m_array = cast_type({X...});
Expand All @@ -1702,6 +1715,7 @@ namespace xt
constexpr typename fixed_shape<X...>::cast_type fixed_shape<X...>::m_array;
#endif


#undef XTENSOR_FIXED_SHAPE_CONSTEXPR

template <class E, std::ptrdiff_t Start, std::ptrdiff_t End = -1>
Expand Down
8 changes: 0 additions & 8 deletions test/test_xbuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,14 +380,6 @@ namespace xt
XT_EXPECT_ANY_THROW(xt::concatenate(xt::xtuple(fa, ta)));
}

template <std::size_t... I, std::size_t... J>
bool operator==(fixed_shape<I...>, fixed_shape<J...>)
{
std::array<std::size_t, sizeof...(I)> ix = {I...};
std::array<std::size_t, sizeof...(J)> jx = {J...};
return sizeof...(J) == sizeof...(I) && std::equal(ix.begin(), ix.end(), jx.begin());
}

#ifndef VS_SKIP_CONCATENATE_FIXED
// This test mimics the relevant parts of `TEST(xbuilder, concatenate)`
TEST(xbuilder, concatenate_fixed)
Expand Down
16 changes: 13 additions & 3 deletions test/test_xstorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ namespace xt
svector_type e(src);
EXPECT_EQ(size_t(10), d.size());
EXPECT_EQ(size_t(1), d[2]);

svector_type f = { 1, 2, 3, 4 };
EXPECT_EQ(size_t(4), f.size());
EXPECT_EQ(size_t(3), f[2]);
Expand All @@ -221,7 +221,7 @@ namespace xt
TEST(svector, assign)
{
svector_type a = { 1, 2, 3, 4 };

svector_type src1(10, 2);
a = src1;
EXPECT_EQ(size_t(10), a.size());
Expand Down Expand Up @@ -332,7 +332,7 @@ namespace xt
EXPECT_EQ(i, a[i]);
}
}

TEST(fixed_shape, fixed_shape)
{
fixed_shape<3, 4, 5> af;
Expand All @@ -344,4 +344,14 @@ namespace xt
EXPECT_EQ(a.front(), size_t(3));
EXPECT_EQ(a.size(), size_t(3));
}

TEST(fixed_shape, operator_eq)
{
fixed_shape<2, 3> a, b;
fixed_shape<5> c;
fixed_shape<2, 4> d;
EXPECT_TRUE(a == b);
EXPECT_FALSE(a == c);
EXPECT_FALSE(a == d);
}
}

0 comments on commit be35a26

Please sign in to comment.