Skip to content

Commit

Permalink
Moved reorder logic to assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
spectre-ns committed Dec 24, 2024
1 parent ae52796 commit 03774e5
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions include/xtensor/xassign.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,16 +439,67 @@ namespace xt
using requested_value_type = detail::conditional_promote_to_complex_t<e1_value_type, e2_requested_value_type>;
};

/**********************************
* Expression Order Optimizations *
**********************************/

class optimize_expression
{
private:
template <class E1, class E2>
struct equal_rank
{
static constexpr bool value = get_rank<E1>::value == get_rank<E2>::value;
};

template <class E1, class... E>
struct all_equal_rank
{
static constexpr bool value = xtl::conjunction<equal_rank<E1, E>...>::value
&& (get_rank<E1>::value != SIZE_MAX);
};

template <class F, class... CT, class... S, size_t... I, size_t... J>
inline auto impl_reorder_function(const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
{
return make_lambda_xfunction(F(), view(std::get<I>(e.arguments()), std::get<J>(slices)...)...);
}

public:
//when we have a view of a function where the closures of the functions are of equal rank (i.e no broadcasting)
//we can flip the order of the function and the view such that we have a function of views of containers which
//can be linearly assigned unlike the inverse.
template <class F, class... CT, class... S, class = std::enable_if_t<all_equal_rank<std::decay_t<CT>...>::value>>
inline auto reorder(const xview<xfunction<F, CT...>, S...>& e)
{
return impl_reorder_function(
e.expression(),
e.slices(),
std::make_index_sequence<sizeof...(CT)>(),
std::make_index_sequence<sizeof...(S)>()
);
}

//base case no applicable optimization
template<class E>
inline auto& reorder(E&& e)
{
return std::forward<E>(e);
}
};

template <class E1, class E2>
inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
xexpression<E1>& e1,
const xexpression<E2>& e2,
bool trivial
)
{
E1& de1 = e1.derived_cast();
const E2& de2 = e2.derived_cast();
using traits = xassign_traits<E1, E2>;
auto& de1 = e1.derived_cast();
const auto& de2 = optimize_expression().reorder(e2.derived_cast());
using dst_type = typename std::decay_t<decltype(de1)>;
using src_type = typename std::decay_t<decltype(de2)>;
using traits = xassign_traits<dst_type, src_type>;

bool linear_assign = traits::linear_assign(de1, de2, trivial);
constexpr bool simd_assign = traits::simd_assign();
Expand Down

0 comments on commit 03774e5

Please sign in to comment.