Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Remove .template to fix build (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick authored Jul 23, 2024
1 parent 3b833f7 commit c9d62cf
Show file tree
Hide file tree
Showing 7 changed files with 457 additions and 252 deletions.
2 changes: 1 addition & 1 deletion src/interface/blas1_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ typename sb_handle_t::event_t _iamax_iamin_impl(
static_cast<index_t>(localSize),
static_cast<index_t>(localMemSize), ret));
}
sb_handle.template release_temp_mem({*ret.rbegin()}, gpu_res);
sb_handle.release_temp_mem({*ret.rbegin()}, gpu_res);
}
return ret;
}
Expand Down
20 changes: 10 additions & 10 deletions src/interface/blas2_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ typename sb_handle_t::event_t _gemv_impl(
lastEvent = sb_handle.execute(assignOp, local_range, gemvEvent));
}

sb_handle.template release_temp_mem(lastEvent, dot_products_buffer);
sb_handle.release_temp_mem(lastEvent, dot_products_buffer);

} else // Local memory kernel
{
Expand Down Expand Up @@ -211,7 +211,7 @@ typename sb_handle_t::event_t _gemv_impl(
lastEvent = sb_handle.execute(assignOp, local_range, gemvEvent));
}

sb_handle.template release_temp_mem(lastEvent, dot_products_buffer);
sb_handle.release_temp_mem(lastEvent, dot_products_buffer);
}
return ret;
}
Expand Down Expand Up @@ -341,7 +341,7 @@ typename sb_handle_t::event_t _trmv_impl(
ret = concatenate_vectors(
ret, lastEvent = sb_handle.execute(assignOp, localSize, ret));

sb_handle.template release_temp_mem(lastEvent, valT1);
sb_handle.release_temp_mem(lastEvent, valT1);

return ret;
}
Expand Down Expand Up @@ -403,7 +403,7 @@ typename sb_handle_t::event_t _trsv_impl(
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);

sb_handle.template release_temp_mem(ret, sync_buffer);
sb_handle.release_temp_mem(ret, sync_buffer);

return ret;
#endif
Expand Down Expand Up @@ -525,8 +525,8 @@ typename sb_handle_t::event_t _symv_impl(
ret = concatenate_vectors(
ret, lastEvent = sb_handle.execute(assignOp, localSize, ret));

sb_handle.template release_temp_mem(lastEvent, valTR);
sb_handle.template release_temp_mem(lastEvent, valTC);
sb_handle.release_temp_mem(lastEvent, valTR);
sb_handle.release_temp_mem(lastEvent, valTC);

return ret;
}
Expand Down Expand Up @@ -680,7 +680,7 @@ typename sb_handle_t::event_t _tbmv_impl(
auto assignEvent = sb_handle.execute(assignOp, local_range, tbmvEvent);
auto ret = concatenate_vectors(tbmvEvent, assignEvent);

sb_handle.template release_temp_mem(assignEvent, res_buffer);
sb_handle.release_temp_mem(assignEvent, res_buffer);

return ret;
}
Expand Down Expand Up @@ -735,7 +735,7 @@ typename sb_handle_t::event_t _tpmv_impl(
auto ret = concatenate_vectors(
tpmvEvent, lastEvent = sb_handle.execute(assignOp, tpmvEvent));

sb_handle.template release_temp_mem(lastEvent, res_buffer);
sb_handle.release_temp_mem(lastEvent, res_buffer);

return ret;
}
Expand Down Expand Up @@ -798,7 +798,7 @@ typename sb_handle_t::event_t _tbsv_impl(
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);

sb_handle.template release_temp_mem(ret, sync_buffer);
sb_handle.release_temp_mem(ret, sync_buffer);

return ret;
#endif
Expand Down Expand Up @@ -863,7 +863,7 @@ typename sb_handle_t::event_t _tpsv_impl(
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);

sb_handle.template release_temp_mem(ret, sync_buffer);
sb_handle.release_temp_mem(ret, sync_buffer);

return ret;
#endif
Expand Down
207 changes: 139 additions & 68 deletions src/interface/blas3/backend/amd_gpu.hpp

Large diffs are not rendered by default.

122 changes: 82 additions & 40 deletions src/interface/blas3/backend/default.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::interleaved)>::_select_gemm(sb_handle, _M, _N,
_K, _alpha, _a,
_lda, _stridea, _b,
_ldb, _strideb,
_beta, _c, _ldc,
_stridec,
batch_size,
_dependencies);
}
#if defined(NAIVE_GEMM)
return blas::Gemm_Launcher<
Expand All @@ -66,10 +71,14 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::naive),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
#else
if (_M <= 128 && _N <= 128 && _K <= 256 && !s_a && !s_b) {
return blas::Gemm_Launcher<
Expand All @@ -78,43 +87,59 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
} else if ((_M * _N) >= 524288 && !s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<4, 4, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
} else if (!s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
}

#endif
Expand Down Expand Up @@ -145,10 +170,15 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::interleaved)>::_select_gemm(sb_handle, _M, _N,
_K, _alpha, _a,
_lda, _stridea, _b,
_ldb, _strideb,
_beta, _c, _ldc,
_stridec,
batch_size,
_dependencies);
}

return blas::Gemm_Launcher<
Expand All @@ -157,10 +187,14 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
}
}

Expand All @@ -184,21 +218,29 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
static_cast<int>(
gemm_batch_type_t::strided)>::_select_gemm(sb_handle, _M, _N, _K,
_alpha, _a, _lda,
_stridea, _b, _ldb,
_strideb, _beta, _c,
_ldc, _stridec,
batch_size,
_dependencies);
}
}
#endif
Expand Down
Loading

0 comments on commit c9d62cf

Please sign in to comment.