From ac26cfe2e2d2a883a926b64c2c6f7dbe889a0a52 Mon Sep 17 00:00:00 2001
From: Ouadie EL FAROUKI
Date: Mon, 11 Sep 2023 13:16:40 +0100
Subject: [PATCH 01/18] Added complex type support to gemm kernels

---
 CMakeLists.txt                              |   1 +
 cmake/CmakeFunctionHelper.cmake             | 119 ++++++++++++++-
 common/include/common/common_utils.hpp      |  31 ++++
 include/blas_meta.h                         |  28 ++++
 include/operations/blas_constants.h         |  10 ++
 src/interface/blas3/backend/amd_gpu.hpp     |  75 ++++++++-
 src/interface/blas3/backend/default_cpu.hpp |  70 ++++++++-
 src/interface/blas3/backend/intel_gpu.hpp   |  84 +++++++++-
 src/interface/blas3/backend/nvidia_gpu.hpp  |  46 +++++-
 src/interface/gemm_interface.hpp            |  29 ++--
 src/operations/blas1_trees.hpp              |  11 ++
 src/operations/blas3/gemm_common.hpp        |  22 +++
 src/operations/blas3/gemm_load_store.hpp    | 144 ++++++++++++++++++
 src/operations/blas3/gemm_local.hpp         |   2 +-
 .../blas3/gemm_no_local_full_vec.hpp        |  45 ++++--
 .../blas3/gemm_no_local_partial_vec.hpp     |  25 +--
 src/operations/blas3/gemm_partial_local.hpp |   4 +-
 17 files changed, 677 insertions(+), 69 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a6b85f570..1037b1098 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -220,6 +220,7 @@ option(BUILD_CUBLAS_BENCHMARKS "Whether to build cuBLAS benchmarks" OFF)
 option(BUILD_ROCBLAS_BENCHMARKS "Whether to build rocBLAS benchmarks" OFF)
 option(BUILD_ACL_BENCHMARKS "Whether to build ARM Compute Library benchmarks" OFF)
 option(BLAS_BUILD_SAMPLES "Whether to build portBLAS samples" ON)
+option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON)
 if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_BENCHMARK)
   message(STATUS "Benchmarks are disabled when installing portBLAS in header only mode")
   set(BLAS_ENABLE_BENCHMARK OFF)
diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake
index 1be6ab8a8..8b411d5e6 100644
--- a/cmake/CmakeFunctionHelper.cmake
+++ b/cmake/CmakeFunctionHelper.cmake
@@ -36,10 +36,30 @@ function(cpp_type output data)
   if (${data} STREQUAL "half")
     set(${output} "cl::sycl::half" PARENT_SCOPE)
     return()
+  elseif(${data} STREQUAL "complex<float>")
+    set(${output} "cl::sycl::ext::oneapi::experimental::complex<float>" PARENT_SCOPE)
+    return()
+  elseif(${data} STREQUAL "complex<double>")
+    set(${output} "cl::sycl::ext::oneapi::experimental::complex<double>" PARENT_SCOPE)
+    return()
   endif()
   set(${output} "${data}" PARENT_SCOPE)
 endfunction()

+function(set_complex_list output input append)
+  set(output_temp "")
+  if(${append} STREQUAL "true")
+    foreach(data ${input})
+      list(APPEND output_temp "${data};complex<${data}>")
+    endforeach(data)
+  else()
+    foreach(data ${input})
+      list(APPEND output_temp "complex<${data}>")
+    endforeach(data)
+  endif()
+  set(${output} ${output_temp} PARENT_SCOPE)
+endfunction(set_complex_list)
+
 ## represent the list of boolean options
 set(boolean_list "true" "false")

@@ -56,6 +76,9 @@ function(sanitize_file_name output file_name)
   set(${output} "${file_name}" PARENT_SCOPE)
 endfunction()

+#List of operators supporting Complex Data types
+set(COMPLEX_OPS "gemm" "gemm_launcher" "scal")
+
 function(set_target_compile_def in_target)
   #setting compiler flag for backend
   if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
@@ -84,16 +107,31 @@ function(set_target_compile_def in_target)
     message(STATUS "Gemm vectorization support enabled for target ${in_target}")
     target_compile_definitions(${in_target} PUBLIC GEMM_VECTORIZATION_SUPPORT=1)
   endif()
-
   #setting const data type support
   if(BLAS_ENABLE_CONST_INPUT)
     target_compile_definitions(${in_target} PUBLIC
        BLAS_ENABLE_CONST_INPUT=1)
   endif()
+  #setting complex support
+  if(${BLAS_ENABLE_COMPLEX})
+    if("${in_target}" IN_LIST COMPLEX_OPS)
+      message(STATUS "Complex Data type support enabled for target ${in_target}")
+      target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_COMPLEX=1)
+    endif()
+  endif()
 endfunction()

 # blas unary function for generating source code
 function(generate_blas_objects blas_level func)
   set(LOCATION "${PORTBLAS_GENERATED_SRC}/${blas_level}/${func}/")
-  foreach(data ${data_list})
+  set(data_list_c ${data_list})
+  # Extend data_list to complex<data> for each data in list
+  # if target function is in COMPLEX_OPS
+  if(BLAS_ENABLE_COMPLEX)
+    if("${func}" IN_LIST COMPLEX_OPS)
+      set_complex_list(data_list_c "${data_list}" "true")
+    endif()
+  endif()
+  foreach(data ${data_list_c})
     cpp_type(cpp_data ${data})
     foreach(index ${index_list})
       foreach(increment ${index_list})
@@ -234,7 +272,11 @@ function(add_gemm_configuration
     batch_type
     use_joint_matrix
   )
-  if(NOT ("${data}" IN_LIST data_list))
+  set(data_list_c ${data_list})
+  if(BLAS_ENABLE_COMPLEX)
+    set_complex_list(data_list_c "${data_list}" "true")
+  endif()
+  if(NOT ("${data}" IN_LIST data_list_c))
     # Data type not enabled, skip configuration
     return()
   endif()
@@ -380,12 +422,36 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
       "${data}" 64 "false" "false" "false"
       64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
   endforeach()
+  if(BLAS_ENABLE_COMPLEX)
+    # Extract list of complex<data> for each data in supported_types
+    # list for complex specific gemm configurations
+    set(data_list_c)
+    set_complex_list(data_list_c "${supported_types}" "false")
+    foreach(data ${data_list_c})
+      add_gemm_configuration(
+        "${data}" 64 "true" "false" "false"
+        64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
+      add_gemm_configuration(
+        "${data}" 64 "false" "false" "false"
+        64 4 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
+      add_gemm_configuration(
+        "${data}" 64 "false" "false" "false"
+        64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false")
+      add_gemm_configuration(
+        "${data}" 32 "true" "true" "true"
+        64 2 1 8 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
+    endforeach()
+  endif() # BLAS_ENABLE_COMPLEX
 elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR)
   set(supported_types
     "float"
     "half"
   )
-  foreach(data ${supported_types})
+  set(data_list_c ${supported_types})
+  if(BLAS_ENABLE_COMPLEX)
+    set_complex_list(data_list_c "${supported_types}" "false")
+  endif()
+  foreach(data ${data_list_c})
     add_gemm_configuration(
       "${data}" 96 "true" "false" "false"
       16 4 6 12 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
@@ -445,6 +511,23 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
       "${data}" 64 "false" "false" "false"
       64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
   endforeach()
+  if(BLAS_ENABLE_COMPLEX)
+    # Extract list of complex<data> for each data in supported_types
+    # list for complex specific gemm configurations
+    set(data_list_c)
+    set_complex_list(data_list_c "${supported_types}" "false")
+    foreach(data ${data_list_c})
+      add_gemm_configuration(
+        "${data}" 256 "true" "true" "true"
+        64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")
+      add_gemm_configuration(
+        "${data}" 256 "false" "false" "false"
+        64 1 1 8 8 1 1 1 1 1 1 1 1 1
float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set(supported_types "float" @@ -486,7 +569,18 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") add_gemm_configuration( "${data}" 256 "false" "true" "true" 128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "false" "false" "true" + 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() + endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend set(supported_types "float" @@ -513,6 +607,23 @@ else() # default cpu backend "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX endif() add_library(${func} OBJECT ${gemm_sources}) set_target_compile_def(${func}) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index a569ed2ff..df14ac062 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1372,6 +1372,37 @@ static inline std::vector random_data(size_t size) { return v; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline complex_std random_scalar() { + scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); + scalar_t im = 1e-3 * ((rand() % 2000) - 1000); + return complex_std({rl, im}); +} + +template +static inline complex_std random_scalar(scalar_t rangeMin, + scalar_t rangeMax) { + static std::random_device rd; + static std::default_random_engine gen(rd()); + std::uniform_real_distribution disRl(rangeMin, rangeMax); + std::uniform_real_distribution disIm(rangeMin, rangeMax); + + return complex_std({disRl(gen), disIm(gen)}); +} + +template +static inline std::vector> random_data(size_t size) { + std::vector> v = + std::vector>(size); + + for (scalar_t& e : v) { + e = random_scalar(scalar_t{-2}, scalar_t{5}); + } + return v; +} +#endif + /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. 
diff --git a/include/blas_meta.h b/include/blas_meta.h
index 6bad4be98..a7634dbca 100644
--- a/include/blas_meta.h
+++ b/include/blas_meta.h
@@ -29,6 +29,11 @@
 #include <CL/sycl.hpp>
 #include <type_traits>
 #include <vector>
+#ifdef BLAS_ENABLE_COMPLEX
+#define SYCL_EXT_ONEAPI_COMPLEX
+#include <complex>
+#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
+#endif

 namespace blas {

@@ -190,6 +195,29 @@ struct is_sycl_scalar<std::complex<float>> : std::false_type {};

 template <>
 struct is_sycl_scalar<std::complex<double>> : std::false_type {};

+#ifdef BLAS_ENABLE_COMPLEX
+// SYCL Complex type alias
+template <typename scalar_t>
+using complex_sycl =
+    typename cl::sycl::ext::oneapi::experimental::complex<scalar_t>;
+
+template <typename scalar_t>
+struct is_complex_sycl
+    : std::integral_constant<
+          bool, std::is_same_v<scalar_t, complex_sycl<float>> ||
+                    std::is_same_v<scalar_t, complex_sycl<double>>> {};
+
+// STD Complex type alias
+template <typename scalar_t>
+using complex_std = typename std::complex<scalar_t>;
+
+template <typename scalar_t>
+struct is_complex_std
+    : std::integral_constant<
+          bool, std::is_same_v<scalar_t, complex_std<float>> ||
+                    std::is_same_v<scalar_t, complex_std<double>>> {};
+
+#endif
+
 }  // namespace blas

 #endif  // BLAS_META_H
diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h
index 103c78152..5fc4afb82 100644
--- a/include/operations/blas_constants.h
+++ b/include/operations/blas_constants.h
@@ -210,6 +210,16 @@ struct constant<std::complex<value_t>, Indicator> {
   }
 };

+#ifdef BLAS_ENABLE_COMPLEX
+template <typename value_t, const_val Indicator>
+struct constant<complex_sycl<value_t>, Indicator> {
+  constexpr static PORTBLAS_INLINE complex_sycl<value_t> value() {
+    return complex_sycl<value_t>(constant<value_t, Indicator>::value(),
+                                 constant<value_t, Indicator>::value());
+  }
+};
+#endif
+
 #ifdef BLAS_DATA_TYPE_HALF
 template <>
 struct constant<cl::sycl::half, const_val::zero>
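The `constant<complex_sycl<value_t>, Indicator>` specialization replicates the real-valued constant into both lanes, and the `is_complex_sycl` trait drives the `enable_if` overload split used by the backends below. A standalone analogue of that dispatch, with std::complex standing in for the SYCL experimental type (all names here are illustrative, not from the patch):

    #include <complex>
    #include <type_traits>

    template <typename T>
    struct is_complex : std::false_type {};
    template <typename T>
    struct is_complex<std::complex<T>> : std::true_type {};

    // Non-complex overload: a plain scalar zero.
    template <typename T>
    typename std::enable_if<!is_complex<T>::value, T>::type zero_of() {
      return static_cast<T>(0);
    }

    // Complex overload: the real constant is replicated into both lanes,
    // as constant<complex_sycl<value_t>, Indicator>::value() does above.
    template <typename T>
    typename std::enable_if<is_complex<T>::value, T>::type zero_of() {
      using value_t = typename T::value_type;
      return T(value_t(0), value_t(0));
    }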
diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp
index be864ae76..3aff8dd46 100644
--- a/src/interface/blas3/backend/amd_gpu.hpp
+++ b/src/interface/blas3/backend/amd_gpu.hpp
@@ -33,13 +33,18 @@ namespace backend {
 template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
           typename sb_handle_t, typename container_0_t, typename container_1_t,
           typename container_2_t, typename element_t, typename index_t>
-typename sb_handle_t::event_t _gemm(
-    sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
-    element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
-    container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
-    container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
-    gemm_batch_type_t batch_type,
-    const typename sb_handle_t::event_t& _dependencies) {
+#ifdef BLAS_ENABLE_COMPLEX
+typename std::enable_if<!is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+#else
+typename sb_handle_t::event_t
+#endif
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
   static constexpr int ClSize = 64;
   static constexpr int tileWgSize = ClSize / sizeof(element_t);
   if (batch_type == gemm_batch_type_t::interleaved) {
@@ -142,6 +147,62 @@ typename sb_handle_t::event_t _gemm(
                               batch_size, _dependencies);
   }
 }
+
+// Complex Configurations
+#ifdef BLAS_ENABLE_COMPLEX
+template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
+          typename sb_handle_t, typename container_0_t, typename container_1_t,
+          typename container_2_t, typename element_t, typename index_t>
+typename std::enable_if<is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
+  static constexpr int ClSize = 64;
+/* Tall & Skinny matrices. */
+#ifdef GEMM_TALL_SKINNY_SUPPORT
+  if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 256, true, true, true,
+        ClSize, Tile<1, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::tall_skinny),
+        static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  }
+#endif
+  if (_M * _N <= 65536) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 256, false, false, false,
+        ClSize, Tile<1, 1, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  } else {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 256, false, false, false,
+        ClSize, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  }
+}
+#endif
+
 }  // namespace backend
 }  // namespace gemm
 }  // namespace blas
diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp
index 17868991e..44a99d1fe 100644
--- a/src/interface/blas3/backend/default_cpu.hpp
+++ b/src/interface/blas3/backend/default_cpu.hpp
@@ -33,13 +33,18 @@ namespace backend {
 template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
           typename sb_handle_t, typename container_0_t, typename container_1_t,
           typename container_2_t, typename element_t, typename index_t>
-typename sb_handle_t::event_t _gemm(
-    sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
-    element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
-    container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
-    container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
-    gemm_batch_type_t batch_type,
-    const typename sb_handle_t::event_t& _dependencies) {
+#ifdef BLAS_ENABLE_COMPLEX
+typename std::enable_if<!is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+#else
+typename sb_handle_t::event_t
+#endif
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
   if (batch_type == gemm_batch_type_t::interleaved) {
     return blas::Gemm_Launcher<
         container_0_t, container_1_t, container_2_t, 64, false, false, false,
@@ -101,6 +106,57 @@ typename sb_handle_t::event_t _gemm(
 #endif
 }
+
+// Complex Configurations
+#ifdef BLAS_ENABLE_COMPLEX
+template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
+          typename sb_handle_t, typename container_0_t, typename container_1_t,
+          typename container_2_t, typename element_t, typename index_t>
+typename std::enable_if<is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
+  if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, false, false, false,
+        64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::no_local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  } else if (!s_a && !s_b) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, false, false, false,
+        64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::no_local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  } else {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, false, false, false,
+        64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  }
+}
+#endif
+
 }  // namespace backend
 }  // namespace gemm
 }  // namespace blas
diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp
index 8fcb3e3a8..e22274008 100644
--- a/src/interface/blas3/backend/intel_gpu.hpp
+++ b/src/interface/blas3/backend/intel_gpu.hpp
@@ -32,13 +32,18 @@ namespace backend {
 template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
           typename sb_handle_t, typename container_0_t, typename container_1_t,
           typename container_2_t, typename element_t, typename index_t>
-typename sb_handle_t::event_t _gemm(
-    sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
-    element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
-    container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
-    container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
-    gemm_batch_type_t batch_type,
-    const typename sb_handle_t::event_t& _dependencies) {
+#ifdef BLAS_ENABLE_COMPLEX
+typename std::enable_if<!is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+#else
+typename sb_handle_t::event_t
+#endif
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
   if (batch_type == gemm_batch_type_t::interleaved) {
     return blas::Gemm_Launcher<
         container_0_t, container_1_t, container_2_t, 64, false, false, false,
@@ -206,6 +211,71 @@ typename sb_handle_t::event_t _gemm(
                               batch_size, _dependencies);
   }
 }
+
+// Complex Configurations
+#ifdef BLAS_ENABLE_COMPLEX
+template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
+          typename sb_handle_t, typename container_0_t, typename container_1_t,
+          typename container_2_t, typename element_t, typename index_t>
+typename std::enable_if<is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
+#ifdef GEMM_TALL_SKINNY_SUPPORT
+  if (!s_a && !s_b && batch_size == 1) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 32, true, true, true, 64,
+        Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::tall_skinny),
+        static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  }
+#endif
+  if (_M <= 128 && _N <= 128) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, true, false, false,
+        64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  } else if (_t_b && !_t_a && !s_a && !s_b) {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, false, false, false,
+        64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::no_local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  } else {
+    return blas::Gemm_Launcher<
+        container_0_t, container_1_t, container_2_t, 64, false, false, false,
+        64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b,
+        static_cast<int>(gemm_memory_t::local),
+        static_cast<int>(gemm_algorithm_t::standard),
+        static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+        static_cast<int>(gemm_batch_type_t::strided)>::
+        template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                              _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
+                              _stridec, batch_size, _dependencies);
+  }
+}
+#endif
+
 }  // namespace backend
 }  // namespace gemm
 }  // namespace blas
diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp
index aeb678704..f13a95d2e 100644
--- a/src/interface/blas3/backend/nvidia_gpu.hpp
+++ b/src/interface/blas3/backend/nvidia_gpu.hpp
@@ -33,13 +33,18 @@ namespace backend {
 template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
           typename sb_handle_t, typename container_0_t, typename container_1_t,
           typename container_2_t, typename element_t, typename index_t>
-typename sb_handle_t::event_t _gemm(
-    sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
-    element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
-    container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
-    container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
-    gemm_batch_type_t batch_type,
-    const typename sb_handle_t::event_t& _dependencies) {
+#ifdef BLAS_ENABLE_COMPLEX
+typename std::enable_if<!is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+#else
+typename sb_handle_t::event_t
+#endif
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
   if (batch_type == gemm_batch_type_t::interleaved) {
    return blas::Gemm_Launcher<
         container_0_t, container_1_t, container_2_t, 64, false, false, false,
@@ -167,6 +172,33 @@ typename sb_handle_t::event_t _gemm(
                                     _stridea, _b, _ldb, _strideb, _beta, _c,
                                     _ldc, _stridec, batch_size, _dependencies);
 }
+
+// Complex Configurations
+#ifdef BLAS_ENABLE_COMPLEX
+template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
+          typename sb_handle_t, typename container_0_t, typename container_1_t,
+          typename container_2_t, typename element_t, typename index_t>
+typename std::enable_if<is_complex_sycl<element_t>::value,
+                        typename sb_handle_t::event_t>::type
+_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
+      element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
+      container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
+      container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
+      gemm_batch_type_t batch_type,
+      const typename sb_handle_t::event_t& _dependencies) {
+  return blas::Gemm_Launcher<
+      container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
+      Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
+      s_a, s_b, static_cast<int>(gemm_memory_t::local),
+      static_cast<int>(gemm_algorithm_t::standard),
+      static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
+      static_cast<int>(gemm_batch_type_t::strided),
+      false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
+                                    _stridea, _b, _ldb, _strideb, _beta, _c,
+                                    _ldc, _stridec, batch_size, _dependencies);
+}
+#endif
+
 }  // namespace backend
 }  // namespace gemm
 }  // namespace blas
diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp
index a5c2c7bb3..f5b7383e6 100644
--- a/src/interface/gemm_interface.hpp
+++ b/src/interface/gemm_interface.hpp
@@ -48,6 +48,18 @@ namespace blas {
  */
 namespace internal {

+// Check whether value is zero (complex & float/double)
+template <typename T>
+inline bool isZero(const T& value) {
+#ifdef BLAS_ENABLE_COMPLEX
+  if constexpr (is_complex_sycl<T>::value) {
+    using value_t = typename T::value_type;
+    return (value == T(value_t(0), value_t(0)));
+  }
+#endif
+  return (value == static_cast<T>(0));
+}
+
 template <bool _t_a, bool _t_b, bool s_a, bool s_b, typename sb_handle_t,
           typename container_0_t, typename container_1_t,
           typename container_2_t, typename element_t, typename index_t>
@@ -73,15 +85,14 @@ typename sb_handle_t::event_t _gemm_is_beta_zero(
     container_2_t _C, index_t _ldc, index_t _stridec, index_t batch_size,
     gemm_batch_type_t batch_type,
     const typename sb_handle_t::event_t& _dependencies) {
-  return ((_beta == static_cast<element_t>(0))
-              ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>(
-                    sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb,
-                    _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type,
-                    _dependencies)
-              : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>(
-                    sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb,
-                    _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type,
-                    _dependencies));
+  return isZero(_beta) ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>(
+                             sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea,
+                             b_, _ldb, _strideb, _beta, _C, _ldc, _stridec,
+                             batch_size, batch_type, _dependencies)
+                       : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>(
+                             sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea,
+                             b_, _ldb, _strideb, _beta, _C, _ldc, _stridec,
+                             batch_size, batch_type, _dependencies);
 }
diff --git a/src/operations/blas1_trees.hpp b/src/operations/blas1_trees.hpp
--- a/src/operations/blas1_trees.hpp
+++ b/src/operations/blas1_trees.hpp
 template <typename element_t>
 struct DetectScalar<std::complex<element_t>> {
   static element_t get_scalar(element_t &scalar) { return scalar; }
 };

+#ifdef BLAS_ENABLE_COMPLEX
+/*! DetectScalar (for sycl::complex)
+ * @brief See Detect Scalar.
+ */
+template <typename T>
+struct DetectScalar<complex_sycl<T>> {
+  using element_t = complex_sycl<T>;
+  static element_t get_scalar(element_t &scalar) { return scalar; }
+};
+#endif
+
 /*! get_scalar.
  * @brief Template autodeduction function for DetectScalar.
  */
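`isZero` is what lets `_gemm_is_beta_zero` work for both real and complex `beta`: comparing the SYCL complex wrapper against a scalar `0` is not the comparison the launchers need, so the complex branch compares against a complex zero, and the result selects the `is_beta_zero` template parameter of the launcher. A compilable standalone sketch of the same logic using std::complex (names are illustrative):

    #include <complex>
    #include <type_traits>

    template <typename T>
    inline bool is_zero_value(const T &value) {
      if constexpr (std::is_same_v<T, std::complex<float>> ||
                    std::is_same_v<T, std::complex<double>>) {
        // Mirrors the patch's is_complex_sycl branch of isZero().
        using value_t = typename T::value_type;
        return value == T(value_t(0), value_t(0));  // complex zero
      } else {
        return value == static_cast<T>(0);  // real zero
      }
    }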
diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp
index 4966b9f13..6923f492b 100644
--- a/src/operations/blas3/gemm_common.hpp
+++ b/src/operations/blas3/gemm_common.hpp
@@ -33,6 +33,28 @@

 namespace blas {

+#ifdef BLAS_ENABLE_COMPLEX
+template <typename T>
+static PORTBLAS_INLINE T
+mul_add(T a, T b, T c,
+        typename std::enable_if<is_complex_sycl<T>::value>::type * = 0) {
+  return (a * b + c);
+}
+
+template <typename T>
+static PORTBLAS_INLINE T
+mul_add(T a, T b, T c,
+        typename std::enable_if<!is_complex_sycl<T>::value>::type * = 0) {
+  return (sycl::mad(a, b, c));
+}
+#else
+
+template <typename T>
+static PORTBLAS_INLINE T mul_add(T a, T b, T c) {
+  return (sycl::mad(a, b, c));
+}
+#endif
+
 template <typename T>
 struct type_string {
   static const char *get_value() { return "unknown"; }
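`sycl::mad` is only defined for floating-point types, so the complex overload of `mul_add` falls back to a plain `a * b + c`. A standalone check of the arithmetic, with std::complex standing in for the SYCL type:

    #include <cassert>
    #include <complex>

    template <typename T>
    T mul_add_ref(T a, T b, T c) {
      return a * b + c;  // complex multiply-add; no hardware mad path here
    }

    int main() {
      std::complex<float> a{1.f, 2.f}, b{3.f, -1.f}, c{0.5f, 0.f};
      // (1+2i)(3-i) = 5+5i, plus 0.5 gives 5.5+5i
      assert(mul_add_ref(a, b, c) == (std::complex<float>{5.5f, 5.f}));
      return 0;
    }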
diff --git a/src/operations/blas3/gemm_load_store.hpp b/src/operations/blas3/gemm_load_store.hpp
index ef44cbfe6..7ae45ce5d 100644
--- a/src/operations/blas3/gemm_load_store.hpp
+++ b/src/operations/blas3/gemm_load_store.hpp
@@ -125,5 +125,149 @@ struct Packetize {
   }
 };

+#ifdef BLAS_ENABLE_COMPLEX
+/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in
+ * Packetize. It serves as a temporary workaround to the upcoming
+ * sycl::vec<complex> container (see
+ * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc)
+ * and only supports size = 1.
+ * @tparam DataT Complex type of the vector's data
+ * @tparam NumElements Element count of the vector (only 1 is supported)
+ */
+template <typename DataT, int NumElements = 1>
+class vec_complex {
+  static_assert(NumElements == 1,
+                "Vector wrapper around sycl::complex of size > 1 unsupported.");
+  using address_t = cl::sycl::access::address_space;
+  using decorated_t = cl::sycl::access::decorated;
+  using DataType = DataT;
+  static constexpr int getNumElements() { return NumElements; }
+  size_t size() const noexcept { return NumElements; }
+
+ private:
+  DataType m_Data;
+
+ public:
+  vec_complex() = default;
+
+  constexpr vec_complex(const vec_complex &rhs) = default;
+  constexpr vec_complex(vec_complex &&rhs) = default;
+  constexpr vec_complex &operator=(const vec_complex &rhs) = default;
+
+  vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {}
+
+  // Conversion operator (valid with NumElements == 1)
+  operator DataT() const { return m_Data; }
+
+  // Subscript operators
+  DataT &operator[](int i) {
+    assert(i < NumElements);
+    return (m_Data);
+  }
+  const DataT &operator[](int i) const {
+    assert(i < NumElements);
+    return (m_Data);
+  }
+
+  // Binary Ops
+  // Multiply
+  vec_complex operator*(const vec_complex &rhs) {
+    return (vec_complex{m_Data * static_cast<DataType>(rhs)});
+  }
+
+  vec_complex operator*(const DataType &rhs) {
+    return (vec_complex{m_Data * rhs});
+  }
+
+  // Compound Multiply
+  vec_complex &operator*=(const DataType &rhs) {
+    this->m_Data = this->m_Data * rhs;
+    return (*this);
+  }
+
+  vec_complex &operator*=(const vec_complex &rhs) {
+    this->m_Data = this->m_Data * static_cast<DataType>(rhs);
+    return (*this);
+  }
+
+  // Add
+  vec_complex operator+(const vec_complex &rhs) {
+    return (vec_complex{m_Data + static_cast<DataType>(rhs)});
+  }
+
+  vec_complex operator+(const DataType &rhs) {
+    return (vec_complex{m_Data + rhs});
+  }
+
+  // Compound Add
+  vec_complex &operator+=(const DataType &rhs) {
+    this->m_Data = this->m_Data + rhs;
+    return (*this);
+  }
+
+  vec_complex &operator+=(const vec_complex &rhs) {
+    this->m_Data = this->m_Data + static_cast<DataType>(rhs);
+    return (*this);
+  }
+
+  // Load
+  template <address_t Space, decorated_t DecorateAddress>
+  void load(size_t Offset,
+            cl::sycl::multi_ptr<const DataT, Space, DecorateAddress> Ptr) {
+    m_Data = *(Ptr + Offset * NumElements);
+  }
+
+  // Store
+  template <address_t Space, decorated_t DecorateAddress>
+  void store(size_t Offset,
+             cl::sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
+    *(Ptr + Offset * NumElements) = m_Data;
+  }
+};
+
+/*! @brief Partial specialization of the Packetize class dedicated to
+ * sycl::complex types. It contains static methods for loading and storing
+ * size = 1 complex packets from/to memory.
+ * @tparam vector_size The desired vector size to be used. Only size = 1 is
+ * supported so far.
+ * @tparam value_t The complex type of the matrix data.
+ */
+template <int vector_size, typename T, typename index_t>
+struct Packetize<vector_size, complex_sycl<T>, index_t> {
+  // Vectorization is not enabled for complex, always set to 1
+  using value_t = complex_sycl<T>;
+  using PacketType = vec_complex<value_t, 1>;
+  static constexpr int packet_size = 1;
+
+  template <index_t dimension>
+  static PORTBLAS_INLINE constexpr bool check_size() {
+    return true;
+  }
+
+  /*! @brief Performs a non-vectorised load of a sycl::complex data element,
+   * whether the block is internal or not, since vectorization is not enabled
+   * for complex types yet.
+   * @tparam trans Whether the source matrix is transposed or not.
+   * @tparam internal True if the current block is internal and no bounds
+   * checking is required.
+   * @tparam ld The leading dimension of the destination memory. */
+  template <bool trans, bool internal, index_t ld, typename SrcPointerType,
+            typename DestPointerType, typename EdgePredicate>
+  static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src,
+                                   DestPointerType dest,
+                                   EdgePredicate edge_in_range) {
+    *(dest) = in_range ? *(src) : value_t{(T)0, (T)0};
+  }
+
+  /*! @brief Store a size = 1 vector packet of sycl::complex data into local
+   * memory (whether the source is transposed or not, since it is only one
+   * element).
+   * @tparam trans Whether the source matrix is transposed or not.
+   * @tparam ld The leading dimension of the destination memory. */
+  template <bool trans, index_t ld, typename DestPointerType>
+  static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
+    *dest = packet[0];
+  }
+};
+#endif
+
 }  // namespace blas
 #endif  // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP
diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp
index 9b1c1c98b..db0fe6f14 100644
--- a/src/operations/blas3/gemm_local.hpp
+++ b/src/operations/blas3/gemm_local.hpp
@@ -754,7 +754,7 @@ class Gemm
           for (index_t l = 0; l < item_rows; ++l) {
-            reg_res[j * item_rows + l] =
-                cl::sycl::mad(reg_a[l], reg_b, reg_res[j * item_rows + l]);
+            reg_res[j * item_rows + l] = mul_add<element_t>(
+                reg_a[l], reg_b, reg_res[j * item_rows + l]);
           }
         }
         A = A + ldsa;
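Since `packet_size` is pinned to 1 for complex data, every packet load and store in the kernels degenerates to a single-element copy. A minimal stand-in showing the contract `vec_complex` implements, with raw pointers in place of `multi_ptr` and std::complex in place of the SYCL type (illustrative only):

    #include <cassert>
    #include <complex>
    #include <cstddef>

    template <typename DataT>
    struct packet1 {
      DataT data;
      DataT &operator[](int) { return data; }
      void load(std::size_t offset, const DataT *ptr) { data = ptr[offset]; }
      void store(std::size_t offset, DataT *ptr) const { ptr[offset] = data; }
    };

    int main() {
      std::complex<float> src[2] = {{1.f, 2.f}, {3.f, 4.f}};
      std::complex<float> dst[2] = {};
      packet1<std::complex<float>> p;
      p.load(1, src);   // one element per load when packet_size == 1
      p.store(0, dst);  // and one element per store
      assert(dst[0] == src[1]);
      return 0;
    }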
diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp
index a5dc683f3..732cc9568 100644
--- a/src/operations/blas3/gemm_no_local_full_vec.hpp
+++ b/src/operations/blas3/gemm_no_local_full_vec.hpp
@@ -69,6 +69,7 @@ class Gemm
   using value_t = typename std::remove_const<typename input_t::value_t>::type;
   using address_t = cl::sycl::access::address_space;
   using packetize_t = Packetize<VectorSize, value_t, index_t>;
+  using vector_t = typename packetize_t::PacketType;
   static constexpr int local_memory_size = 0;
   /*! @brief The number of rows processed by each work item */
   static constexpr index_t item_rows = tile_type::item_rows;
@@ -114,8 +115,8 @@ class Gemm
         if (do_check<check_block>(check_boundary(
                 dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) {
-          cl::sycl::vec<element_t, packet_size> out_vec{};
+          using l_vector_t =
+              typename Packetize<packet_size, element_t, index_t>::PacketType;
+          l_vector_t out_vec{};
           out_vec.template load<address_t::private_space>(
               0, cl::sycl::multi_ptr<const element_t,
                                      address_t::private_space>(
@@ -552,7 +555,9 @@ class Gemm
       const bool in_range =
           do_check<internal>(is_valid_row(j * ptr_next + work_per_load - 1));
-      cl::sycl::vec<element_t, work_per_load> in_vec{};
+      using l_vector_t =
+          typename Packetize<work_per_load, element_t, index_t>::PacketType;
+      l_vector_t in_vec{};
       if (in_range) {
         // if in range perform a vectorised load
         in_vec.template load<address_t::global_space>(
@@ -630,7 +635,9 @@ class Gemm
       const bool in_range = do_check<internal>(is_valid_col(work_per_load - 1));
-      cl::sycl::vec<element_t, work_per_load> in_vec{};
+      using l_vector_t =
+          typename Packetize<work_per_load, element_t, index_t>::PacketType;
+      l_vector_t in_vec{};
       if (in_range) {
         // if in range perform a vectorised load
         in_vec.template load<address_t::global_space>(
@@ -705,7 +712,9 @@ class Gemm
       const bool in_range = do_check<internal>(is_valid_row(work_per_load - 1)) &&
                             do_check<internal>(is_valid_col(col_ofs));
-      cl::sycl::vec<element_t, work_per_load> in_vec{};
+      using l_vector_t =
+          typename Packetize<work_per_load, element_t, index_t>::PacketType;
+      l_vector_t in_vec{};
       if (in_range) {
         // If in range perform a vectorised load.
         in_vec.template load<address_t::global_space>(
@@ -768,7 +777,9 @@ class Gemm
       const bool in_range = do_check<internal>(is_valid_row(row_ofs)) &&
                             do_check<internal>(is_valid_col(work_per_load - 1));
-      cl::sycl::vec<element_t, work_per_load> in_vec{};
+      using l_vector_t =
+          typename Packetize<work_per_load, element_t, index_t>::PacketType;
+      l_vector_t in_vec{};
       if (in_range) {
         // If in range perform a vectorised load.
         in_vec.template load<address_t::global_space>(
@@ -808,7 +819,7 @@ class Gemm
 #pragma unroll
         for (index_t j = 0; j < item_rows; ++j) {
-          reg_res[i * item_rows + j] = cl::sycl::mad(
-              reg_a[j], reg_b[i], reg_res[i * item_rows + j]);
+          reg_res[i * item_rows + j] = mul_add<element_t>(
+              reg_a[j], reg_b[i], reg_res[i * item_rows + j]);
         }
       }
     }
@@ -860,7 +871,7 @@ class Gemm
       for (index_t j = 0; j < item_rows; ++j) {
-        reg_res[j] = cl::sycl::mad(reg_a[j], *reg_b, reg_res[j]);
+        reg_res[j] = mul_add<element_t>(reg_a[j], *reg_b, reg_res[j]);
       }
     }
@@ -887,11 +898,11 @@ class Gemm
   template <typename PointerType, typename check_boundary>
   PORTBLAS_INLINE void store(PointerType C, element_t *reg_res,
-                             const index_t &dim_m_c_start,
-                             const index_t &dim_n_c_start,
-                             const check_boundary &chk_boundary,
-                             const bool out_of_range,
-                             const index_t &ldc) noexcept {
+                              const index_t &dim_m_c_start,
+                              const index_t &dim_n_c_start,
+                              const check_boundary &chk_boundary,
+                              const bool out_of_range,
+                              const index_t &ldc) noexcept {
     if (out_of_range) {
       return;
     }
@@ -901,7 +912,9 @@ class Gemm
         if (do_check<check_block>(chk_boundary(dim_m_c_start + j * wg_rows,
                                                dim_n_c_start + i * wg_cols))) {
-          cl::sycl::vec<element_t, packet_size> out_vec{};
+          using l_vector_t =
+              typename Packetize<packet_size, element_t, index_t>::PacketType;
+          l_vector_t out_vec{};
           out_vec.template load<address_t::private_space>(
               0, cl::sycl::multi_ptr<const element_t,
                                      address_t::private_space>(
diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp
index eb3d19473..189de963b 100644
--- a/src/operations/blas3/gemm_no_local_partial_vec.hpp
+++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp
@@ -69,6 +69,7 @@ class Gemm
   using value_t = typename std::remove_const<typename input_t::value_t>::type;
   using address_t = cl::sycl::access::address_space;
   using packetize_t = Packetize<VectorSize, value_t, index_t>;
+  using vector_t = typename packetize_t::PacketType;
   static constexpr int local_memory_size = 0;
   /*! @brief The number of rows processed by each work item */
   static constexpr index_t item_rows = tile_type::item_rows;
@@ -458,7 +459,9 @@ class Gemm
       const bool in_range =
           do_check<internal>(chk_boundary(index + (work_per_load - 1)));
-      cl::sycl::vec<element_t, work_per_load> in_vec{0};
+      using l_vector_t =
+          typename Packetize<work_per_load, element_t, index_t>::PacketType;
+      l_vector_t in_vec{0};
       if (in_range) {
         in_vec.template load<address_t::global_space>(
             0,
@@ -488,7 +491,7 @@ class Gemm
         for (index_t j = 0; j < item_rows; ++j) {
-          reg_res[i * item_rows + j] = cl::sycl::mad(
-              reg_a[j], reg_b[i], reg_res[i * item_rows + j]);
+          reg_res[i * item_rows + j] = mul_add<element_t>(
+              reg_a[j], reg_b[i], reg_res[i * item_rows + j]);
         }
       }
     }
@@ -502,7 +505,9 @@ class Gemm
   template <bool internal, typename OutputPointerType>
   PORTBLAS_INLINE typename std::enable_if<internal>::type store_packet(
       element_t *reg, OutputPointerType out_ptr) {
-    cl::sycl::vec<element_t, packet_size> out_vec{0};
+    using l_vector_t =
+        typename Packetize<packet_size, element_t, index_t>::PacketType;
+    l_vector_t out_vec{0};
     out_vec.template load<address_t::private_space>(
         0, cl::sycl::multi_ptr<const element_t, address_t::private_space>(reg));
@@ -531,11 +536,11 @@ class Gemm
   template <typename PointerType, typename check_boundary>
   PORTBLAS_INLINE void store(PointerType C, element_t *reg_res,
-                             const index_t &dim_m_c_start,
-                             const index_t &dim_n_c_start,
-                             const check_boundary &chk_boundary,
-                             const bool out_of_range,
-                             const index_t &ldc) noexcept {
+                              const index_t &dim_m_c_start,
+                              const index_t &dim_n_c_start,
+                              const check_boundary &chk_boundary,
+                              const bool out_of_range,
+                              const index_t &ldc) noexcept {
     if (out_of_range) {
       return;
     }
@@ -545,7 +550,9 @@ class Gemm
         if (do_check<check_block>(chk_boundary(dim_m_c_start + j * wg_rows,
                                                dim_n_c_start + i * wg_cols))) {
-          cl::sycl::vec<element_t, packet_size> out_vec{0};
+          using l_vector_t =
+              typename Packetize<packet_size, element_t, index_t>::PacketType;
+          l_vector_t out_vec{0};
           out_vec.template load<address_t::private_space>(
               0, cl::sycl::multi_ptr<const element_t,
                                      address_t::private_space>(
diff --git a/src/operations/blas3/gemm_partial_local.hpp b/src/operations/blas3/gemm_partial_local.hpp
index a9de19fb8..a6f8bf30a 100644
--- a/src/operations/blas3/gemm_partial_local.hpp
+++ b/src/operations/blas3/gemm_partial_local.hpp
@@ -309,8 +309,8 @@ class GemmPartial
-          private_res[wLPTM + idx] =
-              cl::sycl::mad(privateLhs, privateRhs, private_res[wLPTM + idx]);
+          private_res[wLPTM + idx] = mul_add<element_t>(
+              privateLhs, privateRhs, private_res[wLPTM + idx]);
           lhs_index += tile_type::wg_rows;
         }

From a4d6b8fac9d6543088ec318ad41cdfe8b06e6b38 Mon Sep 17 00:00:00 2001
From: Ouadie EL FAROUKI
Date: Mon, 11 Sep 2023 13:17:45 +0100
Subject: [PATCH 02/18] Added unit tests for complex type gemm operators

---
 common/include/common/float_comparison.hpp   |  98 ++++-
 .../include/common/system_reference_blas.hpp |  15 +
 test/blas_test.hpp                           |  48 ++-
 test/blas_test_macros.hpp                    |  42 +++
 test/unittest/CMakeLists.txt                 |   5 +
 .../blas3/blas3_gemm_batched_test.cpp        |  64 ++++
 test/unittest/blas3/blas3_gemm_common.hpp    | 338 +++++++++++++++++-
 .../blas3/blas3_gemm_tall_skinny_test.cpp    |  78 ++++
 test/unittest/blas3/blas3_gemm_test.cpp      | 118 ++++++
 9 files changed, 803 insertions(+), 3 deletions(-)

diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp
index 43f8f578b..e244f0d5a 100644
--- a/common/include/common/float_comparison.hpp
+++ b/common/include/common/float_comparison.hpp
@@ -28,6 +28,9 @@
 #include <cmath>
 #include <iostream>
+#ifdef BLAS_ENABLE_COMPLEX
+#include <complex>
+#endif

 #ifdef BLAS_DATA_TYPE_HALF
 #if SYCL_LANGUAGE_VERSION < 202000
@@ -65,6 +68,23 @@ scalar_t abs(scalar_t value) noexcept {
   return std::abs(value);
 }

+#ifdef BLAS_ENABLE_COMPLEX
+template <typename scalar_t>
+bool isnan(std::complex<scalar_t> value) noexcept {
+  return (isnan(value.real()) || isnan(value.imag()));
+}
+
+template <typename scalar_t>
+bool isinf(std::complex<scalar_t> value) noexcept {
+  return (isinf(value.real()) || isinf(value.imag()));
+}
+
+template <typename scalar_t>
+scalar_t abs(std::complex<scalar_t> value) noexcept {
+  return std::abs(value);
+}
+#endif
+
 #ifdef BLAS_DATA_TYPE_HALF
 template <>
 inline bool isnan<cl::sycl::half>(cl::sycl::half value) noexcept {
@@ -172,7 +192,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
     return true;
   }

-  const scalar_t absolute_diff = utils::abs(scalar1 - scalar2);
+  const auto absolute_diff = utils::abs(scalar1 - scalar2);

   // Close to zero, the relative error doesn't work, use absolute error
   if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
@@ -212,6 +232,37 @@
   return true;
 }

+#ifdef BLAS_ENABLE_COMPLEX
+/**
+ * Compare two vectors of complex data and returns false if the difference is
+ * not acceptable. The second vector is considered the reference.
+ * @tparam scalar_t the type of the complex underlying data present in the
+ * input vectors
+ * @tparam epsilon_t the type used as tolerance.
+ */
+template <typename scalar_t, typename epsilon_t = scalar_t>
+inline bool compare_vectors(std::vector<std::complex<scalar_t>> const& vec,
+                            std::vector<std::complex<scalar_t>> const& ref,
+                            std::ostream& err_stream = std::cerr,
+                            std::string end_line = "\n") {
+  if (vec.size() != ref.size()) {
+    err_stream << "Error: tried to compare vectors of different sizes"
+               << std::endl;
+    return false;
+  }
+
+  for (int i = 0; i < vec.size(); ++i) {
+    if (!almost_equal<std::complex<scalar_t>, epsilon_t>(vec[i], ref[i])) {
+      err_stream << "Value mismatch at index " << i << ": (" << vec[i].real()
+                 << "," << vec[i].imag() << "); expected (" << ref[i].real()
+                 << "," << ref[i].imag() << ")" << end_line;
+      return false;
+    }
+  }
+  return true;
+}
+#endif
+
 /**
  * Compare two vectors at a given stride and window (unit_vec_size) and returns
  * false if the difference is not acceptable. The second vector is considered
@@ -253,6 +304,51 @@
   return true;
 }

+#ifdef BLAS_ENABLE_COMPLEX
+/**
+ * Compare two vectors of complex data at a given stride and window and returns
+ * false if the difference is not acceptable. The second vector is considered
+ * the reference.
+ * @tparam scalar_t the type of the complex underlying data present in the
+ * input vectors
+ * @tparam epsilon_t the type used as tolerance.
+ * @param stride is the stride between two consecutive 'windows'
+ * @param window is the size of a comparison window
+ */
+template <typename scalar_t, typename epsilon_t = scalar_t>
+inline bool compare_vectors_strided(
+    std::vector<std::complex<scalar_t>> const& vec,
+    std::vector<std::complex<scalar_t>> const& ref, int stride, int window,
+    std::ostream& err_stream = std::cerr, std::string end_line = "\n") {
+  if (vec.size() != ref.size()) {
+    err_stream << "Error: tried to compare vectors of different sizes"
+               << std::endl;
+    return false;
+  }
+
+  int k = 0;
+
+  // Loop over windows
+  while (window + (k + 1) * stride < vec.size()) {
+    // Loop within a window
+    for (int i = 0; i < window; ++i) {
+      auto index = i + k * stride;
+      if (!almost_equal<std::complex<scalar_t>, epsilon_t>(vec[index],
+                                                           ref[index])) {
+        err_stream << "Value mismatch at index " << index << ": ("
+                   << vec[index].real() << "," << vec[index].imag()
+                   << "); expected (" << ref[index].real() << ","
+                   << ref[index].imag() << ")" << end_line;
+        return false;
+      }
+    }
+    k += 1;
+  }
+
+  return true;
+}
+#endif
+
 }  // namespace utils

 #endif  // UTILS_FLOAT_COMPARISON_H_
diff --git a/common/include/common/system_reference_blas.hpp b/common/include/common/system_reference_blas.hpp
index afcb4f5e4..cd07e27cf 100644
--- a/common/include/common/system_reference_blas.hpp
+++ b/common/include/common/system_reference_blas.hpp
@@ -133,6 +133,12 @@ auto blas_system_function(floatfn_t ffn, doublefn_t dfn)
   return BlasSystemFunction<scalar_t, floatfn_t, doublefn_t>::get(ffn, dfn);
 }

+template <typename scalar_t, typename floatfn_t, typename doublefn_t>
+auto blas_cplx_system_function(floatfn_t ffn, doublefn_t dfn)
+    -> decltype(BlasSystemFunction<scalar_t, floatfn_t, doublefn_t>::get(ffn,
+                                                                         dfn)) {
+  return BlasSystemFunction<scalar_t, floatfn_t, doublefn_t>::get(ffn, dfn);
+}
+
 // =======
 // Level 1
 // =======
@@ -378,6 +384,15 @@ void gemm(const char *transA, const char *transB, int m, int n, int k,
        lda, b, ldb, beta, c, ldc);
 }

+template <typename scalar_t>
+void cgemm(const char *transA, const char *transB, int m, int n, int k,
+           const void *alpha, const void *a, int lda, const void *b, int ldb,
+           const void *beta, void *c, int ldc) {
+  auto func = blas_cplx_system_function<scalar_t>(&cblas_cgemm, &cblas_zgemm);
+  func(CblasColMajor, c_trans(*transA), c_trans(*transB), m, n, k, alpha, a,
+       lda, b, ldb, beta, c, ldc);
+}
+
 template <typename scalar_t>
 void trsm(const char *side, const char *uplo, const char *trans,
           const char *diag, int m, int n, scalar_t alpha, const scalar_t A[],
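`cgemm` deliberately takes type-erased `void*` arguments so one wrapper can forward to either `cblas_cgemm` or `cblas_zgemm`; `std::complex` is layout-compatible with the interleaved (real, imag) pairs CBLAS expects. A hedged sketch of how the unit tests below drive it (assumes the patch's headers are available; sizes and values are illustrative):

    #include <complex>
    #include <vector>

    void reference_gemm_example(int m, int n, int k) {
      std::vector<std::complex<float>> a(m * k), b(k * n), c(m * n);
      fill_random(a);  // complex overload added to blas_test.hpp below
      fill_random(b);
      fill_random(c);
      std::complex<float> alpha{1.5f, 1.0f}, beta{0.5f, 0.5f};
      // The complex buffers implicitly type-erase to void* for CBLAS.
      reference_blas::cgemm<float>("n", "n", m, n, k, &alpha, a.data(), m,
                                   b.data(), k, &beta, c.data(), m);
    }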
diff --git a/test/blas_test.hpp b/test/blas_test.hpp
index 1d0f39de3..70a32d61c 100644
--- a/test/blas_test.hpp
+++ b/test/blas_test.hpp
@@ -149,6 +149,34 @@ static inline void fill_random(std::vector<scalar_t> &vec) {
   fill_random_with_range(vec, scalar_t{-2}, scalar_t{5});
 }

+#ifdef BLAS_ENABLE_COMPLEX
+/**
+ * @brief Generates a random vector of std::complex values, using a
+ * uniform distribution.
+ * @param vec Input vector to fill
+ * @param rangeMin Minimum value for the uniform distribution (real & imag)
+ * @param rangeMax Maximum value for the uniform distribution (real & imag)
+ */
+template <typename scalar_t>
+static inline void fill_random_with_range(
+    std::vector<complex_std<scalar_t>> &vec, scalar_t rangeMin,
+    scalar_t rangeMax) {
+  for (complex_std<scalar_t> &e : vec) {
+    e = complex_std<scalar_t>{random_scalar<scalar_t>(rangeMin, rangeMax),
+                              random_scalar<scalar_t>(rangeMin, rangeMax)};
+  }
+}
+
+/**
+ * @brief Generates a random vector of std::complex values, using a
+ * uniform distribution.
+ */
+template <typename scalar_t>
+static inline void fill_random(std::vector<complex_std<scalar_t>> &vec) {
+  fill_random_with_range(vec, scalar_t{-2}, scalar_t{5});
+}
+#endif
+
 /**
  * @brief Fills a lower or upper triangular matrix suitable for TRSM testing
  * @param A The matrix to fill. Size must be at least m * lda
@@ -165,7 +193,7 @@
  * @param unused Value to put in the unused parts of the matrix
  */
 template <typename scalar_t>
-static inline void fill_trsm_matrix(std::vector<scalar_t>& A, size_t k,
+static inline void fill_trsm_matrix(std::vector<scalar_t> &A, size_t k,
                                     size_t lda, char uplo, char unit_diag,
                                     scalar_t diag = scalar_t{1},
                                     scalar_t unused = scalar_t{0}) {
@@ -262,6 +290,24 @@ struct dump_arg_helper {
   }
 };

+#ifdef BLAS_ENABLE_COMPLEX
+/** Specialization of dump_arg_helper for std::complex types.
+ * This is required to split the real & imag parts properly and avoid
+ * the default parentheses format.
+ **/
+template <typename T>
+struct dump_arg_helper<
+    T, typename std::enable_if<is_complex_std<T>::value>::type> {
+  inline void operator()(std::ostream &ss, T f) {
+    using scalar_t = typename T::value_type;
+    dump_arg_helper<scalar_t>{}(ss, f.real());
+    ss << "r";
+    dump_arg_helper<scalar_t>{}(ss, f.imag());
+    ss << "i";
+  }
+};
+#endif
+
 /**
  * Type of the tested api
 */
diff --git a/test/blas_test_macros.hpp b/test/blas_test_macros.hpp
index 5b4cf979c..89e733e60 100644
--- a/test/blas_test_macros.hpp
+++ b/test/blas_test_macros.hpp
@@ -93,6 +93,36 @@
                                    combination, name_generator)
 #endif  // BLAS_DATA_TYPE_HALF

+#ifdef BLAS_ENABLE_COMPLEX
+#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name,          \
+                                              test_function, combination_t,    \
+                                              combination, name_generator)     \
+  class class_name##CplxFloat                                                  \
+      : public ::testing::TestWithParam<combination_t<std::complex<float>>> {  \
+  };                                                                           \
+  TEST_P(class_name##CplxFloat, test) {                                        \
+    test_function<std::complex<float>>(GetParam());                            \
+  };                                                                           \
+  INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxFloat,                  \
+                           combination<std::complex<float>>,                   \
+                           name_generator<std::complex<float>>);
+#else
+#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, \
+                                              test_function, combination_t, \
+                                              combination, name_generator)
+#endif  // BLAS_ENABLE_COMPLEX
+
+#if defined(BLAS_DATA_TYPE_DOUBLE) && defined(BLAS_ENABLE_COMPLEX)
+#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name,          \
+                                              test_function, combination_t,    \
+                                              combination, name_generator)     \
+  class class_name##CplxDouble                                                 \
+      : public ::testing::TestWithParam<combination_t<std::complex<double>>> { \
+  };                                                                           \
+  TEST_P(class_name##CplxDouble, test) {                                       \
+    test_function<std::complex<double>>(GetParam());                           \
+  };                                                                           \
+  INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxDouble,                 \
+                           combination<std::complex<double>>,                  \
+                           name_generator<std::complex<double>>);
+#else
+#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, \
+                                              test_function, combination_t, \
+                                              combination, name_generator)
+#endif  // BLAS_DATA_TYPE_DOUBLE && BLAS_ENABLE_COMPLEX
+
 /** Registers test for all supported data types
  * @param test_suite Name of the test suite
  * @param class_name Base name of the test class
@@ -115,6 +145,18 @@
                                     combination_t, combination,           \
                                     name_generator);

+#ifdef BLAS_ENABLE_COMPLEX
+#define BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, class_name,           \
+                                            test_function, combination_t,     \
+                                            combination, name_generator)      \
+  BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name,               \
+                                        test_function, combination_t,         \
+                                        combination, name_generator);         \
+  BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name,               \
+                                        test_function, combination_t,         \
+                                        combination, name_generator);
+#endif  // BLAS_ENABLE_COMPLEX
+
 /** Registers test for all supported data types
  * @see BLAS_REGISTER_TEST_CUSTOM_NAME
  */
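The `dump_arg_helper` specialization above feeds gtest's parameter-name generation, where characters such as '(' and ',' are not allowed, hence the 'r'/'i' suffix encoding of the two parts. A standalone sketch of the resulting format (scalar formatting is delegated to the underlying helper in the real code):

    #include <complex>
    #include <sstream>
    #include <string>

    // e.g. {1.5, 3.0} becomes "1.5r3i" (real part, 'r', imag part, 'i')
    std::string dump_complex(std::complex<float> f) {
      std::ostringstream ss;
      ss << f.real() << "r" << f.imag() << "i";
      return ss.str();
    }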
diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt
index 4f824238d..b4d2b0a3b 100644
--- a/test/unittest/CMakeLists.txt
+++ b/test/unittest/CMakeLists.txt
@@ -116,6 +116,11 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS})
   if(STRESS_TESTING)
     target_compile_definitions(${test_exec} PRIVATE STRESS_TESTING)
   endif()
+  if(${BLAS_ENABLE_COMPLEX})
+    if(${test_exec} MATCHES "gemm")
+      target_compile_definitions(${test_exec} PRIVATE BLAS_ENABLE_COMPLEX=1)
+    endif()
+  endif()
   target_compile_definitions(${test_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE})
   target_link_libraries(${test_exec} PRIVATE gtest_main Clara::Clara blas::blas portblas)
   target_include_directories(${test_exec} PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR})
diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp
index 1ce9413bd..6794ff56c 100644
--- a/test/unittest/blas3/blas3_gemm_batched_test.cpp
+++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp
@@ -145,3 +145,67 @@ const auto AllStridedBatched = ::testing::Combine(
     ::testing::Values(1, 2, 3)  // stride_c_mul
 );
 GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, AllStridedBatched);
+
+#ifdef BLAS_ENABLE_COMPLEX
+template <typename scalar_t>
+const auto CplxBetaNonZeroLDMatch = ::testing::Combine(
+    ::testing::Values("usm", "buf"),                          // allocation type
+    ::testing::Values(0),                                     // offset
+    ::testing::Values(5),                                     // batch
+    ::testing::Values(63, 128),                               // m
+    ::testing::Values(63, 128),                               // n
+    ::testing::Values(63, 128),                               // k
+    ::testing::Values('n', 't'),                              // transa
+    ::testing::Values('n', 't'),                              // transb
+    ::testing::Values<std::complex<scalar_t>>({1.5, 1.0}),    // alpha
+    ::testing::Values<std::complex<scalar_t>>({1.5, 3.0}),    // beta
+    ::testing::Values(1),                                     // lda_mul
+    ::testing::Values(1),                                     // ldb_mul
+    ::testing::Values(1),                                     // ldc_mul
+    ::testing::Values(gemm_batch_type_t::strided)             // batch_type
+);
+GENERATE_CPLX_GEMM_TEST(BatchGemm, CplxBetaNonZeroLDMatch);
+
+template <typename scalar_t>
+const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine(
+    ::testing::Values("usm", "buf"),                          // allocation type
+    ::testing::Values(0),                                     // offset
+    ::testing::Values(1, 5),                                  // batch
+    ::testing::Values(63, 128),                               // m
+    ::testing::Values(63, 128),                               // n
+    ::testing::Values(63, 128),                               // k
+    ::testing::Values('n', 't'),                              // transa
+    ::testing::Values('n', 't'),                              // transb
+    ::testing::Values<std::complex<scalar_t>>({2.5, 1.0}),    // alpha
+    ::testing::Values<std::complex<scalar_t>>({1.5, 3.0}),    // beta
+    ::testing::Values(1),                                     // lda_mul
+    ::testing::Values(1),                                     // ldb_mul
+    ::testing::Values(1),                                     // ldc_mul
+    ::testing::Values(1),                                     // stride_a_mul
+    ::testing::Values(1),                                     // stride_b_mul
+    ::testing::Values(1)                                      // stride_c_mul
+);
+GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm,
+                                       CplxDefaultGemmAndGemmBatched);
+
+template <typename scalar_t>
+const auto CplxAllStridedBatched = ::testing::Combine(
+    ::testing::Values("usm", "buf"),                          // allocation type
+    ::testing::Values(0),                                     // offset
+    ::testing::Values(5),                                     // batch
+    ::testing::Values(128),                                   // m
+    ::testing::Values(128),                                   // n
+    ::testing::Values(128),                                   // k
+    ::testing::Values('n', 't'),                              // transa
+    ::testing::Values('n', 't'),                              // transb
+    ::testing::Values<std::complex<scalar_t>>({2.5, 1.0}),    // alpha
+    ::testing::Values<std::complex<scalar_t>>({1.5, 3.0}),    // beta
+    ::testing::Values(2),                                     // lda_mul
+    ::testing::Values(3),                                     // ldb_mul
+    ::testing::Values(4),                                     // ldc_mul
+    ::testing::Values(0, 1, 2),                               // stride_a_mul
+    ::testing::Values(0, 1, 2),                               // stride_b_mul
+    ::testing::Values(1, 2, 3)                                // stride_c_mul
+);
+GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, CplxAllStridedBatched);
+#endif
diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp
index 48bd28128..d28baa99a 100644
--- a/test/unittest/blas3/blas3_gemm_common.hpp
+++ b/test/unittest/blas3/blas3_gemm_common.hpp
@@ -37,6 +37,19 @@ using gemm_batched_strided_arguments_t =
     std::tuple<std::string, int, int, int, int, int, char, char, scalar_t,
                scalar_t, int, int, int, int, int, int>;

+#ifdef BLAS_ENABLE_COMPLEX
+template <typename scalar_t>
+using gemm_cplx_arguments_t =
+    std::tuple<std::string, int, int, int, int, int, char, char,
+               std::complex<scalar_t>, std::complex<scalar_t>, int, int,
+               int, gemm_batch_type_t>;
+
+template <typename scalar_t>
+using gemm_cplx_batched_strided_arguments_t =
+    std::tuple<std::string, int, int, int, int, int, char, char,
+               std::complex<scalar_t>, std::complex<scalar_t>, int, int, int,
+               int, int, int>;
+#endif
+
 // Convert batch_type=strided to interleaved on the host
 template <typename scalar_t>
 inline std::vector<scalar_t> strided_to_interleaved(
@@ -383,4 +396,327 @@ static std::string generate_batched_strided_name(
   BLAS_REGISTER_TEST_CUSTOM_NAME(test_suite, test_suite##combination,       \
                                  verify_gemm,                               \
                                  gemm_batched_strided_arguments_t,          \
-                                 combination, generate_batched_strided_name);
\ No newline at end of file
+                                 combination, generate_batched_strided_name);
+
+#ifdef BLAS_ENABLE_COMPLEX
+
+template <typename scalar_t, helper::AllocType mem_alloc>
+inline void verify_gemm(const gemm_cplx_arguments_t<scalar_t> arguments) {
+  std::string alloc;
+  index_t offset;
+  index_t batch;
+  index_t m;
+  index_t n;
+  index_t k;
+  char transa;
+  char transb;
+  complex_std<scalar_t> alpha;
+  complex_std<scalar_t> beta;
+  index_t lda_mul;
+  index_t ldb_mul;
+  index_t ldc_mul;
+  gemm_batch_type_t batch_type;
+  std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul,
+           ldb_mul, ldc_mul, batch_type) = arguments;
+
+  const char ta_str[2] = {transa, '\0'};
+  const char tb_str[2] = {transb, '\0'};
+
+  auto q = make_queue();
+  blas::SB_Handle sb_handle(q);
+
+  const index_t lda = ((transa != 'n') ? k : m) * lda_mul;
+  const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul;
+  const index_t ldc = m * ldc_mul;
+
+  const index_t size_a = m * k * lda_mul;
+  const index_t size_b = k * n * ldb_mul;
+  const index_t size_c = m * n * ldc_mul;
+
+  const index_t buffer_size_a = batch * size_a + offset;
+  const index_t buffer_size_b = batch * size_b + offset;
+  const index_t buffer_size_c = batch * size_c + offset;
+
+  std::vector<complex_std<scalar_t>> a_m(buffer_size_a);
+  std::vector<complex_std<scalar_t>> b_m(buffer_size_b);
+  std::vector<complex_std<scalar_t>> c_m_gpu(buffer_size_c);
+
+  fill_random(a_m);
+  fill_random(b_m);
+  fill_random(c_m_gpu);
+  std::vector<complex_std<scalar_t>> c_m_cpu = c_m_gpu;
+
+  // Use system blas to create a reference output
+  for (int i = 0; i < batch; ++i) {
+    reference_blas::cgemm<scalar_t>(
+        ta_str, tb_str, m, n, k, reinterpret_cast<const void*>(&alpha),
+        reinterpret_cast<const void*>(a_m.data() + i * size_a + offset), lda,
+        reinterpret_cast<const void*>(b_m.data() + i * size_b + offset), ldb,
+        reinterpret_cast<const void*>(&beta),
+        reinterpret_cast<void*>(c_m_cpu.data() + i * size_c + offset), ldc);
+  }
+
+  if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) {
+    // Interleaved batched gemm unsupported
+    GTEST_SKIP();
+  }
+
+  auto m_a_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_a, q);
+  auto m_b_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_b, q);
+  auto m_c_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_c, q);
+
+  auto copy_a = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(a_m.data()), m_a_gpu,
+      buffer_size_a);
+  auto copy_b = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(b_m.data()), m_b_gpu,
+      buffer_size_b);
+  auto copy_c = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(c_m_gpu.data()), m_c_gpu,
+      buffer_size_c);
+
+  complex_sycl<scalar_t> alpha_sycl(alpha);
+  complex_sycl<scalar_t> beta_sycl(beta);
+
+  // portBLAS GEMM implementation
+  typename blas::SB_Handle::event_t gemm_event;
+  if (batch == index_t(1)) {
+    gemm_event = _gemm(sb_handle, transa, transb, m, n, k, alpha_sycl,
+                       m_a_gpu + offset, lda, m_b_gpu + offset, ldb, beta_sycl,
+                       m_c_gpu + offset, ldc, {copy_a, copy_b, copy_c});
+  } else {
+    // Batched gemm is not supported for complex data types yet, so this
+    // path only skips the check. The intended call would be:
+    // _gemm_batched(sb_handle, transa, transb, m, n, k, alpha_sycl,
+    //               m_a_gpu + offset, lda, m_b_gpu + offset, ldb, beta_sycl,
+    //               m_c_gpu + offset, ldc, batch, batch_type,
+    //               {copy_a, copy_b, copy_c});
+    return;
+  }
+  sb_handle.wait(gemm_event);
+
+  auto event = blas::helper::copy_to_host(
+      q, m_c_gpu, reinterpret_cast<complex_sycl<scalar_t>*>(c_m_gpu.data()),
+      buffer_size_c);
+  sb_handle.wait(event);
+
+  const bool isAlmostEqual = utils::compare_vectors(c_m_gpu, c_m_cpu);
+  ASSERT_TRUE(isAlmostEqual);
+
+  helper::deallocate<mem_alloc>(m_a_gpu, q);
+  helper::deallocate<mem_alloc>(m_b_gpu, q);
+  helper::deallocate<mem_alloc>(m_c_gpu, q);
+}
+
+template <typename scalar_t>
+inline void verify_gemm(const gemm_cplx_arguments_t<scalar_t> arguments) {
+  std::string alloc;
+  index_t offset;
+  index_t batch;
+  index_t m;
+  index_t n;
+  index_t k;
+  char transa;
+  char transb;
+  complex_std<scalar_t> alpha;
+  complex_std<scalar_t> beta;
+  index_t lda_mul;
+  index_t ldb_mul;
+  index_t ldc_mul;
+  gemm_batch_type_t batch_type;
+  std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul,
+           ldb_mul, ldc_mul, batch_type) = arguments;
+
+  if (alloc == "usm") {
+#ifdef SB_ENABLE_USM
+    verify_gemm<scalar_t, helper::AllocType::usm>(arguments);
+#else
+    GTEST_SKIP();
+#endif
+  } else {
+    verify_gemm<scalar_t, helper::AllocType::buffer>(arguments);
+  }
+}
+
+template <typename scalar_t>
+static std::string generate_cplx_name(
+    const ::testing::TestParamInfo<gemm_cplx_arguments_t<scalar_t>>& info) {
+  std::string alloc;
+  int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul;
+  char transa, transb;
+  complex_std<scalar_t> alpha, beta;
+  gemm_batch_type_t batchType;
+  BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb,
+                     alpha, beta, ldaMul, ldbMul, ldcMul, batchType);
+}
+
+template <typename scalar_t, helper::AllocType mem_alloc>
+inline void verify_gemm(
+    const gemm_cplx_batched_strided_arguments_t<scalar_t> arguments) {
+  std::string alloc;
+  index_t offset;
+  index_t batch;
+  index_t m;
+  index_t n;
+  index_t k;
+  char transa;
+  char transb;
+  complex_std<scalar_t> alpha;
+  complex_std<scalar_t> beta;
+  index_t lda_mul;
+  index_t ldb_mul;
+  index_t ldc_mul;
+  index_t stride_a_mul;
+  index_t stride_b_mul;
+  index_t stride_c_mul;
+  std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul,
+           ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) =
+      arguments;
+
+  const char ta_str[2] = {transa, '\0'};
+  const char tb_str[2] = {transb, '\0'};
+
+  auto q = make_queue();
+  blas::SB_Handle sb_handle(q);
+
+  const index_t lda = ((transa != 'n') ? k : m) * lda_mul;
+  const index_t buffer_size_a = size_a + (batch - 1) * stride_a + offset;
+  const index_t buffer_size_b = size_b + (batch - 1) * stride_b + offset;
+  const index_t buffer_size_c = size_c + (batch - 1) * stride_c + offset;
+
+  std::vector<complex_std<scalar_t>> a_m(buffer_size_a);
+  std::vector<complex_std<scalar_t>> b_m(buffer_size_b);
+  std::vector<complex_std<scalar_t>> c_m_gpu(buffer_size_c);
+
+  fill_random(a_m);
+  fill_random(b_m);
+  fill_random(c_m_gpu);
+  std::vector<complex_std<scalar_t>> c_m_cpu = c_m_gpu;
+
+  // Use system blas to create a reference output
+  for (int i = 0; i < batch; ++i) {
+    reference_blas::cgemm(
+        ta_str, tb_str, m, n, k, reinterpret_cast<const void*>(&alpha),
+        reinterpret_cast<const void*>(a_m.data() + i * stride_a + offset), lda,
+        reinterpret_cast<const void*>(b_m.data() + i * stride_b + offset), ldb,
+        reinterpret_cast<const void*>(&beta),
+        reinterpret_cast<void*>(c_m_cpu.data() + i * stride_c + offset), ldc);
+  }
+
+  auto m_a_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_a, q);
+  auto m_b_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_b, q);
+  auto m_c_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
+      buffer_size_c, q);
+
+  auto copy_a = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(a_m.data()), m_a_gpu,
+      buffer_size_a);
+  auto copy_b = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(b_m.data()), m_b_gpu,
+      buffer_size_b);
+  auto copy_c = blas::helper::copy_to_device(
+      q, reinterpret_cast<complex_sycl<scalar_t>*>(c_m_gpu.data()), m_c_gpu,
+      buffer_size_c);
+
+  complex_sycl<scalar_t> alpha_sycl(alpha);
+  complex_sycl<scalar_t> beta_sycl(beta);
+
+  // portBLAS GEMM STRIDED BATCHED implementation
+  auto gemm_batched_event = _gemm_strided_batched(
+      sb_handle, transa, transb, m, n, k, alpha_sycl, m_a_gpu + offset, lda,
+      stride_a, m_b_gpu + offset, ldb, stride_b, beta_sycl, m_c_gpu + offset,
+      ldc, stride_c, batch, {copy_a, copy_b, copy_c});
+
+  sb_handle.wait({gemm_batched_event});
+  auto event = blas::helper::copy_to_host(
+      q, m_c_gpu, reinterpret_cast<complex_sycl<scalar_t>*>(c_m_gpu.data()),
+      buffer_size_c);
+  sb_handle.wait(event);
+
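+  // stride_c_mul == 1 means the output batches are contiguous and the whole
+  // buffer can be compared; larger strides leave untouched padding that
+  // compare_vectors_strided skips over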
+  const bool isAlmostEqual =
+      (stride_c_mul == 1)
+          ? utils::compare_vectors(c_m_gpu, c_m_cpu)
+          : utils::compare_vectors_strided(c_m_gpu, c_m_cpu, stride_c, size_c);
+  ASSERT_TRUE(isAlmostEqual);
+
+  helper::deallocate<mem_alloc>(m_a_gpu, q);
+  helper::deallocate<mem_alloc>(m_b_gpu, q);
+  helper::deallocate<mem_alloc>(m_c_gpu, q);
+}
+
+template <typename scalar_t>
+inline void verify_gemm(
+    const gemm_cplx_batched_strided_arguments_t<scalar_t> arguments) {
+  std::string alloc;
+  index_t offset;
+  index_t batch;
+  index_t m;
+  index_t n;
+  index_t k;
+  char transa;
+  char transb;
+  complex_std<scalar_t> alpha;
+  complex_std<scalar_t> beta;
+  index_t lda_mul;
+  index_t ldb_mul;
+  index_t ldc_mul;
+  index_t stride_a_mul;
+  index_t stride_b_mul;
+  index_t stride_c_mul;
+  std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul,
+           ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) =
+      arguments;
+
+  if (alloc == "usm") {
+#ifdef SB_ENABLE_USM
+    verify_gemm<scalar_t, helper::AllocType::usm>(arguments);
+#else
+    GTEST_SKIP();
+#endif
+  } else {
+    verify_gemm<scalar_t, helper::AllocType::buffer>(arguments);
+  }
+}
+
+template <class T>
+static std::string generate_cplx_batched_strided_name(
+    const ::testing::TestParamInfo<gemm_cplx_batched_strided_arguments_t<T>>&
+        info) {
+  std::string alloc;
+  int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul, stride_a_mul,
+      stride_b_mul, stride_c_mul;
+  char transa, transb;
+  complex_std<T> alpha, beta;
+  BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb,
+                     alpha, beta, ldaMul, ldbMul, ldcMul, stride_a_mul,
+                     stride_b_mul, stride_c_mul);
+}
+
+/** Registers GEMM test for all supported complex data types
+ * @param test_suite Name of the test suite
+ * @param combination Combinations object
+ * @see BLAS_REGISTER_TEST_CUSTOM_NAME
+ */
+#define GENERATE_CPLX_GEMM_TEST(test_suite, combination)                     \
+  BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, test_suite##combination,   \
+                                      verify_gemm, gemm_cplx_arguments_t,    \
+                                      combination, generate_cplx_name);
+
+#define GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(test_suite, combination)      \
+  BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(                                       \
+      test_suite, test_suite##combination, verify_gemm,                      \
+      gemm_cplx_batched_strided_arguments_t, combination,                    \
+      generate_cplx_batched_strided_name);
+
+#endif
diff --git a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp
index 5e156b7c5..4eeee3cde 100644
--- a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp
+++ b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp
@@ -101,3 +101,81 @@ const auto OffsetNonZero = ::testing::Combine(
     ::testing::Values(gemm_batch_type_t::strided)  // batch_type
 );
 GENERATE_GEMM_TEST(TallSkinnyGemm, OffsetNonZero);
+
+#ifdef BLAS_ENABLE_COMPLEX
+template <typename scalar_t>
+const auto CplxBetaNonZeroLDMatch = ::testing::Combine(
+    ::testing::Values("usm", "buf"),  // allocation type
+    ::testing::Values(0),             // offset
+    ::testing::Values(1),             // batch
+    ::testing::Values(7, 65),         // m
+    ::testing::Values(9, 126),        // n
+    ::testing::Values(2049),          // k
+    ::testing::Values('n', 't'),      // transa
+    ::testing::Values('n', 't'),      // transb
+    ::testing::Values<std::complex<scalar_t>>({1.5, 1.5}),  // alpha
+    ::testing::Values<std::complex<scalar_t>>({0.5, 0.5}),  // beta
+    ::testing::Values(1),             // lda_mul
+    ::testing::Values(1),             // ldb_mul
+    ::testing::Values(1),             // ldc_mul
+    ::testing::Values(gemm_batch_type_t::strided)  // batch_type
+);
+GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMatch);
+
+template <typename scalar_t>
+const auto CplxBetaNonZeroLDMultiplied = ::testing::Combine(
+    ::testing::Values("usm", "buf"),  // allocation type
+    ::testing::Values(0),             // offset
+    ::testing::Values(1),             // batch
+    ::testing::Values(7, 65),         // m
+    ::testing::Values(9, 126),        // n
+    ::testing::Values(2049),          // k
+    ::testing::Values('n',
't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 0.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMultiplied); + +template +const auto CplxBetaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(10), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxOffsetNonZero); +#endif diff --git a/test/unittest/blas3/blas3_gemm_test.cpp b/test/unittest/blas3/blas3_gemm_test.cpp index e5d4a4122..acf4c85d8 100644 --- a/test/unittest/blas3/blas3_gemm_test.cpp +++ b/test/unittest/blas3/blas3_gemm_test.cpp @@ -139,3 +139,121 @@ const auto LargeBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_GEMM_TEST(Gemm, LargeBetaNonZeroLDMatch); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxSmallBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32), // m + ::testing::Values(11, 16, 32), // n + ::testing::Values(16, 17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaNonZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 32), // m + ::testing::Values(11, 32), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, 
CplxSmallBetaZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMultiplied = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 32), // m + ::testing::Values(11, 32), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 3.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMultiplied); + +template +const auto CplxAlphaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16), // m + ::testing::Values(16), // n + ::testing::Values(17), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({0.0, 0.0}), // alpha + ::testing::Values(std::complex{0.0, 0.0}, + std::complex{1.0, 0.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxAlphaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(1, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16, 63), // m + ::testing::Values(16, 63), // n + ::testing::Values(17, 63), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({1.0, 1.0}), // alpha + ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxOffsetNonZero); + +template +const auto CplxLargeBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(253, 511), // m + ::testing::Values(257, 511), // n + ::testing::Values(253, 511), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.0, 1.0}), // alpha + ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxLargeBetaNonZeroLDMatch); + +#endif From 58b8c350a5853244b18291e341724602096eed23 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 14:20:43 +0100 Subject: [PATCH 03/18] Minor fixes --- cmake/CmakeFunctionHelper.cmake | 6 +----- src/operations/blas3/gemm_common.hpp | 7 ++++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 8b411d5e6..fff0f923f 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -447,11 +447,7 @@ elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) "float" "half" ) - set(data_list_c ${supported_types}) - if(BLAS_ENABLE_COMPLEX) - set_complex_list(data_list_c "${supported_types}" "false") - endif() - 
foreach(data ${data_list_c}) + foreach(data ${supported_types}) add_gemm_configuration( "${data}" 96 "true" "false" "false" 16 4 6 12 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index 6923f492b..f1817e70a 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -45,13 +45,13 @@ template static PORTBLAS_INLINE T mul_add(T a, T b, T c, typename std::enable_if::value>::type * = 0) { - return (sycl::mad(a, b, c)); + return (cl::sycl::mad(a, b, c)); } #else template static PORTBLAS_INLINE T mul_add(T a, T b, T c) { - return (sycl::mad(a, b, c)); + return (cl::sycl::mad(a, b, c)); } #endif @@ -84,7 +84,8 @@ template PORTBLAS_INLINE std::string Tile::get_type_string() noexcept { + ItemBatchs, WgBatchs, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>::get_type_string() noexcept { std::ostringstream str{}; str << "Tile<" << item_rows << ", " << item_cols << ", " << wg_rows << ", " << wg_cols << ", " << sg_rows << ", " << sg_cols << ", " << tl_rows From bfc56ba9393166519bc98f9054c487ce940e8d1e Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 16:38:05 +0100 Subject: [PATCH 04/18] Typo fix --- common/include/common/float_comparison.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index e244f0d5a..1222ccc41 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -71,12 +71,12 @@ scalar_t abs(scalar_t value) noexcept { #ifdef BLAS_ENABLE_COMPLEX template bool isnan(std::complex value) noexcept { - return (isnan(value.imag()) || isnan(value.imag())); + return (isnan(value.real()) || isnan(value.imag())); } template bool isinf(std::complex value) noexcept { - return (isinf(value.imag()) || isinf(value.imag())); + return (isinf(value.real()) || isinf(value.imag())); } template From 3f80316c9cd53af9ab777d55eea4736a565cd6f1 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 13 Sep 2023 22:46:47 +0100 Subject: [PATCH 05/18] amd gpu config --- cmake/CmakeFunctionHelper.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index fff0f923f..c50688a6d 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -515,7 +515,7 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 256 "true" "true" "true" - 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") add_gemm_configuration( "${data}" 256 "false" "false" "false" 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") From 2eeb03f05d607bf4582c19eab9e3d64d534fe65a Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 20 Sep 2023 10:53:10 +0100 Subject: [PATCH 06/18] De-coupling complex & scalar enable_if statements --- cmake/CmakeFunctionHelper.cmake | 12 +++++++++--- src/interface/blas3/backend/amd_gpu.hpp | 6 +----- src/interface/blas3/backend/default_cpu.hpp | 6 +----- src/interface/blas3/backend/intel_gpu.hpp | 11 ++++------- src/interface/blas3/backend/nvidia_gpu.hpp | 6 +----- src/interface/gemm_interface.hpp | 18 +++++++++++------- 
src/operations/blas3/gemm_common.hpp | 10 ++-------- 7 files changed, 29 insertions(+), 40 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index c50688a6d..0c30a48d5 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -437,9 +437,15 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU") add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false") - add_gemm_configuration( - "${data}" 32 "true" "true" "true" - 64 2 1 8 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + endif() endforeach() endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 3aff8dd46..3ec103620 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 44a99d1fe..1b7dfd680 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index e22274008..a0ce6f52a 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -32,12 +32,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, @@ -227,9 +223,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT if (!s_a && !s_b && batch_size == 1) { + constexpr int wg_size = sizeof(element_t) == 16 ? 
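+        // sizeof(element_t) == 16 for complex<double>: pick the smaller 4x4
+        // work-group tile; 8-byte complex<float> keeps the 8x8 shape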
4 : 8; return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 32, true, true, true, 64, - Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index f13a95d2e..7d555d902 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index f5b7383e6..8e90a4b82 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -50,16 +50,20 @@ namespace internal { // Check whether value is zero (complex & float/double) template -inline bool isZero(const T& value) { -#ifdef BLAS_ENABLE_COMPLEX - if constexpr (is_complex_sycl::value) { - using value_t = typename T::value_type; - return (value == T(value_t(0), value_t(0))); - } -#endif +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { return (value == static_cast(0)); } +#ifdef BLAS_ENABLE_COMPLEX +template +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { + using value_t = typename T::value_type; + return (value == T(value_t(0), value_t(0))); +} +#endif + template diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index f1817e70a..670dc340d 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -40,20 +40,14 @@ mul_add(T a, T b, T c, typename std::enable_if::value>::type * = 0) { return (a * b + c); } +#endif template static PORTBLAS_INLINE T mul_add(T a, T b, T c, - typename std::enable_if::value>::type * = 0) { - return (cl::sycl::mad(a, b, c)); -} -#else - -template -static PORTBLAS_INLINE T mul_add(T a, T b, T c) { + typename std::enable_if::value>::type * = 0) { return (cl::sycl::mad(a, b, c)); } -#endif template struct type_string { From 06705b31a176409abb00952f68b6583745b4c09a Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 20 Sep 2023 18:00:18 +0100 Subject: [PATCH 07/18] Added static asserts on vector size when using cplx data --- src/operations/blas3/gemm_interleaved.hpp | 20 ++++++++------ src/operations/blas3/gemm_local.hpp | 26 ++++++++++++------- .../blas3/gemm_no_local_full_vec.hpp | 6 +++++ .../blas3/gemm_no_local_partial_vec.hpp | 14 +++++++--- 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/operations/blas3/gemm_interleaved.hpp b/src/operations/blas3/gemm_interleaved.hpp index 551bb465a..66629033e 100644 --- a/src/operations/blas3/gemm_interleaved.hpp +++ b/src/operations/blas3/gemm_interleaved.hpp @@ -146,6 +146,11 @@ class Gemm::value, + "Interleaved GEMM is not supported for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -159,10 +164,9 @@ class Gemm PORTBLAS_INLINE void compute_panel(check_t boundary_check, 
index_t m_stride, - index_t n_stride, index_t mb_start, - index_t m_start, index_t n_start, - in_ptr_t A, in_ptr_t B, out_ptr_t C) { + index_t n_stride, index_t mb_start, + index_t m_start, index_t n_start, + in_ptr_t A, in_ptr_t B, out_ptr_t C) { packet_type reg_a[item_rows * item_batchs / VectorSize]; packet_type reg_b[item_cols * item_batchs / VectorSize]; packet_type reg_res[item_rows * item_cols * item_batchs / VectorSize]; @@ -482,7 +486,7 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = block_rows + nbc_a; //! @brief leading dimension of block of B in local @@ -162,8 +168,8 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); const index_t k = a_.get_size_col(); @@ -546,9 +552,9 @@ class Gemm PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc, - OutputPointerType C, index_t ldc, - element_t *reg_res, - const bool out_of_range) noexcept { + OutputPointerType C, index_t ldc, + element_t *reg_res, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -726,9 +732,9 @@ class Gemm PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B, - InputPointerType A, element_t *reg_a, - element_t ®_b, - element_t *reg_res) noexcept { + InputPointerType A, element_t *reg_a, + element_t ®_b, + element_t *reg_res) noexcept { // NOTE: Adding "#pragma unroll" here reduces performance on AMD R9 // Nano. // Seems that the small reduction of arithmetic operations does @@ -781,7 +787,7 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... 
ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index 732cc9568..df1ce6bd7 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -104,6 +104,12 @@ class Gemm(), "If vectorization is enabled item_cols must equal the packet_size"); +#ifdef BLAS_ENABLE_COMPLEX + static_assert((VectorSize == 1 && is_complex_sycl::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index 189de963b..02a42e938 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -100,6 +100,12 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -111,8 +117,8 @@ class Gemm PORTBLAS_INLINE void load(PointerType ptr, element_t *reg, const index_t &ld, - index_t index, const check_boundary &chk_boundary, - const bool out_of_range) noexcept { + index_t index, const check_boundary &chk_boundary, + const bool out_of_range) noexcept { if (out_of_range) { return; } From f7179c51a5e6edf93b623f4c70f07dc44eeb5874 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Thu, 21 Sep 2023 16:33:12 +0100 Subject: [PATCH 08/18] Fixes to amd gpu configs --- cmake/CmakeFunctionHelper.cmake | 30 +++++++++----- src/interface/blas3/backend/amd_gpu.hpp | 53 +++++++++++++------------ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 0c30a48d5..8dedc9857 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -519,15 +519,27 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation set(data_list_c) set_complex_list(data_list_c "${supported_types}" "false") foreach(data ${data_list_c}) - add_gemm_configuration( - "${data}" 256 "true" "true" "true" - 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 
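+          # complex<float> keeps the 8x8 work-group shape used for the real
+          # types; the complex<double> branch above halves it to 4x4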
1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endif() endforeach() endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 3ec103620..a425b2f2a 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -119,29 +119,29 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } else #endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } + if (_M * _N <= 65536) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } } // Complex Configurations @@ -158,12 +158,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { static constexpr int ClSize = 64; + static constexpr int tileWgSize = ClSize / sizeof(element_t); /* Tall & Skinny matrices. 
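     Taken below when batch_size == 1 and one dimension dominates
     (_M / _N > 8 or _N / _M > 8).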
*/ #ifdef GEMM_TALL_SKINNY_SUPPORT if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -176,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -187,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, From 9ba82fe8190f06ea443c8fe7dbd16abba4e2ec85 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 25 Sep 2023 10:07:42 +0100 Subject: [PATCH 09/18] Addressed PR comments --- common/include/common/common_utils.hpp | 16 +++++----- include/blas_meta.h | 10 ++----- test/blas_test.hpp | 8 ++--- test/unittest/blas3/blas3_gemm_common.hpp | 36 +++++++++++------------ 4 files changed, 33 insertions(+), 37 deletions(-) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index df14ac062..26916483b 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1374,27 +1374,27 @@ static inline std::vector random_data(size_t size) { #ifdef BLAS_ENABLE_COMPLEX template -static inline complex_std random_scalar() { +static inline std::complex random_scalar() { scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); scalar_t im = 1e-3 * ((rand() % 2000) - 1000); - return complex_std({rl, im}); + return std::complex(rl, im); } template -static inline complex_std random_scalar(scalar_t rangeMin, - scalar_t rangeMax) { +static inline std::complex random_scalar(scalar_t rangeMin, + scalar_t rangeMax) { static std::random_device rd; static std::default_random_engine gen(rd()); std::uniform_real_distribution disRl(rangeMin, rangeMax); std::uniform_real_distribution disIm(rangeMin, rangeMax); - return complex_std({disRl(gen), disIm(gen)}); + return std::complex(disRl(gen), disIm(gen)); } template -static inline std::vector> random_data(size_t size) { - std::vector> v = - std::vector>(size); +static inline std::vector> random_data(size_t size) { + std::vector> v = + std::vector>(size); for (scalar_t& e : v) { e = random_scalar(scalar_t{-2}, scalar_t{5}); diff --git a/include/blas_meta.h b/include/blas_meta.h index a7634dbca..d39a395f5 100644 --- a/include/blas_meta.h +++ b/include/blas_meta.h @@ -167,7 +167,7 @@ int append_vector(vector_t &lhs_vector, vector_t const &rhs_vector) { template first_vector_t concatenate_vectors(first_vector_t first_vector, - other_vector_t &&... 
other_vectors) { + other_vector_t &&...other_vectors) { int first_Vector_size = static_cast(first_vector.size()); int s[] = {vec_total_size(first_Vector_size, other_vectors)..., 0}; first_vector.reserve(first_Vector_size); @@ -206,15 +206,11 @@ struct is_complex_sycl std::is_same_v> || std::is_same_v>> {}; -// STD Complex type alias -template -using complex_std = typename std::complex; - template struct is_complex_std : std::integral_constant> || - std::is_same_v>> {}; + std::is_same_v> || + std::is_same_v>> {}; #endif diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 70a32d61c..d159109db 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -161,9 +161,9 @@ template static inline void fill_random_with_range( std::vector> &vec, scalar_t rangeMin, scalar_t rangeMax) { - for (complex_std &e : vec) { - e = complex_std{random_scalar(rangeMin, rangeMax), - random_scalar(rangeMin, rangeMax)}; + for (std::complex &e : vec) { + e = std::complex{random_scalar(rangeMin, rangeMax), + random_scalar(rangeMin, rangeMax)}; } } @@ -172,7 +172,7 @@ static inline void fill_random_with_range( * uniform distribution. */ template -static inline void fill_random(std::vector> &vec) { +static inline void fill_random(std::vector> &vec) { fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); } #endif diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index d28baa99a..3aacf4244 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -410,8 +410,8 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -437,14 +437,14 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { const index_t buffer_size_b = batch * size_b + offset; const index_t buffer_size_c = batch * size_c + offset; - std::vector> a_m(buffer_size_a); - std::vector> b_m(buffer_size_b); - std::vector> c_m_gpu(buffer_size_c); + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); - std::vector> c_m_cpu = c_m_gpu; + std::vector> c_m_cpu = c_m_gpu; // Use system blas to create a reference output for (int i = 0; i < batch; ++i) { @@ -518,8 +518,8 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -544,7 +544,7 @@ static std::string generate_cplx_name( std::string alloc; int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; char transa, transb; - complex_std alpha, beta; + std::complex alpha, beta; gemm_batch_type_t batchType; BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, alpha, beta, ldaMul, ldbMul, ldcMul, batchType); @@ -561,8 +561,8 @@ inline void verify_gemm( index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -595,14 +595,14 @@ inline void verify_gemm( const index_t buffer_size_b = size_b + (batch - 1) * stride_b + offset; const index_t buffer_size_c = size_c + (batch - 1) * stride_c + offset; - std::vector> a_m(buffer_size_a); - std::vector> b_m(buffer_size_b); - std::vector> 
c_m_gpu(buffer_size_c); + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); - std::vector> c_m_cpu = c_m_gpu; + std::vector> c_m_cpu = c_m_gpu; // Use system blas to create a reference output for (int i = 0; i < batch; ++i) { @@ -668,8 +668,8 @@ inline void verify_gemm( index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -697,7 +697,7 @@ static std::string generate_cplx_batched_strided_name( int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul, stride_a_mul, stride_b_mul, stride_c_mul; char transa, transb; - complex_std alpha, beta; + std::complex alpha, beta; BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, alpha, beta, ldaMul, ldbMul, ldcMul, stride_a_mul, stride_b_mul, stride_c_mul); From 007727c24cf6d88a64d0f79f429652aca3027984 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 25 Sep 2023 20:24:12 +0100 Subject: [PATCH 10/18] minor fixes --- common/include/common/common_utils.hpp | 31 -------------------------- 1 file changed, 31 deletions(-) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index 26916483b..a569ed2ff 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1372,37 +1372,6 @@ static inline std::vector random_data(size_t size) { return v; } -#ifdef BLAS_ENABLE_COMPLEX -template -static inline std::complex random_scalar() { - scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); - scalar_t im = 1e-3 * ((rand() % 2000) - 1000); - return std::complex(rl, im); -} - -template -static inline std::complex random_scalar(scalar_t rangeMin, - scalar_t rangeMax) { - static std::random_device rd; - static std::default_random_engine gen(rd()); - std::uniform_real_distribution disRl(rangeMin, rangeMax); - std::uniform_real_distribution disIm(rangeMin, rangeMax); - - return std::complex(disRl(gen), disIm(gen)); -} - -template -static inline std::vector> random_data(size_t size) { - std::vector> v = - std::vector>(size); - - for (scalar_t& e : v) { - e = random_scalar(scalar_t{-2}, scalar_t{5}); - } - return v; -} -#endif - /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda From 7f76dfd2ad3b081512758b9f2fccc52496efe9c5 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 27 Sep 2023 22:28:22 +0100 Subject: [PATCH 11/18] fixed bug in cmake & added readme description to complex --- CMakeLists.txt | 3 ++- README.md | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1037b1098..09785078f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ if(IMGDNN_DIR) endif() option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) +option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) # CmakeFunctionHelper has to be included after any options that it depends on are declared. 
# These include: @@ -115,6 +116,7 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) # * BLAS_DATA_TYPES # * BLAS_INDEX_TYPES # * NAIVE_GEMM +# * BLAS_ENABLE_COMPLEX include(CmakeFunctionHelper) if (INSTALL_HEADER_ONLY) @@ -220,7 +222,6 @@ option(BUILD_CUBLAS_BENCHMARKS "Whether to build cuBLAS benchmarks" OFF) option(BUILD_ROCBLAS_BENCHMARKS "Whether to build rocBLAS benchmarks" OFF) option(BUILD_ACL_BENCHMARKS "Whether to build ARM Compute Library benchmarks" OFF) option(BLAS_BUILD_SAMPLES "Whether to build portBLAS samples" ON) -option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_BENCHMARK) message(STATUS "Benchmarks are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_BENCHMARK OFF) diff --git a/README.md b/README.md index 5720ae145..c5383b73f 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,7 @@ Some of the supported options are: | `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) | | `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` | | `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` | - +| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Kernels only)* (`ON` by default) | ### Cross-Compile (ComputeCpp Only) From 148a2eaa9ddbfc4e6fc86149396fa12998b058ed Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 2 Oct 2023 13:42:59 +0100 Subject: [PATCH 12/18] Reduced complex gemm tests cases sizes --- .../blas3/blas3_gemm_batched_test.cpp | 6 ++--- .../blas3/blas3_gemm_tall_skinny_test.cpp | 6 ++--- test/unittest/blas3/blas3_gemm_test.cpp | 24 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp index 6794ff56c..824bf656b 100644 --- a/test/unittest/blas3/blas3_gemm_batched_test.cpp +++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp @@ -151,7 +151,7 @@ template const auto CplxBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(5), // batch + ::testing::Values(3), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n ::testing::Values(63, 128), // k @@ -170,7 +170,7 @@ template const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(1, 5), // batch + ::testing::Values(1, 4), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n ::testing::Values(63, 128), // k @@ -192,7 +192,7 @@ template const auto CplxAllStridedBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(5), // batch + ::testing::Values(3), // batch ::testing::Values(128), // m ::testing::Values(128), // n ::testing::Values(128), // k diff --git a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp index 4eeee3cde..95abb271a 100644 --- a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp +++ b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp @@ -127,9 +127,9 @@ const auto CplxBetaNonZeroLDMultiplied = ::testing::Combine( 
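+    // the tuple layout here mirrors gemm_cplx_arguments_t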
::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(7, 65), // m - ::testing::Values(9, 126), // n - ::testing::Values(2049), // k + ::testing::Values(7, 33), // m + ::testing::Values(9, 63), // n + ::testing::Values(1026), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 0.5}), // alpha diff --git a/test/unittest/blas3/blas3_gemm_test.cpp b/test/unittest/blas3/blas3_gemm_test.cpp index acf4c85d8..f7cae4630 100644 --- a/test/unittest/blas3/blas3_gemm_test.cpp +++ b/test/unittest/blas3/blas3_gemm_test.cpp @@ -146,8 +146,8 @@ const auto CplxSmallBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(11, 16, 32), // m - ::testing::Values(11, 16, 32), // n + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n ::testing::Values(16, 17), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -171,7 +171,7 @@ const auto CplxSmallBetaZeroLDMatch = ::testing::Combine( ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 1.0}), // alpha - ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values>({0.0, 0.0}), // beta ::testing::Values(1), // lda_mul ::testing::Values(1), // ldb_mul ::testing::Values(1), // ldc_mul @@ -184,16 +184,16 @@ const auto CplxSmallBetaZeroLDMultiplied = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(11, 32), // m - ::testing::Values(11, 32), // n + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n ::testing::Values(17), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 3.0}), // alpha ::testing::Values>({0.0, 0.0}), // beta ::testing::Values(2), // lda_mul - ::testing::Values(3), // ldb_mul - ::testing::Values(4), // ldc_mul + ::testing::Values(2), // ldb_mul + ::testing::Values(3), // ldc_mul ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMultiplied); @@ -242,13 +242,13 @@ const auto CplxLargeBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(253, 511), // m - ::testing::Values(257, 511), // n - ::testing::Values(253, 511), // k + ::testing::Values(63, 253), // m + ::testing::Values(63, 253), // n + ::testing::Values(63, 253), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb - ::testing::Values>({1.0, 1.0}), // alpha - ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values>({1.0, 1.5}), // alpha + ::testing::Values>({1.5, 1.0}), // beta ::testing::Values(1), // lda_mul ::testing::Values(1), // ldb_mul ::testing::Values(1), // ldc_mul From 49e0e01594d510755af1171a46137e5562c707d7 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Oct 2023 11:56:25 +0100 Subject: [PATCH 13/18] removed unused legacy complex data utils --- include/operations/blas_constants.h | 8 -------- src/operations/blas1_trees.hpp | 18 ------------------ 2 files changed, 26 deletions(-) diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index 5fc4afb82..637f23f95 100644 --- 
a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -202,14 +202,6 @@ struct constant, const_val::collapse> { } }; -template -struct constant, Indicator> { - constexpr static PORTBLAS_INLINE std::complex value() { - return std::complex(constant::value(), - constant::value()); - } -}; - #ifdef BLAS_ENABLE_COMPLEX template struct constant, Indicator> { diff --git a/src/operations/blas1_trees.hpp b/src/operations/blas1_trees.hpp index ff51a7915..1b079c98b 100644 --- a/src/operations/blas1_trees.hpp +++ b/src/operations/blas1_trees.hpp @@ -90,24 +90,6 @@ struct DetectScalar { }; #endif // BLAS_DATA_TYPE_HALF -/*! DetectScalar. - * @brief See Detect Scalar. - */ -template <> -struct DetectScalar> { - using element_t = std::complex; - static element_t get_scalar(element_t &scalar) { return scalar; } -}; - -/*! DetectScalar. - * @brief See Detect Scalar. - */ -template <> -struct DetectScalar> { - using element_t = std::complex; - static element_t get_scalar(element_t &scalar) { return scalar; } -}; - #ifdef BLAS_ENABLE_COMPLEX /*! DetectScalar (for sycl::complex) * @brief See Detect Scalar. From d86cfc8598101f657890f6798f69637fef60be9c Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Oct 2023 13:20:05 +0100 Subject: [PATCH 14/18] Tuned gemm complex for cpu --- cmake/CmakeFunctionHelper.cmake | 7 ++----- src/interface/blas3/backend/default_cpu.hpp | 19 ++++--------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 8dedc9857..ef7fc22d6 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -629,13 +629,10 @@ else() # default cpu backend foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + 64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") - add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false" "false") + 64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX endif() diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 1b7dfd680..14c0cd337 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -116,10 +116,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -127,10 +127,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, 
_lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (!s_a && !s_b) { + } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -138,17 +138,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); } } #endif From 2dc363db653413af39acd416f09d57c101de3004 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 13 Oct 2023 17:06:21 +0100 Subject: [PATCH 15/18] Separated complex gemm load store & addressed PR comments --- doc/Gemm.md | 2 +- src/operations/blas3/gemm_load_store.hpp | 144 --------------- .../blas3/gemm_load_store_complex.hpp | 174 ++++++++++++++++++ src/operations/blas3/gemm_local.hpp | 3 + .../blas3/gemm_no_local_full_vec.hpp | 3 + .../blas3/gemm_no_local_partial_vec.hpp | 3 + test/unittest/blas3/blas3_gemm_common.hpp | 10 +- 7 files changed, 189 insertions(+), 150 deletions(-) create mode 100644 src/operations/blas3/gemm_load_store_complex.hpp diff --git a/doc/Gemm.md b/doc/Gemm.md index 0264e3d4c..653549212 100644 --- a/doc/Gemm.md +++ b/doc/Gemm.md @@ -100,7 +100,7 @@ The core of the `GEMM` computation is as follows: ## Vectorized Loading/Storing -Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` . +Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/`*(this feature is limited to non-complex data types)*. These functions are pretty simple but there are some special considerations for how they are used, particularly around whether the matrices are transposed or not. If a matrix is transposed this changes the data layout such that elements are no longer contiguous in memory. diff --git a/src/operations/blas3/gemm_load_store.hpp b/src/operations/blas3/gemm_load_store.hpp index 7ae45ce5d..ef44cbfe6 100644 --- a/src/operations/blas3/gemm_load_store.hpp +++ b/src/operations/blas3/gemm_load_store.hpp @@ -125,149 +125,5 @@ struct Packetize { } }; -#ifdef BLAS_ENABLE_COMPLEX -/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in - * Packetize. It serves as a temporary workaround to the upcoming - * sycl::vec container - * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc - * and only supports size = 1. 
- * @tparam DataT Complex type of the vector's data - * @tparam NumElements Elements count of the vector (only 1 is supported) - */ -template -class vec_complex { - static_assert(NumElements == 1, - "Vector wrapper arround sycl::complex of size>1 unsupported."); - using address_t = cl::sycl::access::address_space; - using decorated_t = cl::sycl::access::decorated; - using DataType = DataT; - static constexpr int getNumElements() { return NumElements; } - size_t size() const noexcept { return NumElements; } - - private: - DataType m_Data; - - public: - vec_complex() = default; - - constexpr vec_complex(const vec_complex &rhs) = default; - constexpr vec_complex(vec_complex &&rhs) = default; - constexpr vec_complex &operator=(const vec_complex &rhs) = default; - - vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} - - // Conversion operator (valid with NumElements==1) - operator DataT() const { return m_Data; } - - // Subscript operators - DataT &operator[](int i) { - assert(i < NumElements); - return (m_Data); - } - const DataT &operator[](int i) const { - assert(i < NumElements); - return (m_Data); - } - - // Binary Ops - // Multiply - vec_complex operator*(const vec_complex &rhs) { - return (vec_complex{m_Data * static_cast(rhs)}); - } - - vec_complex operator*(const DataType &rhs) { - return (vec_complex{m_Data * rhs}); - } - - // Compound Multiply - vec_complex &operator*=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator*=(const vec_complex &rhs) { - this->m_Data = this->m_Data * static_cast(rhs); - return (*this); - } - - // Add - vec_complex operator+(const vec_complex &rhs) { - return (vec_complex{m_Data + static_cast(rhs)}); - } - - vec_complex operator+(const DataType &rhs) { - return (vec_complex{m_Data + rhs}); - } - - // Compound Add - vec_complex &operator+=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator+=(const vec_complex &rhs) { - this->m_Data = this->m_Data + static_cast(rhs); - return (*this); - } - - // Load - template - void load(size_t Offset, - cl::sycl::multi_ptr Ptr) { - m_Data = *(Ptr + Offset * NumElements); - } - - // Store - template - void store(size_t Offset, - cl::sycl::multi_ptr Ptr) const { - *(Ptr + Offset * NumElements) = m_Data; - } -}; - -/*! @brief Partial specialization of the Packetize class dedicated to -sycl::complex types. It contains static methods for loading and storing size=1 -complex packets from/to memory. -* @tparam vector_size The desired vector size to be used. Only size = 1 is -supported so far. -* @tparam value_t The complex type of the matrix data. -*/ -template -struct Packetize, index_t> { - // Vectorization is not enabled for complex, always set to 1 - using value_t = complex_sycl; - using PacketType = vec_complex; - static constexpr int packet_size = 1; - template - static PORTBLAS_INLINE constexpr bool check_size() { - return true; - } - - /*! @brief Performs a non-vectorised load of sycl::complex data element while - * whether block is internal or not since vectorization is not enabled for - * complex types yet. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam internal True if the current block is internal and no bounds - * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template - static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, - DestPointerType dest, - EdgePredicate edge_in_range) { - *(dest) = in_range ? 
*(src) : value_t{(T)0, (T)0}; - } - - /*! @brief Store a size = 1 vector packet of sycl::complex data into local - * memory (whether source is transposed or not since it's only 1 element). - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { - *dest = packet[0]; - } -}; -#endif - } // namespace blas #endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP diff --git a/src/operations/blas3/gemm_load_store_complex.hpp b/src/operations/blas3/gemm_load_store_complex.hpp new file mode 100644 index 000000000..7b1eb769b --- /dev/null +++ b/src/operations/blas3/gemm_load_store_complex.hpp @@ -0,0 +1,174 @@ +/*************************************************************************** + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename gemm_load_store_complex.hpp + * + **************************************************************************/ + +#ifndef PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP +#define PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP + +namespace blas { +#ifdef BLAS_ENABLE_COMPLEX +/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in + * Packetize. It serves as a temporary workaround to the upcoming + * sycl::vec container + * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc + * and only supports size = 1. 
+ * @tparam DataT Complex type of the vector's data + * @tparam NumElements Elements count of the vector (only 1 is supported) + */ +template +class vec_complex { + static_assert(NumElements == 1, + "Vector wrapper around sycl::complex of size>1 unsupported."); + using address_t = cl::sycl::access::address_space; + using decorated_t = cl::sycl::access::decorated; + using DataType = DataT; + static constexpr int getNumElements() { return NumElements; } + size_t size() const noexcept { return NumElements; } + + private: + DataType m_Data; + + public: + vec_complex() = default; + + constexpr vec_complex(const vec_complex &rhs) = default; + constexpr vec_complex(vec_complex &&rhs) = default; + constexpr vec_complex &operator=(const vec_complex &rhs) = default; + + vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} + + // Conversion operator (valid with NumElements==1) + operator DataT() const { return m_Data; } + + // Subscript operators + DataT &operator[](int i) { + assert(i < NumElements); + return (m_Data); + } + const DataT &operator[](int i) const { + assert(i < NumElements); + return (m_Data); + } + + // Binary Ops + // Multiply + vec_complex operator*(const vec_complex &rhs) { + return (vec_complex{m_Data * static_cast(rhs)}); + } + + vec_complex operator*(const DataType &rhs) { + return (vec_complex{m_Data * rhs}); + } + + // Compound Multiply + vec_complex &operator*=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator*=(const vec_complex &rhs) { + this->m_Data = this->m_Data * static_cast(rhs); + return (*this); + } + + // Add + vec_complex operator+(const vec_complex &rhs) { + return (vec_complex{m_Data + static_cast(rhs)}); + } + + vec_complex operator+(const DataType &rhs) { + return (vec_complex{m_Data + rhs}); + } + + // Compound Add + vec_complex &operator+=(const DataType &rhs) { + this->m_Data = this->m_Data + rhs; + return (*this); + } + + vec_complex &operator+=(const vec_complex &rhs) { + this->m_Data = this->m_Data + static_cast(rhs); + return (*this); + } + + // Load + template + void load(size_t Offset, + cl::sycl::multi_ptr Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + + // Store + template + void store(size_t Offset, + cl::sycl::multi_ptr Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } +}; + +/*! @brief Partial specialization of the Packetize class dedicated to +sycl::complex types. It contains static methods for loading and storing size=1 +complex packets from/to memory. +* @tparam vector_size The desired vector size to be used. Only size = 1 is +supported so far. +* @tparam value_t The complex type of the matrix data. +*/ +template +struct Packetize, index_t> { + // Vectorization is not enabled for complex, always set to 1 + using value_t = complex_sycl; + using PacketType = vec_complex; + static constexpr int packet_size = 1; + template + static PORTBLAS_INLINE constexpr bool check_size() { + return true; + } + + /*! @brief Performs a non-vectorised load of a sycl::complex data element, + * whether the block is internal or not, since vectorization is not enabled + * for complex types yet. + * @tparam trans Whether the source matrix is transposed or not. + * @tparam internal True if the current block is internal and no bounds + * checking is required. + * @tparam ld The leading dimension of the destination memory. */ + template + static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, + DestPointerType dest, + EdgePredicate edge_in_range) { + *(dest) = in_range ?
*(src) : value_t{(T)0, (T)0}; + } + + /*! @brief Store a size = 1 vector packet of sycl::complex data into local + * memory (regardless of whether the source is transposed, since it is only + * 1 element). + * @tparam trans Whether the source matrix is transposed or not. + * @tparam ld The leading dimension of the destination memory.*/ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { + *dest = packet[0]; + } +}; +#endif +} // namespace blas + +#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 870349c48..0ca182918 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index df1ce6bd7..77cbafbbf 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index 02a42e938..ba26ef67f 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 3aacf4244..b9bd04e04 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -419,6 +419,11 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, batch_type) = arguments; + if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { + // Interleaved batched gemm unsupported with complex data types + GTEST_SKIP(); + } + const char ta_str[2] = {transa, '\0'}; const char tb_str[2] = {transb, '\0'}; @@ -456,11 +461,6 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { reinterpret_cast(c_m_cpu.data() + i * size_c + offset), ldc); } - if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { - // Interleaved batched gemm unsupported - GTEST_SKIP(); - } - auto m_a_gpu = blas::helper::allocate>( buffer_size_a, q); auto m_b_gpu = blas::helper::allocate>(
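A side note on the file added above: because only size = 1 is supported, a complex "packet" degenerates to a single element, so loads, stores and arithmetic all reduce to plain scalar complex operations. Below is a minimal, self-contained sketch of those semantics, using std::complex as a stand-in for the SYCL complex type; the packet1 name and the driver in main are illustrative only and not part of the patch.

#include <cassert>
#include <complex>
#include <cstddef>

// Hypothetical single-element packet, mirroring the semantics of the
// vec_complex wrapper introduced in gemm_load_store_complex.hpp.
template <typename DataT, int NumElements = 1>
class packet1 {
  static_assert(NumElements == 1, "only size-1 packets are supported");
  DataT m_data;

 public:
  packet1() = default;
  packet1(const DataT &v) : m_data{v} {}
  // Non-vectorised load/store: exactly one element per call.
  void load(std::size_t offset, const DataT *ptr) { m_data = ptr[offset]; }
  void store(std::size_t offset, DataT *ptr) const { ptr[offset] = m_data; }
  packet1 &operator*=(const packet1 &rhs) {
    m_data = m_data * rhs.m_data;  // scalar complex multiply, no vectorization
    return *this;
  }
  operator DataT() const { return m_data; }
};

int main() {
  const std::complex<float> a{1.f, 2.f}, b{3.f, -1.f};
  std::complex<float> out;
  packet1<std::complex<float>> pa{a}, pb{b};
  pa *= pb;
  pa.store(0, &out);
  assert(out == a * b);  // (1+2i)(3-i) = 5+5i
  return 0;
}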
From 6a0e010b7011bcd810a038c15ed3712ab76b4101 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 13 Oct 2023 17:16:24 +0100 Subject: [PATCH 16/18] Removed symm kernels generation from complex data types --- cmake/CmakeFunctionHelper.cmake | 7 +++++-- src/interface/blas3/backend/amd_gpu.hpp | 8 ++++---- src/interface/blas3/backend/default_cpu.hpp | 4 ++-- src/interface/blas3/backend/intel_gpu.hpp | 12 ++++++------ src/interface/blas3/backend/nvidia_gpu.hpp | 6 +++--- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index ef7fc22d6..2ae71bc5e 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -291,6 +291,9 @@ function(add_gemm_configuration cpp_type(cpp_data ${data}) foreach(symm_a ${boolean_list}) foreach(symm_b ${boolean_list}) + if ((${data} MATCHES "complex") AND (symm_a OR symm_b)) + continue() + endif() foreach(trans_a ${boolean_list}) foreach(trans_b ${boolean_list}) foreach(is_beta_zero ${boolean_list}) @@ -591,8 +594,8 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set_complex_list(data_list_c "${supported_types}" "false") foreach(data ${data_list_c}) add_gemm_configuration( - "${data}" 64 "false" "false" "true" - 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + "${data}" 256 "false" "false" "true" + 64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index a425b2f2a..f494f25b9 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -161,10 +161,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, static constexpr int tileWgSize = ClSize / sizeof(element_t); /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT - if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { + if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -177,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -188,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1,
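The tall & skinny dispatch rule kept in the amd_gpu.hpp hunks above is worth spelling out: with integer division, `_M / _N > 8 || _N / _M > 8` fires only when one extent is more than eight times the other, and the path is taken only for non-batched calls. A small standalone illustration follows; the function name and sample sizes are hypothetical, for exposition only.

#include <cstdio>

// Mirrors the dispatch condition in the AMD backend hunk above.
bool use_tall_skinny(int m, int n, int batch_size) {
  return batch_size == 1 && (m / n > 8 || n / m > 8);
}

int main() {
  std::printf("%d\n", use_tall_skinny(4096, 128, 1));   // 1: 4096/128 = 32 > 8
  std::printf("%d\n", use_tall_skinny(1024, 1024, 1));  // 0: square matrix
  std::printf("%d\n", use_tall_skinny(4096, 128, 4));   // 0: batched call
  return 0;
}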
diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 14c0cd337..e62348363 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -119,7 +119,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -130,7 +130,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index a0ce6f52a..8d788c9b5 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -222,11 +222,11 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT - if (!s_a && !s_b && batch_size == 1) { + if (batch_size == 1) { constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, - Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -239,7 +239,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 128 && _N <= 128) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, - Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (_t_b && !_t_a && !s_a && !s_b) { + } else if (_t_b && !_t_a) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -261,7 +261,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 7d555d902..13966172e 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -183,9 +183,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, - Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, + Tile<2, 2, 16, 16, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + false, false, static_cast(gemm_memory_t::local),
static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), From 3e3d4dc39ce2bc0d723e01e269cf7b4845dcfd4c Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 23 Oct 2023 14:40:34 +0100 Subject: [PATCH 17/18] Added & enabled portblas, cublas & rocblas benchmarks for complex GEMM --- benchmark/cublas/CMakeLists.txt | 8 + benchmark/cublas/blas3/gemm.cpp | 168 +++++++++++ benchmark/cublas/blas3/gemm_batched.cpp | 205 ++++++++++++- .../cublas/blas3/gemm_batched_strided.cpp | 200 +++++++++++++ benchmark/cublas/utils.hpp | 10 + benchmark/portblas/CMakeLists.txt | 8 + benchmark/portblas/blas3/gemm.cpp | 185 ++++++++++++ benchmark/portblas/blas3/gemm_batched.cpp | 222 +++++++++++++- .../portblas/blas3/gemm_batched_strided.cpp | 230 ++++++++++++++- benchmark/rocblas/CMakeLists.txt | 9 +- benchmark/rocblas/blas3/gemm.cpp | 183 ++++++++++++ benchmark/rocblas/blas3/gemm_batched.cpp | 200 +++++++++++++ .../rocblas/blas3/gemm_batched_strided.cpp | 217 ++++++++++++++ .../include/common/blas3_state_counters.hpp | 60 ++++ common/include/common/common_utils.hpp | 273 +++++++++++++++++- common/include/common/set_benchmark_label.hpp | 18 ++ 16 files changed, 2182 insertions(+), 14 deletions(-) diff --git a/benchmark/cublas/CMakeLists.txt b/benchmark/cublas/CMakeLists.txt index 250278fac..ad3b4ed05 100644 --- a/benchmark/cublas/CMakeLists.txt +++ b/benchmark/cublas/CMakeLists.txt @@ -74,12 +74,20 @@ set(sources extension/omatadd.cpp ) +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(cublas_bench ${sources}) get_filename_component(bench_cublas_exec ${cublas_bench} NAME_WE) add_executable(bench_cublas_${bench_cublas_exec} ${cublas_bench} main.cpp) target_link_libraries(bench_cublas_${bench_cublas_exec} PRIVATE benchmark CUDA::toolkit CUDA::cublas CUDA::cudart portblas Clara::Clara bench_info) target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) + if(${BLAS_ENABLE_COMPLEX}) + if("${bench_cublas_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() add_sycl_to_target( TARGET bench_cublas_${bench_cublas_exec} SOURCES ${cublas_bench} diff --git a/benchmark/cublas/blas3/gemm.cpp b/benchmark/cublas/blas3/gemm.cpp index 5a103d032..c74c9e98e 100644 --- a/benchmark/cublas/blas3/gemm.cpp +++ b/benchmark/cublas/blas3/gemm.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... 
args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemm(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemm(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta, @@ -168,6 +180,162 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; + +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k, + static_cast(1)); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n, 0); + + blas_benchmark::utils::CUDAVector> a_gpu( + m * k, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVector> b_gpu( + k * n, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVector> c_gpu( + n * m, reinterpret_cast*>(c.data())); + + cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = (*t_b == 'n') ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + + reference_blas::cgemm(t_a, t_b, m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + std::vector> c_temp = c; + { + blas_benchmark::utils::CUDAVector, true> c_temp_gpu( + m * n, reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_temp_gpu, + ldc); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + auto blas_warmup = [&]() -> void { + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_gpu, ldc); + return; + }; + + cudaEvent_t start; + cudaEvent_t stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_gpu, ldc); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + for (auto p : gemm_params) { + std::string t1s, t2s; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) {
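One detail all of these complex benchmarks lean on is layout compatibility: std::complex<T> buffers are handed to cuBLAS through reinterpret_cast as cuComplex/cuDoubleComplex. The C++ standard guarantees that std::complex<T> may be treated as an array of two T values with the real part first, which matches the {x, y} layout of the CUDA complex structs. A standalone sketch of that assumption, for illustration only:

#include <cassert>
#include <complex>

int main() {
  // std::complex<T> is specified to be usable as T[2]: real part at index 0,
  // imaginary part at index 1. The reinterpret_casts in the benchmarks above
  // rely on this when passing host buffers to the cuBLAS complex routines.
  static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
                "unexpected std::complex layout");
  std::complex<double> z{1.0, -2.0};
  const double *parts = reinterpret_cast<const double *>(&z);
  assert(parts[0] == 1.0 && parts[1] == -2.0);
  return 0;
}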
diff --git a/benchmark/cublas/blas3/gemm_batched.cpp b/benchmark/cublas/blas3/gemm_batched.cpp index 4cce28ff5..c0c50631f 100644 --- a/benchmark/cublas/blas3/gemm_batched.cpp +++ b/benchmark/cublas/blas3/gemm_batched.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemmBatched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemmBatched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, index_t t2, index_t m, index_t k, index_t n, scalar_t alpha, @@ -164,7 +176,7 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, state.counters["bytes_processed"]); blas_benchmark::utils::calc_avg_counters(state); - + CUDA_CHECK(cudaEventDestroy(start)); CUDA_CHECK(cudaEventDestroy(stop)); }; @@ -209,6 +221,197 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, + index_t t2, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_count, int batch_type_i, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard setup + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + auto batch_type = static_cast(batch_type_i); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_count); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + const index_t size_a = m * k; + const index_t size_b = k * n; + const index_t size_c = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a * batch_count); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b * batch_count); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c * batch_count, 0); + + blas_benchmark::utils::CUDAVectorBatched> d_A_array( + size_a, batch_count, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVectorBatched> d_B_array( + size_b, batch_count, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVectorBatched> d_C_array( + size_c, batch_count); + + cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = (*t_b == 'n') ?
CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + { + std::vector> c_ref = c; + auto _base = [=](index_t dim0, index_t dim1, index_t idx) { + return dim0 * dim1 * idx; + }; + for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + _base(m, k, batch_idx)), lda, + reinterpret_cast(b.data() + _base(k, n, batch_idx)), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + _base(m, n, batch_idx)), ldc); + } + + std::vector> c_temp(size_c * batch_count); + + { + blas_benchmark::utils::CUDAVectorBatched, true> + c_temp_gpu(n * m, batch_count, + reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, c_temp_gpu.get_batch_array(), ldc, batch_count); + } + + std::ostringstream err_stream; + for (int i = 0; i < batch_count; ++i) { + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; + } + + } // close scope for verify benchmark +#endif + + auto blas_warmup = [&]() -> void { + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, d_C_array.get_batch_array(), ldc, batch_count); + return; + }; + + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, d_C_array.get_batch_array(), ldc, batch_count); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_batched_params = + blas_benchmark::utils::get_gemm_cplx_batched_params(args); + + for (auto p : gemm_batched_params) { + std::string t1s, t2s; + index_t m, n, k, batch_count; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_count, + batch_type) = p; + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + if (batch_type == 1) { + std::cerr << "interleaved memory for gemm_batched operator is not " + "supported by cuBLAS\n"; + continue; + } + + int t1 = 
static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_count, + int batch_type, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_count, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_count, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_count, + batch_type, success) + ->UseRealTime(); + } +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) { diff --git a/benchmark/cublas/blas3/gemm_batched_strided.cpp b/benchmark/cublas/blas3/gemm_batched_strided.cpp index d96b7adfe..beb81fb4c 100644 --- a/benchmark/cublas/blas3/gemm_batched_strided.cpp +++ b/benchmark/cublas/blas3/gemm_batched_strided.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemmStridedBatched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemmStridedBatched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta, @@ -208,6 +220,194 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; + +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, index_t stride_a_mul, + index_t stride_b_mul, index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + const bool trA = t_a[0] == 'n'; + const bool trB = t_b[0] == 'n'; + + index_t lda = trA ? m : k; + index_t ldb = trB ? 
k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices (Total size is equal to matrix size x batch_size since we're using + // default striding values) + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, 0); + + blas_benchmark::utils::CUDAVector> a_gpu( + size_a_batch, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVector> b_gpu( + size_b_batch, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVector> c_gpu( + size_c_batch, reinterpret_cast*>(c.data())); + + cublasOperation_t c_t_a = trA ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = trB ? CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch_idx * stride_a), lda, + reinterpret_cast(b.data() + batch_idx * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch_idx * stride_c), ldc); + } + + std::vector> c_temp = c; + { + blas_benchmark::utils::CUDAVector, true> c_temp_gpu( + size_c_batch, reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, a_gpu, lda, stride_a, + b_gpu, ldb, stride_b, &cuBeta, c_temp_gpu, ldc, stride_c, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors_strided(c_temp, c_ref, stride_c, c_size, + err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, stride_a, b_gpu, ldb, stride_b, + &cuBeta, c_gpu, ldc, stride_c, batch_size); + return; + }; + + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, stride_a, b_gpu, ldb, stride_b, + &cuBeta, c_gpu, ldc, stride_c, batch_size); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + 
std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + + for (auto p : gemm_batched_strided_params) { + std::string t1s, t2s; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t strd_a_mul, index_t strd_b_mul, + index_t strd_c_mul, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, strd_a_mul, strd_b_mul, strd_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) { diff --git a/benchmark/cublas/utils.hpp b/benchmark/cublas/utils.hpp index eeaee7371..362fdce51 100644 --- a/benchmark/cublas/utils.hpp +++ b/benchmark/cublas/utils.hpp @@ -33,6 +33,7 @@ #include "portblas.h" #include +#include #include #include #include @@ -179,6 +180,15 @@ class CUDAVectorBatched : private CUDADeviceMemory { } } + CUDAVectorBatched(size_t matrix_size, size_t batch_count, T* h_v) + : CUDAVectorBatched(matrix_size, batch_count) { + if constexpr (CopyToHost) h_data = h_v; + for (int i = 0; i < batch_count; ++i) { + CUDA_CHECK(cudaMemcpy(d_data[i], &h_v[matrix_size * i], + sizeof(T) * c_matrix_size, cudaMemcpyHostToDevice)); + } + } + ~CUDAVectorBatched() { if constexpr (CopyToHost) { for (int i = 0; i < c_batch_count; ++i) { diff --git a/benchmark/portblas/CMakeLists.txt b/benchmark/portblas/CMakeLists.txt index 785996422..4ac3fdeaa 100644 --- a/benchmark/portblas/CMakeLists.txt +++ b/benchmark/portblas/CMakeLists.txt @@ -75,12 +75,20 @@ if(${BLAS_ENABLE_EXTENSIONS}) list(APPEND sources extension/reduction.cpp) endif() +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(portblas_bench ${sources}) get_filename_component(bench_exec ${portblas_bench} NAME_WE) add_executable(bench_${bench_exec} ${portblas_bench} main.cpp) target_link_libraries(bench_${bench_exec} PRIVATE benchmark Clara::Clara portblas bench_info) 
target_compile_definitions(bench_${bench_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) + if(${BLAS_ENABLE_COMPLEX}) + if("${bench_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_${bench_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() add_sycl_to_target( TARGET bench_${bench_exec} SOURCES ${portblas_bench} diff --git a/benchmark/portblas/blas3/gemm.cpp b/benchmark/portblas/blas3/gemm.cpp index 51d4869a8..27bb90650 100644 --- a/benchmark/portblas/blas3/gemm.cpp +++ b/benchmark/portblas/blas3/gemm.cpp @@ -177,6 +177,191 @@ void register_benchmark(blas_benchmark::Args& args, #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k, + static_cast(1)); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n, 0); + + auto a_gpu = + blas::helper::allocate>(m * k, q); + auto b_gpu = + blas::helper::allocate>(k * n, q); + auto c_gpu = + blas::helper::allocate>(m * n, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + m * k); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + n * k); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + m * n); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + reference_blas::cgemm(t_a, t_b, m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + + std::vector> c_temp = c; + + { + auto c_temp_gpu = + blas::helper::allocate>(m * n, + q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, m * n); + sb_handle.wait(copy_temp); + auto gemm_event = _gemm(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, + lda, b_gpu, ldb, beta_sycl, c_temp_gpu, ldc); + sb_handle.wait(gemm_event); + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), m * n); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_method_def = 
[&]() -> std::vector { + auto event = _gemm(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, + b_gpu, ldb, beta_sycl, c_gpu, ldc); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, + std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_params); +#endif +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/portblas/blas3/gemm_batched.cpp b/benchmark/portblas/blas3/gemm_batched.cpp index 959f9eae7..3d98c3697 100644 --- a/benchmark/portblas/blas3/gemm_batched.cpp +++ b/benchmark/portblas/blas3/gemm_batched.cpp @@ -225,8 +225,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - t1s, t2s, m, k, n, batch_size, batch_type, - mem_type).c_str(), + t1s, t2s, m, k, n, batch_size, batch_type, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, batch_type, success) ->UseRealTime(); @@ -239,13 +239,227 @@ void register_benchmark(blas_benchmark::Args& args, auto gemm_batched_params = blas_benchmark::utils::get_gemm_batched_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, 
gemm_batched_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_params); #ifdef SB_ENABLE_USM register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_batched_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_params); #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, int batch_type_i, + bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + auto batch_type = static_cast(batch_type_i); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_size); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k * batch_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n * batch_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n * batch_size, + scalar_t(0)); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + auto _base = [=](index_t dim0, index_t dim1, index_t idx) { + return dim0 * dim1 * idx; + }; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + _base(m, k, batch_idx)), lda, + reinterpret_cast(b.data() + _base(k, n, batch_idx)), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + _base(m, n, batch_idx)), ldc); + } + + if (batch_type == blas::gemm_batch_type_t::interleaved) { + state.SkipWithError("Interleaved unsupported with Complex data types."); + *success = false; + } + +#endif // BLAS_VERIFY_BENCHMARK + + auto a_gpu = blas::helper::allocate>( + m * k * batch_size, q); + auto b_gpu = blas::helper::allocate>( + k * n * batch_size, q); + auto c_gpu = blas::helper::allocate>( + m * n * batch_size, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + m * k * batch_size); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + n * k * batch_size); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + m * n * batch_size); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + std::vector> c_temp = c; + { + auto c_temp_gpu = + blas::helper::allocate>( + m * n * batch_size, q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, m * n * batch_size); + sb_handle.wait(copy_temp); + auto gemm_batched_event = _gemm_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, b_gpu, ldb, + beta_sycl, c_temp_gpu, ldc, batch_size, 
batch_type); + sb_handle.wait(gemm_batched_event); + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), + m * n * batch_size); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif // BLAS_VERIFY_BENCHMARK + + auto blas_method_def = [&]() -> std::vector { + auto event = _gemm_batched(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, + a_gpu, lda, b_gpu, ldb, beta_sycl, c_gpu, ldc, + batch_size, batch_type); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm batched. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark( + blas::SB_Handle* sb_handle_ptr, bool* success, std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k, batch_size; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + batch_type) = p; + // Only batch_type == strided is supported with Complex data + if (batch_type == 1) { + std::cerr << "interleaved memory for gemm_batched operator is not " + "supported with complex data types\n"; + continue; + } + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + int batch_type, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, batch_type, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + batch_type, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_batched_params = + blas_benchmark::utils::get_gemm_cplx_batched_params(args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_params); +#endif +} +#endif + namespace
blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/portblas/blas3/gemm_batched_strided.cpp b/benchmark/portblas/blas3/gemm_batched_strided.cpp index 0fdb29db9..a24a2a188 100644 --- a/benchmark/portblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/portblas/blas3/gemm_batched_strided.cpp @@ -195,7 +195,8 @@ void register_benchmark( benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, - stride_c_mul, mem_type).c_str(), + stride_c_mul, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, stride_a_mul, stride_b_mul, stride_c_mul, success) ->UseRealTime(); @@ -208,13 +209,236 @@ void register_benchmark(blas_benchmark::Args& args, auto gemm_batched_strided_params = blas_benchmark::utils::get_gemm_batched_strided_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, gemm_batched_strided_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_strided_params); #ifdef SB_ENABLE_USM register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_batched_strided_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_strided_params); #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, index_t stride_a_mul, + index_t stride_b_mul, index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + const bool trA = t_a[0] != 'n'; + const bool trB = t_b[0] != 'n'; + + index_t lda = trA ? k : m; + index_t ldb = trB ? 
n : k; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, + scalar_t(0)); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch_idx * stride_a), lda, + reinterpret_cast(b.data() + batch_idx * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch_idx * stride_c), ldc); + } + +#endif + + auto a_gpu = blas::helper::allocate>( + size_a_batch, q); + auto b_gpu = blas::helper::allocate>( + size_b_batch, q); + auto c_gpu = blas::helper::allocate>( + size_c_batch, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + size_a_batch); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + size_b_batch); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + size_c_batch); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + std::vector> c_temp = c; + { + auto c_temp_gpu = + blas::helper::allocate>( + size_c_batch, q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, size_c_batch); + sb_handle.wait(copy_temp); + auto gemm_batched_strided_event = _gemm_strided_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, stride_a, b_gpu, + ldb, stride_b, beta_sycl, c_temp_gpu, ldc, stride_c, batch_size); + sb_handle.wait(gemm_batched_strided_event); + + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), + size_c_batch); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!::utils::compare_vectors_strided(c_temp, c_ref, stride_c, + c_size, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_method_def = [&]() -> std::vector { + auto event = _gemm_strided_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, stride_a, b_gpu, + ldb, stride_b, beta_sycl, c_gpu, ldc, stride_c, batch_size); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); 
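+  // The untimed warmup runs above let the SYCL runtime JIT-compile the
+  // kernels; the wait below drains the queue so that one-off cost stays out
+  // of the timed loop.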
+ sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm batched strided. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark( + blas::SB_Handle* sb_handle_ptr, bool* success, std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t stride_a_mul, index_t stride_b_mul, + index_t stride_c_mul, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_strided_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_strided_params); +#endif +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/rocblas/CMakeLists.txt b/benchmark/rocblas/CMakeLists.txt index caa884725..64a559931 100644 --- a/benchmark/rocblas/CMakeLists.txt +++ b/benchmark/rocblas/CMakeLists.txt @@ -77,6 +77,9 @@ set(sources ) +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(rocblas_benchmark ${sources}) get_filename_component(rocblas_bench_exec ${rocblas_benchmark} NAME_WE) @@ -84,7 +87,11 @@ foreach(rocblas_benchmark ${sources}) target_link_libraries(bench_rocblas_${rocblas_bench_exec} PRIVATE benchmark Clara::Clara roc::rocblas bench_info) 
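  # BLAS_INDEX_T fixes the index type the benchmark translation units are
  # compiled with (BLAS_BENCHMARK_INDEX_TYPE is set at configure time).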
target_compile_definitions(bench_rocblas_${rocblas_bench_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) target_include_directories(bench_rocblas_${rocblas_bench_exec} PRIVATE ${PORTBLAS_INCLUDE} ${rocblas_INCLUDE_DIRS} ${CBLAS_INCLUDE} ${BLAS_BENCH} ${PORTBLAS_COMMON_INCLUDE_DIR}) - + if(${BLAS_ENABLE_COMPLEX}) + if("${rocblas_bench_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_rocblas_${rocblas_bench_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() # Even though rocblas does not use sycl, the common tools indirectly include sycl headers. add_sycl_to_target( TARGET bench_rocblas_${rocblas_bench_exec} diff --git a/benchmark/rocblas/blas3/gemm.cpp b/benchmark/rocblas/blas3/gemm.cpp index b403bafec..ca07ba2ba 100644 --- a/benchmark/rocblas/blas3/gemm.cpp +++ b/benchmark/rocblas/blas3/gemm.cpp @@ -38,6 +38,18 @@ static inline void rocblas_gemm_f(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_f(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_cgemm(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_zgemm(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -183,6 +195,177 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; + +template +void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, + int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + + index_t lda = t_a_str[0] == 'n' ? m : k; + index_t ldb = t_b_str[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + t_a_str[0] == 'n' ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + t_b_str[0] == 'n' ? 
rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + const int a_size = m * k; + const int b_size = k * n; + const int c_size = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(a_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(b_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(c_size, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVector> a_gpu( + a_size, reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVector> b_gpu( + b_size, reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVector> c_gpu( + c_size, reinterpret_cast*>(c.data())); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference gemm + std::vector> c_ref = c; + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + + // Rocblas verification gemm + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVector, true> c_temp_gpu( + c_size, reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_temp_gpu, ldc); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_gpu, ldc); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_gpu, ldc); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + + for (auto p : gemm_params) { + std::string t_a, t_b; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + 
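+    // The parameter tuples carry alpha and beta as separate real and
+    // imaginary components; recombine them into std::complex values before
+    // registering the benchmark.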
std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t1i, int t2i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/benchmark/rocblas/blas3/gemm_batched.cpp b/benchmark/rocblas/blas3/gemm_batched.cpp index 4cfb1418d..40147d5ff 100644 --- a/benchmark/rocblas/blas3/gemm_batched.cpp +++ b/benchmark/rocblas/blas3/gemm_batched.cpp @@ -38,6 +38,18 @@ static inline void rocblas_gemm_batched_f(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_batched_f(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_cgemm_batched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_zgemm_batched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i, index_t t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -209,6 +221,194 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; +template +void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i, + index_t t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_size, int batch_type_i, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard setup + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + auto batch_type = static_cast(batch_type_i); + + const bool trA = (t_a_str[0] == 'n'); + const bool trB = (t_b_str[0] == 'n'); + + index_t lda = trA ? m : k; + index_t ldb = trB ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_size); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + trA ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + trB ? 
rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + const int a_size = m * k; + const int b_size = k * n; + const int c_size = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(a_size * batch_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(b_size * batch_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(c_size * batch_size, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVectorBatched> a_batched_gpu( + a_size, batch_size, reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVectorBatched> b_batched_gpu( + b_size, batch_size, reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVectorBatched> c_batched_gpu( + c_size, batch_size); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference batched gemm + std::vector> c_ref = c; + for (int batch = 0; batch < batch_size; batch++) { + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch * a_size), lda, + reinterpret_cast(b.data() + batch * b_size), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch * c_size), ldc); + } + + // Rocblas verification gemm_batched + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVectorBatched, true> + c_temp_gpu(c_size, batch_size, + reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_temp_gpu, ldc, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_batched_gpu, ldc, batch_size); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_batched_gpu, ldc, batch_size); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_batched_params = + 
blas_benchmark::utils::get_gemm_cplx_batched_params(args); + + for (auto p : gemm_batched_params) { + std::string t_a, t_b; + index_t m, n, k, batch_size; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + batch_type) = p; + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + if (batch_type == 1) { + std::cerr << "interleaved memory for gemm_batched operator is not " + "supported by rocBLAS\n"; + continue; + } + + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t_a_i, int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + int batch_type, bool* success) { + run(st, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, + batch_size, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, batch_size, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, + batch_type, success) + ->UseRealTime(); + } +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/benchmark/rocblas/blas3/gemm_batched_strided.cpp b/benchmark/rocblas/blas3/gemm_batched_strided.cpp index 15dac9896..3ecbff82c 100644 --- a/benchmark/rocblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/rocblas/blas3/gemm_batched_strided.cpp @@ -40,6 +40,20 @@ static inline void rocblas_gemm_strided_batched(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_strided_batched(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS( + rocblas_cgemm_strided_batched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS( + rocblas_zgemm_strided_batched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -219,6 +233,209 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; + +template +void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, + int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_size, index_t stride_a_mul, index_t stride_b_mul, + index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + + const bool trA = (t_a_str[0] == 'n'); + const bool trB = (t_b_str[0] == 'n'); + + index_t lda = trA ? m : k; + index_t ldb = trB ? 
k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + trA ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + trB ? rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVectorBatchedStrided> + a_batched_gpu(a_size, batch_size, stride_a, + reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVectorBatchedStrided> + b_batched_gpu(b_size, batch_size, stride_b, + reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVectorBatchedStrided> + c_batched_gpu(c_size, batch_size, stride_c, + reinterpret_cast*>(c.data())); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference gemm batched strided (strided loop of gemm) + std::vector> c_ref = c; + for (int batch = 0; batch < batch_size; batch++) { + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch * stride_a), lda, + reinterpret_cast(b.data() + batch * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch * stride_c), ldc); + } + + // Rocblas verification gemm_batched_strided + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVectorBatchedStrided, true> + c_temp_gpu(c_size, batch_size, stride_c, + reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, stride_b, &rocBeta, c_temp_gpu, + ldc, stride_c, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors_strided(c_temp, c_ref, stride_c, c_size, + err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, stride_b, &rocBeta, c_batched_gpu, + ldc, stride_c, batch_size); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, 
stride_b, &rocBeta, c_batched_gpu, + ldc, stride_c, batch_size); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + + for (auto p : gemm_batched_strided_params) { + std::string t_a, t_b; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t1i, int t2i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t strd_a_mul, index_t strd_b_mul, + index_t strd_c_mul, bool* success) { + run(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, batch_size, + strd_a_mul, strd_b_mul, strd_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/common/include/common/blas3_state_counters.hpp b/common/include/common/blas3_state_counters.hpp index c7515eb07..68e332773 100644 --- a/common/include/common/blas3_state_counters.hpp +++ b/common/include/common/blas3_state_counters.hpp @@ -76,6 +76,66 @@ init_level_3_counters(benchmark::State& state, scalar_t beta = 0, index_t m = 0, return; } +#ifdef BLAS_ENABLE_COMPLEX +template +inline typename std::enable_if::type +init_level_3_cplx_counters( + benchmark::State& state, + std::complex beta = std::complex(0, 0), index_t m = 0, + index_t n = 0, index_t k = 0, index_t batch_size = 1, + index_t stride_a_mul = 1, index_t stride_b_mul = 1, + index_t stride_c_mul = 1) { + // Google-benchmark counters are double. 
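+  // Complex flop model used for the counters below: one complex multiply,
+  // (a+bi)(c+di) = (ac-bd) + (ad+bc)i, costs 4 real multiplies plus 2 real
+  // additions (6 flops), and one complex addition costs 2 flops; hence the
+  // factors 6 and 2 in the n_fl_ops terms.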
+  double beta_real_d = static_cast<double>(beta.real());
+  double beta_imag_d = static_cast<double>(beta.imag());
+  double m_d = static_cast<double>(m);
+  double n_d = static_cast<double>(n);
+  double k_d = static_cast<double>(k);
+  double batch_size_d = static_cast<double>(batch_size);
+  state.counters["beta_real"] = beta_real_d;
+  state.counters["beta_imag"] = beta_imag_d;
+  state.counters["m"] = m_d;
+  state.counters["n"] = n_d;
+  state.counters["k"] = k_d;
+  state.counters["batch_size"] = batch_size_d;
+  if constexpr (op == Level3Op::gemm_batched_strided) {
+    double stride_a_mul_d = static_cast<double>(stride_a_mul);
+    double stride_b_mul_d = static_cast<double>(stride_b_mul);
+    double stride_c_mul_d = static_cast<double>(stride_c_mul);
+
+    state.counters["stride_a_mul"] = stride_a_mul_d;
+    state.counters["stride_b_mul"] = stride_b_mul_d;
+    state.counters["stride_c_mul"] = stride_c_mul_d;
+  }
+
+  // These counts assume fully complex operands; pure real or pure imaginary
+  // inputs need fewer flops, so revisit them if such cases are benchmarked.
+
+  bool beta_zero =
+      (beta.real() == scalar_t{0}) && (beta.imag() == scalar_t{0});
+
+  const double nflops_AtimesB =
+      k_d * m_d * n_d * 6 + k_d * m_d * n_d * 2;  // MulFlops + AddFlops
+  double nflops_timesAlpha = m_d * n_d * 6;
+  const double nflops_addBetaC =
+      beta_zero ? 0 : 6 * m_d * n_d + 2 * m_d * n_d;  // MulFlops + AddFlops
+  const double nflops_tot =
+      (nflops_AtimesB + nflops_timesAlpha + nflops_addBetaC) * batch_size_d;
+  state.counters["n_fl_ops"] = nflops_tot;
+
+  const double mem_readA = m_d * k_d;
+  const double mem_readB = k_d * n_d;
+  const double mem_writeC = m_d * n_d;
+  const double mem_readC = beta_zero ? 0 : m_d * n_d;
+  const double total_mem = (mem_readA + mem_readB + mem_readC + mem_writeC) *
+                           batch_size_d * sizeof(scalar_t) * 2;
+  state.counters["bytes_processed"] = total_mem;
+  return;
+}
+
+#endif
+
 template
 inline typename std::enable_if::type
 init_level_3_counters(benchmark::State& state, scalar_t beta = 0, index_t m = 0,
diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp
index a569ed2ff..251ee9b7f 100644
--- a/common/include/common/common_utils.hpp
+++ b/common/include/common/common_utils.hpp
@@ -53,6 +53,24 @@ using gemm_batched_strided_param_t = std::tuple;
 
+#ifdef BLAS_ENABLE_COMPLEX
+template <typename scalar_t>
+using blas3_cplx_param_t =
+    std::tuple<std::string, std::string, index_t, index_t, index_t, scalar_t,
+               scalar_t, scalar_t, scalar_t>;
+
+template <typename scalar_t>
+using gemm_batched_strided_cplx_param_t =
+    std::tuple<std::string, std::string, index_t, index_t, index_t, scalar_t,
+               scalar_t, scalar_t, scalar_t, index_t, index_t, index_t,
+               index_t>;
+
+template <typename scalar_t>
+using gemm_batched_cplx_param_t =
+    std::tuple<std::string, std::string, index_t, index_t, index_t, scalar_t,
+               scalar_t, scalar_t, scalar_t, index_t, int>;
+#endif
+
 using reduction_param_t = std::tuple;
 
 template
@@ -485,6 +503,157 @@ static inline std::vector> get_blas3_params(
   }
 }
 
+#ifdef BLAS_ENABLE_COMPLEX
+/**
+ * @fn get_blas3_cplx_params for complex data type
+ * @brief Returns a vector containing the blas 3 benchmark cplx parameters,
+ * either read from a file according to the command-line args, or the default
+ * ones. So far only GEMM and its batched extensions use these.
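+ * Each tuple is (transA, transB, m, k, n, alpha_real, alpha_imag, beta_real,
+ * beta_imag).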
+ */ +template +static inline std::vector> get_blas3_cplx_params( + Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> blas3_default; + constexpr index_t dmin = 32, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t m = dmin; m <= dmax; m *= 8) { + for (index_t k = dmin; k <= dmax; k *= 8) { + for (index_t n = dmin; n <= dmax; n *= 8) { + blas3_default.push_back( + std::make_tuple(t1, t2, m, k, n, alpha.real(), alpha.imag(), + beta.real(), beta.imag())); + } + } + } + } + } + return blas3_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 9) { + throw std::runtime_error( + "invalid number of parameters (9 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8])); + } catch (...) { + throw std::runtime_error("invalid parameter"); + } + }); + } +} + +/** + * @fn get_gemm_batched_strided_cplx_params for complex data type + * @brief Returns a vector containing the gemm_batched_strided cplx benchmark + * parameters, either read from a file according to the command-line args, or + * the default ones. + */ +template +inline std::vector> +get_gemm_batched_strided_cplx_params(Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> + gemm_batched_strided_default; + constexpr index_t dmin = 128, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + index_t batch_size = 8; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t m = dmin; m <= dmax; m *= 8) { + gemm_batched_strided_default.push_back( + std::make_tuple(t1, t2, m, m, m, alpha.real(), alpha.imag(), + beta.real(), beta.imag(), batch_size, 2, 2, 2)); + } + } + } + return gemm_batched_strided_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 13) { + throw std::runtime_error( + "invalid number of parameters (13 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8]), + str_to_int(v[9]), str_to_int(v[10]), + str_to_int(v[11]), str_to_int(v[12])); + } catch (...) { + std::throw_with_nested(std::runtime_error("invalid parameter")); + } + }); + } +} + +/** + * @fn get_gemm_cplx_batched_params + * @brief Returns a vector containing the gemm_batched cplx benchmark + * parameters, either read from a file according to the command-line args, or + * the default ones. 
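+ * Each tuple is (transA, transB, m, k, n, alpha_real, alpha_imag, beta_real,
+ * beta_imag, batch_size, batch_type).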
+ */ +template +inline std::vector> +get_gemm_cplx_batched_params(Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> gemm_batched_default; + constexpr index_t dmin = 128, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + index_t batch_size = 8; + int batch_type = 0; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t n = dmin; n <= dmax; n *= 8) { + gemm_batched_default.push_back(std::make_tuple( + t1, t2, n, n, n, alpha.real(), alpha.imag(), beta.real(), + beta.imag(), batch_size, batch_type)); + } + } + } + return gemm_batched_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 11) { + throw std::runtime_error( + "invalid number of parameters (11 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8]), + str_to_int(v[9]), str_to_batch_type(v[10])); + } catch (...) { + std::throw_with_nested(std::runtime_error("invalid parameter")); + } + }); + } +} +#endif + /** * @fn get_gemm_batched_params * @brief Returns a vector containing the gemm_batched benchmark parameters, @@ -1334,6 +1503,17 @@ inline std::string get_type_name() { return "double"; } +#ifdef BLAS_ENABLE_COMPLEX +template <> +inline std::string get_type_name>() { + return "complex"; +} +template <> +inline std::string get_type_name>() { + return "complex"; +} +#endif + /** * @fn random_scalar * @brief Generates a random scalar value, using an arbitrary low quality @@ -1372,6 +1552,67 @@ static inline std::vector random_data(size_t size) { return v; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * @fn random_cplx_scalar + * @brief Generates a random complex value, using an arbitrary low quality + * algorithm. + */ +template +static inline std::complex random_cplx_scalar() { + scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); + scalar_t im = 1e-3 * ((rand() % 2000) - 1000); + return std::complex(rl, im); +} + +/** + * @brief Generates a random complex in the specified range of its underlying + * data elements (real & imag) + * @param rangeMin range minimum + * @param rangeMax range maximum + */ +template +static inline std::complex random_cplx_scalar(scalar_t rangeMin, + scalar_t rangeMax) { + static std::random_device rd; + static std::default_random_engine gen(rd()); + std::uniform_real_distribution disRl(rangeMin, rangeMax); + std::uniform_real_distribution disIm(rangeMin, rangeMax); + + return std::complex(disRl(gen), disIm(gen)); +} + +/** + * @fn random_cplx_data + * @brief Generates a random vector of complex values, using a uniform + * distribution of the underlying data elements (real & imag). + */ +template +static inline std::vector> random_cplx_data( + size_t size) { + std::vector> v(size); + + for (std::complex& e : v) { + e = random_cplx_scalar(scalar_t{-2}, scalar_t{5}); + } + return v; +} + +/** + * @fn const_cplx_data + * @brief Generates a vector of constant complex values, of a given length. + */ +template +static inline std::vector> const_cplx_data( + size_t size, scalar_t const_value = 0) { + std::vector> v(size); + std::complex const_cplx_value{const_value, const_value}; + std::fill(v.begin(), v.end(), const_cplx_value); + return v; +} + +#endif // BLAS_ENABLE_COMPLEX + /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. 
Size must be at least m * lda
@@ -1575,17 +1816,39 @@ static inline void calc_avg_counters(benchmark::State& state) {
 #define BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success)
 #endif // BLAS_DATA_TYPE_HALF
 
+#ifdef BLAS_ENABLE_COMPLEX
+/** Registers benchmark for the float complex data type
+ * @see BLAS_REGISTER_BENCHMARK
+ */
+#define BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success) \
+  register_cplx_benchmark<float>(args, sb_handle_ptr, success)
+#else
+#define BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success)
+#endif
+
+#if defined(BLAS_ENABLE_COMPLEX) && defined(BLAS_DATA_TYPE_DOUBLE)
+/** Registers benchmark for the double complex data type
+ * @see BLAS_REGISTER_BENCHMARK
+ */
+#define BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success) \
+  register_cplx_benchmark<double>(args, sb_handle_ptr, success)
+#else
+#define BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success)
+#endif
+
 /** Registers benchmark for all supported data types.
  * Expects register_benchmark to exist.
  * @param args Reference to blas_benchmark::Args
  * @param sb_handle_ptr Pointer to blas::SB_Handle
  * @param[out] success Pointer to boolean indicating success
  */
-#define BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success) \
-  do { \
-    BLAS_REGISTER_BENCHMARK_FLOAT(args, sb_handle_ptr, success); \
-    BLAS_REGISTER_BENCHMARK_DOUBLE(args, sb_handle_ptr, success); \
-    BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success); \
+#define BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success) \
+  do { \
+    BLAS_REGISTER_BENCHMARK_FLOAT(args, sb_handle_ptr, success); \
+    BLAS_REGISTER_BENCHMARK_DOUBLE(args, sb_handle_ptr, success); \
+    BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success); \
+    BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success); \
+    BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success); \
   } while (false)
 
 #endif
diff --git a/common/include/common/set_benchmark_label.hpp b/common/include/common/set_benchmark_label.hpp
index b1d4c3ca7..9495a3195 100644
--- a/common/include/common/set_benchmark_label.hpp
+++ b/common/include/common/set_benchmark_label.hpp
@@ -28,6 +28,10 @@
 #include
 #include
+#ifdef BLAS_ENABLE_COMPLEX
+#define SYCL_EXT_ONEAPI_COMPLEX
+#include
+#endif
 
 #ifdef BUILD_CUBLAS_BENCHMARKS
 #include
@@ -178,6 +182,20 @@ inline void add_datatype_info(
 }
 #endif // BLAS_DATA_TYPE_HALF
 
+#ifdef BLAS_ENABLE_COMPLEX
+template <>
+inline void add_datatype_info<std::complex<float>>(
+    std::map<std::string, std::string>& key_value_map) {
+  key_value_map["@datatype"] = "complex<float>";
+}
+
+template <>
+inline void add_datatype_info<std::complex<double>>(
+    std::map<std::string, std::string>& key_value_map) {
+  key_value_map["@datatype"] = "complex<double>";
+}
+#endif // BLAS_ENABLE_COMPLEX
+
 }  // namespace datatype_info
 
 inline void set_label(benchmark::State& state,

From ad82e1742c8b77d1f78acfab8f90314a287f3a4d Mon Sep 17 00:00:00 2001
From: Ouadie EL FAROUKI
Date: Tue, 24 Oct 2023 15:17:27 +0100
Subject: [PATCH 18/18] Removed redundant gemm batch type check in benchmark

---
 benchmark/portblas/blas3/gemm_batched.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/benchmark/portblas/blas3/gemm_batched.cpp b/benchmark/portblas/blas3/gemm_batched.cpp
index 3d98c3697..aabd9449a 100644
--- a/benchmark/portblas/blas3/gemm_batched.cpp
+++ b/benchmark/portblas/blas3/gemm_batched.cpp
@@ -302,11 +302,6 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1,
         reinterpret_cast<void*>(c_ref.data() + _base(m, n, batch_idx)), ldc);
   }
 
-  if (batch_type == blas::gemm_batch_type_t::interleaved) {
state.SkipWithError("Interleaved unsupported with Complex data types."); - *success = false; - } - #endif // BLAS_VERIFY_BENCHMARK auto a_gpu = blas::helper::allocate>( @@ -417,7 +412,7 @@ void register_cplx_benchmark( batch_type) = p; // Only batch_type == strided is supported with Complex data if (batch_type == 1) { - std::cerr << "interleaved memory for gemm_batched operator is not " + std::cerr << "Interleaved memory for gemm_batched operator is not " "supported whith complex data type\n"; continue; }
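The rocComplex alias that the rocBLAS benchmarks above rely on simply selects
the rocBLAS complex struct matching the element type. A minimal self-contained
sketch of that mapping follows; the include path and the to_roc helper are
illustrative assumptions, not part of the patch:

    #include <complex>
    #include <type_traits>
    #include <rocblas/rocblas.h>  // assumed include path

    // Pick the rocBLAS complex type that matches the element type.
    template <typename scalar_t>
    using rocComplex =
        typename std::conditional<std::is_same_v<scalar_t, float>,
                                  rocblas_float_complex,
                                  rocblas_double_complex>::type;

    // Hypothetical helper mirroring the rocAlpha/rocBeta initialisation in
    // the benchmarks: repack a std::complex value into the rocBLAS aggregate.
    template <typename scalar_t>
    rocComplex<scalar_t> to_roc(const std::complex<scalar_t>& z) {
      return rocComplex<scalar_t>{z.real(), z.imag()};
    }

    static_assert(std::is_same_v<rocComplex<float>, rocblas_float_complex>);
    static_assert(std::is_same_v<rocComplex<double>, rocblas_double_complex>);

The same std::is_same_v dispatch is what routes rocblas_cplx_gemm_f and its
batched variants to rocblas_cgemm or rocblas_zgemm at compile time.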