Skip to content

Commit

Permalink
push up version that takes index_seq
Browse files Browse the repository at this point in the history
  • Loading branch information
artv3 committed Dec 23, 2024
1 parent 4d05ea5 commit b0a836a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
30 changes: 29 additions & 1 deletion examples/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ int main(int RAJA_UNUSED_ARG(argc), char **RAJA_UNUSED_ARG(argv[]))
// Allocate memory for pointer
int *Rptr = memoryManager::allocate<int>(K * N * M);
int *Lptr = memoryManager::allocate<int>(K * N * M);
int *Cptr = memoryManager::allocate<int>(K * N * M);

//----------------------------------------------------------------------------//
//
Expand Down Expand Up @@ -90,13 +91,40 @@ int main(int RAJA_UNUSED_ARG(argc), char **RAJA_UNUSED_ARG(argv[]))
for(int k = 0; k < K; ++k) {

const int idx = k + K * (n + N * m);
Lview(k,m,n) = idx;
Lview(k,n,m) = idx;
}
}
}

checkResult(Lptr, K, N, M);


//----------------------------------------------------------------------------//
//
// Initialize memory using custom ordering, left most value of the index sequence
// is assumed to have unit stride, while the right most has the longest stride
//
//----------------------------------------------------------------------------//
std::cout << "\n\nInitialize array with custom indexing...\n";

using custom_seq = std::index_sequence<1U,0U,2U>;

auto Cview = RAJA::Reshape<custom_seq>::get(Cptr, K, N, M);

//Note the loop ordering has change from above
for(int k = 0; k < K; ++k) {
for(int m = 0; m < M; ++m) {
for(int n = 0; n < N; ++n) {

const int idx = n + N * (m + M * k);
Cview(k,n,m) = idx;
}
}
}

checkResult(Cptr, K, N, M);


//
// Clean up.
//
Expand Down
38 changes: 33 additions & 5 deletions include/RAJA/util/View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,34 @@ struct layout_right{};
template<typename LAYOUT>
struct Reshape;

template<typename LAYOUT>
struct Reshape;

template<std::size_t Id, std::size_t...Ids>
constexpr auto get_first_index()
{
return Id;
}

template<std::size_t...Is>
struct Reshape<std::index_sequence<Is...>>
{
template<typename T, typename...Ts>
static auto get(T *ptr, Ts... s)
{
constexpr int N = sizeof...(Ts);
std::array<RAJA::idx_t, N> extent{s...};

auto custom_layout =
RAJA::make_permuted_layout(extent, std::array<RAJA::idx_t, N>{Is...});

constexpr auto unit_stride = get_first_index<Is...>();

return RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, unit_stride>>
(ptr, custom_layout);
}
};

template<>
struct Reshape<layout_right>
{
Expand All @@ -327,15 +355,15 @@ struct Reshape<layout_left>
template<typename T, typename...Ts>
static auto get(T *ptr, Ts... s)
{
constexpr int N = sizeof...(Ts);
constexpr int N = sizeof...(Ts);

std::array<RAJA::idx_t, N> extent{s...};
std::array<RAJA::idx_t, N> extent{s...};

constexpr auto reverse_array = make_reverse_array(std::make_index_sequence<N>{});
constexpr auto reverse_array = make_reverse_array(std::make_index_sequence<N>{});

auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array);
auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array);

return RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, 0>>(ptr, reverse_layout);
return RAJA::View<T, RAJA::Layout<N, RAJA::Index_type, 0>>(ptr, reverse_layout);
}

};
Expand Down

0 comments on commit b0a836a

Please sign in to comment.