diff --git a/examples/reshape.cpp b/examples/reshape.cpp index 84c78465cf..b756cf712c 100644 --- a/examples/reshape.cpp +++ b/examples/reshape.cpp @@ -54,6 +54,7 @@ int main(int RAJA_UNUSED_ARG(argc), char **RAJA_UNUSED_ARG(argv[])) // Allocate memory for pointer int *Rptr = memoryManager::allocate(K * N * M); int *Lptr = memoryManager::allocate(K * N * M); + int *Cptr = memoryManager::allocate(K * N * M); //----------------------------------------------------------------------------// // @@ -90,13 +91,40 @@ int main(int RAJA_UNUSED_ARG(argc), char **RAJA_UNUSED_ARG(argv[])) for(int k = 0; k < K; ++k) { const int idx = k + K * (n + N * m); - Lview(k,m,n) = idx; + Lview(k,n,m) = idx; } } } checkResult(Lptr, K, N, M); + +//----------------------------------------------------------------------------// +// +// Initialize memory using custom ordering, left most value of the index sequence +// is assumed to have unit stride, while the right most has the longest stride +// +//----------------------------------------------------------------------------// + std::cout << "\n\nInitialize array with custom indexing...\n"; + + using custom_seq = std::index_sequence<1U,0U,2U>; + + auto Cview = RAJA::Reshape::get(Cptr, K, N, M); + + //Note the loop ordering has change from above + for(int k = 0; k < K; ++k) { + for(int m = 0; m < M; ++m) { + for(int n = 0; n < N; ++n) { + + const int idx = n + N * (m + M * k); + Cview(k,n,m) = idx; + } + } + } + + checkResult(Cptr, K, N, M); + + // // Clean up. // diff --git a/include/RAJA/util/View.hpp b/include/RAJA/util/View.hpp index 6e16361616..c21c794513 100644 --- a/include/RAJA/util/View.hpp +++ b/include/RAJA/util/View.hpp @@ -303,6 +303,34 @@ struct layout_right{}; template struct Reshape; +template +struct Reshape; + +template +constexpr auto get_first_index() +{ + return Id; +} + +template +struct Reshape> +{ + template + static auto get(T *ptr, Ts... s) + { + constexpr int N = sizeof...(Ts); + std::array extent{s...}; + + auto custom_layout = + RAJA::make_permuted_layout(extent, std::array{Is...}); + + constexpr auto unit_stride = get_first_index(); + + return RAJA::View> + (ptr, custom_layout); + } +}; + template<> struct Reshape { @@ -327,15 +355,15 @@ struct Reshape template static auto get(T *ptr, Ts... s) { - constexpr int N = sizeof...(Ts); + constexpr int N = sizeof...(Ts); - std::array extent{s...}; + std::array extent{s...}; - constexpr auto reverse_array = make_reverse_array(std::make_index_sequence{}); + constexpr auto reverse_array = make_reverse_array(std::make_index_sequence{}); - auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array); + auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array); - return RAJA::View>(ptr, reverse_layout); + return RAJA::View>(ptr, reverse_layout); } };