diff --git a/include/xtensor/xassign.hpp b/include/xtensor/xassign.hpp index 2ec698ae6..0760fdd13 100644 --- a/include/xtensor/xassign.hpp +++ b/include/xtensor/xassign.hpp @@ -439,6 +439,58 @@ namespace xt using requested_value_type = detail::conditional_promote_to_complex_t; }; + /********************************** + * Expression Order Optimizations * + **********************************/ + + class optimize_expression + { + private: + + template + struct equal_rank + { + static constexpr bool value = get_rank::value == get_rank::value; + }; + + template + struct all_equal_rank + { + static constexpr bool value = xtl::conjunction...>::value + && (get_rank::value != SIZE_MAX); + }; + + template + inline auto + impl_reorder_function(const xfunction& e, const std::tuple& slices, std::index_sequence, std::index_sequence) + { + return make_lambda_xfunction(F(), view(std::get(e.arguments()), std::get(slices)...)...); + } + + public: + + // when we have a view of a function where the closures of the functions are of equal rank (i.e no + // broadcasting) we can flip the order of the function and the view such that we have a function of + // views of containers which can be linearly assigned unlike the inverse. + template ...>::value>> + inline auto reorder(const xview, S...>& e) + { + return impl_reorder_function( + e.expression(), + e.slices(), + std::make_index_sequence(), + std::make_index_sequence() + ); + } + + // base case no applicable optimization + template + inline auto& reorder(E&& e) + { + return std::forward(e); + } + }; + template inline void xexpression_assigner_base::assign_data( xexpression& e1, @@ -446,9 +498,11 @@ namespace xt bool trivial ) { - E1& de1 = e1.derived_cast(); - const E2& de2 = e2.derived_cast(); - using traits = xassign_traits; + auto& de1 = e1.derived_cast(); + const auto& de2 = optimize_expression().reorder(e2.derived_cast()); + using dst_type = typename std::decay_t; + using src_type = typename std::decay_t; + using traits = xassign_traits; bool linear_assign = traits::linear_assign(de1, de2, trivial); constexpr bool simd_assign = traits::simd_assign();