Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Enabed Complex data type for Gemm (#462)
Browse files Browse the repository at this point in the history
Added preliminary support for sycl::complex<float/double> data types
for GEMM operator along with the relevant unit tests.
  • Loading branch information
OuadiElfarouki authored Oct 24, 2023
1 parent 445764f commit 19b0fed
Show file tree
Hide file tree
Showing 28 changed files with 1,514 additions and 117 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ if(IMGDNN_DIR)
endif()

option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON)

# CmakeFunctionHelper has to be included after any options that it depends on are declared.
# These include:
Expand All @@ -115,6 +116,7 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
# * BLAS_DATA_TYPES
# * BLAS_INDEX_TYPES
# * NAIVE_GEMM
# * BLAS_ENABLE_COMPLEX
include(CmakeFunctionHelper)

if (INSTALL_HEADER_ONLY)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ Some of the supported options are:
| `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) |
| `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` |
| `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` |

| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Kernels only)* (`ON` by default) |

### Cross-Compile (ComputeCpp Only)

Expand Down
131 changes: 128 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,30 @@ function(cpp_type output data)
if (${data} STREQUAL "half")
set(${output} "cl::sycl::half" PARENT_SCOPE)
return()
elseif(${data} STREQUAL "complex<float>")
set(${output} "cl::sycl::ext::oneapi::experimental::complex<float>" PARENT_SCOPE)
return()
elseif(${data} STREQUAL "complex<double>")
set(${output} "cl::sycl::ext::oneapi::experimental::complex<double>" PARENT_SCOPE)
return()
endif()
set(${output} "${data}" PARENT_SCOPE)
endfunction()

function(set_complex_list output input append)
set(output_temp "")
if(${append} STREQUAL "true")
foreach(data ${input})
list(APPEND output_temp "${data};complex<${data}>")
endforeach(data)
else()
foreach(data ${input})
list(APPEND output_temp "complex<${data}>")
endforeach(data)
endif()
set(${output} ${output_temp} PARENT_SCOPE)
endfunction(set_complex_list)

## represent the list of bolean options
set(boolean_list "true" "false")

Expand All @@ -56,6 +76,9 @@ function(sanitize_file_name output file_name)
set(${output} "${file_name}" PARENT_SCOPE)
endfunction()

#List of operators supporting Complex Data types
set(COMPLEX_OPS "gemm" "gemm_launcher" "scal")

function(set_target_compile_def in_target)
#setting compiler flag for backend
if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
Expand Down Expand Up @@ -84,16 +107,31 @@ function(set_target_compile_def in_target)
message(STATUS "Gemm vectorization support enabled for target ${in_target}")
target_compile_definitions(${in_target} PUBLIC GEMM_VECTORIZATION_SUPPORT=1)
endif()

#setting const data type support
if(BLAS_ENABLE_CONST_INPUT)
target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_CONST_INPUT=1)
endif()
#setting complex support
if(${BLAS_ENABLE_COMPLEX})
if("${in_target}" IN_LIST COMPLEX_OPS)
message(STATUS "Complex Data type support enabled for target ${in_target}")
target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_COMPLEX=1)
endif()
endif()
endfunction()

# blas unary function for generating source code
function(generate_blas_objects blas_level func)
set(LOCATION "${PORTBLAS_GENERATED_SRC}/${blas_level}/${func}/")
foreach(data ${data_list})
set(data_list_c ${data_list})
# Extend data_list to complex<data> for each data in list
# if target function is in COMPLEX_OPS
if(BLAS_ENABLE_COMPLEX)
if("${func}" IN_LIST COMPLEX_OPS)
set_complex_list(data_list_c "${data_list}" "true")
endif()
endif()
foreach(data ${data_list_c})
cpp_type(cpp_data ${data})
foreach(index ${index_list})
foreach(increment ${index_list})
Expand Down Expand Up @@ -234,7 +272,11 @@ function(add_gemm_configuration
batch_type
use_joint_matrix
)
if(NOT ("${data}" IN_LIST data_list))
set(data_list_c ${data_list})
if(BLAS_ENABLE_COMPLEX)
set_complex_list(data_list_c "${data_list}" "true")
endif()
if(NOT ("${data}" IN_LIST data_list_c))
# Data type not enabled, skip configuration
return()
endif()
Expand All @@ -249,6 +291,9 @@ function(add_gemm_configuration
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
foreach(is_beta_zero ${boolean_list})
Expand Down Expand Up @@ -380,6 +425,32 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 64 "true" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 4 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false")
if (${data} STREQUAL "complex<double>")
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
else()
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
endif()
endforeach()
endif() # BLAS_ENABLE_COMPLEX
elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR)
set(supported_types
"float"
Expand Down Expand Up @@ -445,6 +516,35 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
if (${data} STREQUAL "complex<double>")
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
else()
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endif()
endforeach()
endif() # BLAS_ENABLE_COMPLEX
elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
set(supported_types
"float"
Expand Down Expand Up @@ -486,7 +586,18 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
endif() # BLAS_ENABLE_COMPLEX
else() # default cpu backend
set(supported_types
"float"
Expand All @@ -513,6 +624,20 @@ else() # default cpu backend
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false")
endforeach()
endif() # BLAS_ENABLE_COMPLEX
endif()
add_library(${func} OBJECT ${gemm_sources})
set_target_compile_def(${func})
Expand Down
98 changes: 97 additions & 1 deletion common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

#include <cmath>
#include <iostream>
#ifdef BLAS_ENABLE_COMPLEX
#include <complex>
#endif

#ifdef BLAS_DATA_TYPE_HALF
#if SYCL_LANGUAGE_VERSION < 202000
Expand Down Expand Up @@ -65,6 +68,23 @@ scalar_t abs(scalar_t value) noexcept {
return std::abs(value);
}

#ifdef BLAS_ENABLE_COMPLEX
template <typename scalar_t>
bool isnan(std::complex<scalar_t> value) noexcept {
return (isnan<scalar_t>(value.real()) || isnan<scalar_t>(value.imag()));
}

template <typename scalar_t>
bool isinf(std::complex<scalar_t> value) noexcept {
return (isinf<scalar_t>(value.real()) || isinf<scalar_t>(value.imag()));
}

template <typename scalar_t>
scalar_t abs(std::complex<scalar_t> value) noexcept {
return std::abs(value);
}
#endif

#ifdef BLAS_DATA_TYPE_HALF
template <>
inline bool isnan<cl::sycl::half>(cl::sycl::half value) noexcept {
Expand Down Expand Up @@ -172,7 +192,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
return true;
}

const scalar_t absolute_diff = utils::abs(scalar1 - scalar2);
const auto absolute_diff = utils::abs(scalar1 - scalar2);

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
Expand Down Expand Up @@ -212,6 +232,37 @@ inline bool compare_vectors(std::vector<scalar_t> const& vec,
return true;
}

#ifdef BLAS_ENABLE_COMPLEX
/**
* Compare two vectors of complex data and returns false if the difference is
* not acceptable. The second vector is considered the reference.
* @tparam scalar_t the type of complex underying data present in the input
* vectors
* @tparam epilon_t the type used as tolerance.
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<std::complex<scalar_t>> const& vec,
std::vector<std::complex<scalar_t>> const& ref,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
return false;
}

for (int i = 0; i < vec.size(); ++i) {
if (!almost_equal<std::complex<scalar_t>, epsilon_t>(vec[i], ref[i])) {
err_stream << "Value mismatch at index " << i << ": (" << vec[i].real()
<< "," << vec[i].imag() << "); expected (" << ref[i].real()
<< "," << ref[i].imag() << ")" << end_line;
return false;
}
}
return true;
}
#endif

/**
* Compare two vectors at a given stride and window (unit_vec_size) and returns
* false if the difference is not acceptable. The second vector is considered
Expand Down Expand Up @@ -253,6 +304,51 @@ inline bool compare_vectors_strided(std::vector<scalar_t> const& vec,
return true;
}

#ifdef BLAS_ENABLE_COMPLEX
/**
* Compare two vectors of complex data at a given stride and window and returns
* false if the difference is not acceptable. The second vector is considered
* the reference.
* @tparam scalar_t the type of the complex underying data present in the input
* vectors
* @tparam epsilon_t the type used as tolerance.
* @param stride is the stride between two consecutive 'windows'
* @param window is the size of a comparison window
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors_strided(
std::vector<std::complex<scalar_t>> const& vec,
std::vector<std::complex<scalar_t>> const& ref, int stride, int window,
std::ostream& err_stream = std::cerr, std::string end_line = "\n") {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
return false;
}

int k = 0;

// Loop over windows
while (window + (k + 1) * stride < vec.size()) {
// Loop within a window
for (int i = 0; i < window; ++i) {
auto index = i + k * stride;
if (!almost_equal<std::complex<scalar_t>, epsilon_t>(vec[index],
ref[index])) {
err_stream << "Value mismatch at index " << index << ": ("
<< vec[index].real() << "," << vec[index].imag()
<< "); expected (" << ref[index].real() << ","
<< ref[index].imag() << ")" << end_line;
return false;
}
}
k += 1;
}

return true;
}
#endif

} // namespace utils

#endif // UTILS_FLOAT_COMPARISON_H_
15 changes: 15 additions & 0 deletions common/include/common/system_reference_blas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ auto blas_system_function(floatfn_t ffn, doublefn_t dfn)
return BlasSystemFunction<scalar_t>::get(ffn, dfn);
}

template <typename scalar_t, typename floatfn_t, typename doublefn_t>
auto blas_cplx_system_function(floatfn_t ffn, doublefn_t dfn)
-> decltype(BlasSystemFunction<scalar_t>::get(ffn, dfn)) {
return BlasSystemFunction<scalar_t>::get(ffn, dfn);
}

// =======
// Level 1
// =======
Expand Down Expand Up @@ -378,6 +384,15 @@ void gemm(const char *transA, const char *transB, int m, int n, int k,
lda, b, ldb, beta, c, ldc);
}

template <typename scalar_t>
void cgemm(const char *transA, const char *transB, int m, int n, int k,
const void *alpha, const void *a, int lda, const void *b, int ldb,
const void *beta, void *c, int ldc) {
auto func = blas_cplx_system_function<scalar_t>(&cblas_cgemm, &cblas_zgemm);
func(CblasColMajor, c_trans(*transA), c_trans(*transB), m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc);
}

template <typename scalar_t>
void trsm(const char *side, const char *uplo, const char *trans,
const char *diag, int m, int n, scalar_t alpha, const scalar_t A[],
Expand Down
Loading

0 comments on commit 19b0fed

Please sign in to comment.