diff --git a/CMakeLists.txt b/CMakeLists.txt index a93e6463e..50496efa1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,6 +48,10 @@ option(BUILD_EXAMPLES "Build examples" TRUE) option(BUILD_FT "Build functional tests" TRUE) option(BUILD_UT "Build unit tests" FALSE) option(BUILD_CONFIG "Build cmake configs" TRUE) +option(ENABLE_MPI "Enable MPI for library" TRUE) +option(ENABLE_MPI_TESTS "Enable MPI for tests" TRUE) +option(ENABLE_SYCL_INTEROP_EVENT "Enable support for interop event functionality" TRUE) +option(ENABLE_OFI_HMEM "Enable support for OFI HMEM" FALSE) option(USE_CODECOV_FLAGS "Calculate code coverage" FALSE) option(WITH_ASAN "Use address sanitizer, can only be used in Debug build" FALSE) @@ -55,11 +59,11 @@ option(WITH_ASAN "Use address sanitizer, can only be used in Debug build" FALSE) #installation path variables include(GNUInstallDirs) -if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) +if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}/_install" CACHE PATH "Default install path" FORCE) endif() -#show build info +# show build info message(STATUS "Installation directory: ${CMAKE_INSTALL_PREFIX}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE_CASE_INSENSITIVE}") message(STATUS "C compiler : ${CMAKE_C_COMPILER}") @@ -68,6 +72,10 @@ message(STATUS "Build examples: ${BUILD_EXAMPLES}") message(STATUS "Build functional tests: ${BUILD_FT}") message(STATUS "Build unit tests: ${BUILD_UT}") message(STATUS "Build cmake configs: ${BUILD_CONFIG}") +message(STATUS "Enable MPI for library: ${ENABLE_MPI}") +message(STATUS "Enable MPI for tests: ${ENABLE_MPI_TESTS}") +message(STATUS "Enable support for interop event functionality: ${ENABLE_SYCL_INTEROP_EVENT}") +message(STATUS "Enable support for OFI HMEM: ${ENABLE_OFI_HMEM}") add_definitions(-DCCL_C_COMPILER="${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") add_definitions(-DCCL_CXX_COMPILER="${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") @@ -91,36 +99,29 @@ set(CCL_INSTALL_KERNELS "${CMAKE_INSTALL_PREFIX}/lib/kernels") set(CCL_UNIT_TESTS_BUILD "${CMAKE_BINARY_DIR}/tests/unit") - # setup dependency directories - set(DEPS_DIR "${PROJECT_SOURCE_DIR}/deps") set(MPI_INCLUDE_DIR "${DEPS_DIR}/mpi/include/") set(MPI_LIB_DIR "${DEPS_DIR}/mpi/lib/") -if ( "${LIBFABRIC_DIR}" STREQUAL "") +message(STATUS "MPI_INCLUDE_DIR: ${MPI_INCLUDE_DIR}") +message(STATUS "MPI_LIB_DIR: ${MPI_LIB_DIR}") + +if ("${LIBFABRIC_DIR}" STREQUAL "") set(LIBFABRIC_INCLUDE_DIR "${DEPS_DIR}/ofi/include") set(LIBFABRIC_LIB_DIR "${DEPS_DIR}/ofi/lib/") else() set(LIBFABRIC_INCLUDE_DIR "${LIBFABRIC_DIR}/include/") set(LIBFABRIC_LIB_DIR "${LIBFABRIC_DIR}/lib") endif() -set(HWLOC_INCLUDE_DIR "${DEPS_DIR}/hwloc/include/") -set(HWLOC_LIB_DIR "${DEPS_DIR}/hwloc/lib/") - -message(STATUS "MPI_INCLUDE_DIR: ${MPI_INCLUDE_DIR}") -message(STATUS "MPI_LIB_DIR: ${MPI_LIB_DIR}") message(STATUS "LIBFABRIC_LIB_DIR: ${LIBFABRIC_LIB_DIR}") message(STATUS "LIBFABRIC_INCLUDE_DIR: ${LIBFABRIC_INCLUDE_DIR}") + +set(HWLOC_INCLUDE_DIR "${DEPS_DIR}/hwloc/include/") +set(HWLOC_LIB_DIR "${DEPS_DIR}/hwloc/lib/") message(STATUS "HWLOC_INCLUDE_DIR: ${HWLOC_INCLUDE_DIR}") message(STATUS "HWLOC_LIB_DIR: ${HWLOC_LIB_DIR}") -include_directories(${MPI_INCLUDE_DIR}) -include_directories(${LIBFABRIC_INCLUDE_DIR}) - -link_directories(${MPI_LIB_DIR}) -link_directories(${LIBFABRIC_LIB_DIR}) - set(CMAKE_SKIP_INSTALL_RPATH TRUE) set(CMAKE_SKIP_RPATH TRUE) @@ -132,36 +133,9 @@ if (${CMAKE_VERSION} VERSION_LESS 3.1) set(C_COMPILER_FLAGS "-std=gnu99") endif() -# special flags for CCL library only -set(SRC_C_FLAGS "") -set(SRC_CXX_FLAGS "") -set(SRC_SHARED_LINKER_FLAGS "") - -#common settings of security options -if(USE_SECURITY_FLAGS) - set(SRC_C_FLAGS "${SRC_C_FLAGS} -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector") - set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector") - set(SRC_SHARED_LINKER_FLAGS "${SRC_SHARED_LINKER_FLAGS} -fPIE -fPIC -z noexecstack -z relro -z now") - if(${CMAKE_C_COMPILER_ID} STREQUAL "GNU" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") - if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9) - set(SRC_C_FLAGS "${SRC_C_FLAGS} -fstack-protector-strong") - set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -fstack-protector-strong") - endif() - endif() -endif() - -set(SRC_SHARED_LINKER_FLAGS "${SRC_SHARED_LINKER_FLAGS} -Wl,--version-script=${PROJECT_SOURCE_DIR}/ccl.map") - -if(${CMAKE_C_COMPILER_ID} STREQUAL "Intel" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL "Intel") - if (USE_CODECOV_FLAGS) - set(SRC_C_FLAGS "${SRC_C_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") - set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") - endif() -endif() +# TODO: add -Wextra to c/cxx flags -#TODO: add -Wextra to c/cxx flags - -#common release/debug compilation settings +# common release/debug compilation settings set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_COMPILER_FLAGS} -Wall -Werror -D_GNU_SOURCE -fvisibility=internal") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${C_COMPILER_FLAGS} -O0 -g -DENABLE_DEBUG") set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${C_COMPILER_FLAGS} -O3") @@ -182,25 +156,35 @@ set(COMMON_CMAKE_DIR ${PROJECT_SOURCE_DIR}/cmake) if (COMPUTE_BACKEND) message(STATUS "COMPUTE_BACKEND: ${COMPUTE_BACKEND}") set_compute_backend(${COMMON_CMAKE_DIR}) + if (${COMPUTE_BACKEND} STREQUAL "dpcpp_level_zero" AND ENABLE_OFI_HMEM) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCCL_ENABLE_OFI_HMEM=1") + message(STATUS "Enable OFI HMEM support for compute backend ${COMPUTE_BACKEND}") + endif() endif() -if(${CMAKE_C_COMPILER_ID} STREQUAL "GNU" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") +if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) #c++17 introduces algined new operator, use it set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") endif() endif() +# This is a temporal workaround until we fully switch to a new version of the compiler +# that supports the functionality +if (${ENABLE_SYCL_INTEROP_EVENT}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCCL_ENABLE_SYCL_INTEROP_EVENT=1") +endif() + # Clang doesn't automatically detects ninja processes as supporting colored output # due to the way they are spawned. In order to fix the issue we need to use the option # to force colored output -if(${CMAKE_GENERATOR} STREQUAL "Ninja") +if (${CMAKE_GENERATOR} STREQUAL "Ninja") if (${CMAKE_C_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") add_compile_options(-fcolor-diagnostics) endif() endif() -if(WITH_ASAN AND ${CMAKE_BUILD_TYPE_CASE_INSENSITIVE} STREQUAL "debug") +if (WITH_ASAN AND ${CMAKE_BUILD_TYPE_CASE_INSENSITIVE} STREQUAL "debug") message(STATUS "Compiling with address sanitizer") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address -fno-omit-frame-pointer") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -fsanitize=address -fno-omit-frame-pointer") @@ -214,14 +198,16 @@ set(CCL_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/src) enable_testing() set(EXTERNAL_LIBS "") + set(EXAMPLES_INC_DIRS ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/examples/include ${MPI_INCLUDE_DIR}) +set(EXAMPLES_LIB_DIRS ${MPI_LIB_DIR} ${LIBFABRIC_LIB_DIR}) # allow `deprecated` set(CMAKE_CLANG_FLAGS "${CMAKE_CLANG_FLAGS}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -#generate & install vars.sh +# generate & install vars.sh configure_file(cmake/vars.sh.in ${CMAKE_CURRENT_BINARY_DIR}/vars.sh @ONLY) configure_file(cmake/setvars.sh.in ${CMAKE_CURRENT_BINARY_DIR}/setvars.sh @ONLY) configure_file(cmake/ccl ${CMAKE_CURRENT_BINARY_DIR}/ccl @ONLY) @@ -233,15 +219,15 @@ install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/third-party-programs.txt DESTINATIO install(PROGRAMS ${PROJECT_SOURCE_DIR}/LICENSE DESTINATION ${CCL_INSTALL_LICENSE}) # copy kernels -if(COMPUTE_BACKEND AND EXISTS "${PROJECT_SOURCE_DIR}/src/kernels") -file(GLOB spv_kernels "${PROJECT_SOURCE_DIR}/src/kernels/ring_*.spv") +if (COMPUTE_BACKEND AND EXISTS "${PROJECT_SOURCE_DIR}/src/kernels") +file(GLOB spv_kernels "${PROJECT_SOURCE_DIR}/src/kernels/kernels.spv") install(PROGRAMS ${spv_kernels} DESTINATION ${CCL_INSTALL_KERNELS} PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ) endif() set(CCL_MAJOR_VERSION "2021") -set(CCL_MINOR_VERSION "3") +set(CCL_MINOR_VERSION "4") set(CCL_UPDATE_VERSION "0") set(CCL_PRODUCT_STATUS "Gold") string(TIMESTAMP CCL_PRODUCT_BUILD_DATE "%Y-%m-%dT %H:%M:%SZ") @@ -262,24 +248,24 @@ if (BUILD_CONFIG) @ONLY) endif() -#include other CMakeLists +# include other CMakeLists add_subdirectory(src) -if (BUILD_EXAMPLES) - add_subdirectory(examples/benchmark) - add_subdirectory(examples/common) - add_subdirectory(examples/cpu) - add_subdirectory(examples/external_launcher) - if (CCL_ENABLE_SYCL) - add_subdirectory(examples/sycl) +if (ENABLE_MPI_TESTS) + if (BUILD_EXAMPLES) + add_subdirectory(examples/benchmark) + add_subdirectory(examples/common) + add_subdirectory(examples/cpu) + add_subdirectory(examples/external_launcher) + if (CCL_ENABLE_SYCL) + add_subdirectory(examples/sycl) + endif() + endif() + if (BUILD_FT) + add_subdirectory(tests/functional) + endif() + if (BUILD_UT AND EXISTS "${PROJECT_SOURCE_DIR}/tests/unit") + add_subdirectory(tests/unit) endif() -endif() - -if (BUILD_FT) - add_subdirectory(tests/functional) -endif() - -if (BUILD_UT AND EXISTS "${PROJECT_SOURCE_DIR}/tests/unit") - add_subdirectory(tests/unit) endif() diff --git a/INSTALL.md b/INSTALL.md index 0dcf910c9..c0394a488 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -55,7 +55,7 @@ If your CXX compiler requires SYCL, it is possible to specify it (DPC++ is suppo Modify `cmake` command as follows: ``` -cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp +cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero ``` ## Specify the build type diff --git a/README.md b/README.md index 17b5e4f92..22bde5ca2 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ cmake .. make -j install ``` -If you need a clean build, create a new build directory and invoke `cmake` within it. Refer to FAQ to learn [when you might need a clean build](#when-do-i-need-a-clean-build-when-should-i-remove-my-favorite-build-directory). +If you need a clean build, create a new build directory and invoke `cmake` within it. You can also do the following during installation: - [Specify installation directory](INSTALL.md#specify-installation-directory) @@ -57,7 +57,6 @@ You can also do the following during installation: - [Specify `SYCL` cross-platform abstraction level](INSTALL.md#specify-sycl-cross-platform-abstraction-level) - [Specify the build type](INSTALL.md#specify-the-build-type) - [Enable `make` verbose output](INSTALL.md#enable-make-verbose-output) -- [Build with address sanitizer](INSTALL.md#build-with-address-sanitizer) ## Usage @@ -139,16 +138,6 @@ cmake [-DOUTPUT_DIR=] -P cmake/script/config_generation.cmake - oneAPI, oneCCL and OFI: Path to Heterogeneous Architecure Programming with Scalable Collective Communications: [recording](https://www.youtube.com/watch?v=ksiZ90EtP98&feature=youtu.be) and [slides](https://www.openfabrics.org/wp-content/uploads/2020-workshop-presentations/502.-OFA-Virtual-Workshop-2020-oneCCL-v5.pdf) -## FAQ - -### When do I need a clean build? When should I remove my favorite build directory? - -In most cases, there is no need to remove the current build directory. You can just run `make` to -compile and link changed files. - -However, if you see some suspicious build errors after a significant -change in code (for example, after rebase or a change of branch), it is a hint for you to clean the build directory. - ## Contribute See [CONTRIBUTING](CONTRIBUTING.md) for more information. diff --git a/cmake/helpers.cmake b/cmake/helpers.cmake index b24acc18b..245de34e1 100644 --- a/cmake/helpers.cmake +++ b/cmake/helpers.cmake @@ -297,7 +297,7 @@ function(set_compute_backend COMMON_CMAKE_DIR) set(MULTI_GPU_SUPPORT ON) endif() if (MULTI_GPU_SUPPORT) - message(STATUS "Enable multi GPU support using L0") + message(STATUS "Enable GPU support using level-zero") endif() # need to pass these variables to overlying function diff --git a/deps/mpi/bin/hydra_bstrap_proxy b/deps/mpi/bin/hydra_bstrap_proxy index 218ff80c3..6a2e27a5b 100755 Binary files a/deps/mpi/bin/hydra_bstrap_proxy and b/deps/mpi/bin/hydra_bstrap_proxy differ diff --git a/deps/mpi/bin/hydra_nameserver b/deps/mpi/bin/hydra_nameserver index 028fa08df..3af2dc9bc 100755 Binary files a/deps/mpi/bin/hydra_nameserver and b/deps/mpi/bin/hydra_nameserver differ diff --git a/deps/mpi/bin/hydra_pmi_proxy b/deps/mpi/bin/hydra_pmi_proxy index 14efe5656..6e09d880f 100755 Binary files a/deps/mpi/bin/hydra_pmi_proxy and b/deps/mpi/bin/hydra_pmi_proxy differ diff --git a/deps/mpi/bin/mpicc b/deps/mpi/bin/mpicc index 089b40ad6..c0af92fc0 100755 --- a/deps/mpi/bin/mpicc +++ b/deps/mpi/bin/mpicc @@ -1,6 +1,6 @@ #!/bin/bash # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -28,7 +28,7 @@ if [ -z "$1" ] ; then fi #------------------------------------------------------------------------------ -dir=`dirname $0` +dir=$(dirname "$0") compiler_name=${I_MPI_CC:-${MPICH_CC:-${default_compiler_name:?}}} for arg in "$@" ; do @@ -49,9 +49,9 @@ fi if [ x"$opt_args" == x"" ]; then case "${compiler_short_name}" in - icc|icx) $dir/mpiicc -cc=$compiler_name "$@" ;; - cc|*gcc*|clang*) $dir/mpigcc -cc=$compiler_name "$@" ;; - mpicc) $dir/mpigcc "$@" ;; + icc|icx) "$dir"/mpiicc -cc=$compiler_name "$@" ;; + cc|*gcc*|clang*) "$dir"/mpigcc -cc=$compiler_name "$@" ;; + mpicc) "$dir"/mpigcc "$@" ;; *) echo "Error: unsupported compiler name '$compiler_name'." echo "Check -cc= command line option and I_MPI_CC='$I_MPI_CC' and MPICH_CC='$MPICH_CC' variables."; @@ -59,9 +59,9 @@ if [ x"$opt_args" == x"" ]; then esac else case "${compiler_short_name}" in - icc|icx) $dir/mpiicc -cc=$compiler_name "$@" $opt_args ;; - cc|*gcc*|clang*) $dir/mpigcc -cc=$compiler_name "$@" $opt_args ;; - mpicc) $dir/mpigcc "$@" $opt_args ;; + icc|icx) "$dir"/mpiicc -cc=$compiler_name "$@" $opt_args ;; + cc|*gcc*|clang*) "$dir"/mpigcc -cc=$compiler_name "$@" $opt_args ;; + mpicc) "$dir"/mpigcc "$@" $opt_args ;; *) echo "Error: unsupported compiler name '$compiler_name'." echo "Check -cc= command line option and I_MPI_CC='$I_MPI_CC' and MPICH_CC='$MPICH_CC' variables."; diff --git a/deps/mpi/bin/mpicxx b/deps/mpi/bin/mpicxx index b48d16712..38185770d 100755 --- a/deps/mpi/bin/mpicxx +++ b/deps/mpi/bin/mpicxx @@ -1,6 +1,6 @@ #!/bin/bash # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -28,7 +28,7 @@ if [ -z "$1" ] ; then fi #------------------------------------------------------------------------------ -dir=`dirname $0` +dir=$(dirname "$0") compiler_name=${I_MPI_CXX:-${MPICH_CXX:-${default_compiler_name:?}}} for arg in "$@" ; do @@ -49,9 +49,9 @@ fi if [ x"$opt_args" == x"" ]; then case "${compiler_short_name}" in - icc|icpc|dpcpp) $dir/mpiicpc -cxx=$compiler_name "$@" ;; - *g++*) $dir/mpigxx -cxx=$compiler_name "$@" ;; - mpicxx) $dir/mpigxx "$@" ;; + icc|icpc|dpcpp) "$dir"/mpiicpc -cxx=$compiler_name "$@" ;; + *g++*) "$dir"/mpigxx -cxx=$compiler_name "$@" ;; + mpicxx) "$dir"/mpigxx "$@" ;; *) echo "Error: unsupported compiler name '$compiler_name'." echo "Check -cxx= command line option and I_MPI_CXX='$I_MPI_CXX' and MPICH_CXX='$MPICH_CXX' variables."; @@ -59,9 +59,9 @@ if [ x"$opt_args" == x"" ]; then esac else case "${compiler_short_name}" in - icc|icpc|dpcpp) $dir/mpiicpc -cxx=$compiler_name "$@" $opt_args ;; - *g++*) $dir/mpigxx -cxx=$compiler_name "$@" $opt_args ;; - mpicxx) $dir/mpigxx "$@" $opt_args ;; + icc|icpc|dpcpp) "$dir"/mpiicpc -cxx=$compiler_name "$@" $opt_args ;; + *g++*) "$dir"/mpigxx -cxx=$compiler_name "$@" $opt_args ;; + mpicxx) "$dir"/mpigxx "$@" $opt_args ;; *) echo "Error: unsupported compiler name '$compiler_name'." echo "Check -cxx= command line option and I_MPI_CXX='$I_MPI_CXX' and MPICH_CXX='$MPICH_CXX' variables."; diff --git a/deps/mpi/bin/mpiexec b/deps/mpi/bin/mpiexec index 8826a76d3..61a4ff30a 100755 Binary files a/deps/mpi/bin/mpiexec and b/deps/mpi/bin/mpiexec differ diff --git a/deps/mpi/bin/mpiexec.hydra b/deps/mpi/bin/mpiexec.hydra index 8826a76d3..61a4ff30a 100755 Binary files a/deps/mpi/bin/mpiexec.hydra and b/deps/mpi/bin/mpiexec.hydra differ diff --git a/deps/mpi/bin/mpigcc b/deps/mpi/bin/mpigcc index 9a306a10a..c54304f11 100755 --- a/deps/mpi/bin/mpigcc +++ b/deps/mpi/bin/mpigcc @@ -1,6 +1,6 @@ #! /bin/sh # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -79,7 +79,7 @@ # # Directory locations: Fixed for any MPI implementation. # Set from the directory arguments to configure (e.g., --prefix=/usr/local) -prefix=/usr/local +prefix=../_install # The environment variable I_MPI_ROOT may be used to override installation folder path if [ -n "$I_MPI_ROOT" ] ; then prefix=$I_MPI_ROOT; @@ -104,7 +104,7 @@ CFLAGS="" CPPFLAGS="" LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.3" +MPIVERSION="2021.4" MPILIBNAME="mpi" @@ -319,7 +319,7 @@ for arg in "$@" ; do -v) # Pass this argument to the compiler as well. echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" - echo "Copyright 2003-2020, Intel Corporation." + echo "Copyright Intel Corporation." # if there is only 1 argument, it must be -v. if [ "$#" -eq "1" ] ; then linking=no @@ -399,7 +399,7 @@ for arg in "$@" ; do addarg=no ;; -g) - MPILIBDIR="/debug" + MPILIBDIR="/release" ;; -static_log) static_log=yes @@ -494,22 +494,22 @@ if [ $# -eq 0 ] ; then "$0" -help exit 1 fi -MPILIBDIR_MT="mt" + if [ -n "$mpilib_override" ] ; then case "$mpilib_override" in opt ) MPILIBDIR="/release" - MPILIBDIR_MT= ;; opt_mt ) MPILIBDIR="/release" + MPILIBDIR_MT="mt" ;; dbg ) MPILIBDIR="/debug" - MPILIBDIR_MT= ;; dbg_mt ) MPILIBDIR="/debug" + MPILIBDIR_MT="mt" ;; * ) echo "Warning: incorrect library version specified. Automatically selected library will be used." @@ -534,7 +534,7 @@ if [ "$static_mpi" = yes ] ; then else mpilibs="${libdir}${MPILIBDIR}/lib${MPILIBNAME}.a" fi - I_MPI_OTHERLIBS=" -lrt -lpthread -lfabric" + I_MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then mpilibs="$libdir/libmpi_ilp64.a $mpilibs" fi @@ -587,7 +587,7 @@ if [ -n "$profConf" ] ; then fi final_cflags=" " if [ "${static_mpi}" = "yes" ] ; then - final_cflags=" -Xlinker --export-dynamic -lfabric" + final_cflags=" -Xlinker --export-dynamic " else final_cflags=" " fi @@ -595,7 +595,7 @@ final_cppflags=" " final_ldflags=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " final_libs="-lpthread -lrt " if test "no" = "no" -o "${interlib_deps}" = "no" ; then - final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib" + final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib -L/p/pdsd/scratch/Uploads/IMPI/other/software/libfabric/linux/v1.9.0/lib" final_libs="${final_libs} -lm -lpthread -lfabric -lrt " fi @@ -614,15 +614,15 @@ fi if [ "$no_rpath" = "yes" ]; then rpath_opt="-Xlinker --enable-new-dtags" else - rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker ${libdir}${MPILIBDIR} -Xlinker -rpath -Xlinker $libdir" + rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker \"${libdir}${MPILIBDIR}\" -Xlinker -rpath -Xlinker \"${libdir}\"" fi if [ "$linking" = yes ] ; then if [ "$nativelinking" = yes ] ; then - $Show $CC ${final_cppflags} $PROFILE_INCPATHS ${final_cflags} ${final_ldflags} $allargs -I$includedir + $Show $CC ${final_cppflags} $PROFILE_INCPATHS ${final_cflags} ${final_ldflags} $allargs -I\"${includedir}\" rc=$? else - $Show $CC $CPPFLAGS $CFLAGS $allargs -I$includedir ${path_list} -L${libdir}${MPILIBDIR} -L${libdir} $rpath_opt $mpilibs $I_MPI_OTHERLIBS $LDFLAGS + $Show $CC $CPPFLAGS $CFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $mpilibs $I_MPI_OTHERLIBS $LDFLAGS rc=$? if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then @@ -643,7 +643,7 @@ if [ "$linking" = yes ] ; then # fi fi else - cmd_line="$CC $CPPFLAGS $CFLAGS $allargs -I$includedir" + cmd_line="$CC $CPPFLAGS $CFLAGS $allargs -I\"${includedir}\"" if [ "$Show" = echo ] ; then echo $cmd_line else diff --git a/deps/mpi/bin/mpigxx b/deps/mpi/bin/mpigxx index ca11d0e20..b9382fd8c 100755 --- a/deps/mpi/bin/mpigxx +++ b/deps/mpi/bin/mpigxx @@ -1,6 +1,6 @@ #! /bin/sh # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -78,7 +78,7 @@ # Set the default values of all variables. # # Directory locations: Fixed for any MPI implementation -prefix=/usr/local +prefix=../_install # The environment variable I_MPI_ROOT may be used to override installation folder path if [ -n "$I_MPI_ROOT" ] ; then prefix=$I_MPI_ROOT; @@ -101,7 +101,7 @@ MPICH_VERSION="3.3" CXXFLAGS="" LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.3" +MPIVERSION="2021.4" MPILIBNAME="mpi" MPICXXLIBNAME="mpicxx" @@ -321,7 +321,7 @@ for arg in "$@" ; do -v) # Pass this argument to the compiler as well. echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" - echo "Copyright 2003-2020, Intel Corporation." + echo "Copyright Intel Corporation." # if there is only 1 argument, it must be -v. if [ "$#" -eq "1" ] ; then linking=no @@ -401,7 +401,7 @@ for arg in "$@" ; do addarg=no ;; -g) - MPILIBDIR="/debug" + MPILIBDIR="/release" ;; -static_log) static_log=yes @@ -496,22 +496,22 @@ if [ $# -eq 0 ] ; then "$0" -help exit 1 fi -MPILIBDIR_MT="mt" + if [ -n "$mpilib_override" ] ; then case "$mpilib_override" in opt ) MPILIBDIR="/release" - MPILIBDIR_MT= ;; opt_mt ) MPILIBDIR="/release" + MPILIBDIR_MT="mt" ;; dbg ) MPILIBDIR="/debug" - MPILIBDIR_MT= ;; dbg_mt ) MPILIBDIR="/debug" + MPILIBDIR_MT="mt" ;; * ) echo "Warning: incorrect library version specified. Automatically selected library will be used." @@ -537,7 +537,7 @@ if [ "$static_mpi" = yes ] ; then else mpilibs="${libdir}${MPILIBDIR}/lib${MPILIBNAME}.a" fi - I_MPI_OTHERLIBS=" -lrt -lpthread -lfabric " + I_MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then mpilibs="$libdir/libmpi_ilp64.a $mpilibs" fi @@ -598,7 +598,7 @@ if [ -n "$profConf" ] ; then fi if [ "${static_mpi}" = "yes" ] ; then - final_cxxflags=" -Xlinker --export-dynamic -lfabric" + final_cxxflags=" -Xlinker --export-dynamic " else final_cxxflags=" " fi @@ -607,7 +607,7 @@ final_cppflags=" " final_ldflags=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " final_libs="-lpthread -lrt " if test "no" = "no" -o "${interlib_deps}" = "no" ; then - final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib" + final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib -L/p/pdsd/scratch/Uploads/IMPI/other/software/libfabric/linux/v1.9.0/lib" final_libs="${final_libs} -lm -lpthread -lfabric -lrt " fi @@ -618,14 +618,14 @@ fi if [ "$no_rpath" = "yes" ]; then rpath_opt="-Xlinker --enable-new-dtags" else - rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker ${libdir}${MPILIBDIR} -Xlinker -rpath -Xlinker $libdir" + rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker \"${libdir}${MPILIBDIR}\" -Xlinker -rpath -Xlinker \"${libdir}\"" fi if [ "$linking" = yes ] ; then if [ "$nativelinking" = yes ] ; then - $Show $CXX ${final_cppflags} $PROFILE_INCPATHS ${final_cxxflags} ${final_ldflags} $allargs -I$includedir + $Show $CXX ${final_cppflags} $PROFILE_INCPATHS ${final_cxxflags} ${final_ldflags} $allargs -I\"${includedir}\" rc=$? else - $Show $CXX $CXXFLAGS $allargs -I$includedir ${path_list} -L${libdir}${MPILIBDIR} -L${libdir} $rpath_opt $shllibpath $cxxlibs $mpilibs $I_MPI_OTHERLIBS $LDFLAGS + $Show $CXX $CXXFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $shllibpath $cxxlibs $mpilibs $I_MPI_OTHERLIBS $LDFLAGS rc=$? if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then $Show objcopy --only-keep-debug ${executable} ${executable}.dbg @@ -646,7 +646,7 @@ if [ "$linking" = yes ] ; then # fi # fi else - cmd_line="$CXX ${final_cppflags} $PROFILE_INCPATHS ${final_cxxflags} $allargs -I$includedir" + cmd_line="$CXX ${final_cppflags} $PROFILE_INCPATHS ${final_cxxflags} $allargs -I\"${includedir}\"" if [ "$Show" = echo ] ; then echo $cmd_line else diff --git a/deps/mpi/bin/mpiicc b/deps/mpi/bin/mpiicc index c623722ee..25c4dea5b 100755 --- a/deps/mpi/bin/mpiicc +++ b/deps/mpi/bin/mpiicc @@ -1,6 +1,6 @@ #!/bin/sh # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -81,28 +81,13 @@ # # Directory locations: Fixed for any MPI implementation. # Set from the directory arguments to configure (e.g., --prefix=/usr/local) -prefix=/usr/local +prefix=../_install # The environment variable I_MPI_ROOT may be used to override installation folder path -if [ -n "$I_MPI_ROOT" ] ; then - prefix=$I_MPI_ROOT; +if [ -n "${I_MPI_ROOT}" ] ; then + prefix="${I_MPI_ROOT}"; fi -# Recognize the '-mmic' compiler switch. -for arg in "$@" ; do -{ - if [ "${arg}" = "-mmic" ] ; then - . ${prefix}/mic/bin/mpivars.sh - if [ -n "$VT_ROOT" ]; then - if [ -f "$VT_ROOT/mic/bin/itacvars.sh" ]; then - . "$VT_ROOT/mic/bin/itacvars.sh" "" - fi - fi - exec mpiicc "$@" - fi -} -done PLATFORM="" -exec_prefix=${prefix} sysconfdir=${prefix}/etc # The environment variable I_MPI_COMPILER_CONFIG_DIR may be used to override # folder where *.conf files are placed @@ -119,10 +104,9 @@ CC="icc" CFLAGS="" LDFLAGS="-ldl" MPILIBNAME="mpi" -PMPILIBNAME="pmpi" # MPIVERSION is the version of the MPICH2 library that mpicc is intended for -MPIVERSION="2021.3" +MPIVERSION="2021.4" # # Internal variables # Show is set to echo to cause the compilation command to be echoed instead @@ -132,9 +116,7 @@ static_mpi=no strip_debug_info= handle_executable= executable=a.out -static_log=yes ilp64=no -trace_opt=no no_rpath=no # # End of initialization of variables @@ -149,14 +131,14 @@ no_rpath=no # script (defined above) if [ -n "$I_MPI_CC" ] ; then CC="$I_MPI_CC" - CCname=`echo $CC | sed 's/ /-/g'` + CCname=$(echo "$CC" | sed 's/ /-/g') if [ -s $sysconfdir/mpicc-$(basename $CCname).conf ] ; then - . $sysconfdir/mpicc-$(basename $CCname).conf + . $sysconfdir/mpicc-$(basename $CCname).conf fi else if [ -n "$MPICH_CC" ] ; then CC="$MPICH_CC" - CCname=`echo $CC | sed 's/ /-/g'` + CCname=$(echo $CC | sed 's/ /-/g') if [ -s $sysconfdir/mpicc-$(basename $CCname).conf ] ; then . $sysconfdir/mpicc-$(basename $CCname).conf fi @@ -198,239 +180,227 @@ fi linking=yes allargs="" -argno=0 for arg in "$@" ; do # Set addarg to no if this arg should be ignored by the C compiler addarg=yes qarg=$arg if [ "x$handle_executable" = "xyes" ] ; then - executable=$arg - handle_executable= + executable=$arg + handle_executable= fi case "$arg" in # ---------------------------------------------------------------- # Compiler options that affect whether we are linking or no -c|-S|-E|-M|-MM) # The compiler links by default - linking=no - ;; + linking=no + ;; -o ) - handle_executable=yes - addarg=yes - ;; + handle_executable=yes + addarg=yes + ;; # ---------------------------------------------------------------- # Options that control how we use mpicc (e.g., -show, # -cc=* -config=* -echo) - addarg=no - set -x - ;; + addarg=no + set -x + ;; -cc=*) - CC=`echo A$arg | sed -e 's/A-cc=//g'` - addarg=no - ;; + CC=$(echo A$arg | sed -e 's/A-cc=//g') + addarg=no + ;; -show) - addarg=no - Show=echo - ;; + addarg=no + Show=echo + ;; -show_env) - show_env=yes - ;; + show_env=yes + ;; -config=*) - addarg=no - CCname=`echo A$arg | sed -e 's/A-config=//g'` - if [ -s "$sysconfdir/mpicc-$CCname.conf" ] ; then - . "$sysconfdir/mpicc-$CCname.conf" - else - echo "Configuration file mpicc-$CCname.conf not found" - fi - ;; + addarg=no + CCname=$(echo A$arg | sed -e 's/A-config=//g') + if [ -s "$sysconfdir/mpicc-$CCname.conf" ] ; then + . "$sysconfdir/mpicc-$CCname.conf" + else + echo "Configuration file mpicc-$CCname.conf not found" + fi + ;; -compile-info|-compile_info) - # -compile_info included for backward compatibility - Show=echo - addarg=no - ;; + # -compile_info included for backward compatibility + Show=echo + addarg=no + ;; -link-info|-link_info) - # -link_info included for backward compatibility - Show=echo - addarg=no - ;; + # -link_info included for backward compatibility + Show=echo + addarg=no + ;; -v) - # Pass this argument to the compiler as well. - echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" - echo "Copyright 2003-2020, Intel Corporation." - # if there is only 1 argument, it must be -v. - if [ "$#" -eq "1" ] ; then - linking=no - fi - ;; + # Pass this argument to the compiler as well. + echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" + echo "Copyright Intel Corporation." + # if there is only 1 argument, it must be -v. + if [ "$#" -eq "1" ] ; then + linking=no + fi + ;; -V) - # Pass this argument to the compiler to query the compiler version. - if [ "$#" -eq "1" ] ; then - linking=no - fi - ;; + # Pass this argument to the compiler to query the compiler version. + if [ "$#" -eq "1" ] ; then + linking=no + fi + ;; -profile=*) - # Pass the name of a profiling configuration. As - # a special case, lib.so or lib.la may be used - # if the library is in $libdir - profConf=`echo A$arg | sed -e 's/A-profile=//g'` - addarg=no - # Loading the profConf file is handled below - ;; + # Pass the name of a profiling configuration. As + # a special case, lib.so or lib.la may be used + # if the library is in $libdir + # Loading the profConf file is handled below + profConf=$(echo A$arg | sed -e 's/A-profile=//g') + addarg=no + ;; -help) - # Print mini-help if started without parameters - echo "Simple script to compile and/or link MPI programs." - echo "Usage: `basename $0` [options] " - echo "----------------------------------------------------------------------------" - echo "The following options are supported:" - echo " -cc= specify a C compiler name: i.e. -cc=icc" - echo " -echo print the scripts during their execution" - echo " -show show command lines without real calling" - echo " -show_env show environment variables" - echo " -config= specify a configuration file: i.e. -config=icc for mpicc-icc.conf file" -# echo " -compile-info show compiler command line" -# echo " -link-info show linker command line" - echo " -v print version info of $(basename $0) and its native compiler" - echo " -profile= specify a profile configuration file (an MPI profiling" - echo " library): i.e. -profile=myprofile for the myprofile.cfg file." - echo " As a special case, lib.so or lib.a may be used" - echo " if the library is found" - echo " -check_mpi link against the Intel(R) Trace Collector (-profile=vtmc)." - echo " -static_mpi link the Intel(R) MPI Library statically" - echo " -mt_mpi link the thread safe version of the Intel(R) MPI Library" - echo " -ilp64 link the ILP64 support of the Intel(R) MPI Library" - echo " -t or -trace" - echo " link against the Intel(R) Trace Collector" - echo " -trace-imbalance" - echo " link against the Intel(R) Trace Collector imbalance library" - echo " (-profile=vtim)" - echo " -dynamic_log link against the Intel(R) Trace Collector dynamically" - echo " -static use static linkage method" - echo " -nostrip turn off the debug information stripping during static linking" - echo " -fast the same as -static_mpi + pass -fast option to a compiler" - echo " -O enable optimization" - echo " -link_mpi=" - echo " link against the specified version of the Intel(R) MPI Library" - echo " i.e -link_mpi=opt|opt_mt|dbg|dbg_mt" - echo " -norpath disable rpath for compiler wrapper of the Intel(R) MPI Library" - echo "All other options will be passed to the compiler without changing." - echo "----------------------------------------------------------------------------" - echo "The following environment variables are used:" - echo " I_MPI_ROOT the Intel(R) MPI Library installation directory path" - echo " I_MPI_CC or MPICH_CC" - echo " the path/name of the underlying compiler to be used" - echo " I_MPI_CC_PROFILE or MPICC_PROFILE" - echo " the name of profile file (without extension)" - echo " I_MPI_COMPILER_CONFIG_DIR" - echo " the folder which contains configuration files *.conf" - echo " I_MPI_TRACE_PROFILE" - echo " specify a default profile for the -trace option" - echo " I_MPI_CHECK_PROFILE" - echo " specify a default profile for the -check_mpi option" - echo " I_MPI_LINK specify the version of the Intel(R) MPI Library" - echo " I_MPI_DEBUG_INFO_STRIP" - echo " turn on/off the debug information stripping during static linking" - echo "----------------------------------------------------------------------------" - exit 0 - ;; + # Print mini-help if started without parameters + echo "Simple script to compile and/or link MPI programs." + echo "Usage: $(basename $0) [options] " + echo "----------------------------------------------------------------------------" + echo "The following options are supported:" + echo " -cc= specify a C compiler name: i.e. -cc=icc" + echo " -echo print the scripts during their execution" + echo " -show show command lines without real calling" + echo " -show_env show environment variables" + echo " -config= specify a configuration file: i.e. -config=icc for mpicc-icc.conf file" + echo " -v print version info of $(basename $0) and its native compiler" + echo " -profile= specify a profile configuration file (an MPI profiling" + echo " library): i.e. -profile=myprofile for the myprofile.cfg file." + echo " As a special case, lib.so or lib.a may be used" + echo " if the library is found" + echo " -check_mpi link against the Intel(R) Trace Collector (-profile=vtmc)." + echo " -static_mpi link the Intel(R) MPI Library statically" + echo " -mt_mpi link the thread safe version of the Intel(R) MPI Library" + echo " -ilp64 link the ILP64 support of the Intel(R) MPI Library" + echo " -t or -trace" + echo " link against the Intel(R) Trace Collector" + echo " -trace-imbalance" + echo " link against the Intel(R) Trace Collector imbalance library" + echo " (-profile=vtim)" + echo " -dynamic_log link against the Intel(R) Trace Collector dynamically" + echo " -static use static linkage method" + echo " -nostrip turn off the debug information stripping during static linking" + echo " -fast the same as -static_mpi + pass -fast option to a compiler" + echo " -O enable optimization" + echo " -link_mpi=" + echo " link against the specified version of the Intel(R) MPI Library" + echo " i.e -link_mpi=opt|opt_mt|dbg|dbg_mt" + echo " -norpath disable rpath for compiler wrapper of the Intel(R) MPI Library" + echo "All other options will be passed to the compiler without changing." + echo "----------------------------------------------------------------------------" + echo "The following environment variables are used:" + echo " I_MPI_ROOT the Intel(R) MPI Library installation directory path" + echo " I_MPI_CC or MPICH_CC" + echo " the path/name of the underlying compiler to be used" + echo " I_MPI_CC_PROFILE or MPICC_PROFILE" + echo " the name of profile file (without extension)" + echo " I_MPI_COMPILER_CONFIG_DIR" + echo " the folder which contains configuration files *.conf" + echo " I_MPI_TRACE_PROFILE" + echo " specify a default profile for the -trace option" + echo " I_MPI_CHECK_PROFILE" + echo " specify a default profile for the -check_mpi option" + echo " I_MPI_LINK specify the version of the Intel(R) MPI Library" + echo " I_MPI_DEBUG_INFO_STRIP" + echo " turn on/off the debug information stripping during static linking" + echo "----------------------------------------------------------------------------" + exit 0 + ;; -nolinkage) - # This internal option is used by wrapper driver scripts mpicc, mpicxx, mpifc when -v option is used. - linking=no - addarg=no - ;; + # This internal option is used by wrapper driver scripts mpicc, mpicxx, mpifc when -v option is used. + linking=no + addarg=no + ;; -g) - MPILIBDIR="/release" - ;; + MPILIBDIR="/release" + ;; -static_mpi) - static_mpi=yes - CFLAGS="$CFLAGS -Xlinker --export-dynamic" - addarg=no - ;; - -static_log) - static_log=yes - addarg=no - ;; - -dynamic_log) - static_log=no - addarg=no - ;; + static_mpi=yes + CFLAGS="$CFLAGS -Xlinker --export-dynamic" + addarg=no + ;; -static) - static_mpi=yes - static_log=yes - CFLAGS="$CFLAGS -Xlinker --export-dynamic" - addarg=yes - ;; + static_mpi=yes + CFLAGS="$CFLAGS -Xlinker --export-dynamic" + addarg=yes + ;; -mt_mpi) - addarg=no - ;; + addarg=no + ;; -ilp64) - ilp64=yes - addarg=no - ;; + ilp64=yes + addarg=no + ;; -check_mpi) - if [ -z "$profConf" ]; then - if [ -z "$I_MPI_CHECK_PROFILE" ]; then - profConf="vtmc" + if [ -z "$profConf" ]; then + if [ -z "$I_MPI_CHECK_PROFILE" ]; then + profConf="vtmc" + else + profConf="$I_MPI_CHECK_PROFILE" + fi else - profConf="$I_MPI_CHECK_PROFILE" + echo "Warning: the -check_mpi option will be ignored because the profile was set." fi - else - echo "Warning: the -check_mpi option will be ignored because the profile was set." - fi - addarg=no - ;; + addarg=no + ;; -trace-imbalance) - if [ -z "$profConf" ]; then - profConf="vtim" - else - echo "Warning: the -trace-imbalance option will be ignored because the profile was set." - fi - addarg=no - ;; + if [ -z "$profConf" ]; then + profConf="vtim" + else + echo "Warning: the -trace-imbalance option will be ignored because the profile was set." + fi + addarg=no + ;; -t | -trace | -t=* | -trace=* ) - if [ -z "$profConf" ]; then - if [ -z "$I_MPI_TRACE_PROFILE" ]; then - profConf="vt" + if [ -z "$profConf" ]; then + if [ -z "$I_MPI_TRACE_PROFILE" ]; then + profConf="vt" + else + profConf="$I_MPI_TRACE_PROFILE" + fi else - profConf="$I_MPI_TRACE_PROFILE" + echo "Warning: the -trace option will be ignored because the profile was set." fi - else - echo "Warning: the -trace option will be ignored because the profile was set." - fi - # Disable strip to prevent debug symbols into separate dbg file in case of static linking IMPI-1493 - strip_debug_info=no - addarg=no - ;; + # Disable strip to prevent debug symbols into separate dbg file in case of static linking IMPI-1493 + strip_debug_info=no + addarg=no + ;; -fast) - echo "Warning: the -fast option forces static linkage method for the Intel(R) MPI Library." - static_mpi=yes - CFLAGS="$CFLAGS -Xlinker --export-dynamic" - ;; + echo "Warning: the -fast option forces static linkage method for the Intel(R) MPI Library." + static_mpi=yes + CFLAGS="$CFLAGS -Xlinker --export-dynamic" + ;; -link_mpi=* ) - mpilib_override=`echo A$arg | sed -e 's/A-link_mpi=//g'` - addarg=no - ;; + mpilib_override=`echo A$arg | sed -e 's/A-link_mpi=//g'` + addarg=no + ;; -nostrip ) - strip_debug_info=no - addarg=no - ;; + strip_debug_info=no + addarg=no + ;; -norpath ) - no_rpath=yes - addarg=no - ;; + no_rpath=yes + addarg=no + ;; # Other arguments. We are careful to handle arguments with # quotes (we try to quote all arguments in case they include # any spaces) *\"*) - qarg="'$arg'" - ;; + qarg="'$arg'" + ;; *\'*) - qarg=`echo \"$arg\"` - ;; + qarg=$(echo \"$arg\") + ;; *) - qarg="'$arg'" - ;; + qarg="'$arg'" + ;; esac if [ $addarg = yes ] ; then allargs="$allargs $qarg" @@ -468,30 +438,30 @@ fi # ----------------------------------------------------------------------- case "$MPILIBDIR" in release | /release | debug | /debug) - if [ ! -z "$MPILIBDIR_MT" ]; then - MPILIBDIR=${MPILIBDIR}_${MPILIBDIR_MT} - fi - ;; + if [ -n "$MPILIBDIR_MT" ]; then + MPILIBDIR=${MPILIBDIR}_${MPILIBDIR_MT} + fi + ;; "" ) - MPILIBDIR="/release" - ;; + MPILIBDIR="/release" + ;; esac if [ "$static_mpi" = yes ] ; then mpilibs="${libdir}/libmpifort.a ${libdir}${MPILIBDIR}/lib${MPILIBNAME}.a" I_MPI_OTHERLIBS="" - MPI_OTHERLIBS=" -lrt -lpthread -lfabric " + MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then - mpilibs="$libdir/libmpi_ilp64.a $mpilibs" + mpilibs="$libdir/libmpi_ilp64.a $mpilibs" fi if [ "x$strip_debug_info" = "x" ] ; then - strip_debug_info=yes + strip_debug_info=yes fi else mpilibs="-lmpifort -l$MPILIBNAME" I_MPI_OTHERLIBS="" MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then - mpilibs="-lmpi_ilp64 $mpilibs" + mpilibs="-lmpi_ilp64 $mpilibs" fi fi # Derived variables. These are assembled from variables set from the @@ -502,27 +472,27 @@ fi # Handle the case of a profile switch if [ -n "$profConf" ] ; then profConffile= - if [ -s "$libdir/lib$profConf.a" -o -s "$libdir/lib$profConf.so" ] ; then - mpilibs="-l$profConf $mpilibs" + if [ -s "$libdir/lib$profConf.a" ] || [ -s "$libdir/lib$profConf.so" ] ; then + mpilibs="-l$profConf $mpilibs" elif [ -s "$sysconfdir/$profConf.conf" ] ; then - profConffile="$sysconfdir/$profConf.conf" + profConffile="$sysconfdir/$profConf.conf" elif [ -s "$profConf.conf" ] ; then profConffile="$profConf.conf" else echo "Profiling configuration file $profConf.conf not found in $sysconfdir" fi - if [ -n "$profConffile" -a -s "$profConffile" ] ; then - . $profConffile - if [ -n "$PROFILE_INCPATHS" ] ; then - CFLAGS="$PROFILE_INCPATHS $CFLAGS" + if [ -n "$profConffile" ] && [ -s "$profConffile" ] ; then + . $profConffile + if [ -n "$PROFILE_INCPATHS" ] ; then + CFLAGS="$PROFILE_INCPATHS $CFLAGS" + fi + if [ -n "$PROFILE_PRELIB" ] ; then + mpilibs="$PROFILE_PRELIB $mpilibs" + fi + if [ -n "$PROFILE_POSTLIB" ] ; then + mpilibs="$mpilibs $PROFILE_POSTLIB" + fi fi - if [ -n "$PROFILE_PRELIB" ] ; then - mpilibs="$PROFILE_PRELIB $mpilibs" - fi - if [ -n "$PROFILE_POSTLIB" ] ; then - mpilibs="$mpilibs $PROFILE_POSTLIB" - fi - fi fi # ----------------------------------------------------------------------- @@ -546,27 +516,27 @@ fi if [ "$no_rpath" = "yes" ]; then rpath_opt="-Xlinker --enable-new-dtags" else - rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker ${libdir}${MPILIBDIR} -Xlinker -rpath -Xlinker $libdir" + rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker \"${libdir}${MPILIBDIR}\" -Xlinker -rpath -Xlinker \"${libdir}\"" fi if [ "$linking" = yes ] ; then - cmd_line="$CC $CFLAGS $allargs -I$includedir -L${libdir}${MPILIBDIR} -L$libdir $rpath_opt $mpilibs $I_MPI_OTHERLIBS $LDFLAGS $MPI_OTHERLIBS" + cmd_line="$CC $CFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $mpilibs $I_MPI_OTHERLIBS $LDFLAGS $MPI_OTHERLIBS" if [ "$Show" = echo ] ; then echo $cmd_line else - eval `echo $cmd_line` + eval $(echo $cmd_line) fi rc=$? - if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then - $Show objcopy --only-keep-debug ${executable} ${executable}.dbg - $Show objcopy --strip-debug ${executable} - $Show objcopy --add-gnu-debuglink=${executable}.dbg ${executable} + if [ $rc -eq 0 ] && [ "x$strip_debug_info" = "xyes" ] ; then + $Show objcopy --only-keep-debug ${executable} ${executable}.dbg + $Show objcopy --strip-debug ${executable} + $Show objcopy --add-gnu-debuglink=${executable}.dbg ${executable} fi else - cmd_line="$CC $CFLAGS $allargs -I$includedir" + cmd_line="$CC $CFLAGS $allargs -I\"${includedir}\"" if [ "$Show" = echo ] ; then - echo $cmd_line + echo "$cmd_line" else - eval `echo $cmd_line` + eval $(echo $cmd_line) fi rc=$? fi diff --git a/deps/mpi/bin/mpiicpc b/deps/mpi/bin/mpiicpc index 13695ab64..1e221dbbd 100755 --- a/deps/mpi/bin/mpiicpc +++ b/deps/mpi/bin/mpiicpc @@ -1,6 +1,6 @@ #!/bin/sh # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -81,48 +81,34 @@ # Set the default values of all variables. # # Directory locations: Fixed for any MPI implementation -prefix=/usr/local +prefix=../_install # The environment variable I_MPI_ROOT may be used to override installation folder path -if [ -n "$I_MPI_ROOT" ] ; then - prefix=$I_MPI_ROOT; +if [ -n "${I_MPI_ROOT}" ] ; then + prefix="${I_MPI_ROOT}"; fi -# Recognize the '-mmic' compiler switch. -for arg in "$@" ; do -{ - if [ "${arg}" = "-mmic" ] ; then - . ${prefix}/mic/bin/mpivars.sh - if [ -n "$VT_ROOT" ]; then - if [ -f "$VT_ROOT/mic/bin/itacvars.sh" ]; then - . "$VT_ROOT/mic/bin/itacvars.sh" "" - fi - fi - exec mpiicpc "$@" - fi -} -done PLATFORM="" -exec_prefix=${prefix} sysconfdir=${prefix}/etc + # The environment variable I_MPI_COMPILER_CONFIG_DIR may be used to override # folder where *.conf files are placed if [ -n "$I_MPI_COMPILER_CONFIG_DIR" ] ; then sysconfdir=$I_MPI_COMPILER_CONFIG_DIR; fi + includedir=${prefix}/include libdir=${prefix}/lib -# + # Default settings for compiler, flags, and libraries CXX="icpc" CXXFLAGS="" LDFLAGS="-ldl" MPILIBNAME="mpi" -PMPILIBNAME="pmpi" MPICXXLIBNAME="mpicxx" # MPIVERSION is the version of the Intel(R) MPI Library that mpiicpc is intended for -MPIVERSION="2021.3" -# +MPIVERSION="2021.4" + # Internal variables # Show is set to echo to cause the compilation command to be echoed instead # of executed. @@ -131,12 +117,10 @@ static_mpi=no strip_debug_info= handle_executable= executable=a.out -static_log=yes ilp64=no -trace_opt=no no_rpath=no -# # End of initialization of variables +# #--------------------------------------------------------------------- # Environment Variables. # The environment variables I_MPI_CXX, MPICH_CXX may be used to override the @@ -146,6 +130,7 @@ no_rpath=no # (e.g., "CC -64" becomes "CC--64", that file is sources, allowing other # changes to the compilation environment. See the variables used by the # script (defined above) + if [ -n "$I_MPI_CXX" ] ; then CXX="$I_MPI_CXX" CXXname=`echo $CXX | sed 's/ /-/g'` @@ -197,244 +182,232 @@ fi linking=yes allargs="" -argno=0 for arg in "$@" ; do # Set addarg to no if this arg should be ignored by the C compiler addarg=yes qarg=$arg if [ "x$handle_executable" = "xyes" ] ; then - executable=$arg - handle_executable= + executable=$arg + handle_executable= fi case "$arg" in # ---------------------------------------------------------------- # Compiler options that affect whether we are linking or no -c|-S|-E|-M|-MM) - # The compiler links by default - linking=no - ;; + # The compiler links by default + linking=no + ;; -o ) - handle_executable=yes - addarg=yes - ;; + handle_executable=yes + addarg=yes + ;; # ---------------------------------------------------------------- # Options that control how we use mpicxx (e.g., -show, # -cxx=* -config=* -echo) - addarg=no - set -x - ;; + addarg=no + set -x + ;; -cxx=*) - CXX=`echo A$arg | sed -e 's/A-cxx=//g'` - addarg=no - ;; + CXX=$(echo A$arg | sed -e 's/A-cxx=//g') + addarg=no + ;; # Backwards compatibility for MPICH1 - scripts -CC=*) - CXX=`echo A$arg | sed -e 's/A-CC=//g'` - addarg=no - ;; + CXX=$(echo A$arg | sed -e 's/A-CC=//g') + addarg=no + ;; -show) - addarg=no - Show=echo - ;; + addarg=no + Show=echo + ;; -show_env) - show_env=yes - ;; + show_env=yes + ;; -config=*) - addarg=no - CXXname=`echo A$arg | sed -e 's/A-config=//g'` - if [ -s "$sysconfdir/mpicxx-$CXXname.conf" ] ; then - . "$sysconfdir/mpicxx-$CXXname.conf" - else - echo "Configuration file mpicxx-$CXXname.conf not found" - fi - ;; + addarg=no + CXXname=$(echo A$arg | sed -e 's/A-config=//g') + if [ -s "$sysconfdir/mpicxx-$CXXname.conf" ] ; then + . "$sysconfdir/mpicxx-$CXXname.conf" + else + echo "Configuration file mpicxx-$CXXname.conf not found" + fi + ;; -compile-info|-compile_info) - # -compile_info included for backward compatibility - Show=echo - addarg=no - ;; + # -compile_info included for backward compatibility + Show=echo + addarg=no + ;; -link-info|-link_info) - # -link_info included for backward compatibility - Show=echo - addarg=no - ;; + # -link_info included for backward compatibility + Show=echo + addarg=no + ;; -v) - # Pass this argument to the compiler as well. - echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" - echo "Copyright 2003-2020, Intel Corporation." - # if there is only 1 argument, it must be -v. - if [ "$#" -eq "1" ] ; then - linking=no - fi - ;; + # Pass this argument to the compiler as well. + echo "$(basename $0) for the Intel(R) MPI Library $MPIVERSION for Linux*" + echo "Copyright Intel Corporation." + # if there is only 1 argument, it must be -v. + if [ "$#" -eq "1" ] ; then + linking=no + fi + ;; -V) - # Pass this argument to the compiler to query the compiler version. - if [ "$#" -eq "1" ] ; then - linking=no - fi - ;; + # Pass this argument to the compiler to query the compiler version. + if [ "$#" -eq "1" ] ; then + linking=no + fi + ;; -profile=*) - # Pass the name of a profiling configuration. As - # a special case, lib.so or lib.la may be used - # if the library is in $libdir - profConf=`echo A$arg | sed -e 's/A-profile=//g'` - addarg=no - # Loading the profConf file is handled below - ;; + # Pass the name of a profiling configuration. As + # a special case, lib.so or lib.la may be used + # if the library is in $libdir + profConf=$(echo A$arg | sed -e 's/A-profile=//g') + addarg=no + # Loading the profConf file is handled below + ;; -help) - # Print mini-help if started without parameters - echo "Simple script to compile and/or link MPI programs." - echo "Usage: `basename $0` [options] " - echo "----------------------------------------------------------------------------" - echo "The following options are supported:" - echo " -cxx= specify a C++ compiler name: i.e. -cxx=icpc" - echo " -echo print the scripts during their execution" - echo " -show show command lines without real calling" - echo " -show_env show environment variables" - echo " -config= specify a configuration file: i.e. -config=icpc for mpicc-icpc.conf file" -# echo " -compile-info show compiler command line" -# echo " -link-info show linker command line" - echo " -v print version info of $(basename $0) and its native compiler" - echo " -profile= specify a profile configuration file (an MPI profiling" - echo " library): i.e. -profile=myprofile for the myprofile.cfg file." - echo " As a special case, lib.so or lib.a may be used" - echo " if the library is found" - echo " -check_mpi link against the Intel(R) Trace Collector (-profile=vtmc)." - echo " -static_mpi link the Intel(R) MPI Library statically" - echo " -mt_mpi link the thread safe version of the Intel(R) MPI Library" - echo " -ilp64 link the ILP64 support of the Intel(R) MPI Library" - echo " -fast the same as -static_mpi + pass -fast option to a compiler" - echo " -t or -trace" - echo " link against the Intel(R) Trace Collector" - echo " -trace-imbalance" - echo " link against the Intel(R) Trace Collector imbalance library" - echo " (-profile=vtim)" - echo " -static use static linkage method" - echo " -nostrip turn off the debug information stripping during static linking" - echo " -dynamic_log link against the Intel(R) Trace Collector dynamically" - echo " -O enable optimization" - echo " -link_mpi=" - echo " link against the specified version of the Intel(R) MPI Library" - echo " i.e -link_mpi=opt|opt_mt|dbg|dbg_mt" - echo " -norpath disable rpath for compiler wrapper of the Intel(R) MPI Library" - echo "All other options will be passed to the compiler without changing." - echo "----------------------------------------------------------------------------" - echo "The following environment variables are used:" - echo " I_MPI_ROOT the Intel(R) MPI Library installation directory path" - echo " I_MPI_CXX or MPICH_CXX" - echo " the path/name of the underlying compiler to be used" - echo " I_MPI_CXX_PROFILE or MPICXX_PROFILE" - echo " the name of profile file (without extension)" - echo " I_MPI_COMPILER_CONFIG_DIR" - echo " the folder which contains configuration files *.conf" - echo " I_MPI_TRACE_PROFILE" - echo " specify a default profile for the -trace option" - echo " I_MPI_CHECK_PROFILE" - echo " specify a default profile for the -check_mpi option" - echo " I_MPI_LINK specify the version of the Intel(R) MPI Library" - echo " I_MPI_DEBUG_INFO_STRIP" - echo " turn on/off the debug information stripping during static linking" - echo "----------------------------------------------------------------------------" - exit 0 - ;; + # Print mini-help if started without parameters + echo "Simple script to compile and/or link MPI programs." + echo "Usage: $(basename $0) [options] " + echo "----------------------------------------------------------------------------" + echo "The following options are supported:" + echo " -cxx= specify a C++ compiler name: i.e. -cxx=icpc" + echo " -echo print the scripts during their execution" + echo " -show show command lines without real calling" + echo " -show_env show environment variables" + echo " -config= specify a configuration file: i.e. -config=icpc for mpicc-icpc.conf file" + echo " -v print version info of $(basename $0) and its native compiler" + echo " -profile= specify a profile configuration file (an MPI profiling" + echo " library): i.e. -profile=myprofile for the myprofile.cfg file." + echo " As a special case, lib.so or lib.a may be used" + echo " if the library is found" + echo " -check_mpi link against the Intel(R) Trace Collector (-profile=vtmc)." + echo " -static_mpi link the Intel(R) MPI Library statically" + echo " -mt_mpi link the thread safe version of the Intel(R) MPI Library" + echo " -ilp64 link the ILP64 support of the Intel(R) MPI Library" + echo " -fast the same as -static_mpi + pass -fast option to a compiler" + echo " -t or -trace" + echo " link against the Intel(R) Trace Collector" + echo " -trace-imbalance" + echo " link against the Intel(R) Trace Collector imbalance library" + echo " (-profile=vtim)" + echo " -static use static linkage method" + echo " -nostrip turn off the debug information stripping during static linking" + echo " -dynamic_log link against the Intel(R) Trace Collector dynamically" + echo " -O enable optimization" + echo " -link_mpi=" + echo " link against the specified version of the Intel(R) MPI Library" + echo " i.e -link_mpi=opt|opt_mt|dbg|dbg_mt" + echo " -norpath disable rpath for compiler wrapper of the Intel(R) MPI Library" + echo "All other options will be passed to the compiler without changing." + echo "----------------------------------------------------------------------------" + echo "The following environment variables are used:" + echo " I_MPI_ROOT the Intel(R) MPI Library installation directory path" + echo " I_MPI_CXX or MPICH_CXX" + echo " the path/name of the underlying compiler to be used" + echo " I_MPI_CXX_PROFILE or MPICXX_PROFILE" + echo " the name of profile file (without extension)" + echo " I_MPI_COMPILER_CONFIG_DIR" + echo " the folder which contains configuration files *.conf" + echo " I_MPI_TRACE_PROFILE" + echo " specify a default profile for the -trace option" + echo " I_MPI_CHECK_PROFILE" + echo " specify a default profile for the -check_mpi option" + echo " I_MPI_LINK specify the version of the Intel(R) MPI Library" + echo " I_MPI_DEBUG_INFO_STRIP" + echo " turn on/off the debug information stripping during static linking" + echo "----------------------------------------------------------------------------" + exit 0 + ;; -nolinkage) - # This internal option is used by wrapper driver scripts mpicc, mpicxx, mpifc when -v option is used. - linking=no - addarg=no - ;; + # This internal option is used by wrapper driver scripts mpicc, mpicxx, mpifc when -v option is used. + linking=no + addarg=no + ;; -g) - MPILIBDIR="/release" - ;; + MPILIBDIR="/release" + ;; -static_mpi) - static_mpi=yes - CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" - addarg=no - ;; - -static_log) - static_log=yes - addarg=no - ;; - -dynamic_log) - static_log=no - addarg=no - ;; + static_mpi=yes + CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" + addarg=no + ;; -static) - static_mpi=yes - static_log=yes - CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" - addarg=yes - ;; + static_mpi=yes + CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" + addarg=yes + ;; -mt_mpi) - addarg=no - ;; + addarg=no + ;; -ilp64) - ilp64=yes - addarg=no - ;; + ilp64=yes + addarg=no + ;; -check_mpi) - if [ -z "$profConf" ]; then - if [ -z "$I_MPI_CHECK_PROFILE" ]; then - profConf="vtmc" + if [ -z "$profConf" ]; then + if [ -z "$I_MPI_CHECK_PROFILE" ]; then + profConf="vtmc" + else + profConf="$I_MPI_CHECK_PROFILE" + fi else - profConf="$I_MPI_CHECK_PROFILE" + echo "Warning: the -check_mpi option will be ignored because the profile was set." fi - else - echo "Warning: the -check_mpi option will be ignored because the profile was set." - fi - addarg=no - ;; + addarg=no + ;; -trace-imbalance) - if [ -z "$profConf" ]; then - profConf="vtim" - else - echo "Warning: the -trace-imbalance option will be ignored because the profile was set." - fi - addarg=no - ;; + if [ -z "$profConf" ]; then + profConf="vtim" + else + echo "Warning: the -trace-imbalance option will be ignored because the profile was set." + fi + addarg=no + ;; -t | -trace | -t=* | -trace=* ) - if [ -z "$profConf" ]; then - if [ -z "$I_MPI_TRACE_PROFILE" ]; then - profConf="vt" + if [ -z "$profConf" ]; then + if [ -z "$I_MPI_TRACE_PROFILE" ]; then + profConf="vt" + else + profConf="$I_MPI_TRACE_PROFILE" + fi else - profConf="$I_MPI_TRACE_PROFILE" + echo "Warning: the -trace option will be ignored because the profile was set." fi - else - echo "Warning: the -trace option will be ignored because the profile was set." - fi - # Disable strip to prevent debug symbols into separate dbg file in case of static linking IMPI-1493 - strip_debug_info=no - addarg=no - ;; + # Disable strip to prevent debug symbols into separate dbg file in case of static linking IMPI-1493 + strip_debug_info=no + addarg=no + ;; -fast) - echo "Warning: the -fast option forces static linkage method for the Intel(R) MPI Library." - static_mpi=yes - CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" - ;; + echo "Warning: the -fast option forces static linkage method for the Intel(R) MPI Library." + static_mpi=yes + CXXFLAGS="$CXXFLAGS -Xlinker --export-dynamic" + ;; -link_mpi=* ) - mpilib_override=`echo A$arg | sed -e 's/A-link_mpi=//g'` - addarg=no - ;; + mpilib_override=$(echo A$arg | sed -e 's/A-link_mpi=//g') + addarg=no + ;; -nostrip ) - strip_debug_info=no - addarg=no - ;; + strip_debug_info=no + addarg=no + ;; -norpath ) - no_rpath=yes - addarg=no - ;; + no_rpath=yes + addarg=no + ;; # Other arguments. We are careful to handle arguments with # quotes (we try to quote all arguments in case they include # any spaces) *\"*) - qarg="'$arg'" - ;; + qarg="'$arg'" + ;; *\'*) - qarg=`echo \"$arg\"` - ;; + qarg=$(echo \"$arg\") + ;; *) - qarg="'$arg'" - ;; + qarg="'$arg'" + ;; esac if [ $addarg = yes ] ; then allargs="$allargs $qarg" @@ -472,32 +445,32 @@ fi # ----------------------------------------------------------------------- case "$MPILIBDIR" in release | /release | debug | /debug) - if [ ! -z "$MPILIBDIR_MT" ]; then - MPILIBDIR=${MPILIBDIR}_${MPILIBDIR_MT} - fi - ;; - "" ) - MPILIBDIR="/release" - ;; + if [ -n "$MPILIBDIR_MT" ]; then + MPILIBDIR=${MPILIBDIR}_${MPILIBDIR_MT} + fi + ;; + "" ) MPILIBDIR="/release" ;; esac if [ "$static_mpi" = yes ] ; then mpilibs="${libdir}/libmpifort.a ${libdir}${MPILIBDIR}/lib${MPILIBNAME}.a" I_MPI_OTHERLIBS="" - MPI_OTHERLIBS=" -lrt -lpthread -lfabric " + MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then - mpilibs="$libdir/libmpi_ilp64.a $mpilibs" + mpilibs="$libdir/libmpi_ilp64.a $mpilibs" fi + CXX_BIND_LIB="$libdir/libmpicxx.a" if [ "x$strip_debug_info" = "x" ] ; then - strip_debug_info=yes + strip_debug_info=yes fi else mpilibs="-lmpifort -l$MPILIBNAME" I_MPI_OTHERLIBS="" MPI_OTHERLIBS=" -lrt -lpthread " if [ "$ilp64" = yes ]; then - mpilibs="-lmpi_ilp64 $mpilibs" + mpilibs="-lmpi_ilp64 $mpilibs" fi + CXX_BIND_LIB="-lmpicxx" fi # Derived variables. These are assembled from variables set from the @@ -512,25 +485,25 @@ fi # Handle the case of a profile switch if [ -n "$profConf" ] ; then profConffile= - if [ -s "$libdir/lib$profConf.a" -o -s "$libdir/lib$profConf.so" ] ; then - mpilibs="-l$profConf $mpilibs" + if [ -s "$libdir/lib$profConf.a" ] || [ -s "$libdir/lib$profConf.so" ] ; then + mpilibs="-l$profConf $mpilibs" elif [ -s "$sysconfdir/$profConf.conf" ] ; then - profConffile="$sysconfdir/$profConf.conf" + profConffile="$sysconfdir/$profConf.conf" elif [ -s "$profConf.conf" ] ; then profConffile="$profConf.conf" else echo "Profiling configuration file $profConf.conf not found in $sysconfdir" fi - if [ -n "$profConffile" -a -s "$profConffile" ] ; then - . $profConffile - if [ -n "$PROFILE_INCPATHS" ] ; then - CXXFLAGS="$PROFILE_INCPATHS $CXXFLAGS" + if [ -n "$profConffile" ] && [ -s "$profConffile" ] ; then + . $profConffile + if [ -n "$PROFILE_INCPATHS" ] ; then + CXXFLAGS="$PROFILE_INCPATHS $CXXFLAGS" fi if [ -n "$PROFILE_PRELIB" ] ; then - mpilibs="$PROFILE_PRELIB $mpilibs" + mpilibs="$PROFILE_PRELIB $mpilibs" fi if [ -n "$PROFILE_POSTLIB" ] ; then - mpilibs="$mpilibs $PROFILE_POSTLIB" + mpilibs="$mpilibs $PROFILE_POSTLIB" fi fi fi @@ -547,23 +520,23 @@ fi if [ "$no_rpath" = "yes" ]; then rpath_opt="-Xlinker --enable-new-dtags" else - rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker ${libdir}${MPILIBDIR} -Xlinker -rpath -Xlinker $libdir" + rpath_opt="-Xlinker --enable-new-dtags -Xlinker -rpath -Xlinker \"${libdir}${MPILIBDIR}\" -Xlinker -rpath -Xlinker \"${libdir}\"" fi if [ "$linking" = yes ] ; then - cmd_line="$CXX $CXXFLAGS $allargs -I$includedir -L${libdir}${MPILIBDIR} -L$libdir $rpath_opt $cxxlibs $mpilibs $I_MPI_OTHERLIBS $LDFLAGS $MPI_OTHERLIBS" + cmd_line="$CXX $CXXFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $cxxlibs $mpilibs $I_MPI_OTHERLIBS $LDFLAGS $MPI_OTHERLIBS" if [ "$Show" = echo ] ; then echo $cmd_line else eval `echo $cmd_line` fi rc=$? - if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then - $Show objcopy --only-keep-debug ${executable} ${executable}.dbg - $Show objcopy --strip-debug ${executable} - $Show objcopy --add-gnu-debuglink=${executable}.dbg ${executable} + if [ $rc -eq 0 ] && [ "x$strip_debug_info" = "xyes" ] ; then + $Show objcopy --only-keep-debug ${executable} ${executable}.dbg + $Show objcopy --strip-debug ${executable} + $Show objcopy --add-gnu-debuglink=${executable}.dbg ${executable} fi else - cmd_line="$CXX $CXXFLAGS $allargs -I$includedir" + cmd_line="$CXX $CXXFLAGS $allargs -I\"${includedir}\"" if [ "$Show" = echo ] ; then $Show $cmd_line else diff --git a/deps/mpi/bin/mpirun b/deps/mpi/bin/mpirun index 30107bf0d..86628a81d 100755 --- a/deps/mpi/bin/mpirun +++ b/deps/mpi/bin/mpirun @@ -1,6 +1,6 @@ #!/bin/sh # -# Copyright 2003-2020 Intel Corporation. +# Copyright Intel Corporation. # # This software and the related documents are Intel copyrighted materials, and # your use of them is governed by the express license under which they were @@ -22,15 +22,15 @@ if [ -n "$I_MPI_TMPDIR" ]; then fi np_boot= -username=`whoami` +username=$(whoami) rc=0 -if [ -z "$I_MPI_ROOT" -a -z "`uname -m | grep 1om`" ] ; then +if [ -z "$I_MPI_ROOT" -a -z "$(uname -m | grep 1om)" ] ; then if [ -f ${0%/*}/mpivars.sh ]; then . ${0%/*}/mpivars.sh "" - else + elif [ -f ${0%/*}/../env/vars.sh ]; then . ${0%/*}/../env/vars.sh "" - fi + fi # else it can be a runtime package without any scripts to source fi ##### mpirun detection ##### @@ -41,12 +41,12 @@ export I_MPI_MPIRUN="mpirun" ####### Job Scheduler autodetection ####### ############################################# # PBS -if [ -n "$PBS_ENVIRONMENT" -a -z "$I_MPI_HYDRA_RMK" ] ; then +if [ -n "$PBS_ENVIRONMENT" ] && [ -z "$I_MPI_HYDRA_RMK" ] ; then export I_MPI_HYDRA_RMK=pbs -elif [ -n "$LSB_JOBID" -a -z "$I_MPI_HYDRA_RMK" ]; then +elif [ -n "$LSB_JOBID" ] && [ -z "$I_MPI_HYDRA_RMK" ]; then export I_MPI_HYDRA_RMK=lsf fi -if [ -z "$I_MPI_HYDRA_BOOTSTRAP" -a -z "$I_MPI_HYDRA_BOOTSTRAP_EXEC" ]; then +if [ -z "$I_MPI_HYDRA_BOOTSTRAP" ] && [ -z "$I_MPI_HYDRA_BOOTSTRAP_EXEC" ]; then # SLURM if [ -n "$SLURM_JOBID" ]; then export I_MPI_HYDRA_BOOTSTRAP=slurm @@ -57,7 +57,7 @@ if [ -z "$I_MPI_HYDRA_BOOTSTRAP" -a -z "$I_MPI_HYDRA_BOOTSTRAP_EXEC" ]; then elif [ -n "$PE_HOSTFILE" ]; then export I_MPI_HYDRA_BOOTSTRAP=sge # Fujitsu NQS (Network Queuing System) - elif [ -n "$ENVIRONMENT" -a -n "$QSUB_REQID" -a -n "$QSUB_NODEINF" ] ; then + elif [ -n "$ENVIRONMENT" ] && [ -n "$QSUB_REQID" ] && [ -n "$QSUB_NODEINF" ] ; then if [ -z "$I_MPI_HYDRA_HOST_FILE" ]; then export I_MPI_HYDRA_HOST_FILE=$QSUB_NODEINF fi @@ -65,12 +65,14 @@ if [ -z "$I_MPI_HYDRA_BOOTSTRAP" -a -z "$I_MPI_HYDRA_BOOTSTRAP_EXEC" ]; then export I_MPI_HYDRA_BOOTSTRAP=rsh export I_MPI_HYDRA_BOOTSTRAP_EXEC=/usr/bin/plesh fi + # TORQUE + elif [ -n "${PBS_NODEFILE}" ]; then + export I_MPI_HYDRA_BOOTSTRAP=pbs fi fi # Slurm, LoadLeveler, LSF, SGE -if [ -n "$LOADL_HOSTFILE" -o \ - -n "$PE_HOSTFILE" ]; then +if [ -n "$LOADL_HOSTFILE" ] || [ -n "$PE_HOSTFILE" ]; then # Create a host file if [ -n "$LOADL_HOSTFILE" ]; then mpiexec.hydra "$@" <&0 @@ -80,8 +82,8 @@ if [ -n "$LOADL_HOSTFILE" -o \ > $machinefile while read line; do if [ -n "$line" ]; then - host_name=`echo $line | sed -e "s/ .*//"` - num_of_processes=`expr match "$line" '.* \([0-9]\+\) .*'` + host_name=$(echo $line | sed -e "s/ .*//") + num_of_processes=$(expr match "$line" '.* \([0-9]\+\) .*') echo "$host_name:$num_of_processes" >> $machinefile np_boot=$(( np_boot + 1 )) fi @@ -96,8 +98,8 @@ elif [ "x$I_MPI_YARN" = "xyes" -o "x$I_MPI_YARN" = "xenable" -o "x$I_MPI_YARN" = rc=$? # Netbatch elif [ -n "$NB_PARALLEL_JOB_HOSTS" ]; then - hosts_opt=`echo $NB_PARALLEL_JOB_HOSTS | tr ' ' ','` - hosts_opt="-hosts `hostname -s`,$hosts_opt" + hosts_opt=$(echo $NB_PARALLEL_JOB_HOSTS | tr ' ' ',') + hosts_opt="-hosts $(hostname -s),$hosts_opt" mpiexec.hydra $hosts_opt "$@" <&0 rc=$? # PBS or ordinary job diff --git a/deps/mpi/include/mpi.h b/deps/mpi/include/mpi.h old mode 100644 new mode 100755 index 658e5a3a5..3dc48685b --- a/deps/mpi/include/mpi.h +++ b/deps/mpi/include/mpi.h @@ -580,8 +580,8 @@ typedef int (MPI_Delete_function) ( MPI_Comm, int, void *, void * ); * digits for REV, 1 digit for EXT and 2 digits for EXT_NUMBER. So, * 2019.0.0b0 will have the numeric version 20190000100. */ -#define I_MPI_VERSION "2021.3.0" -#define I_MPI_NUMVERSION 20210300300 +#define I_MPI_VERSION "2021.4.0" +#define I_MPI_NUMVERSION 20210400300 /* for the datatype decoders */ enum MPIR_Combiner_enum { diff --git a/deps/mpi/include/mpicxx.h b/deps/mpi/include/mpicxx.h old mode 100644 new mode 100755 diff --git a/deps/mpi/include/mpio.h b/deps/mpi/include/mpio.h old mode 100644 new mode 100755 diff --git a/deps/mpi/lib/libmpi.so b/deps/mpi/lib/libmpi.so index d7243ada7..84631e5a7 100755 Binary files a/deps/mpi/lib/libmpi.so and b/deps/mpi/lib/libmpi.so differ diff --git a/deps/mpi/lib/libmpi.so.12 b/deps/mpi/lib/libmpi.so.12 index d7243ada7..84631e5a7 100755 Binary files a/deps/mpi/lib/libmpi.so.12 and b/deps/mpi/lib/libmpi.so.12 differ diff --git a/deps/mpi/lib/libmpi.so.12.0 b/deps/mpi/lib/libmpi.so.12.0 index d7243ada7..84631e5a7 100755 Binary files a/deps/mpi/lib/libmpi.so.12.0 and b/deps/mpi/lib/libmpi.so.12.0 differ diff --git a/deps/mpi/lib/libmpi.so.12.0.0 b/deps/mpi/lib/libmpi.so.12.0.0 index d7243ada7..84631e5a7 100755 Binary files a/deps/mpi/lib/libmpi.so.12.0.0 and b/deps/mpi/lib/libmpi.so.12.0.0 differ diff --git a/deps/mpi/lib/libmpifort.so b/deps/mpi/lib/libmpifort.so index f67aaad45..00d80af4b 100755 Binary files a/deps/mpi/lib/libmpifort.so and b/deps/mpi/lib/libmpifort.so differ diff --git a/deps/mpi/lib/libmpifort.so.12 b/deps/mpi/lib/libmpifort.so.12 index f67aaad45..00d80af4b 100755 Binary files a/deps/mpi/lib/libmpifort.so.12 and b/deps/mpi/lib/libmpifort.so.12 differ diff --git a/deps/mpi/lib/libmpifort.so.12.0 b/deps/mpi/lib/libmpifort.so.12.0 index f67aaad45..00d80af4b 100755 Binary files a/deps/mpi/lib/libmpifort.so.12.0 and b/deps/mpi/lib/libmpifort.so.12.0 differ diff --git a/deps/mpi/lib/libmpifort.so.12.0.0 b/deps/mpi/lib/libmpifort.so.12.0.0 index f67aaad45..00d80af4b 100755 Binary files a/deps/mpi/lib/libmpifort.so.12.0.0 and b/deps/mpi/lib/libmpifort.so.12.0.0 differ diff --git a/deps/mpi/licensing/third-party-programs.txt b/deps/mpi/licensing/third-party-programs.txt index 307780de4..f85123769 100644 --- a/deps/mpi/licensing/third-party-programs.txt +++ b/deps/mpi/licensing/third-party-programs.txt @@ -1,4 +1,4 @@ -Intel(R) MPI Library 2021.3 Third Party Programs File +Intel(R) MPI Library 2021.4 Third Party Programs File This file is the "third-party-programs.txt" file specified in the associated Intel end user license agreement for the Intel software you are licensing. @@ -48,100 +48,83 @@ terms are listed below. 2. Open MPI - Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana - University Research and Technology - Corporation. All rights reserved. - Copyright (c) 2004-2017 The University of Tennessee and The University - of Tennessee Research Foundation. All rights - reserved. - Copyright (c) 2004-2010 High Performance Computing Center Stuttgart, - University of Stuttgart. All rights reserved. - Copyright (c) 2004-2008 The Regents of the University of California. - All rights reserved. - Copyright (c) 2006-2017 Los Alamos National Security, LLC. All rights - reserved. - Copyright (c) 2006-2017 Cisco Systems, Inc. All rights reserved. - Copyright (c) 2006-2010 Voltaire, Inc. All rights reserved. - Copyright (c) 2006-2017 Sandia National Laboratories. All rights reserved. - Copyright (c) 2006-2010 Sun Microsystems, Inc. All rights reserved. - Use is subject to license terms. - Copyright (c) 2006-2017 The University of Houston. All rights reserved. - Copyright (c) 2006-2009 Myricom, Inc. All rights reserved. - Copyright (c) 2007-2017 UT-Battelle, LLC. All rights reserved. - Copyright (c) 2007-2017 IBM Corporation. All rights reserved. - Copyright (c) 1998-2005 Forschungszentrum Juelich, Juelich Supercomputing - Centre, Federal Republic of Germany - Copyright (c) 2005-2008 ZIH, TU Dresden, Federal Republic of Germany - Copyright (c) 2007 Evergrid, Inc. All rights reserved. - Copyright (c) 2008 Chelsio, Inc. All rights reserved. - Copyright (c) 2008-2009 Institut National de Recherche en - Informatique. All rights reserved. - Copyright (c) 2007 Lawrence Livermore National Security, LLC. - All rights reserved. - Copyright (c) 2007-2017 Mellanox Technologies. All rights reserved. - Copyright (c) 2006-2010 QLogic Corporation. All rights reserved. - Copyright (c) 2008-2017 Oak Ridge National Labs. All rights reserved. - Copyright (c) 2006-2012 Oracle and/or its affiliates. All rights reserved. - Copyright (c) 2009-2015 Bull SAS. All rights reserved. - Copyright (c) 2010 ARM ltd. All rights reserved. - Copyright (c) 2016 ARM, Inc. All rights reserved. - Copyright (c) 2010-2011 Alex Brick . All rights reserved. - Copyright (c) 2012 The University of Wisconsin-La Crosse. All rights - reserved. - Copyright (c) 2013-2016 Intel, Inc. All rights reserved. - Copyright (c) 2011-2017 NVIDIA Corporation. All rights reserved. - Copyright (c) 2016 Broadcom Limited. All rights reserved. - Copyright (c) 2011-2017 Fujitsu Limited. All rights reserved. - Copyright (c) 2014-2015 Hewlett-Packard Development Company, LP. All - rights reserved. - Copyright (c) 2013-2017 Research Organization for Information Science (RIST). - All rights reserved. - Copyright (c) 2017-2018 Amazon.com, Inc. or its affiliates. All Rights - reserved. - Copyright (c) 2018 DataDirect Networks. All rights reserved. - Copyright (c) 2018-2019 Triad National Security, LLC. All rights reserved. - - -3. hwloc - - Copyright (c) 2004-2005 The University of Tennessee and The University of - Tennessee Research Foundation. All rights reserved. - Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, - University of Stuttgart. All rights reserved. - Copyright (c) 2004-2005 The Regents of the University of California. - All rights reserved. - Copyright (c) 2009 CNRS - Copyright (c) 2009-2016 Inria. All rights reserved. - Copyright (c) 2009-2015 Université Bordeaux - Copyright (c) 2009-2015 Cisco Systems, Inc. All rights reserved. - Copyright (c) 2009-2012 Oracle and/or its affiliates. All rights reserved. - Copyright (c) 2010 IBM - Copyright (c) 2010 Jirka Hladky - Copyright (c) 2012 Aleksej Saushev, The NetBSD Foundation - Copyright (c) 2012 Blue Brain Project, EPFL. All rights reserved. - Copyright (c) 2013-2014 University of Wisconsin-La Crosse. All rights reserved. - Copyright (c) 2015 Research Organization for Information Science and - Technology (RIST). All rights reserved. - Copyright (c) 2015-2016 Intel, Inc. All rights reserved. - - - The -3-Clause BSD license + Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana + University Research and Technology + Corporation. All rights reserved. + Copyright (c) 2004-2017 The University of Tennessee and The University + of Tennessee Research Foundation. All rights + reserved. + Copyright (c) 2004-2010 High Performance Computing Center Stuttgart, + University of Stuttgart. All rights reserved. + Copyright (c) 2004-2008 The Regents of the University of California. + All rights reserved. + Copyright (c) 2006-2018 Los Alamos National Security, LLC. All rights + reserved. + Copyright (c) 2006-2020 Cisco Systems, Inc. All rights reserved. + Copyright (c) 2006-2010 Voltaire, Inc. All rights reserved. + Copyright (c) 2006-2017 Sandia National Laboratories. All rights reserved. + Copyright (c) 2006-2010 Sun Microsystems, Inc. All rights reserved. + Use is subject to license terms. + Copyright (c) 2006-2017 The University of Houston. All rights reserved. + Copyright (c) 2006-2009 Myricom, Inc. All rights reserved. + Copyright (c) 2007-2017 UT-Battelle, LLC. All rights reserved. + Copyright (c) 2007-2020 IBM Corporation. All rights reserved. + Copyright (c) 1998-2005 Forschungszentrum Juelich, Juelich Supercomputing + Centre, Federal Republic of Germany + Copyright (c) 2005-2008 ZIH, TU Dresden, Federal Republic of Germany + Copyright (c) 2007 Evergrid, Inc. All rights reserved. + Copyright (c) 2008 Chelsio, Inc. All rights reserved. + Copyright (c) 2008-2009 Institut National de Recherche en + Informatique. All rights reserved. + Copyright (c) 2007 Lawrence Livermore National Security, LLC. + All rights reserved. + Copyright (c) 2007-2017 Mellanox Technologies. All rights reserved. + Copyright (c) 2006-2010 QLogic Corporation. All rights reserved. + Copyright (c) 2008-2017 Oak Ridge National Labs. All rights reserved. + Copyright (c) 2006-2012 Oracle and/or its affiliates. All rights reserved. + Copyright (c) 2009-2015 Bull SAS. All rights reserved. + Copyright (c) 2010 ARM ltd. All rights reserved. + Copyright (c) 2016 ARM, Inc. All rights reserved. + Copyright (c) 2010-2011 Alex Brick . All rights reserved. + Copyright (c) 2012 The University of Wisconsin-La Crosse. All rights + reserved. + Copyright (c) 2013-2020 Intel, Inc. All rights reserved. + Copyright (c) 2011-2017 NVIDIA Corporation. All rights reserved. + Copyright (c) 2016 Broadcom Limited. All rights reserved. + Copyright (c) 2011-2017 Fujitsu Limited. All rights reserved. + Copyright (c) 2014-2015 Hewlett-Packard Development Company, LP. All + rights reserved. + Copyright (c) 2013-2017 Research Organization for Information Science (RIST). + All rights reserved. + Copyright (c) 2017-2020 Amazon.com, Inc. or its affiliates. All Rights + reserved. + Copyright (c) 2018 DataDirect Networks. All rights reserved. + Copyright (c) 2018-2020 Triad National Security, LLC. All rights reserved. + Copyright (c) 2020 Google, LLC. All rights reserved. + Copyright (c) 2002 University of Chicago + Copyright (c) 2001 Argonne National Laboratory + + $COPYRIGHT$ + + Additional copyrights may follow + + $HEADER$ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. + notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer listed - in this license in the documentation and/or other materials - provided with the distribution. + notice, this list of conditions and the following disclaimer listed + in this license in the documentation and/or other materials + provided with the distribution. - Neither the name of the copyright holders nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. The copyright holders provide no reassurances that the source code provided does not infringe any patent, copyright, or any other @@ -161,6 +144,94 @@ terms are listed below. THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +----------------[Copyright from inclusion of MPICH code]---------------- + + The following is a notice of limited availability of the code, and disclaimer + which must be included in the prologue of the code and in all source listings + of the code. + + Copyright Notice + + 2002 University of Chicago + + Permission is hereby granted to use, reproduce, prepare derivative works, and + to redistribute to others. This software was authored by: + + Mathematics and Computer Science Division + Argonne National Laboratory, Argonne IL 60439 + + (and) + + Department of Computer Science + University of Illinois at Urbana-Champaign + + + GOVERNMENT LICENSE + + Portions of this material resulted from work developed under a U.S. + Government Contract and are subject to the following license: the Government + is granted for itself and others acting on its behalf a paid-up, nonexclusive, + irrevocable worldwide license in this computer software to reproduce, prepare + derivative works, and perform publicly and display publicly. + + DISCLAIMER + + This computer code material was prepared, in part, as an account of work + sponsored by an agency of the United States Government. Neither the United + States, nor the University of Chicago, nor any of their employees, makes any + warranty express or implied, or assumes any legal liability or responsibility + for the accuracy, completeness, or usefulness of any information, apparatus, + product, or process disclosed, or represents that its use would not infringe + privately owned rights. + +------------------------------------------------------------------------------- + +3. hwloc + + Copyright © 2004-2006 The Trustees of Indiana University and Indiana University Research and Technology Corporation. All rights reserved. + Copyright © 2004-2005 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. + Copyright © 2004-2005 High Performance Computing Center Stuttgart, University of Stuttgart. All rights reserved. + Copyright © 2004-2005 The Regents of the University of California. All rights reserved. + Copyright © 2009 CNRS + Copyright © 2009-2016 Inria. All rights reserved. + Copyright © 2009-2015 Université Bordeaux + Copyright © 2009-2015 Cisco Systems, Inc. All rights reserved. + Copyright © 2009-2012 Oracle and/or its affiliates. All rights reserved. + Copyright © 2010 IBM + Copyright © 2010 Jirka Hladky + Copyright © 2012 Aleksej Saushev, The NetBSD Foundation + Copyright © 2012 Blue Brain Project, EPFL. All rights reserved. + Copyright © 2013-2014 University of Wisconsin-La Crosse. All rights reserved. + Copyright © 2015 Research Organization for Information Science and Technology (RIST). All rights reserved. + Copyright © 2015-2016 Intel, Inc. All rights reserved. + See COPYING in top-level directory. + + The 3-Clause BSD License + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + - The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------------------------------------------------------------------------------- @@ -199,58 +270,86 @@ terms are listed below. ------------------------------------------------------------------------------- -5. Python - - Copyright © 2001-2020 Python Software Foundation. All rights reserved. - - Copyright © 2000 BeOpen.com. All rights reserved. - - Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved. +5. Intel® Distribution for Python + + Intel Simplified Software License (Version February 2020) + + Use and Redistribution. You may use and redistribute the software (the + "Software"), without modification, provided the following conditions are met: + + * Redistributions must reproduce the above copyright notice and the following + terms of use in the Software and in the documentation and/or other materials + provided with the distribution. + * Neither the name of Intel nor the names of its suppliers may be used to + endorse or promote products derived from this Software without specific prior + written permission. + * No reverse engineering, decompilation, or disassembly of this Software is + permitted. + + Limited patent license. Intel grants you a world-wide, royalty-free, + non-exclusive license under patents it now or hereafter owns or controls to + make, have made, use, import, offer to sell and sell ("Utilize") this Software, + but solely to the extent that any such patent is necessary to Utilize the + Software alone. The patent license shall not apply to any combinations which + include this software. No hardware per se is licensed hereunder. + + Third party programs. The Software may contain Third Party Programs. "Third + Party Programs" are third party software, open source software or other Intel + software listed in the "third-party-programs.txt" or other similarly named text + file that is included with the Software. Third Party Programs, even if included + with the distribution of the Software, may be governed by separate license + terms, including without limitation, third party license terms, open source + software notices and terms, and/or other Intel software license terms. These + separate license terms may govern your use of the Third Party Programs. + + DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE + DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS + WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE + THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND + ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT + INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. + + LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD + INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR + UNAUTHORIZED USE OF THE SOFTWARE. + + No support. Intel may make changes to the Software, at any time without notice, + and is not obligated to support, update or provide training for the Software. + + Termination. Intel may terminate your right to use the Software in the event of + your breach of this Agreement and you fail to cure the breach within a + reasonable period of time. + + Feedback. Should you provide Intel with comments, modifications, corrections, + enhancements or other input ("Feedback") related to the Software Intel will be + free to use, disclose, reproduce, license or otherwise distribute or exploit the + Feedback in its sole discretion without any obligations or restrictions of any + kind, including without limitation, intellectual property rights or licensing + obligations. + + Compliance with laws. You agree to comply with all relevant laws and regulations + governing your use, transfer, import or export (or prohibition thereof) of the + Software. + + Governing law. All disputes will be governed by the laws of the United States of + America and the State of Delaware without reference to conflict of law + principles and subject to the exclusive jurisdiction of the state or federal + courts sitting in the State of Delaware, and each party agrees that it submits + to the personal jurisdiction and venue of those courts and waives any + objections. The United Nations Convention on Contracts for the International + Sale of Goods (1980) is specifically excluded and will not apply to the + Software. + + *Other names and brands may be claimed as the property of others. - Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved. - - PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 - - PSF LICENSE AGREEMENT FOR PYTHON - - 1. This LICENSE AGREEMENT is between the Python Software Foundation - ("PSF"), and the Individual or Organization ("Licensee") accessing and - otherwise using Python software in source or binary form and its - associated documentation. - 2. Subject to the terms and conditions of this License Agreement, PSF - hereby grants Licensee a nonexclusive, royalty-free, world-wide - license to reproduce, analyze, test, perform and/or display publicly, - prepare derivative works, distribute, and otherwise use Python - alone or in any derivative version, provided, however, that PSF's - License Agreement and PSF's notice of copyright, for example, "Copyright (c) - 2001, 2002, 2003, 2004 Python Software Foundation; All Rights Reserved" are - retained in Python alone or in any derivative version prepared by - Licensee. - 3. In the event Licensee prepares a derivative work that is based on - or incorporates Python or any part thereof, and wants to make - the derivative work available to others as provided herein, then - Licensee hereby agrees to include in any such work a brief summary of - the changes made to Python. - 4. PSF is making Python available to Licensee on an "AS IS" - basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR - IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND - DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS - FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT - INFRINGE ANY THIRD PARTY RIGHTS. - 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON - FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS - A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, - OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. - 6. This License Agreement will automatically terminate upon a material - breach of its terms and conditions. - 7. Nothing in this License Agreement shall be deemed to create any - relationship of agency, partnership, or joint venture between PSF and - Licensee. This License Agreement does not grant permission to use PSF - trademarks or trade name in a trademark sense to endorse or promote - products or services of Licensee, or any third party. - 8. By copying, installing or otherwise using Python, Licensee - agrees to be bound by the terms and conditions of this License - Agreement. ------------------------------------------------------------------------------- @@ -309,8 +408,6 @@ terms are listed below. ------------------------------------------------------------------------------- 8. zlib - - Copyright (c) <''year''> <''copyright holders''> Permissive software license This software is provided 'as-is', without any express or implied @@ -331,245 +428,94 @@ terms are listed below. ------------------------------------------------------------------------------- -9. Intel(R) MPI Benchmarks - - Common Public License Version 1.0 - - THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS COMMON - PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF - THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. +9. Intel(R) MPI Benchmarks + Copyright (c) Intel Corporation. + + OpenUCX/UCX + Copyright (c) 2014-2015 UT-Battelle, LLC. All rights reserved. + Copyright (C) 2014-2020 Mellanox Technologies Ltd. All rights reserved. + Copyright (C) 2014-2015 The University of Houston System. All rights reserved. + Copyright (C) 2015 The University of Tennessee and The University + of Tennessee Research Foundation. All rights reserved. + Copyright (C) 2016-2020 ARM Ltd. All rights reserved. + Copyright (c) 2016 Los Alamos National Security, LLC. All rights reserved. + Copyright (C) 2016-2020 Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 UChicago Argonne, LLC. All rights reserved. + Copyright (c) 2018-2020 NVIDIA CORPORATION. All rights reserved. + Copyright (C) 2020 Huawei Technologies Co., Ltd. All rights reserved. + Copyright (C) 2016-2020 Stony Brook University. All rights reserved. + + Mellanox Hardware MultiCast library + Copyright (C) 2014-2020 Mellanox Technologies Ltd. All rights reserved. + Copyright (c) 2018-2020 NVIDIA CORPORATION. All rights reserved. + + BSD 3-Clause "New" or "Revised" License + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + - Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - 1. DEFINITIONS +------------------------------------------------------------------------------- - "Contribution" means: +10. PMIx + Copyright (c) 2019, PMIx + All rights reserved. - a) in the case of the initial Contributor, the initial code and - documentation distributed under this Agreement, and + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: - b) in the case of each subsequent Contributor: - i) changes to the Program, and - ii) additions to the Program; where such changes and/or additions - to the Program originate from and are distributed by that - particular Contributor. A Contribution 'originates' from a - Contributor if it was added to the Program by such - Contributor itself or anyone acting on such Contributor's - behalf. Contributions do not include additions to the Program - which: (i) are separate modules of software distributed in - conjunction with the Program under their own license - agreement, and (ii) are not derivative works of the Program. + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. - "Contributor" means any person or entity that distributes the Program. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. - "Licensed Patents " mean patent claims licensable by a Contributor - which are necessarily infringed by the use or sale of its Contribution - alone or when combined with the Program. + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. - "Program" means the Contributions distributed in accordance with this - Agreement. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - "Recipient" means anyone who receives the Program under this - Agreement, including all Contributors. - - 2. GRANT OF RIGHTS - - a) Subject to the terms of this Agreement, each Contributor hereby - grants Recipient a non-exclusive, worldwide, royalty-free - copyright license to reproduce, prepare derivative works of, - publicly display, publicly perform, distribute and sublicense the - Contribution of such Contributor, if any, and such derivative - works, in source code and object code form. - - b) Subject to the terms of this Agreement, each Contributor hereby - grants Recipient a non-exclusive, worldwide, royalty-free patent - license under Licensed Patents to make, use, sell, offer to sell, - import and otherwise transfer the Contribution of such - Contributor, if any, in source code and object code form. This - patent license shall apply to the combination of the Contribution - and the Program if, at the time the Contribution is added by the - Contributor, such addition of the Contribution causes such - combination to be covered by the Licensed Patents. The patent - license shall not apply to any other combinations which include - the Contribution. No hardware per se is licensed hereunder. - - c) Recipient understands that although each Contributor grants the - licenses to its Contributions set forth herein, no assurances are - provided by any Contributor that the Program does not infringe - the patent or other intellectual property rights of any other - entity. Each Contributor disclaims any liability to Recipient for - claims brought by any other entity based on infringement of - intellectual property rights or otherwise. As a condition to - exercising the rights and licenses granted hereunder, each - Recipient hereby assumes sole responsibility to secure any other - intellectual property rights needed, if any. For example, if a - third party patent license is required to allow Recipient to - distribute the Program, it is Recipient's responsibility to - acquire that license before distributing the Program. - - d) Each Contributor represents that to its knowledge it has - sufficient copyright rights in its Contribution, if any, to grant - the copyright license set forth in this Agreement. - - 3. REQUIREMENTS - - A Contributor may choose to distribute the Program in object code form - under its own license agreement, provided that: - - a) it complies with the terms and conditions of this Agreement; and - - b) its license agreement: - i) effectively disclaims on behalf of all Contributors all - warranties and conditions, express and implied, including - warranties or conditions of title and non-infringement, and - implied warranties or conditions of merchantability and - fitness for a particular purpose; - ii) effectively excludes on behalf of all Contributors all - liability for damages, including direct, indirect, special, - incidental and consequential damages, such as lost profits; - iii) states that any provisions which differ from this Agreement - are offered by that Contributor alone and not by any other - party; and - iv) states that source code for the Program is available from - such Contributor, and informs licensees how to obtain it in a - reasonable manner on or through a medium customarily used for - software exchange. - - When the Program is made available in source code form: - - a) it must be made available under this Agreement; and - - b) a copy of this Agreement must be included with each copy of the - Program. Contributors may not remove or alter any copyright - notices contained within the Program. - - Each Contributor must identify itself as the originator of its - Contribution, if any, in a manner that reasonably allows subsequent - Recipients to identify the originator of the Contribution. - - 4. COMMERCIAL DISTRIBUTION - - Commercial distributors of software may accept certain - responsibilities with respect to end users, business partners and the - like. While this license is intended to facilitate the commercial use - of the Program, the Contributor who includes the Program in a - commercial product offering should do so in a manner which does not - create potential liability for other Contributors. Therefore, if a - Contributor includes the Program in a commercial product offering, - such Contributor ("Commercial Contributor") hereby agrees to defend - and indemnify every other Contributor ("Indemnified Contributor") - against any losses, damages and costs (collectively "Losses") arising - from claims, lawsuits and other legal actions brought by a third party - against the Indemnified Contributor to the extent caused by the acts - or omissions of such Commercial Contributor in connection with its - distribution of the Program in a commercial product offering. The - obligations in this section do not apply to any claims or Losses - relating to any actual or alleged intellectual property - infringement. In order to qualify, an Indemnified Contributor must: - - a) promptly notify the Commercial Contributor in writing of such - claim, and - - b) allow the Commercial Contributor to control, and cooperate with - the Commercial Contributor in, the defense and any related - settlement negotiations. The Indemnified Contributor may - participate in any such claim at its own expense. - - For example, a Contributor might include the Program in a commercial - product offering, Product X. That Contributor is then a Commercial - Contributor. If that Commercial Contributor then makes performance - claims, or offers warranties related to Product X, those performance - claims and warranties are such Commercial Contributor's responsibility - alone. Under this section, the Commercial Contributor would have to - defend claims against the other Contributors related to those - performance claims and warranties, and if a court requires any other - Contributor to pay any damages as a result, the Commercial Contributor - must pay those damages. - - 5. NO WARRANTY - - EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS - PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY - WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY - OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely - responsible for determining the appropriateness of using and - distributing the Program and assumes all risks associated with its - exercise of rights under this Agreement, including but not limited to - the risks and costs of program errors, compliance with applicable - laws, damage to or loss of data, programs or equipment, and - unavailability or interruption of operations. - - 6. DISCLAIMER OF LIABILITY - - EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR - ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, - INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING - WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR - DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED - HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - - 7. GENERAL - - If any provision of this Agreement is invalid or unenforceable under - applicable law, it shall not affect the validity or enforceability of - the remainder of the terms of this Agreement, and without further - action by the parties hereto, such provision shall be reformed to the - minimum extent necessary to make such provision valid and - enforceable. - - If Recipient institutes patent litigation against a Contributor with - respect to a patent applicable to software (including a cross-claim or - counterclaim in a lawsuit), then any patent licenses granted by that - Contributor to such Recipient under this Agreement shall terminate as - of the date such litigation is filed. In addition, if Recipient - institutes patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Program - itself (excluding combinations of the Program with other software or - hardware) infringes such Recipient's patent(s), then such Recipient's - rights granted under Section 2(b) shall terminate as of the date such - litigation is filed. - - All Recipient's rights under this Agreement shall terminate if it - fails to comply with any of the material terms or conditions of this - Agreement and does not cure such failure in a reasonable period of - time after becoming aware of such noncompliance. If all Recipient's - rights under this Agreement terminate, Recipient agrees to cease use - and distribution of the Program as soon as reasonably - practicable. However, Recipient's obligations under this Agreement and - any licenses granted by Recipient relating to the Program shall - continue and survive. - - Everyone is permitted to copy and distribute copies of this Agreement, - but in order to avoid inconsistency the Agreement is copyrighted and - may only be modified in the following manner. The Agreement Steward - reserves the right to publish new versions (including revisions) of - this Agreement from time to time. No one other than the Agreement - Steward has the right to modify this Agreement. IBM is the initial - Agreement Steward. IBM may assign the responsibility to serve as the - Agreement Steward to a suitable separate entity. Each new version of - the Agreement will be given a distinguishing version number. The - Program (including Contributions) may always be distributed subject to - the version of the Agreement under which it was received. In addition, - after a new version of the Agreement is published, Contributor may - elect to distribute the Program (including its Contributions) under - the new version. Except as expressly stated in Sections 2(a) and 2(b) - above, Recipient receives no rights or licenses to the intellectual - property of any Contributor under this Agreement, whether expressly, - by implication, estoppel or otherwise. All rights in the Program not - expressly granted under this Agreement are reserved. - - This Agreement is governed by the laws of the State of New York and - the intellectual property laws of the United States of America. No - party to this Agreement will bring a legal action under this Agreement - more than one year after the cause of action arose. Each party waives - its rights to a jury trial in any resulting litigation. - ------------------------------------------------------------------------------- - + The following third party programs have their own third party programs. These additional third party program files are as follows: - 1. Intel(R) MPI Benchmarks /imb/license/third-party-programs.txt + 1. Intel(R) MPI Benchmarks https://raw.githubusercontent.com/intel/mpi-benchmarks/master/license/third-party-programs.txt + 2. Intel(R) Distribution for Python: third-party-programs-python.txt file ------------------------------------------------------------------------------- diff --git a/deps/ofi/bin/fi_info b/deps/ofi/bin/fi_info index 347648e8a..b4df1a8e6 100755 Binary files a/deps/ofi/bin/fi_info and b/deps/ofi/bin/fi_info differ diff --git a/deps/ofi/include/rdma/fabric.h b/deps/ofi/include/rdma/fabric.h index 71628035e..cdfa11e8d 100644 --- a/deps/ofi/include/rdma/fabric.h +++ b/deps/ofi/include/rdma/fabric.h @@ -79,8 +79,8 @@ extern "C" { #endif #define FI_MAJOR_VERSION 1 -#define FI_MINOR_VERSION 12 -#define FI_REVISION_VERSION 1 +#define FI_MINOR_VERSION 13 +#define FI_REVISION_VERSION 0 enum { FI_PATH_MAX = 256, @@ -166,6 +166,7 @@ typedef struct fid *fid_t; #define FI_COMMIT_COMPLETE (1ULL << 30) #define FI_MATCH_COMPLETE (1ULL << 31) +#define FI_HMEM_DEVICE_ONLY (1ULL << 46) #define FI_HMEM (1ULL << 47) #define FI_VARIABLE_MSG (1ULL << 48) #define FI_RMA_PMEM (1ULL << 49) @@ -519,6 +520,8 @@ enum { FI_CLASS_MC, FI_CLASS_NIC, FI_CLASS_AV_SET, + FI_CLASS_MR_CACHE, + FI_CLASS_MEM_MONITOR, }; struct fi_eq_attr; @@ -577,7 +580,10 @@ struct fid_fabric { uint32_t api_version; }; -int fi_fabric(struct fi_fabric_attr *attr, struct fid_fabric **fabric, void *context); +int fi_fabric(struct fi_fabric_attr *attr, struct fid_fabric **fabric, + void *context); +int fi_open(uint32_t version, const char *name, void *attr, size_t attr_len, + uint64_t flags, struct fid **fid, void *context); struct fid_nic { struct fid fid; @@ -641,6 +647,7 @@ enum { FI_GETWAITOBJ, /*enum fi_wait_obj * */ FI_GET_VAL, /* struct fi_fid_var */ FI_SET_VAL, /* struct fi_fid_var */ + FI_EXPORT_FID, /* struct fi_fid_export */ }; static inline int fi_control(struct fid *fid, int command, void *arg) diff --git a/deps/ofi/include/rdma/fi_domain.h b/deps/ofi/include/rdma/fi_domain.h index 27d6dd398..b5399ba00 100644 --- a/deps/ofi/include/rdma/fi_domain.h +++ b/deps/ofi/include/rdma/fi_domain.h @@ -185,6 +185,12 @@ enum fi_datatype { FI_LONG_DOUBLE_COMPLEX, /* End of point to point atomic datatypes */ FI_DATATYPE_LAST, + /* + * enums for 128-bit integer atomics, existing ordering and + * FI_DATATYPE_LAST preserved for compatabilty. + */ + FI_INT128 = FI_DATATYPE_LAST, + FI_UINT128, /* Collective datatypes */ FI_VOID = FI_COLLECTIVE_OFFSET, diff --git a/deps/ofi/lib/libfabric.so b/deps/ofi/lib/libfabric.so index 35c21dfc3..da151da5d 100755 Binary files a/deps/ofi/lib/libfabric.so and b/deps/ofi/lib/libfabric.so differ diff --git a/deps/ofi/lib/libfabric.so.1 b/deps/ofi/lib/libfabric.so.1 index 35c21dfc3..da151da5d 100755 Binary files a/deps/ofi/lib/libfabric.so.1 and b/deps/ofi/lib/libfabric.so.1 differ diff --git a/deps/ofi/lib/prov/libpsm3-fi.so b/deps/ofi/lib/prov/libpsm3-fi.so index 47830166d..fab9b8d4e 100755 Binary files a/deps/ofi/lib/prov/libpsm3-fi.so and b/deps/ofi/lib/prov/libpsm3-fi.so differ diff --git a/deps/ofi/lib/prov/libpsmx2-fi.so b/deps/ofi/lib/prov/libpsmx2-fi.so index 375463c58..28235ef3a 100755 Binary files a/deps/ofi/lib/prov/libpsmx2-fi.so and b/deps/ofi/lib/prov/libpsmx2-fi.so differ diff --git a/deps/ofi/lib/prov/librxm-fi.so b/deps/ofi/lib/prov/librxm-fi.so index 83af28e2e..99a542183 100755 Binary files a/deps/ofi/lib/prov/librxm-fi.so and b/deps/ofi/lib/prov/librxm-fi.so differ diff --git a/deps/ofi/lib/prov/libshm-fi.so b/deps/ofi/lib/prov/libshm-fi.so index dfce33131..73ec980df 100755 Binary files a/deps/ofi/lib/prov/libshm-fi.so and b/deps/ofi/lib/prov/libshm-fi.so differ diff --git a/deps/ofi/lib/prov/libsockets-fi.so b/deps/ofi/lib/prov/libsockets-fi.so index b164233e5..83d743b77 100755 Binary files a/deps/ofi/lib/prov/libsockets-fi.so and b/deps/ofi/lib/prov/libsockets-fi.so differ diff --git a/deps/ofi/lib/prov/libtcp-fi.so b/deps/ofi/lib/prov/libtcp-fi.so index 10f430bc1..89b2c7a01 100755 Binary files a/deps/ofi/lib/prov/libtcp-fi.so and b/deps/ofi/lib/prov/libtcp-fi.so differ diff --git a/deps/ofi/lib/prov/libverbs-fi.so b/deps/ofi/lib/prov/libverbs-fi.so index 2a895fbd5..91c41bce2 100755 Binary files a/deps/ofi/lib/prov/libverbs-fi.so and b/deps/ofi/lib/prov/libverbs-fi.so differ diff --git a/doc/rst/source/_templates/layout.html b/doc/rst/source/_templates/layout.html index 7488ec4aa..a3774b953 100644 --- a/doc/rst/source/_templates/layout.html +++ b/doc/rst/source/_templates/layout.html @@ -1,4 +1,16 @@ {% extends "!layout.html" %} {% block extrahead %} + + {% endblock %} \ No newline at end of file diff --git a/doc/rst/source/advanced-configuration/dmabuf.rst b/doc/rst/source/advanced-configuration/dmabuf.rst new file mode 100644 index 000000000..4201d2704 --- /dev/null +++ b/doc/rst/source/advanced-configuration/dmabuf.rst @@ -0,0 +1,64 @@ +.. _`here`: https://github.com/ofiwg/libfabric/releases/tag/v1.13.1 +.. _`documentation`: https://one-api.gitlab-pages.devtools.intel.com/level_zero/core/PROG.html#affinity-mask + +===================================== +Enabling OFI/verbs dmabuf support +===================================== + +|product_short| provides experimental support for device memory transfers using Linux dmabuf, +which is exposed through OFI API for verbs provider. + + +Requirements +############ + +- Linux kernel version >= 5.12 +- RDMA core version >= 34.0 +- level-zero-devel package + + +Limitations +########### + +- Only first tile should be used from each GPU card. + For example, if GPU with 2 tiles is used then set ZE_AFFINITY_MASK=0.0. + More information about GPU selection can be found in level-zero `documentation`_. + + +Build instructions +################## + +OFI +*** + +:: + + git clone --single-branch --branch v1.13.1 https://github.com/ofiwg/libfabric.git + cd libfabric + ./autogen.sh + ./configure --prefix= --enable-verbs= --enable-ze-dlopen=yes + make -j install + +.. note:: + You may also get OFI release package directly from `here`_. + No need to run autogen.sh if using the release package. + +|product_short| +*************** + +:: + + cmake -DCMAKE_INSTALL_PREFIX= -DLIBFABRIC_DIR= -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero -DENABLE_OFI_HMEM=1 .. + make -j install + + +Run instructions +################ + +Run allreduce test with ring algorithm and SYCL USM device buffers. + +:: + + source /env/setvars.sh + export LD_LIBRARY_PATH=/lib:${LD_LIBRARY_PATH} + CCL_ATL_TRANSPORT=ofi CCL_ATL_HMEM=1 CCL_ALLREDUCE=ring FI_PROVIDER=verbs mpiexec -n 2 /examples/sycl/sycl_allreduce_usm_test gpu device diff --git a/doc/rst/source/conf.py b/doc/rst/source/conf.py index 4c0fbd89f..74157a343 100755 --- a/doc/rst/source/conf.py +++ b/doc/rst/source/conf.py @@ -115,5 +115,6 @@ 'path_to_docs': 'doc/source', 'use_issues_button': True, 'use_edit_page_button': True, - 'repository_branch': 'master' + 'repository_branch': 'master', + 'extra_footer': '

Cookies

' } diff --git a/doc/rst/source/env-variables.rst b/doc/rst/source/env-variables.rst index 988554c9f..393839863 100644 --- a/doc/rst/source/env-variables.rst +++ b/doc/rst/source/env-variables.rst @@ -369,8 +369,11 @@ CCL_FUSION_CYCLE_MS Set this environment variable to specify the frequency of checking for collectives operations to be fused. +ATL +### + CCL_ATL_TRANSPORT -################# +***************** **Syntax** :: @@ -393,7 +396,36 @@ CCL_ATL_TRANSPORT **Description** -Set this environment variable to select the transport for inter-node communications. +Set this environment variable to select the transport for inter-process communications. + + +CCL_ATL_HMEM +************ +**Syntax** + +:: + + CCL_ATL_HMEM= + +**Arguments** + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + :align: left + + * - + - Description + * - ``1`` + - Enable heterogeneous memory support on the transport layer. + * - ``0`` + - Disable heterogeneous memory support on the transport layer (**default**). + +**Description** + +Set this environment variable to enable handling of HMEM/GPU buffers by the transport layer. +The actual HMEM support depends on the limitations on the transport level and system configuration. + CCL_UNORDERED_COLL ################## @@ -483,7 +515,7 @@ CCL_WORKER_AFFINITY :: - CCL_WORKER_AFFINITY= + CCL_WORKER_AFFINITY= **Arguments** @@ -492,21 +524,54 @@ CCL_WORKER_AFFINITY :header-rows: 1 :align: left - * - + * - - Description * - ``auto`` - Workers are automatically pinned to last cores of pin domain. Pin domain depends from process launcher. If ``mpirun`` from |product_short| package is used then pin domain is MPI process pin domain. Otherwise, pin domain is all cores on the node. - * - ``n1,n2,..`` - - Affinity is explicitly specified for all local workers. + * - ```` + - A comma-separated list of core numbers and/or ranges of core numbers for all local workers, one number per worker. + The i-th local worker is pinned to the i-th core in the list. + For example ,- defines list of cores contaning core with number + and range of cores with numbers from to . + The number should not exceed the number of cores available on the system. **Description** Set this environment variable to specify cpu affinity for |product_short| worker threads. +CCL_WORKER_MEM_AFFINITY +####################### +**Syntax** + +:: + + CCL_WORKER_MEM_AFFINITY= + +**Arguments** + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + :align: left + + * - + - Description + * - ``auto`` + - Workers are automatically pinned to NUMA nodes that correspond to CPU affinity of workers. + * - ```` + - A comma-separated list of NUMA node numbers for all local workers, one number per worker. + The i-th local worker is pinned to the i-th NUMA node in the list. + The number should not exceed the number of NUMA nodes available on the system. + +**Description** + +Set this environment variable to specify memory affinity for |product_short| worker threads. + + CCL_LOG_LEVEL ############# **Syntax** @@ -558,8 +623,16 @@ CCL_MAX_SHORT_SIZE Set this environment variable to specify the threshold of the number of bytes for a collective operation to be split. +Multi-NIC +######### + + +CCL_MNIC, CCL_MNIC_NAME and CCL_MNIC_COUNT define filters to select multiple NICs. +|product_short| workers will be pinned on selected NICs in a round-robin way. + + CCL_MNIC -######## +******** **Syntax** :: @@ -584,12 +657,39 @@ CCL_MNIC **Description** -Set this environment variable to control multi-NIC selection policy. -|product_short| workers will be pinned on selected NICs in a round-robin way. +Set this environment variable to control multi-NIC selection by NIC locality. + + +CCL_MNIC_NAME +************* +**Syntax** + +:: + + CCL_MNIC_NAME= + +**Arguments** + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + :align: left + + * - + - Description + * - ```` + - A comma-separated list of NIC full names or prefixes to filter NICs. + Use the ``^`` symbol to exclude NICs starting with the specified prefixes. For example, + if you provide a list ``mlx5_0,mlx5_1,^mlx5_2``, NICs with the names ``mlx5_0`` and ``mlx5_1`` + will be selected, while ``mlx5_2`` will be excluded from the selection. + +**Description** + +Set this environment variable to control multi-NIC selection by NIC names. CCL_MNIC_COUNT -############## +************** **Syntax** :: diff --git a/doc/rst/source/introduction/installation.rst b/doc/rst/source/introduction/installation.rst index b74aa5266..a3d905dd7 100644 --- a/doc/rst/source/introduction/installation.rst +++ b/doc/rst/source/introduction/installation.rst @@ -78,7 +78,7 @@ You can customize CLI-based installation (for example, specify directory, compil :: - cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp + cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero * To specify the **build type**, modify the ``cmake`` command: diff --git a/doc/rst/source/programming-model/limitations.rst b/doc/rst/source/programming-model/limitations.rst index fdcc67132..53ebf3d94 100644 --- a/doc/rst/source/programming-model/limitations.rst +++ b/doc/rst/source/programming-model/limitations.rst @@ -5,4 +5,3 @@ Limitations The list of scenarios not yet supported by oneCCL: - Creation of multiple ranks within single process -- Handling of dependencies as operation parameter (for example, ``deps`` vector in ``ccl::allreduce(..., deps)``) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 499d34283..a5186efaf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -98,7 +98,7 @@ endif() include_directories(include) add_subdirectory(cpu) -if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") +if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL "IntelLLVM") add_subdirectory(sycl) endif() add_subdirectory(common) diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index 41879eb9e..2a6c4199d 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -26,6 +26,8 @@ include_directories(src) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") find_package(NUMA) +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src}) diff --git a/examples/benchmark/include/benchmark.hpp b/examples/benchmark/include/benchmark.hpp index d7c624d5c..aafa9684a 100644 --- a/examples/benchmark/include/benchmark.hpp +++ b/examples/benchmark/include/benchmark.hpp @@ -37,7 +37,7 @@ #include using namespace cl::sycl; using namespace cl::sycl::access; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #include "base.hpp" #include "base_utils.hpp" @@ -59,18 +59,18 @@ void print_help_usage(const char* app) { "\t[-f,--min_elem_count ]: %d\n" "\t[-t,--max_elem_count ]: %d\n" "\t[-y,--elem_counts ]: [%d-%d]\n" - "\t[-c,--check ]: %d\n" + "\t[-c,--check ]: %s\n" "\t[-p,--cache ]: %d\n" "\t[-q,--inplace ]: %d\n" "\t[-k,--ranks_per_proc ]: %d\n" #ifdef CCL_ENABLE_NUMA "\t[-s,--numa_node ]: %s\n" -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL "\t[-a,--sycl_dev_type ]: %s\n" "\t[-m,--sycl_mem_type ]: %s\n" "\t[-u,--sycl_usm_type ]: %s\n" -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL "\t[-l,--coll ]: %s\n" "\t[-d,--dtype ]: %s\n" "\t[-r,--reduction ]: %s\n" @@ -89,18 +89,18 @@ void print_help_usage(const char* app) { DEFAULT_MAX_ELEM_COUNT, DEFAULT_MIN_ELEM_COUNT, DEFAULT_MAX_ELEM_COUNT, - DEFAULT_CHECK_VALUES, + check_values_names[DEFAULT_CHECK_VALUES].c_str(), DEFAULT_CACHE_OPS, DEFAULT_INPLACE, DEFAULT_RANKS_PER_PROC, #ifdef CCL_ENABLE_NUMA DEFAULT_NUMA_NODE_STR, -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL sycl_dev_names[DEFAULT_SYCL_DEV_TYPE].c_str(), sycl_mem_names[DEFAULT_SYCL_MEM_TYPE].c_str(), sycl_usm_names[DEFAULT_SYCL_USM_TYPE].c_str(), -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL DEFAULT_COLL_LIST, DEFAULT_DTYPES_LIST, DEFAULT_REDUCTIONS_LIST, @@ -127,6 +127,13 @@ bool find_key_val(ccl::reduction& key, Container& mp, const Dtype& val) { return false; } +bool is_check_values_enabled(check_values_t check_values) { + bool ret = false; + if (check_values == CHECK_LAST_ITER || check_values == CHECK_ALL_ITERS) + return true; + return ret; +} + int check_supported_options(const std::string& option_name, const std::string& option_value, const std::set& supported_option_values) { @@ -192,6 +199,29 @@ int set_iter_policy(const std::string& option_value, iter_policy_t& policy) { return 0; } +int set_check_values(const std::string& option_value, check_values_t& check) { + std::string option_name = "check"; + + std::set supported_option_values{ check_values_names[CHECK_OFF], + check_values_names[CHECK_LAST_ITER], + check_values_names[CHECK_ALL_ITERS] }; + + if (check_supported_options(option_name, option_value, supported_option_values)) + return -1; + + if (option_value == check_values_names[CHECK_OFF]) { + check = CHECK_OFF; + } + else if (option_value == check_values_names[CHECK_LAST_ITER]) { + check = CHECK_LAST_ITER; + } + else if (option_value == check_values_names[CHECK_ALL_ITERS]) { + check = CHECK_ALL_ITERS; + } + + return 0; +} + #ifdef CCL_ENABLE_SYCL int set_sycl_dev_type(const std::string& option_value, sycl_dev_type_t& dev) { std::string option_name = "sycl_dev_type"; @@ -237,12 +267,14 @@ int set_sycl_usm_type(const std::string& option_value, sycl_usm_type_t& usm) { return 0; } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL -int set_datatypes(std::string option_value, int check_values, std::list& datatypes) { +int set_datatypes(std::string option_value, + check_values_t check_values, + std::list& datatypes) { datatypes.clear(); if (option_value == "all") { - if (check_values) { + if (is_check_values_enabled(check_values)) { datatypes = tokenize(ALL_DTYPES_LIST_WITH_CHECK, ','); } else { @@ -257,7 +289,7 @@ int set_datatypes(std::string option_value, int check_values, std::list& reductions) { +int set_reductions(std::string option_value, + check_values_t check_values, + std::list& reductions) { reductions.clear(); if (option_value == "all") { - if (check_values) { + if (is_check_values_enabled(check_values)) { reductions = tokenize(ALL_REDUCTIONS_LIST_WITH_CHECK, ','); } else { @@ -292,14 +326,15 @@ int set_reductions(std::string option_value, int check_values, std::list supported_option_values; for (auto p : reduction_names) { - if ((p.first != ccl::reduction::sum) && check_values) + if ((p.first != ccl::reduction::sum) && is_check_values_enabled(check_values)) continue; supported_option_values.insert(p.second); } for (auto r : reductions) { if (check_supported_options(option_name, r, supported_option_values)) { - if ((r != reduction_names[ccl::reduction::sum]) && check_values) { + if ((r != reduction_names[ccl::reduction::sum]) && + is_check_values_enabled(check_values)) { PRINT("WARN: correctness checking is not implemented for '%s'", r.c_str()); } } @@ -533,12 +568,12 @@ int parse_user_options(int& argc, char**(&argv), user_options_t& options) { #ifdef CCL_ENABLE_NUMA const char* numa_options = "s:"; memcpy(short_options + strlen(short_options), numa_options, strlen(numa_options)); -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL const char* sycl_options = "a:m:u:"; memcpy(short_options + strlen(short_options), sycl_options, strlen(sycl_options)); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL struct option getopt_options[] = { { "backend", required_argument, nullptr, 'b' }, @@ -556,12 +591,12 @@ int parse_user_options(int& argc, char**(&argv), user_options_t& options) { { "ranks_per_proc", required_argument, nullptr, 'k' }, #ifdef CCL_ENABLE_NUMA { "numa_node", required_argument, nullptr, 's' }, -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL { "sycl_dev_type", required_argument, nullptr, 'a' }, { "sycl_mem_type", required_argument, nullptr, 'm' }, { "sycl_usm_type", required_argument, nullptr, 'u' }, -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL { "coll", required_argument, nullptr, 'l' }, { "dtype", required_argument, nullptr, 'd' }, { "reduction", required_argument, nullptr, 'r' }, @@ -639,7 +674,12 @@ int parse_user_options(int& argc, char**(&argv), user_options_t& options) { else errors++; break; - case 'c': options.check_values = atoi(optarg); break; + case 'c': + if (set_check_values(optarg, options.check_values)) { + PRINT("failed to parse 'check' option"); + errors++; + } + break; case 'p': options.cache_ops = atoi(optarg); break; case 'q': options.inplace = atoi(optarg); break; case 'k': @@ -675,7 +715,7 @@ int parse_user_options(int& argc, char**(&argv), user_options_t& options) { errors++; } break; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL case 'l': if (strcmp("all", optarg) == 0) { options.coll_names = tokenize(ALL_COLLS_LIST, ','); @@ -786,6 +826,7 @@ void print_user_options(const user_options_t& options, const ccl::communicator& std::string backend_str = find_str_val(backend_names, options.backend); std::string loop_str = find_str_val(loop_names, options.loop); std::string iter_policy_str = find_str_val(iter_policy_names, options.iter_policy); + std::string check_values_str = find_str_val(check_values_names, options.check_values); #ifdef CCL_ENABLE_SYCL std::string sycl_dev_type_str = find_str_val(sycl_dev_names, options.sycl_dev_type); @@ -805,18 +846,18 @@ void print_user_options(const user_options_t& options, const ccl::communicator& "\n min_elem_count: %zu" "\n max_elem_count: %zu" "\n elem_counts: %s" - "\n check: %d" + "\n check: %s" "\n cache: %d" "\n inplace: %d" "\n ranks_per_proc: %zu" #ifdef CCL_ENABLE_NUMA "\n numa_node: %s" -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL "\n sycl_dev_type: %s" "\n sycl_mem_type: %s" "\n sycl_usm_type: %s" -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL "\n collectives: %s" "\n datatypes: %s" "\n reductions: %s" @@ -831,7 +872,7 @@ void print_user_options(const user_options_t& options, const ccl::communicator& options.min_elem_count, options.max_elem_count, elem_counts_str.c_str(), - options.check_values, + check_values_str.c_str(), options.cache_ops, options.inplace, options.ranks_per_proc, @@ -839,12 +880,12 @@ void print_user_options(const user_options_t& options, const ccl::communicator& (options.numa_node == DEFAULT_NUMA_NODE) ? DEFAULT_NUMA_NODE_STR : std::to_string(options.numa_node).c_str(), -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #ifdef CCL_ENABLE_SYCL sycl_dev_type_str.c_str(), sycl_mem_type_str.c_str(), sycl_usm_type_str.c_str(), -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL collectives_str.c_str(), datatypes_str.c_str(), reductions_str.c_str(), diff --git a/examples/benchmark/include/config.hpp b/examples/benchmark/include/config.hpp index 78794fa8a..fbd981fa7 100644 --- a/examples/benchmark/include/config.hpp +++ b/examples/benchmark/include/config.hpp @@ -32,9 +32,9 @@ #ifdef CCL_ENABLE_SYCL #define DEFAULT_BACKEND BACKEND_SYCL -#else /* CCL_ENABLE_SYCL */ +#else // CCL_ENABLE_SYCL #define DEFAULT_BACKEND BACKEND_HOST -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #define DEFAULT_LOOP LOOP_REGULAR #define DEFAULT_ITERS (16) #define DEFAULT_WARMUP_ITERS (16) @@ -42,7 +42,7 @@ #define DEFAULT_BUF_COUNT (1) #define DEFAULT_MIN_ELEM_COUNT (1) #define DEFAULT_MAX_ELEM_COUNT (128) -#define DEFAULT_CHECK_VALUES (0) +#define DEFAULT_CHECK_VALUES CHECK_LAST_ITER #define DEFAULT_CACHE_OPS (1) #define DEFAULT_INPLACE (0) #define DEFAULT_RANKS_PER_PROC (1) diff --git a/examples/benchmark/include/cpu_coll.hpp b/examples/benchmark/include/cpu_coll.hpp index 176ee958e..4287bab01 100644 --- a/examples/benchmark/include/cpu_coll.hpp +++ b/examples/benchmark/include/cpu_coll.hpp @@ -17,7 +17,7 @@ #ifdef CCL_ENABLE_NUMA #include -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA #include "coll.hpp" @@ -139,7 +139,7 @@ struct cpu_base_coll : base_coll, protected strategy { ptr, "failed to allocate buffer with size %zu on NUMA node %d", bytes, numa_node); } else -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA { size_t alignment = REG_MSG_ALIGNMENT; if (bytes >= LARGE_MSG_THRESHOLD) @@ -160,7 +160,7 @@ struct cpu_base_coll : base_coll, protected strategy { numa_free(ptr, bytes); } else -#endif /* CCL_ENABLE_NUMA */ +#endif // CCL_ENABLE_NUMA { free(ptr); } diff --git a/examples/benchmark/include/sycl_coll.hpp b/examples/benchmark/include/sycl_coll.hpp index 064a333eb..a605af700 100644 --- a/examples/benchmark/include/sycl_coll.hpp +++ b/examples/benchmark/include/sycl_coll.hpp @@ -207,4 +207,4 @@ struct sycl_base_coll : base_coll, private strategy { std::vector> allocators; }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/include/types.hpp b/examples/benchmark/include/types.hpp index 9466d84af..2c12cdc67 100644 --- a/examples/benchmark/include/types.hpp +++ b/examples/benchmark/include/types.hpp @@ -24,7 +24,7 @@ if (comm.rank() == 0) { \ printf(fmt "\n", ##__VA_ARGS__); \ } -#endif /* PRINT_BY_ROOT */ +#endif // PRINT_BY_ROOT constexpr std::initializer_list all_dtypes = { ccl::datatype::int8, ccl::datatype::int32, ccl::datatype::int64, ccl::datatype::uint64, @@ -34,6 +34,7 @@ constexpr std::initializer_list all_dtypes = { typedef enum { BACKEND_HOST, BACKEND_SYCL } backend_type_t; typedef enum { LOOP_REGULAR, LOOP_UNORDERED } loop_type_t; typedef enum { ITER_POLICY_OFF, ITER_POLICY_AUTO } iter_policy_t; +typedef enum { CHECK_OFF, CHECK_LAST_ITER, CHECK_ALL_ITERS } check_values_t; typedef enum { SYCL_DEV_HOST, SYCL_DEV_CPU, SYCL_DEV_GPU } sycl_dev_type_t; typedef enum { SYCL_MEM_USM, SYCL_MEM_BUF } sycl_mem_type_t; @@ -49,6 +50,12 @@ std::map iter_policy_names = { std::make_pair(ITER_P std::make_pair(ITER_POLICY_AUTO, "auto") }; +std::map check_values_names = { + std::make_pair(CHECK_OFF, "off"), + std::make_pair(CHECK_LAST_ITER, "last"), + std::make_pair(CHECK_ALL_ITERS, "all") +}; + #ifdef CCL_ENABLE_SYCL std::map sycl_dev_names = { std::make_pair(SYCL_DEV_HOST, "host"), std::make_pair(SYCL_DEV_CPU, "cpu"), @@ -115,7 +122,7 @@ typedef struct user_options_t { size_t min_elem_count; size_t max_elem_count; std::list elem_counts; - int check_values; + check_values_t check_values; int cache_ops; int inplace; size_t ranks_per_proc; diff --git a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp index c8d465bf6..1d99ac63e 100644 --- a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp +++ b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp @@ -94,4 +94,4 @@ struct sycl_allgatherv_coll : sycl_base_coll { } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp index 400b3e53c..cd79face6 100644 --- a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp +++ b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp @@ -87,4 +87,4 @@ struct sycl_allreduce_coll : sycl_base_coll { } } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp index 9400551f8..5d51be30e 100644 --- a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp +++ b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp @@ -88,4 +88,4 @@ struct sycl_alltoall_coll : sycl_base_coll { } } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp index 3b018de31..4e1a31af2 100644 --- a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp +++ b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp @@ -88,4 +88,4 @@ struct sycl_alltoallv_coll : sycl_base_coll { } } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp index 616bd6567..f0a06af50 100644 --- a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp +++ b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp @@ -91,7 +91,7 @@ struct sycl_bcast_coll : sycl_base_coll { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = host_recv_buf[e_idx]; - if (value != b_idx) { + if (value != static_cast(b_idx)) { // comparison float16 with size_t ?? std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " << rank_idx << ", elem_idx " << e_idx << ", expected " << (Dtype)b_idx << ", got " << value << std::endl; @@ -101,4 +101,4 @@ struct sycl_bcast_coll : sycl_base_coll { } } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/benchmark.cpp b/examples/benchmark/src/benchmark.cpp index 558c87092..1c169b091 100644 --- a/examples/benchmark/src/benchmark.cpp +++ b/examples/benchmark/src/benchmark.cpp @@ -29,6 +29,20 @@ #include "declarations.hpp" #include "transport_impl.hpp" +inline void prepare_coll(const user_options_t& options, + ccl::communicator& service_comm, + std::shared_ptr coll, + const size_t elem_count) { + coll->prepare(elem_count); + ccl::barrier(service_comm); +} + +inline void finalize_coll(const user_options_t& options, + std::shared_ptr coll, + const size_t elem_count) { + coll->finalize(elem_count); +} + void do_regular(ccl::communicator& service_comm, bench_exec_attr& bench_attr, coll_list_t& all_colls, @@ -118,9 +132,8 @@ void do_regular(ccl::communicator& service_comm, for (size_t iter_idx = 0; iter_idx < (iter_count + warmup_iter_count); iter_idx++) { - if (options.check_values) { - coll->prepare(count); - ccl::barrier(service_comm); + if (options.check_values == CHECK_ALL_ITERS) { + prepare_coll(options, service_comm, coll, count); } double coll_start_time = when(); @@ -147,13 +160,34 @@ void do_regular(ccl::communicator& service_comm, wait_time += wait_end_time - wait_start_time; } - if (options.check_values) { - coll->finalize(count); + if (options.check_values == CHECK_ALL_ITERS) { + finalize_coll(options, coll, count); } } total_timers[coll_idx] += coll_time + wait_time; wait_timers[coll_idx] += wait_time; + + if (options.check_values == CHECK_LAST_ITER) { + prepare_coll(options, service_comm, coll, count); + + for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { + match_id_stream << "coll_" << coll->name() << "_" << coll_idx + << "_count_" << count << "_buf_" << buf_idx + << "_dt_" << dtype_name << "_rt_" << reduction; + bench_attr.set( + ccl::string_class(match_id_stream.str())); + match_id_stream.str(""); + coll->start(count, buf_idx, bench_attr, reqs); + } + + for (auto& req : reqs) { + req.wait(); + } + reqs.clear(); + + finalize_coll(options, coll, count); + } } print_timings(service_comm, @@ -172,6 +206,8 @@ void do_regular(ccl::communicator& service_comm, } } + ccl::barrier(service_comm); + PRINT_BY_ROOT(service_comm, "\n# All done\n"); } @@ -446,7 +482,7 @@ void create_sycl_colls(bench_init_attr& init_attr, user_options_t& options, coll " - empty colls, reason: " + coll_processing_log); } } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL template void create_colls(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) { @@ -513,7 +549,7 @@ int main(int argc, char* argv[]) { #ifdef CCL_ENABLE_SYCL init_attr.sycl_mem_type = options.sycl_mem_type; init_attr.sycl_usm_type = options.sycl_usm_type; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL try { create_all_colls(init_attr, options, colls); diff --git a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp index 4b87ce244..b9ac0ce95 100644 --- a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp +++ b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp @@ -92,4 +92,4 @@ struct sycl_reduce_coll : sycl_base_coll { } } }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp index 186b02a15..0013ec1cb 100644 --- a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp +++ b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp @@ -91,4 +91,4 @@ struct sycl_reduce_scatter_coll : sycl_base_coll& expected_recv_counts, } } } -} /* namespace sparse_detail */ +} // namespace sparse_detail #endif diff --git a/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp b/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp index ea7507a89..95540400d 100644 --- a/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp +++ b/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp @@ -134,6 +134,6 @@ struct sycl_sparse_allreduce_coll : base_sparse_allreduce_coll #include "sycl_coll.hpp" -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #include "transport.hpp" @@ -113,7 +113,8 @@ void transport_data::init_comms(user_options_t& options) { else if (options.backend == BACKEND_SYCL) { auto sycl_queues = create_sycl_queues(sycl_dev_names[options.sycl_dev_type], local_ranks); ASSERT(!sycl_queues.empty(), "queues should contain at least one queue"); - ASSERT(ranks_per_proc == sycl_queues.size(), "ranks and queues sizes should match"); + ASSERT(static_cast(ranks_per_proc) == sycl_queues.size(), + "ranks and queues sizes should match"); auto sycl_context = sycl_queues[0].get_context(); context = ccl::create_context(sycl_context); @@ -128,7 +129,7 @@ void transport_data::init_comms(user_options_t& options) { // "all sycl queues should be from the same sycl context"); } } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL else { ASSERT(0, "unknown backend %d", (int)options.backend); } @@ -150,6 +151,7 @@ std::vector& transport_data::get_comms() { } void transport_data::reset_comms() { + ccl::barrier(get_service_comm()); comms.clear(); service_comms.clear(); } diff --git a/examples/common/CMakeLists.txt b/examples/common/CMakeLists.txt index c95d44bc6..296a83adb 100644 --- a/examples/common/CMakeLists.txt +++ b/examples/common/CMakeLists.txt @@ -15,6 +15,8 @@ # file(GLOB sources "*.c" "*.cpp") +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src}) diff --git a/examples/cpu/CMakeLists.txt b/examples/cpu/CMakeLists.txt index 403a409d4..77a2c0342 100644 --- a/examples/cpu/CMakeLists.txt +++ b/examples/cpu/CMakeLists.txt @@ -15,6 +15,8 @@ # file(GLOB sources "*.c" "*.cpp") +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src}) diff --git a/examples/cpu/allreduce.cpp b/examples/cpu/allreduce.cpp index 5ebafb312..04f94665a 100644 --- a/examples/cpu/allreduce.cpp +++ b/examples/cpu/allreduce.cpp @@ -21,7 +21,7 @@ void run_collective(const char* cmd_name, const ccl::communicator& comm, const ccl::allreduce_attr& attr) { std::chrono::system_clock::duration exec_time{ 0 }; - float expected = (comm.size() - 1) * (static_cast(comm.size()) / 2); + float expected = (static_cast(comm.size()) + 1) / 2 * static_cast(comm.size()); ccl::barrier(comm); @@ -80,7 +80,7 @@ int main() { std::terminate(); } - MSG_LOOP(comm, std::vector send_buf(msg_count, static_cast(comm.rank())); + MSG_LOOP(comm, std::vector send_buf(msg_count, static_cast(comm.rank() + 1)); std::vector recv_buf(msg_count); attr.set(false); run_collective("warmup allreduce", send_buf, recv_buf, comm, attr); diff --git a/examples/external_launcher/CMakeLists.txt b/examples/external_launcher/CMakeLists.txt index fc2e50f69..31fadf72f 100644 --- a/examples/external_launcher/CMakeLists.txt +++ b/examples/external_launcher/CMakeLists.txt @@ -15,6 +15,8 @@ # file(GLOB sources "*.c" "*.cpp") +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src}) diff --git a/examples/external_launcher/external_launcher.cpp b/examples/external_launcher/external_launcher.cpp index f304fe644..bafcb5136 100644 --- a/examples/external_launcher/external_launcher.cpp +++ b/examples/external_launcher/external_launcher.cpp @@ -22,7 +22,7 @@ #define ELEM_COUNT (256 * 1024) #define ITER_COUNT 10 -#define REINIT_COUNT 20 +#define REINIT_COUNT 10 #define STORE_TIMEOUT_SEC 120 #define MAX_SLEEP_MSEC 500 diff --git a/examples/include/base.hpp b/examples/include/base.hpp index 2421fb309..6506acd44 100644 --- a/examples/include/base.hpp +++ b/examples/include/base.hpp @@ -35,7 +35,7 @@ #include using namespace cl::sycl; using namespace cl::sycl::access; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #define GETTID() syscall(SYS_gettid) diff --git a/examples/include/bf16.hpp b/examples/include/bf16.hpp index 72ab78955..3b8039d24 100644 --- a/examples/include/bf16.hpp +++ b/examples/include/bf16.hpp @@ -23,13 +23,13 @@ /* - https://www.johndcook.com/blog/2018/11/15/bfloat16/ + https://www.johndcook.com/blog/2018/11/15/bfloat16/ In this example we use the accuracy 0.00781250 of calculations performed in the bfloat16, but don't take into account the error that may occur during conversion - from float32 datatype to bfloat16. - + from float32 datatype to bfloat16. + */ #define BF16_PRECISION 0.00781250 /* 2^-7 */ @@ -154,7 +154,7 @@ void convert_bf16_to_fp32_arrays(void* recv_buf_bf16, float* recv_buf, int count memcpy((recv_buf + i), &int_val_shifted, 4); } } -#else /* CCL_BF16_COMPILER */ +#else // CCL_BF16_COMPILER void convert_fp32_to_bf16_arrays(void* send_buf, void* send_buf_bf16, int count) { printf("unsupported\n"); @@ -166,4 +166,37 @@ void convert_bf16_to_fp32_arrays(void* recv_buf_bf16, float* recv_buf, int count assert(0); } -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER + +// Routines to convert between fp32 and bf16 without relying on AVX instructions. +// These are useful when bf16 is only natively supported on a device. +void convert_fp32_to_bf16_arrays_generic(float* send_buf_float, void* send_buf_bf16, int count) { + int int_val = 0, int_val_shifted = 0; + + for (int i = 0; i < count; ++i) { + /* iterate over send_buf_bf16 */ + int* send_bfp_tail = (int*)(((char*)send_buf_bf16) + (2 * i)); + /* copy float (4 bytes) data as is to int variable, */ + memcpy(&int_val, &send_buf_float[i], 4); + /* then perform shift and */ + int_val_shifted = int_val >> BF16_SHIFT; + /* save pointer to result */ + *send_bfp_tail = int_val_shifted; + } +} + +void convert_bf16_to_fp32_arrays_generic(void* recv_buf_bf16, float* recv_buf_float, int count) { + int int_val = 0, int_val_shifted = 0; + + /* proceed remaining bf16's in buffer */ + for (int i = 0; i < count; i++) { + /* iterate over recv_buf_bf16 */ + int* recv_bfp_tail = (int*)((char*)recv_buf_bf16 + (2 * i)); + /* copy bf16 data as is to int variable, */ + memcpy(&int_val, recv_bfp_tail, 4); + /* then perform shift and */ + int_val_shifted = int_val << BF16_SHIFT; + /* copy result to output */ + memcpy((recv_buf_float + i), &int_val_shifted, 4); + } +} diff --git a/examples/include/sycl_base.hpp b/examples/include/sycl_base.hpp index c7861d2a7..f944018f2 100644 --- a/examples/include/sycl_base.hpp +++ b/examples/include/sycl_base.hpp @@ -206,7 +206,7 @@ inline std::vector create_sycl_gpu_devices() { << " sub-devices\n"; result.insert(result.end(), sub_devices.begin(), sub_devices.end()); - for (auto idx = 0; idx < sub_devices.size(); idx++) { + for (size_t idx = 0; idx < sub_devices.size(); idx++) { ss << sub_dev_prefix << "sub-device " << idx << ": [" << sub_devices[idx].get_info() << "]\n"; } @@ -285,7 +285,7 @@ inline std::vector create_sycl_queues(const std::string& device_typ try { ctx = sycl::context(rank_devices); } - catch (sycl::runtime_error&) { + catch (sycl::exception&) { size_t preferred_idx = (ranks.back() / ranks.size()) % devices.size(); cout << "Can not create context from all rank devices of type: " << device_type << ", create context from single device, idx " << preferred_idx << "\n"; diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt index 4bff71065..ff4fc8b5b 100644 --- a/examples/sycl/CMakeLists.txt +++ b/examples/sycl/CMakeLists.txt @@ -18,6 +18,8 @@ file(GLOB sources "*.c" "*.cpp") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src}) diff --git a/examples/sycl/sycl_allgatherv_custom_usm_test.cpp b/examples/sycl/sycl_allgatherv_custom_usm_test.cpp index 935c6f1c0..1bf2b0d49 100644 --- a/examples/sycl/sycl_allgatherv_custom_usm_test.cpp +++ b/examples/sycl/sycl_allgatherv_custom_usm_test.cpp @@ -26,7 +26,6 @@ struct custom_data_type { int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -125,6 +124,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < size * send_count; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allgatherv_inplace_test.cpp b/examples/sycl/sycl_allgatherv_inplace_test.cpp index 6fd7d7bc2..237a10b3e 100644 --- a/examples/sycl/sycl_allgatherv_inplace_test.cpp +++ b/examples/sycl/sycl_allgatherv_inplace_test.cpp @@ -21,8 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; - int j = 0; int size = 0; int rank = 0; size_t send_buf_count = 0; @@ -79,15 +77,15 @@ int main(int argc, char *argv[]) { host_accessor recv_buf_acc(recv_buf, write_only); host_accessor expected_acc_buf(expected_buf, write_only); - for (i = 0; i < send_buf_count; i++) { + for (size_t i = 0; i < send_buf_count; i++) { send_buf_acc[i] = rank; } - for (i = 0; i < recv_buf_count; i++) { + for (size_t i = 0; i < recv_buf_count; i++) { recv_buf_acc[i] = -1; } size_t idx = 0; - for (i = 0; i < size; i++) { - for (j = 0; j < count + i; j++) { + for (int i = 0; i < size; i++) { + for (size_t j = 0; j < count + i; j++) { expected_acc_buf[idx + j] = i + 1; } idx += count + i; @@ -97,7 +95,7 @@ int main(int argc, char *argv[]) { /* open send_buf and modify it on the device side */ /* make in-place updates in the appropriate place */ size_t rbuf_idx = 0; - for (i = 0; i < rank; i++) + for (int i = 0; i < rank; i++) rbuf_idx += recv_counts[i]; q.submit([&](auto &h) { @@ -132,6 +130,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < recv_buf_count; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allgatherv_test.cpp b/examples/sycl/sycl_allgatherv_test.cpp index 3b2c9f764..e176b240b 100644 --- a/examples/sycl/sycl_allgatherv_test.cpp +++ b/examples/sycl/sycl_allgatherv_test.cpp @@ -21,8 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; - int j = 0; int size = 0; int rank = 0; @@ -73,14 +71,14 @@ int main(int argc, char *argv[]) { host_accessor recv_buf_acc(recv_buf, write_only); host_accessor expected_acc_buf(expected_buf, write_only); - for (i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { send_buf_acc[i] = rank; } - for (i = 0; i < count * size; i++) { + for (size_t i = 0; i < count * size; i++) { recv_buf_acc[i] = -1; } - for (i = 0; i < size; i++) { - for (j = 0; j < count; j++) { + for (int i = 0; i < size; i++) { + for (size_t j = 0; j < count; j++) { expected_acc_buf[i * count + j] = i + 1; } } @@ -117,6 +115,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < size * count; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allgatherv_usm_test.cpp b/examples/sycl/sycl_allgatherv_usm_test.cpp index 3e95c3d1c..895bbd31b 100644 --- a/examples/sycl/sycl_allgatherv_usm_test.cpp +++ b/examples/sycl/sycl_allgatherv_usm_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -115,6 +114,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < size * count; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp index 55bfd3fd2..ab2de50d6 100644 --- a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -105,6 +104,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allreduce_test.cpp b/examples/sycl/sycl_allreduce_test.cpp index 1f94bfbdb..d9b2651d6 100644 --- a/examples/sycl/sycl_allreduce_test.cpp +++ b/examples/sycl/sycl_allreduce_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -67,7 +66,7 @@ int main(int argc, char *argv[]) { /* open buffers and initialize them on the host side */ host_accessor send_buf_acc(send_buf, write_only); host_accessor recv_buf_acc(recv_buf, write_only); - for (i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { send_buf_acc[i] = rank; recv_buf_acc[i] = -1; } @@ -102,6 +101,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_allreduce_usm_test.cpp b/examples/sycl/sycl_allreduce_usm_test.cpp index 5065a3d39..26f2ce4f4 100644 --- a/examples/sycl/sycl_allreduce_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_usm_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -107,6 +106,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_alltoall_test.cpp b/examples/sycl/sycl_alltoall_test.cpp index 23e20629a..b9c1d4c9b 100644 --- a/examples/sycl/sycl_alltoall_test.cpp +++ b/examples/sycl/sycl_alltoall_test.cpp @@ -21,8 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; - int j = 0; int size = 0; int rank = 0; @@ -69,8 +67,8 @@ int main(int argc, char *argv[]) { host_accessor send_buf_acc(send_buf, write_only); host_accessor recv_buf_acc(recv_buf, write_only); - for (i = 0; i < size; i++) { - for (j = 0; j < count; j++) { + for (int i = 0; i < size; i++) { + for (size_t j = 0; j < count; j++) { send_buf_acc[(i * count) + j] = i; recv_buf_acc[(i * count) + j] = -1; } @@ -107,6 +105,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < count * size; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_alltoall_usm_test.cpp b/examples/sycl/sycl_alltoall_usm_test.cpp index ecb75538f..8d035c0d4 100644 --- a/examples/sycl/sycl_alltoall_usm_test.cpp +++ b/examples/sycl/sycl_alltoall_usm_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -107,6 +106,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < count * size; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_alltoallv_test.cpp b/examples/sycl/sycl_alltoallv_test.cpp index fd9bd7810..2905235bc 100644 --- a/examples/sycl/sycl_alltoallv_test.cpp +++ b/examples/sycl/sycl_alltoallv_test.cpp @@ -21,8 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; - int j = 0; int size = 0; int rank = 0; @@ -72,8 +70,8 @@ int main(int argc, char *argv[]) { host_accessor send_buf_acc(send_buf, write_only); host_accessor recv_buf_acc(recv_buf, write_only); - for (i = 0; i < size; i++) { - for (j = 0; j < count; j++) { + for (int i = 0; i < size; i++) { + for (size_t j = 0; j < count; j++) { send_buf_acc[(i * count) + j] = i; recv_buf_acc[(i * count) + j] = -1; } @@ -110,6 +108,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < count * size; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_alltoallv_usm_test.cpp b/examples/sycl/sycl_alltoallv_usm_test.cpp index 36b2b2d1d..211b548d1 100644 --- a/examples/sycl/sycl_alltoallv_usm_test.cpp +++ b/examples/sycl/sycl_alltoallv_usm_test.cpp @@ -21,7 +21,6 @@ using namespace sycl; int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; - int i = 0; int size = 0; int rank = 0; @@ -110,6 +109,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ { host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < count * size; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_broadcast_test.cpp b/examples/sycl/sycl_broadcast_test.cpp index d03d6e33d..5c960d3be 100644 --- a/examples/sycl/sycl_broadcast_test.cpp +++ b/examples/sycl/sycl_broadcast_test.cpp @@ -22,7 +22,6 @@ int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; const size_t root_rank = 0; - int i = 0; int size = 0; int rank = 0; @@ -66,7 +65,7 @@ int main(int argc, char *argv[]) { if (rank == root_rank) { /* open buf and initialize it on the host side */ host_accessor send_buf_acc(buf, write_only); - for (i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { send_buf_acc[i] = 10; } @@ -100,6 +99,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ host_accessor recv_buf_acc(buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_broadcast_usm_test.cpp b/examples/sycl/sycl_broadcast_usm_test.cpp index 1f47abfc8..662455560 100644 --- a/examples/sycl/sycl_broadcast_usm_test.cpp +++ b/examples/sycl/sycl_broadcast_usm_test.cpp @@ -22,7 +22,6 @@ int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; const size_t root_rank = 0; - int i = 0; int size = 0; int rank = 0; @@ -107,6 +106,7 @@ int main(int argc, char *argv[]) { /* print out the result of the test on the host side */ host_accessor check_buf_acc(check_buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (check_buf_acc[i] == -1) { cout << "FAILED\n"; diff --git a/examples/sycl/sycl_reduce_test.cpp b/examples/sycl/sycl_reduce_test.cpp index 8d3230a2a..120ceb6a4 100644 --- a/examples/sycl/sycl_reduce_test.cpp +++ b/examples/sycl/sycl_reduce_test.cpp @@ -22,7 +22,6 @@ int main(int argc, char *argv[]) { const size_t count = 10 * 1024 * 1024; const size_t root_rank = 0; - int i = 0; int size = 0; int rank = 0; @@ -69,7 +68,7 @@ int main(int argc, char *argv[]) { host_accessor send_buf_acc(send_buf, write_only); host_accessor recv_buf_acc(recv_buf, write_only); - for (i = 0; i < count; i++) { + for (size_t i = 0; i < count; i++) { send_buf_acc[i] = rank; recv_buf_acc[i] = 0; } @@ -113,6 +112,7 @@ int main(int argc, char *argv[]) { { if (rank == root_rank) { host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; for (i = 0; i < count; i++) { if (recv_buf_acc[i] == -1) { cout << "FAILED for rank: " << rank << "\n"; diff --git a/include/oneapi/ccl/communicator.hpp b/include/oneapi/ccl/communicator.hpp index b2046f235..c66a84fdb 100644 --- a/include/oneapi/ccl/communicator.hpp +++ b/include/oneapi/ccl/communicator.hpp @@ -99,12 +99,6 @@ class communicator final : public ccl_api_base_movable - stream create_stream(attr_val_type&&... avs) { - // return stream::create_stream_from_attr(get_device(), get_context(), std::forward(avs)...); - throw ccl::unsupported("API", "create_stream"); - } - communicator split(const comm_split_attr& attr); private: diff --git a/include/oneapi/ccl/environment.hpp b/include/oneapi/ccl/environment.hpp index 3956dff89..02f322e71 100644 --- a/include/oneapi/ccl/environment.hpp +++ b/include/oneapi/ccl/environment.hpp @@ -185,42 +185,7 @@ class environment { class = typename std::enable_if()>::type> stream create_stream(native_stream_type& native_stream); - template ()>::type> - stream create_stream(native_stream_type& native_stream, native_context_type& native_ctx); - - template - stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_val_type&&... avs) { - stream str = create_stream(device); - int expander[]{ (str.template set(avs.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - - template - stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context, - attr_val_type&&... avs) { - stream str = create_stream(device, context); - int expander[]{ (str.template set(avs.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - /******************** COMMUNICATOR ********************/ - -#ifdef CCL_ENABLE_SYCL - communicator create_single_device_communicator(int comm_size, - int rank, - const cl::sycl::device& device, - const cl::sycl::context& context, - shared_ptr_class kvs) const; -#endif - template static comm_split_attr create_comm_split_attr(attr_val_type&&... avs) { auto split_attr = create_postponed_api_type(); @@ -280,11 +245,6 @@ class environment { auto version = get_library_version(); return ccl_api_type(std::forward(args)..., version); } - - stream create_stream(typename unified_device_type::ccl_native_t device); - - stream create_stream(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context); }; } // namespace detail diff --git a/include/oneapi/ccl/native_device_api/interop_utils.hpp b/include/oneapi/ccl/native_device_api/interop_utils.hpp index 02bd77083..7756a6f05 100644 --- a/include/oneapi/ccl/native_device_api/interop_utils.hpp +++ b/include/oneapi/ccl/native_device_api/interop_utils.hpp @@ -26,49 +26,7 @@ namespace detail { #ifdef CCL_ENABLE_SYCL size_t get_sycl_device_id(const cl::sycl::device& dev); size_t get_sycl_subdevice_id(const cl::sycl::device& dev); -std::string usm_to_string(cl::sycl::usm::alloc val); #endif -enum usm_support_mode { prohibited = 0, direct, shared, need_conversion, last_value }; -std::string to_string(usm_support_mode val); - -using assoc_result = std::tuple; -enum assoc_result_index { SUPPORT_MODE = 0, POINTER_VALUE, ERROR_CAUSE }; - -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -// TODO: move to src -assoc_result check_assoc_device_memory(const void* mem, - const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_context_type::ccl_native_t& ctx); - -usm_support_mode check_assoc_device_memory(const std::vector& mems, - const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_context_type::ccl_native_t& ctx); - -#endif //defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -std::string to_string(const assoc_result& res); - -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -template -using multiple_assoc_result = std::array; - -template -auto check_multiple_assoc_device_memory(const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_context_type::ccl_native_t& ctx, - const mem_type*... mem) - -> multiple_assoc_result { - multiple_assoc_result ret{ check_assoc_device_memory(mem, device, ctx)... }; - return ret; -} - -template -std::string to_string(const multiple_assoc_result& res) { - std::stringstream ss; - for (size_t i = 0; i < N; i++) { - ss << "Arg: " << std::to_string(i) << to_string(res[i]) << std::endl; - } - return ss.str(); -} -#endif //defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) } // namespace detail } // namespace native diff --git a/include/oneapi/ccl/native_device_api/l0/base_impl.hpp b/include/oneapi/ccl/native_device_api/l0/base_impl.hpp index 6becbfe1c..93a314e2a 100644 --- a/include/oneapi/ccl/native_device_api/l0/base_impl.hpp +++ b/include/oneapi/ccl/native_device_api/l0/base_impl.hpp @@ -115,7 +115,7 @@ const typename cl_base::context_ptr_t cl_base constexpr size_t cl_base::get_size_for_serialize() { - return resource_owner::get_size_for_serialize() + sizeof(handle_t); + return resource_owner::get_size_for_serialize() + sizeof(handle_t) + sizeof(size_t); } template @@ -129,14 +129,21 @@ size_t cl_base::serialize(std::vector& out, throw std::runtime_error("cannot serialize without owner"); } - constexpr size_t expected_bytes = sizeof(handle_t); + constexpr size_t expected_bytes = sizeof(handle_t) + sizeof(size_t); // serialize from position size_t serialized_bytes = lock->serialize(out, from_pos, expected_bytes, args...); //resize vector inside uint8_t* data_start = out.data() + from_pos + serialized_bytes; + + // Looks like this method is only used with ipc handles, so we can safely downcast to the corresponding + // child class + auto* handle_ptr = static_cast*>(this); + assert(handle_ptr != nullptr); + *(reinterpret_cast(data_start)) = handle; + *(reinterpret_cast(data_start + sizeof(handle_t))) = handle_ptr->get_offset(); serialized_bytes += expected_bytes; return serialized_bytes; } @@ -147,7 +154,7 @@ std::shared_ptr cl_base::deserialize(const uint8_t** dat size_t& size, std::shared_ptr ctx, helpers&... args) { - constexpr size_t expected_bytes = sizeof(handle); + constexpr size_t expected_bytes = sizeof(handle) + sizeof(size_t); size_t initial_size = size; // recover parent handle at first @@ -163,9 +170,11 @@ std::shared_ptr cl_base::deserialize(const uint8_t** dat } handle_t h = *(reinterpret_cast(*data)); + size_t off = *(reinterpret_cast(*data + sizeof(handle_t))); + *data += expected_bytes; size -= expected_bytes; - return std::shared_ptr{ new type(h, owner, ctx) }; + return std::shared_ptr{ new type(h, owner, ctx, off) }; } #undef TEMPLATE_DEF_ARG diff --git a/include/oneapi/ccl/native_device_api/l0/primitives.hpp b/include/oneapi/ccl/native_device_api/l0/primitives.hpp index 303e15750..778624977 100644 --- a/include/oneapi/ccl/native_device_api/l0/primitives.hpp +++ b/include/oneapi/ccl/native_device_api/l0/primitives.hpp @@ -57,7 +57,24 @@ template using module = cl_base; template -using ipc_memory_handle = cl_base; +class ipc_memory_handle : public cl_base { + using base = cl_base; + +public: + ipc_memory_handle(ze_ipc_mem_handle_t handle, + std::weak_ptr owner, + std::weak_ptr ctx, + size_t offset = 0) + : base(handle, owner, ctx), + offset(offset) {} + + size_t get_offset() const { + return offset; + } + +private: + size_t offset; +}; template using queue_fence = cl_base; @@ -140,6 +157,7 @@ struct memory /**/ : private cl_base diff --git a/include/oneapi/ccl/stream.hpp b/include/oneapi/ccl/stream.hpp index 7f97487ae..a75320296 100644 --- a/include/oneapi/ccl/stream.hpp +++ b/include/oneapi/ccl/stream.hpp @@ -85,25 +85,16 @@ class stream : public ccl_api_base_copyable - friend stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context, - attr_val_type&&... avs); - template - friend stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_val_type&&... avs); - stream(impl_value_t&& impl); /** - *Parametrized stream creation helper + * Parameterized stream creation helper */ template ()>::type*/> typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); - void build_from_params(); stream(const typename detail::ccl_api_type_attr_traits::type& version); @@ -113,20 +104,6 @@ class stream : public ccl_api_base_copyable()>::type> static stream create_stream(native_stream_type& native_stream); - - template ()>::type> - static stream create_stream(native_stream_type& native_stream, native_context_type& native_ctx); - - template - static stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_val_type&&... avs); - - template - static stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context, - attr_val_type&&... avs); }; /** diff --git a/include/oneapi/ccl/stream_attr_ids.hpp b/include/oneapi/ccl/stream_attr_ids.hpp index 7a63721c2..ba5b941fa 100644 --- a/include/oneapi/ccl/stream_attr_ids.hpp +++ b/include/oneapi/ccl/stream_attr_ids.hpp @@ -31,13 +31,6 @@ enum class stream_attr_id : int { version, native_handle, - device, - context, - ordinal, - index, - flags, - mode, - priority, }; } // namespace v1 diff --git a/include/oneapi/ccl/stream_attr_ids_traits.hpp b/include/oneapi/ccl/stream_attr_ids_traits.hpp index 9d44e8f01..2f3546762 100644 --- a/include/oneapi/ccl/stream_attr_ids_traits.hpp +++ b/include/oneapi/ccl/stream_attr_ids_traits.hpp @@ -39,50 +39,6 @@ struct ccl_api_type_attr_traits { using return_type = type; }; -template <> -struct ccl_api_type_attr_traits { - using type = typename unified_device_type::ccl_native_t; - using handle_t = typename unified_device_type::handle_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = typename unified_context_type::ccl_native_t; - using handle_t = typename unified_context_type::handle_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = uint32_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = uint32_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = size_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = size_t; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = size_t; - using return_type = type; -}; - } // namespace detail } // namespace ccl diff --git a/include/oneapi/ccl/type_traits.hpp b/include/oneapi/ccl/type_traits.hpp index 0f2f33300..6f26e04f5 100644 --- a/include/oneapi/ccl/type_traits.hpp +++ b/include/oneapi/ccl/type_traits.hpp @@ -126,7 +126,7 @@ CCL_CLASS_TYPE_TRAITS(ccl::datatype::bfloat16, cl::sycl::buffer, sizeof(bfloat16), bfloat16) -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL /** * Checks for supporting @c type in ccl API diff --git a/pkgconfig/template.pc b/pkgconfig/template.pc new file mode 100755 index 000000000..85b74168f --- /dev/null +++ b/pkgconfig/template.pc @@ -0,0 +1,13 @@ +# +prefix=${pcfiledir}/../../ +exec_prefix=${prefix} +libdir=${exec_prefix}/lib/@BUILD_TYPE@ +includedir=${prefix}/include/@BUILD_TYPE@ + +Name: oneAPI Collective Communications Library (oneCCL) +Description: oneCCL provides an efficient implementation of communication patterns used in deep learning. +URL: https://github.com/oneapi-src/oneCCL +Version: CCL_SUBSTITUTE_OFFICIAL_VERSION +Requires: impi +Libs: -L${libdir} -lccl @OTHER_FLAGS@ +Cflags: -I${includedir} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 356e0b49a..b262ba7bc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,15 +22,11 @@ if (CCL_ENABLE_SYCL) native_device_api/l0/utils.cpp native_device_api/sycl/export.cpp native_device_api/interop_utils.cpp - #common/comm/l0/comm_context.cpp - #common/comm/l0/comm_context_storage.cpp - common/comm/single_device_communicator/single_device_communicator.cpp ) endif(CCL_ENABLE_SYCL) if (MULTI_GPU_SUPPORT) list (APPEND EXTENSIONS_SRC - ccl_gpu_modules.cpp ccl_cpp_utils.cpp native_device_api/l0/base.cpp @@ -44,9 +40,6 @@ list (APPEND EXTENSIONS_SRC native_device_api/l0/utils.cpp native_device_api/l0/primitives.cpp native_device_api/interop_utils.cpp - sched/gpu_sched.cpp - sched/gpu_concurrent_sched.cpp - common/event/impls/gpu_event.cpp common/comm/l0/comm_context.cpp common/comm/l0/comm_context_storage.cpp @@ -62,14 +55,6 @@ list (APPEND EXTENSIONS_SRC common/comm/l0/devices/communication_structs/ipc_server.cpp common/comm/l0/devices/communication_structs/ipc_client.cpp - common/comm/single_device_communicator/single_device_communicator.cpp - common/comm/l0/communicator/device_group/device_ring_communicator.cpp - common/comm/l0/communicator/device_group/device_a2a_communicator.cpp - common/comm/l0/communicator/thread_group/thread_ring_communicator.cpp - common/comm/l0/communicator/thread_group/thread_a2a_communicator.cpp - common/comm/l0/communicator/process_group/process_ring_communicator.cpp - common/comm/l0/communicator/process_group/process_a2a_communicator.cpp - common/comm/l0/context/process_group_ctx.cpp common/comm/l0/context/thread_group_ctx.cpp common/comm/l0/context/device_group_ctx.cpp @@ -91,9 +76,29 @@ list (APPEND EXTENSIONS_SRC common/comm/l0/gpu_comm_attr.cpp common/comm/l0/modules/base_entry_module.cpp common/comm/l0/modules/modules_source_data.cpp - common/comm/l0/modules/kernel_utils.cpp) + common/comm/l0/modules/kernel_utils.cpp + + sched/gpu_sched.cpp + sched/gpu_concurrent_sched.cpp + sched/entry/gpu/ze_cache.cpp + sched/entry/gpu/ze_call.cpp + sched/entry/gpu/ze_primitives.cpp) endif(MULTI_GPU_SUPPORT) +if (CCL_ENABLE_SYCL AND MULTI_GPU_SUPPORT) + list (APPEND EXTENSIONS_SRC + sched/entry/gpu/ze_base_entry.cpp + sched/entry/gpu/ze_allreduce_entry.cpp + sched/entry/gpu/ze_copy_entry.cpp + sched/entry/gpu/ze_handle_exchange_entry.cpp + sched/entry/gpu/ze_event_signal_entry.cpp + sched/entry/gpu/ze_event_wait_entry.cpp + sched/entry/gpu/ze_reduce_entry.cpp + sched/entry/reduce_local_entry.cpp + sched/ze_handle_manager.cpp + ) +endif(CCL_ENABLE_SYCL AND MULTI_GPU_SUPPORT) + set(CCL_SRC ccl_cpp_communicator.cpp ccl_cpp_environment.cpp @@ -122,6 +127,7 @@ set(CCL_SRC atl/atl_wrapper.cpp atl/mpi/atl_mpi.cpp atl/ofi/atl_ofi.cpp + atl/ofi/atl_ofi_helper.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.cpp @@ -162,6 +168,8 @@ set(CCL_SRC coll/algorithms/reduce.cpp coll/algorithms/reduce_scatter.cpp coll/coll.cpp + coll/coll_check.cpp + coll/selection/selection.cpp coll/selection/selector_allgatherv.cpp coll/selection/selector_allreduce.cpp coll/selection/selector_alltoall.cpp @@ -176,11 +184,13 @@ set(CCL_SRC comp/comp.cpp comp/fp16/fp16.cpp comp/fp16/fp16_intrisics.cpp - hwloc/hwloc_wrapper.c + hwloc/hwloc_wrapper.cpp sched/sched.cpp + sched/buffer_cache.cpp sched/extra_sched.cpp sched/master_sched.cpp sched/sched_base.cpp + sched/sched_timer.cpp sched/cache/cache.cpp sched/cache/key.cpp sched/queue/flow_control.cpp @@ -188,6 +198,7 @@ set(CCL_SRC sched/queue/queue.cpp sched/entry/coll/coll_entry.cpp sched/entry/coll/coll_entry_helper.cpp + sched/entry/copy/copy_entry.cpp sched/entry/copy/copy_helper.cpp sched/entry/entry.cpp sched/entry/factory/chunked_entry_factory.cpp @@ -224,60 +235,99 @@ set(CCL_SRC ${EXTENSIONS_SRC}) -list(APPEND CCL_INC_DIRS +set(SRC_C_FLAGS) +set(SRC_CXX_FLAGS) +set(SRC_SHARED_LINKER_FLAGS) +set(SRC_INCLUDE_DIRS) +set(SRC_LINK_DIRS) +set(SRC_LINK_LIBS) + +# common settings of security options +if (USE_SECURITY_FLAGS) + set(SRC_C_FLAGS "${SRC_C_FLAGS} -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector") + set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector") + set(SRC_SHARED_LINKER_FLAGS "${SRC_SHARED_LINKER_FLAGS} -fPIE -fPIC -z noexecstack -z relro -z now") + if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") + if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9) + set(SRC_C_FLAGS "${SRC_C_FLAGS} -fstack-protector-strong") + set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -fstack-protector-strong") + endif() + endif() +endif() + +set(SRC_SHARED_LINKER_FLAGS "${SRC_SHARED_LINKER_FLAGS} -Wl,--version-script=${PROJECT_SOURCE_DIR}/ccl.map") + +if (${CMAKE_C_COMPILER_ID} STREQUAL "Intel" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL "Intel") + if (USE_CODECOV_FLAGS) + set(SRC_C_FLAGS "${SRC_C_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") + set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") + endif() +endif() + +list(APPEND SRC_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/include - ${MPI_INCLUDE_DIR} - ${LIBFABRIC_INCLUDE_DIR} - ${HWLOC_INCLUDE_DIR} ${PROJECT_SOURCE_DIR}/src - ${PROJECT_SOURCE_DIR}/src/atl) + ${PROJECT_SOURCE_DIR}/src/atl + ${LIBFABRIC_INCLUDE_DIR} + ${HWLOC_INCLUDE_DIR}) + +list(APPEND SRC_LINK_DIRS ${LIBFABRIC_LIB_DIR}) + +list(APPEND SRC_LINK_LIBS + dl + pthread + ${EXTERNAL_LIBS} + ${COMPUTE_BACKEND_TARGET_NAME} + fabric + ${HWLOC_LIB_DIR}/libhwloc.a) + +if (ENABLE_MPI) + set(SRC_C_FLAGS "${SRC_C_FLAGS} -DCCL_ENABLE_MPI") + set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -DCCL_ENABLE_MPI") + list(APPEND SRC_INCLUDE_DIRS ${MPI_INCLUDE_DIR}) + list(APPEND SRC_LINK_DIRS ${MPI_LIB_DIR}) + list(APPEND SRC_LINK_LIBS mpi) +endif() + +link_directories(${SRC_LINK_DIRS}) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SRC_C_FLAGS} -pthread") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SRC_CXX_FLAGS} -pthread") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${SRC_SHARED_LINKER_FLAGS}") -message(STATUS "SRC C_FLAGS: ${CMAKE_C_FLAGS}") -message(STATUS "SRC CXX_FLAGS: ${CMAKE_CXX_FLAGS}") -message(STATUS "SRC SHARED_LINKER_FLAGS: ${CMAKE_SHARED_LINKER_FLAGS}") -message(STATUS "SRC INC_DIRS: ${CCL_INC_DIRS}") - -#special library that holds objects only +# special library that holds objects only add_library(ccl-objects OBJECT ${CCL_SRC}) set_target_properties(ccl-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) -target_include_directories(ccl-objects PRIVATE ${CCL_INC_DIRS}) +target_include_directories(ccl-objects PRIVATE ${SRC_INCLUDE_DIRS}) -if(COMPUTE_BACKEND_TARGET_NAME) +if (COMPUTE_BACKEND_TARGET_NAME) target_include_directories(ccl-objects PRIVATE $) endif() -# add library search directory -link_directories(${MPI_LIB_DIR}) -link_directories(${LIBFABRIC_LIB_DIR}) - # shared library add_library(ccl SHARED $) -target_include_directories(ccl PUBLIC ${CCL_INC_DIRS}) +target_include_directories(ccl PUBLIC ${SRC_INCLUDE_DIRS}) # link with release_mt libmpi.so for oneAPI Base toolkit # libccl.so -> cpu_icc/cpu_gpu_dpcpp -> lib -> latest -> ccl -> mpi -> ... set(ONEAPI_IMPI_RPATH "'$ORIGIN'/../../../../mpi/latest/lib/release_mt/") set_target_properties(ccl PROPERTIES LINK_FLAGS "-Wl,-rpath,${ONEAPI_IMPI_RPATH}") -target_link_libraries(ccl PUBLIC - dl - pthread - fabric - mpi - ${HWLOC_LIB_DIR}/libhwloc.a - ${EXTERNAL_LIBS} - ${COMPUTE_BACKEND_TARGET_NAME}) +target_link_libraries(ccl PUBLIC ${SRC_LINK_LIBS}) if (NOT LIB_SO_VERSION AND NOT LIB_MAJOR_VERSION) - set_target_properties(ccl PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) + set_target_properties(ccl PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) else() - set_target_properties(ccl PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR} VERSION ${LIB_SO_VERSION} SOVERSION ${LIB_MAJOR_VERSION}) + set_target_properties(ccl PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR} VERSION ${LIB_SO_VERSION} SOVERSION ${LIB_MAJOR_VERSION}) endif() +message(STATUS "SRC C_FLAGS: ${CMAKE_C_FLAGS}") +message(STATUS "SRC CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "SRC SHARED_LINKER_FLAGS: ${CMAKE_SHARED_LINKER_FLAGS}") +message(STATUS "SRC INCLUDE_DIRS: ${SRC_INCLUDE_DIRS}") +message(STATUS "SRC LINK_DIRS: ${SRC_LINK_DIRS}") +message(STATUS "SRC LINK_LIBS: ${SRC_LINK_LIBS}") + install(TARGETS ccl LIBRARY DESTINATION ${CCL_INSTALL_LIB}) install(FILES "${PROJECT_SOURCE_DIR}/cmake/FindComputeCpp.cmake" @@ -299,21 +349,21 @@ install(TARGETS ccl-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB} OPTIONAL) install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION ${CCL_INSTALL_INCLUDE} FILES_MATCHING REGEX ".*\\.(h|hpp)$") -# MPI and OFI runtimes -file(GLOB mpi_bins "${DEPS_DIR}/mpi/bin/*") -install(PROGRAMS ${mpi_bins} DESTINATION ${CCL_INSTALL_BIN}) - -install(DIRECTORY ${DEPS_DIR}/ofi/lib/ - DESTINATION ${CCL_INSTALL_LIB}) - -install(DIRECTORY ${DEPS_DIR}/mpi/include/ - DESTINATION ${CCL_INSTALL_INCLUDE}) - -install(DIRECTORY ${DEPS_DIR}/mpi/lib/ - DESTINATION ${CCL_INSTALL_LIB}) - -install(DIRECTORY ${DEPS_DIR}/mpi/etc/ - DESTINATION ${CCL_INSTALL_ETC}) +if ("${LIBFABRIC_DIR}" STREQUAL "") + # internal libfabric is used, install it into package + install(DIRECTORY ${DEPS_DIR}/ofi/lib/ + DESTINATION ${CCL_INSTALL_LIB}) +endif() -install(DIRECTORY ${DEPS_DIR}/mpi/licensing/ - DESTINATION ${CCL_INSTALL_LICENSE}/mpi/) +if (ENABLE_MPI) + file(GLOB mpi_bins "${DEPS_DIR}/mpi/bin/*") + install(PROGRAMS ${mpi_bins} DESTINATION ${CCL_INSTALL_BIN}) + install(DIRECTORY ${DEPS_DIR}/mpi/include/ + DESTINATION ${CCL_INSTALL_INCLUDE}) + install(DIRECTORY ${DEPS_DIR}/mpi/lib/ + DESTINATION ${CCL_INSTALL_LIB}) + install(DIRECTORY ${DEPS_DIR}/mpi/etc/ + DESTINATION ${CCL_INSTALL_ETC}) + install(DIRECTORY ${DEPS_DIR}/mpi/licensing/ + DESTINATION ${CCL_INSTALL_LICENSE}/mpi/) +endif() diff --git a/src/atl/atl_def.h b/src/atl/atl_def.h index 5fd6ab20c..3e02a0aa8 100644 --- a/src/atl/atl_def.h +++ b/src/atl/atl_def.h @@ -17,6 +17,7 @@ #include #include +#include #ifndef container_of #define container_of(ptr, type, field) ((type*)((char*)ptr - offsetof(type, field))) @@ -32,8 +33,9 @@ #define SIZEOFARR(arr) (sizeof(arr) / sizeof(arr[0])) #define ATL_CACHELINE_LEN 64 -#define ATL_REQ_SIZE 8 +#define ATL_REQ_SIZE 16 #define ATL_PROGRESS_MODE_ENV "ATL_PROGRESS_MODE" +#define ATL_MAX_HOSTNAME_LEN 64 #define DIR_SEP '/' #define FILENAME (strrchr(__FILE__, DIR_SEP) ? strrchr(__FILE__, DIR_SEP) + 1 : __FILE__) @@ -43,16 +45,11 @@ * This is invoked by the ATL framework when the transport library is loaded. */ -#define ATL_EXT_INI atl_status_t atl_ini(atl_transport_t* atl_transport) - -#define ATL_OFI_INI ATL_EXT_INI -#define ATL_MPI_INI ATL_EXT_INI - #define ATL_CALL(func, err_action) \ do { \ atl_status_t status = func; \ if (status != FI_SUCCESS) { \ - LOG_ERROR(#func "\n fails with status: ", status); \ + CCL_THROW(#func "\n fails with status: ", status); \ err_action; \ } \ } while (0) @@ -115,17 +112,18 @@ typedef struct { struct { int enable_shm; int enable_rma; - int enable_device_buf; + int enable_hmem; int enable_sync_coll; int enable_extra_ep; size_t ep_count; atl_mnic_t mnic_type; + std::string mnic_name; size_t mnic_count; } in; struct { int enable_shm; int enable_rma; - int enable_device_buf; + int enable_hmem; atl_mnic_t mnic_type; size_t mnic_count; size_t tag_bits; @@ -146,6 +144,7 @@ typedef struct { int global_count; int local_idx; int local_count; + size_t hostname_hash; } atl_proc_coord_t; typedef struct { @@ -154,29 +153,7 @@ typedef struct { void* internal[ATL_REQ_SIZE]; } atl_req_t __attribute__((aligned(ATL_CACHELINE_LEN))); -typedef struct { - const char* name; - atl_status_t (*init)(int* argc, - char*** argv, - atl_attr_t* attr, - atl_ctx_t** ctx, - const char* main_addr, - ipmi* pmi); - atl_status_t (*reserve_addr)(char* main_addr); -} atl_transport_t; - -typedef struct { - atl_status_t (*finalize)(atl_ctx_t* ctx); -} atl_ops_t; - -typedef struct { - atl_status_t (*mr_reg)(atl_ctx_t* ctx, const void* buf, size_t len, atl_mr_t** mr); - atl_status_t (*mr_dereg)(atl_ctx_t* ctx, atl_mr_t* mr); -} atl_mr_ops_t; - struct atl_ctx { - atl_ops_t* ops; - atl_mr_ops_t* mr_ops; atl_proc_coord_t coord; size_t ep_count; @@ -189,96 +166,7 @@ struct atl_ctx { count - for iov and for dtype-arrays like in reduce/allreduce */ -typedef struct { - atl_status_t (*send)(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req); - atl_status_t ( - *recv)(atl_ep_t* ep, void* buf, size_t len, int src_proc_idx, uint64_t tag, atl_req_t* req); - atl_status_t ( - *probe)(atl_ep_t* ep, int src_proc_idx, uint64_t tag, int* found, size_t* recv_len); -} atl_p2p_ops_t; - -typedef struct { - /* order convention - keep alphabetical order */ - atl_status_t (*allgatherv)(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req); - atl_status_t (*allreduce)(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req); - atl_status_t ( - *alltoall)(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, atl_req_t* req); - atl_status_t (*alltoallv)(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req); - atl_status_t (*barrier)(atl_ep_t* ep, atl_req_t* req); - atl_status_t (*bcast)(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req); - atl_status_t (*reduce)(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req); - atl_status_t (*reduce_scatter)(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req); -} atl_coll_ops_t; - -typedef struct { - atl_status_t (*read)(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req); - atl_status_t (*write)(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req); -} atl_rma_ops_t; - -typedef struct { - atl_status_t (*wait)(atl_ep_t* ep, atl_req_t* req); - atl_status_t (*wait_all)(atl_ep_t* ep, atl_req_t* reqs, size_t count); - atl_status_t (*cancel)(atl_ep_t* ep, atl_req_t* req); - atl_status_t (*poll)(atl_ep_t* ep); - atl_status_t (*check)(atl_ep_t* ep, int* is_completed, atl_req_t* req); -} atl_comp_ops_t; - struct atl_ep { size_t idx; atl_ctx_t* ctx; - atl_p2p_ops_t* p2p_ops; - atl_coll_ops_t* coll_ops; - atl_rma_ops_t* rma_ops; - atl_comp_ops_t* comp_ops; }; diff --git a/src/atl/atl_wrapper.cpp b/src/atl/atl_wrapper.cpp index fc1bed457..00eba6f9d 100644 --- a/src/atl/atl_wrapper.cpp +++ b/src/atl/atl_wrapper.cpp @@ -18,7 +18,9 @@ #include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" #include "atl/util/pm/pmi_resizable_rt/pmi_resizable.h" #include "atl/ofi/atl_ofi.hpp" +#ifdef CCL_ENABLE_MPI #include "atl/mpi/atl_mpi.hpp" +#endif // CCL_ENABLE_MPI #include "atl_wrapper.h" #include "common/global/global.hpp" #include "exec/exec.hpp" @@ -32,11 +34,12 @@ atl_attr_t atl_wrapper::attr = { { 0, /* enable_shm */ 0, /* enable_rma */ - 0, /* enable_device_buf */ + 0, /* enable_hmem */ 0, /* enable_sync_coll */ 0, /* enable_extra_ep */ 1, /* ep_count */ ATL_MNIC_NONE, /* mnic_type */ + "", /* mnic_name */ 1 /* mnic_count */ }, @@ -47,10 +50,12 @@ atl_attr_t atl_wrapper::attr = { void atl_wrapper::set_internal_env(const atl_attr_t& attr) { auto transport_type = ccl::global_data::env().atl_transport; - if (transport_type == ccl_atl_mpi) - atl_mpi::atl_set_env(attr); - else if (transport_type == ccl_atl_ofi) + if (transport_type == ccl_atl_ofi) atl_ofi::atl_set_env(attr); +#ifdef CCL_ENABLE_MPI + else if (transport_type == ccl_atl_mpi) + atl_mpi::atl_set_env(attr); +#endif // CCL_ENABLE_MPI } void atl_wrapper::set_exec(ccl_executor* exec) { @@ -81,7 +86,9 @@ atl_wrapper::atl_wrapper() { } transport = std::shared_ptr(new atl_ofi()); break; +#ifdef CCL_ENABLE_MPI case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; +#endif // CCL_ENABLE_MPI default: LOG_ERROR("Unsupported yet"); break; } @@ -111,7 +118,9 @@ atl_wrapper::atl_wrapper(std::shared_ptr k) { } transport = std::shared_ptr(new atl_ofi()); break; +#ifdef CCL_ENABLE_MPI case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; +#endif // CCL_ENABLE_MPI default: LOG_ERROR("Unsupported yet"); break; } @@ -148,7 +157,9 @@ atl_wrapper::atl_wrapper(int total_rank_count, transport = transports.back(); } } break; +#ifdef CCL_ENABLE_MPI case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; +#endif // CCL_ENABLE_MPI default: LOG_ERROR("Unsupported yet"); break; } @@ -174,19 +185,21 @@ void atl_wrapper::init_transport() { rank = pmi->get_rank(); size = pmi->get_size(); } +#ifdef CCL_ENABLE_MPI else { threads_per_process = 1; ranks_per_process = 1; rank = static_cast(transport.get())->get_rank(); size = static_cast(transport.get())->get_size(); } +#endif // CCL_ENABLE_MPI if (rank == 0) { tag->print(); LOG_INFO("atl-in-attrs:"); LOG_INFO(" enable_shm: ", attr.in.enable_shm); LOG_INFO(" enable_rma: ", attr.in.enable_rma); - LOG_INFO(" enable_device_buf: ", attr.in.enable_device_buf); + LOG_INFO(" enable_hmem: ", attr.in.enable_hmem); LOG_INFO(" enable_sync_coll: ", attr.in.enable_sync_coll); LOG_INFO(" enable_extra_ep: ", attr.in.enable_extra_ep); LOG_INFO(" ep_count: ", attr.in.ep_count); @@ -196,7 +209,7 @@ void atl_wrapper::init_transport() { LOG_INFO("atl-out-attrs:"); LOG_INFO(" enable_shm: ", attr.out.enable_shm); LOG_INFO(" enable_rma: ", attr.out.enable_rma); - LOG_INFO(" enable_device_buf: ", attr.out.enable_device_buf); + LOG_INFO(" enable_hmem: ", attr.out.enable_hmem); LOG_INFO(" mnic_type: ", attr.out.mnic_type); LOG_INFO(" mnic_count: ", attr.out.mnic_count); LOG_INFO(" tag_bits: ", attr.out.tag_bits); diff --git a/src/atl/atl_wrapper.h b/src/atl/atl_wrapper.h index 5edad0aa3..4c5ec38e1 100644 --- a/src/atl/atl_wrapper.h +++ b/src/atl/atl_wrapper.h @@ -242,6 +242,14 @@ class atl_wrapper { return size; } + int get_r2r_color() { + return transport->atl_get_proc_coord()->local_idx; + } + + int get_host_color() { + return transport->atl_get_proc_coord()->hostname_hash; + } + /* * TODO: Temporary change. * Need to define correct to unique id diff --git a/src/atl/mpi/atl_mpi.cpp b/src/atl/mpi/atl_mpi.cpp index f76df1a4b..2e742d80e 100644 --- a/src/atl/mpi/atl_mpi.cpp +++ b/src/atl/mpi/atl_mpi.cpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef CCL_ENABLE_MPI + #include "atl_mpi.hpp" #include "atl_mpi_impl.cpp" @@ -195,3 +197,5 @@ atl_mpi::~atl_mpi() { if (!is_finalized) atl_finalize(); } + +#endif // CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi.hpp b/src/atl/mpi/atl_mpi.hpp index 1a98e3419..03760a0ae 100644 --- a/src/atl/mpi/atl_mpi.hpp +++ b/src/atl/mpi/atl_mpi.hpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef CCL_ENABLE_MPI + #include "atl.h" class atl_mpi final : public iatl { @@ -159,3 +161,5 @@ class atl_mpi final : public iatl { bool is_finalized{ false }; bool inited{ false }; }; + +#endif // CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi_impl.cpp b/src/atl/mpi/atl_mpi_impl.cpp index 923d538bb..95636ac82 100644 --- a/src/atl/mpi/atl_mpi_impl.cpp +++ b/src/atl/mpi/atl_mpi_impl.cpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef CCL_ENABLE_MPI + #include #include #include @@ -46,7 +48,7 @@ typedef enum { ATL_MPI_LIB_IMPI, ATL_MPI_LIB_MPICH, ATL_MPI_LIB_NONE } atl_mpi_l typedef struct { atl_mpi_lib_type_t type; - int device_buf; + int hmem; } atl_mpi_lib_attr_t; typedef struct { @@ -62,8 +64,8 @@ typedef struct { /* minimal expected version of library, mandatory */ int min_version_value; - /* minimal expected version of library with device_buf support, mandatory */ - int min_device_buf_version_value; + /* minimal expected version of library with hmem support, mandatory */ + int min_hmem_version_value; /* string prefix before library kind, optional */ const char* kind_prefix; @@ -89,11 +91,11 @@ static atl_mpi_lib_info_t mpi_lib_infos[MPI_LIB_INFO_MAX_COUNT] = { #ifdef CCL_BF16_COMPILER #define ATL_MPI_BF16 -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER #ifdef CCL_FP16_COMPILER #define ATL_MPI_FP16 -#endif /* CCL_FP16_COMPILER */ +#endif // CCL_FP16_COMPILER typedef struct { // custom MPI operations for BF16 @@ -132,7 +134,7 @@ typedef struct atl_mpi_global_data { mnic_type(ATL_MNIC_NONE), mnic_count(1) { mpi_lib_attr.type = ATL_MPI_LIB_NONE; - mpi_lib_attr.device_buf = 0; + mpi_lib_attr.hmem = 0; bf16.dtype = MPI_DATATYPE_NULL; bf16.sum_op = MPI_OP_NULL; @@ -277,7 +279,7 @@ static void BF16_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_max_op(void* in, atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_bf16_base_op(in, inout, length, ccl::reduction::max); } -#endif /* ATL_MPI_BF16 */ +#endif // ATL_MPI_BF16 #ifdef ATL_MPI_FP16 @@ -323,7 +325,7 @@ static void FP16_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_max_op(void* in, atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_fp16_base_op(in, inout, length, ccl::reduction::max); } -#endif /* ATL_MPI_FP16 */ +#endif // ATL_MPI_FP16 static int atl_mpi_bf16_init() { int ret = MPI_SUCCESS; @@ -381,7 +383,7 @@ static int atl_mpi_bf16_init() { return RET2ATL(ret); } -#endif /* ATL_MPI_BF16 */ +#endif // ATL_MPI_BF16 return RET2ATL(ret); } @@ -464,7 +466,7 @@ static int atl_mpi_fp16_init() { return RET2ATL(ret); } -#endif /* ATL_MPI_FP16 */ +#endif // ATL_MPI_FP16 return RET2ATL(ret); } @@ -519,7 +521,7 @@ static MPI_Op atl2mpi_op_bf16(atl_reduction_t rtype) { default: printf("unknown reduction type: %d\n", rtype); exit(1); } } -#endif /* ATL_MPI_BF16 */ +#endif // ATL_MPI_BF16 #ifdef ATL_MPI_FP16 static MPI_Op atl2mpi_op_fp16(atl_reduction_t rtype) { @@ -531,18 +533,18 @@ static MPI_Op atl2mpi_op_fp16(atl_reduction_t rtype) { default: printf("unknown reduction type: %d\n", rtype); exit(1); } } -#endif /* ATL_MPI_FP16 */ +#endif // ATL_MPI_FP16 static MPI_Op atl2mpi_op(atl_reduction_t rtype, MPI_Datatype dtype) { #ifdef ATL_MPI_BF16 if (dtype == global_data.bf16.dtype) return atl2mpi_op_bf16(rtype); -#endif /* ATL_MPI_BF16 */ +#endif // ATL_MPI_BF16 #ifdef ATL_MPI_FP16 if (dtype == global_data.fp16.dtype) return atl2mpi_op_fp16(rtype); -#endif /* ATL_MPI_FP16 */ +#endif // ATL_MPI_FP16 (void)dtype; switch (rtype) { @@ -683,8 +685,7 @@ atl_mpi_lib_attr_t atl_mpi_get_lib_attr() { ")"); lib_attr.type = final_info->type; - lib_attr.device_buf = - (final_info->min_device_buf_version_value >= version_value) ? 1 : 0; + lib_attr.hmem = (final_info->min_hmem_version_value >= version_value) ? 1 : 0; break; } @@ -721,7 +722,7 @@ atl_status_t atl_mpi_set_base_env(const atl_attr_t& attr) { #ifdef CCL_ENABLE_SYCL setenv("FI_SHM_DISABLE_CMA", "1", 0); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL setenv("MPIR_CVAR_DEFAULT_THREAD_LEVEL", "MPI_THREAD_MULTIPLE", 0); @@ -740,18 +741,19 @@ atl_status_t atl_mpi_set_impi_env(const atl_attr_t& attr, const atl_mpi_lib_attr #ifdef CCL_ENABLE_SYCL setenv("I_MPI_SHM_CMA", "0", 0); - if (attr.in.enable_device_buf && lib_attr.device_buf) { + if (attr.in.enable_hmem && lib_attr.hmem) { setenv("I_MPI_OFFLOAD", "2", 0); setenv("I_MPI_OFFLOAD_TOPOLIB", "l0", 0); setenv("I_MPI_OFFLOAD_QUEUE_CACHE", "1", 0); setenv("I_MPI_OFFLOAD_LIST_CACHE", "1", 0); + setenv("I_MPI_OFFLOAD_MEMCPY_KIND", "blocked", 0); if (attr.in.ep_count > 1) { /* try to set global lock level before vci level because setenv is invoked with overwrite=0 */ setenv("I_MPI_THREAD_LOCK_LEVEL", "global", 0); } } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL setenv("I_MPI_THREAD_SPLIT", "1", 0); setenv("I_MPI_THREAD_RUNTIME", "generic", 0); @@ -976,7 +978,7 @@ static atl_status_t atl_mpi_finalize(atl_ctx_t* ctx) { free(mpi_ep); } if ((global_data.ctx_count == 0) && (ctx->coord.global_idx == 0)) { - LOG_WARN("MPI_Finalize has been called"); + LOG_WARN("MPI_Finalize has been called before CCL finalization"); } } @@ -1440,41 +1442,6 @@ static atl_status_t atl_mpi_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* return status; } -static atl_ops_t atl_mpi_ops = { - .finalize = atl_mpi_finalize, -}; - -static atl_mr_ops_t atl_mpi_mr_ops = { - .mr_reg = atl_mpi_mr_reg, - .mr_dereg = atl_mpi_mr_dereg, -}; - -static atl_p2p_ops_t atl_mpi_ep_p2p_ops = { - .send = atl_mpi_ep_send, - .recv = atl_mpi_ep_recv, - .probe = atl_mpi_ep_probe, -}; - -static atl_coll_ops_t atl_mpi_ep_coll_ops = { .allgatherv = atl_mpi_ep_allgatherv, - .allreduce = atl_mpi_ep_allreduce, - .alltoall = atl_mpi_ep_alltoall, - .alltoallv = atl_mpi_ep_alltoallv, - .barrier = atl_mpi_ep_barrier, - .bcast = atl_mpi_ep_bcast, - .reduce = atl_mpi_ep_reduce, - .reduce_scatter = atl_mpi_ep_reduce_scatter }; - -static atl_rma_ops_t atl_mpi_ep_rma_ops = { - .read = atl_mpi_ep_read, - .write = atl_mpi_ep_write, -}; - -static atl_comp_ops_t atl_mpi_ep_comp_ops = { .wait = atl_mpi_ep_wait, - .wait_all = atl_mpi_ep_wait_all, - .cancel = NULL, - .poll = atl_mpi_ep_poll, - .check = atl_mpi_ep_check }; - static atl_status_t atl_mpi_ep_init(atl_mpi_ctx_t* mpi_ctx, size_t idx, atl_ep_t** ep) { int ret; @@ -1535,10 +1502,6 @@ static atl_status_t atl_mpi_ep_init(atl_mpi_ctx_t* mpi_ctx, size_t idx, atl_ep_t *ep = &mpi_ep->ep; (*ep)->idx = idx; (*ep)->ctx = &mpi_ctx->ctx; - (*ep)->p2p_ops = &atl_mpi_ep_p2p_ops; - (*ep)->coll_ops = &atl_mpi_ep_coll_ops; - (*ep)->rma_ops = &atl_mpi_ep_rma_ops; - (*ep)->comp_ops = &atl_mpi_ep_comp_ops; return ATL_STATUS_SUCCESS; @@ -1567,6 +1530,8 @@ static atl_status_t atl_mpi_init(int* argc, void* tag_ub_ptr = NULL; int required_thread_level = MPI_THREAD_MULTIPLE, provided_thread_level; + char my_hostname[ATL_MAX_HOSTNAME_LEN] = { 0 }; + atl_mpi_ctx_t* mpi_ctx = (atl_mpi_ctx_t*)calloc(1, sizeof(atl_mpi_ctx_t)); if (!mpi_ctx) return ATL_STATUS_FAILURE; @@ -1654,8 +1619,9 @@ static atl_status_t atl_mpi_init(int* argc, MPI_Comm_size(local_comm, (int*)&(coord->local_count)); MPI_Comm_free(&local_comm); - ctx->ops = &atl_mpi_ops; - ctx->mr_ops = &atl_mpi_mr_ops; + gethostname(my_hostname, ATL_MAX_HOSTNAME_LEN - 1); + coord->hostname_hash = std::hash{}(my_hostname); + ctx->ep_count = attr->in.ep_count; ctx->eps = (atl_ep_t**)calloc(1, sizeof(void*) * attr->in.ep_count); if (!ctx->eps) @@ -1676,7 +1642,7 @@ static atl_status_t atl_mpi_init(int* argc, LOG_INFO("atl-mpi-global:") LOG_INFO(" is_external_init: ", global_data.is_external_init); LOG_INFO(" mpi_lib_attr.type: ", mpi_lib_infos[global_data.mpi_lib_attr.type].name); - LOG_INFO(" mpi_lib_attr.device_buf: ", global_data.mpi_lib_attr.device_buf); + LOG_INFO(" mpi_lib_attr.hmem: ", global_data.mpi_lib_attr.hmem); LOG_INFO(" extra_ep: ", global_data.extra_ep); LOG_INFO(" mnic_type: ", global_data.mnic_type); if (global_data.mnic_type != ATL_MNIC_NONE) @@ -1700,7 +1666,7 @@ static atl_status_t atl_mpi_init(int* argc, /* report actual attributes back to upper level */ attr->out.enable_shm = 0; attr->out.enable_rma = 0; - attr->out.enable_device_buf = attr->in.enable_device_buf & global_data.mpi_lib_attr.device_buf; + attr->out.enable_hmem = attr->in.enable_hmem & global_data.mpi_lib_attr.hmem; attr->out.mnic_type = global_data.mnic_type; attr->out.mnic_count = global_data.mnic_count; attr->out.tag_bits = 32; @@ -1741,3 +1707,5 @@ static atl_status_t atl_mpi_init(int* argc, atl_status_t atl_mpi_main_addr_reserve(char* main_addr) { return ATL_STATUS_UNSUPPORTED; } + +#endif // CCL_ENABLE_MPI diff --git a/src/atl/ofi/atl_ofi.cpp b/src/atl/ofi/atl_ofi.cpp index d753ea040..21bc2cb6f 100644 --- a/src/atl/ofi/atl_ofi.cpp +++ b/src/atl/ofi/atl_ofi.cpp @@ -14,7 +14,150 @@ limitations under the License. */ #include "atl_ofi.hpp" -#include "atl_ofi_impl.cpp" + +#ifdef CCL_ENABLE_SYCL +#include "common/utils/sycl_utils.hpp" +#endif // CCL_ENABLE_SYCL + +// cache + +void atl_ofi::fi_cache::clear() { + for (auto& instance : memory_regions) { + instance.clear(); + } +} + +void atl_ofi::fi_cache::init(size_t instance_count, int enable_hmem) { + this->enable_hmem = enable_hmem; + memory_regions.resize(instance_count); +} + +atl_ofi::fi_cache::~fi_cache() { + clear(); +} + +void atl_ofi::fi_cache::get(size_t idx, fid_domain* domain, void* buf, size_t bytes, fid_mr** mr) { + CCL_THROW_IF_NOT(mr); + *mr = nullptr; +#ifdef CCL_ENABLE_OFI_HMEM + if (enable_hmem) { + memory_regions.at(idx % memory_regions.size()).get(domain, buf, bytes, mr); + } +#endif // CCL_ENABLE_OFI_HMEM +} + +void atl_ofi::fi_cache::push(size_t idx, fid_mr* mr) { +#ifdef CCL_ENABLE_OFI_HMEM + if (mr) + memory_regions.at(idx % memory_regions.size()).push(mr); +#endif // CCL_ENABLE_OFI_HMEM +} + +atl_ofi::mr_cache::~mr_cache() { + if (!cache.empty()) { + LOG_WARN("mr cache is not empty, size: ", cache.size()); + clear(); + } +} + +void atl_ofi::mr_cache::clear() { + LOG_DEBUG("mr cache size: ", cache.size()); + for (auto& key_value : cache) { + fi_close(&key_value.second->fid); + } + cache.clear(); +} + +void atl_ofi::mr_cache::get(fid_domain* domain, void* buf, size_t bytes, fid_mr** mr) { + CCL_THROW_IF_NOT(domain); + CCL_THROW_IF_NOT(mr); + + if (ccl::global_data::env().enable_atl_cache) { + key_t key(domain, buf, bytes); + auto key_value = cache.find(key); + if (key_value != cache.end()) { + *mr = key_value->second; + LOG_DEBUG("loaded from mr cache: buf: ", buf, ", bytes: ", bytes); + return; + } + } + + struct fi_mr_attr mr_attr = {}; + struct iovec iov = {}; + + iov.iov_base = buf; + iov.iov_len = bytes; + mr_attr.mr_iov = &iov; + mr_attr.iov_count = 1; + mr_attr.access = FI_SEND | FI_RECV | FI_REMOTE_READ | FI_REMOTE_WRITE; + mr_attr.requested_key = mr_key++; + +#ifdef CCL_ENABLE_OFI_HMEM + + mr_attr.iface = FI_HMEM_SYSTEM; + mr_attr.device.ze = 0; + + atl_ofi_ze_data& ze_data = global_data.ze_data; + ze_memory_allocation_properties_t alloc_props = ccl::ze::default_alloc_props; + ze_device_handle_t alloc_dev = nullptr; + ZE_CALL(zeMemGetAllocProperties, (ze_data.context, buf, &alloc_props, &alloc_dev)); + + LOG_DEBUG("alloc_props: dev ", alloc_dev, ", type ", alloc_props.type); + + if (alloc_props.type == ZE_MEMORY_TYPE_HOST || alloc_props.type == ZE_MEMORY_TYPE_DEVICE || + alloc_props.type == ZE_MEMORY_TYPE_SHARED) { + mr_attr.iface = FI_HMEM_ZE; + } + + if (alloc_dev) { + ze_device_properties_t alloc_dev_props = ccl::ze::default_device_props; + ZE_CALL(zeDeviceGetProperties, (alloc_dev, &alloc_dev_props)); + + int dev_idx = -1; + for (int idx = 0; idx < ze_data.device_count; idx++) { + ze_device_properties_t dev_props = ccl::ze::default_device_props; + ZE_CALL(zeDeviceGetProperties, (ze_data.devices[idx], &dev_props)); + + if (!std::memcmp(&dev_props.uuid, &alloc_dev_props.uuid, sizeof(ze_device_uuid_t))) { + dev_idx = idx; + LOG_DEBUG("buf ", buf, " corresponds to ze device idx ", dev_idx); + break; + } + } + CCL_THROW_IF_NOT(dev_idx != -1); + mr_attr.device.ze = dev_idx; + } +#endif // CCL_ENABLE_OFI_HMEM + + int ofi_ret; + ATL_OFI_CALL(fi_mr_regattr(domain, &mr_attr, 0, mr), + ofi_ret, + CCL_THROW("failed to register mr, ret: ", + ofi_ret, + ", buf: ", + buf, + ", bytes: ", + bytes, + ", iface: ", + mr_attr.iface)); + + if (ccl::global_data::env().enable_atl_cache) { + key_t key(domain, buf, bytes); + LOG_DEBUG("inserted to mr cache: buf: ", buf, ", bytes: ", bytes); + cache.insert({ std::move(key), *mr }); + } +} + +void atl_ofi::mr_cache::push(fid_mr* mr) { + CCL_THROW_IF_NOT(mr); + if (ccl::global_data::env().enable_atl_cache) { + /* do nothing, all mem regions will be closed in clear() */ + return; + } + fi_close(&mr->fid); +} + +// atl_ofi atl_status_t atl_ofi::atl_set_env(const atl_attr_t& attr) { return atl_ofi_set_env(attr); @@ -26,12 +169,348 @@ atl_status_t atl_ofi::atl_init(int* argc, const char* main_addr, std::unique_ptr& pmi) { inited = true; - return atl_ofi_init(argc, argv, attr, &ctx, main_addr, pmi.get()); + struct fi_info *prov_list = nullptr, *base_hints = nullptr, *prov_hints = nullptr; + int fi_version; + ssize_t ret = 0; + size_t idx = 0, ep_idx = 0, prov_idx = 0; + char* prov_name = nullptr; + char* prov_env = nullptr; + char* fi_version_env = nullptr; + atl_ofi_prov_t* prov = nullptr; + char *max_retry_count_env = nullptr, *progress_mode_env = nullptr; + int open_nw_provs = 1; + int enable_shm = 0; + + CCL_THROW_IF_NOT((sizeof(atl_ofi_req_t) <= sizeof(atl_req_t) - offsetof(atl_req_t, internal)), + "unexpected offset: atl_ofi_request size ", + sizeof(atl_ofi_req_t), + ", atl_request size ", + sizeof(atl_req_t), + ", expected offset ", + offsetof(atl_req_t, internal)); + + if (global_data.ctx_count == 0) { + ret = atl_ofi_set_env(*attr); + if (ret != ATL_STATUS_SUCCESS) { + LOG_ERROR("atl_ofi_set_env error"); + return ATL_STATUS_FAILURE; + } + + fi_version_env = getenv(ATL_OFI_MAJOR_VERSION); + if (fi_version_env) { + global_data.fi_major_version = safe_c_strtol(fi_version_env, nullptr, 10); + } + + fi_version_env = getenv(ATL_OFI_MINOR_VERSION); + if (fi_version_env) { + global_data.fi_minor_version = safe_c_strtol(fi_version_env, nullptr, 10); + } + + LOG_INFO("fi_version: ", global_data.fi_major_version, ".", global_data.fi_minor_version); + +#ifdef CCL_ENABLE_OFI_HMEM + atl_ofi_init_ze_data(); +#endif // CCL_ENABLE_OFI_HMEM + } + global_data.ctx_count++; + + atl_ofi_ctx_t* ofi_ctx; + ofi_ctx = (atl_ofi_ctx_t*)calloc(1, sizeof(atl_ofi_ctx_t)); + if (!ofi_ctx) + return ATL_STATUS_FAILURE; + + ctx = &(ofi_ctx->ctx); + + ctx->ep_count = attr->in.ep_count; + ctx->eps = (atl_ep**)calloc(1, sizeof(void*) * attr->in.ep_count); + if (!ctx->eps) + goto err; + + ctx->coord.global_count = pmi->get_size(); + ctx->coord.global_idx = pmi->get_rank(); + + ret = atl_ofi_get_local_proc_coord(ofi_ctx, pmi); + if (ret) { + LOG_ERROR("atl_ofi_get_local_proc_coord error"); + goto err; + } + + atl_proc_coord_t* coord; + coord = &(ctx->coord); + + base_hints = fi_allocinfo(); + if (!base_hints) { + LOG_ERROR("can't alloc base_hints"); + goto err; + } + + base_hints->mode = FI_CONTEXT; + base_hints->ep_attr->type = FI_EP_RDM; + base_hints->domain_attr->resource_mgmt = FI_RM_ENABLED; + base_hints->domain_attr->control_progress = FI_PROGRESS_MANUAL; + base_hints->domain_attr->data_progress = FI_PROGRESS_MANUAL; + base_hints->caps = FI_TAGGED; + + prov_env = getenv("FI_PROVIDER"); + + ofi_ctx->enable_hmem = 0; + +#ifdef CCL_ENABLE_OFI_HMEM + if (prov_env && strstr(prov_env, "verbs") && attr->in.enable_hmem) { + ofi_ctx->enable_hmem = 1; + } + + if (ofi_ctx->enable_hmem) { + base_hints->caps |= FI_HMEM; + base_hints->domain_attr->mr_mode = + (FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_VIRT_ADDR | FI_MR_LOCAL | FI_MR_HMEM); + + /* TODO: enable shm with HMEM */ + attr->in.enable_shm = 0; + + /* TODO: implement fallback logic if HMEM can't be enabled */ + } +#endif // CCL_ENABLE_OFI_HMEM + + cache.init(attr->in.ep_count, ofi_ctx->enable_hmem); + + fi_version = FI_VERSION(global_data.fi_major_version, global_data.fi_minor_version); + + if (coord->global_idx == 0) + LOG_INFO("libfabric version: ", fi_tostr("1" /* ignored */, FI_TYPE_VERSION)); + + if (prov_env && !strcmp(prov_env, ATL_OFI_SHM_PROV_NAME)) { + if (coord->global_count != coord->local_count) { + LOG_ERROR("shm provider is requested as primary provider but global_count (", + coord->global_count, + ") != local_count (", + coord->local_count, + ")"); + goto err; + } + + if (!attr->in.enable_shm) { + LOG_ERROR( + "shm provider is requested through FI_PROVIDER but not requested from CCL level"); + goto err; + } + } + + atl_ofi_print_coord(coord); + + enable_shm = attr->in.enable_shm; + if (enable_shm) { + prov_hints = fi_dupinfo(base_hints); + prov_hints->fabric_attr->prov_name = strdup(ATL_OFI_SHM_PROV_NAME); + ret = fi_getinfo(fi_version, nullptr, nullptr, 0ULL, prov_hints, &prov_list); + if (ret || !prov_list) { + enable_shm = 0; + LOG_INFO("shm provider is requested but not available"); + } + else { + LOG_INFO("shm provider is requested and available"); + } + + fi_freeinfo(prov_list); + prov_list = nullptr; + + fi_freeinfo(prov_hints); + prov_hints = nullptr; + } + + ofi_ctx->prov_count = 0; + ofi_ctx->nw_prov_count = 0; + ofi_ctx->shm_prov_idx = 0; + ofi_ctx->nw_prov_first_idx = (enable_shm) ? 1 : 0; + ofi_ctx->mnic_type = attr->in.mnic_type; + ATL_CALL(atl_ofi_parse_mnic_name(ctx, attr->in.mnic_name), goto err); + ofi_ctx->mnic_count = std::min(attr->in.mnic_count, (size_t)(ATL_OFI_MAX_NW_PROV_COUNT)); + ofi_ctx->mnic_count = std::min(ofi_ctx->mnic_count, attr->in.ep_count); + ofi_ctx->mnic_count = std::max(ofi_ctx->mnic_count, (size_t)(1)); + + if ((ofi_ctx->mnic_type != ATL_MNIC_NONE) && + !ccl::global_data::get().hwloc_wrapper->is_initialized()) { + ofi_ctx->mnic_type = ATL_MNIC_NONE; + LOG_WARN("hwloc is not initialized, disable multi-nic") + } + + if (ofi_ctx->mnic_type == ATL_MNIC_NONE) + ofi_ctx->mnic_count = 1; + + attr->out.tag_bits = 64; + attr->out.max_tag = 0xFFFFFFFFFFFFFFFF; + + /* open SHM provider */ + if (enable_shm) { + prov_idx = ofi_ctx->shm_prov_idx; + prov_name = strdup(ATL_OFI_SHM_PROV_NAME); + prov = &ofi_ctx->provs[prov_idx]; + prov->idx = prov_idx; + prov->is_shm = 1; + ATL_CALL(atl_ofi_get_prov_list(ctx, prov_name, base_hints, &prov_list), goto err); + ATL_CALL(atl_ofi_prov_init(ctx, prov_list, prov, attr, pmi), goto err); + free(prov_name); + fi_freeinfo(prov_list); + ofi_ctx->prov_count++; + } + + /* open NW provider(s) */ + if (prov_env && !strcmp(prov_env, ATL_OFI_SHM_PROV_NAME) && enable_shm) { + open_nw_provs = 0; + } + + if (open_nw_provs) { + ATL_CALL(atl_ofi_open_nw_provs(ctx, base_hints, attr, pmi), goto err); + ofi_ctx->mnic_count = ofi_ctx->nw_prov_count; + } + + for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { + atl_ofi_ep_t* ofi_ep; + ofi_ep = (atl_ofi_ep_t*)calloc(1, sizeof(atl_ofi_ep_t)); + if (!ofi_ep) { + LOG_ERROR("can't alloc ofi_ep, idx ", ep_idx); + goto err; + } + + atl_ep_t* ep; + ep = &(ofi_ep->ep); + ep->idx = ep_idx; + ep->ctx = ctx; + + ofi_ep->active_prov_count = 0; + if (enable_shm) { + ofi_ep->active_prov_idxs[ofi_ep->active_prov_count] = ofi_ctx->shm_prov_idx; + ofi_ep->active_prov_count++; + } + if (open_nw_provs) { + ofi_ep->active_prov_idxs[ofi_ep->active_prov_count] = + ofi_ctx->nw_prov_first_idx + ep_idx % ofi_ctx->nw_prov_count; + ofi_ep->active_prov_count++; + } + CCL_THROW_IF_NOT(ofi_ep->active_prov_count, "no active providers for ep_idx ", ep_idx); + + if (coord->global_idx == 0) { + std::stringstream ss; + for (idx = 0; idx < ofi_ep->active_prov_count; idx++) { + ss << ofi_ep->active_prov_idxs[idx] << " "; + } + LOG_INFO("ep_idx: ", ep_idx, ", active_prov_idxs: ", ss.str()); + } + + ctx->eps[ep_idx] = ep; + } + + pmi->pmrt_barrier(); + + max_retry_count_env = getenv(ATL_OFI_MAX_RETRY_COUNT_ENV); + if (max_retry_count_env) { + ofi_ctx->max_retry_count = safe_c_strtol(max_retry_count_env, nullptr, 10); + } + else { + ofi_ctx->max_retry_count = ATL_OFI_MAX_RETRY_COUNT; + } + + if ((coord->global_count == coord->local_count) && (coord->global_count <= 4)) { + ofi_ctx->progress_mode = ATL_PROGRESS_CHECK; + } + else { + ofi_ctx->progress_mode = ATL_PROGRESS_POLL; + } + + progress_mode_env = getenv(ATL_PROGRESS_MODE_ENV); + if (progress_mode_env) { + ofi_ctx->progress_mode = static_cast(atoi(progress_mode_env)); + } + + if (coord->global_idx == 0) { + LOG_INFO("atl-ofi-ctx:"); + LOG_INFO(" new ctx_count: ", global_data.ctx_count); + LOG_INFO(" prov_count: ", ofi_ctx->prov_count); + LOG_INFO(" nw_prov_count: ", ofi_ctx->nw_prov_count); + LOG_INFO(" nw_prov_first_idx: ", ofi_ctx->nw_prov_first_idx); + LOG_INFO(" mnic_type: ", ofi_ctx->mnic_type); + LOG_INFO(" mnic_include_names: ", vec_to_string(ofi_ctx->mnic_include_names)); + LOG_INFO(" mnic_exclude_names: ", vec_to_string(ofi_ctx->mnic_exclude_names)); + LOG_INFO(" mnic_count: ", ofi_ctx->mnic_count); + LOG_INFO(" max_retry_count: ", ofi_ctx->max_retry_count); + LOG_INFO(" progress_mode: ", ofi_ctx->progress_mode); +#ifdef CCL_ENABLE_OFI_HMEM + LOG_INFO(" hmem: ", ofi_ctx->enable_hmem); +#endif // CCL_ENABLE_OFI_HMEM + } + + fi_freeinfo(base_hints); + base_hints = nullptr; + + /* report actual attributes back to upper level */ + attr->out.enable_shm = enable_shm; + attr->out.enable_rma = 0; + attr->out.enable_hmem = ofi_ctx->enable_hmem; + attr->out.mnic_type = ofi_ctx->mnic_type; + attr->out.mnic_count = ofi_ctx->mnic_count; + attr->out.max_order_waw_size = 0; + + return ATL_STATUS_SUCCESS; + +err: + LOG_ERROR("can't find suitable provider"); + + if (prov_list) { + fi_freeinfo(prov_list); + } + + if (base_hints) { + fi_freeinfo(base_hints); + } + + if (prov_hints) { + fi_freeinfo(prov_hints); + } + + if (ctx != nullptr) + atl_finalize(); + + return ATL_STATUS_FAILURE; } atl_status_t atl_ofi::atl_finalize() { is_finalized = true; - return atl_ofi_finalize(ctx); + int ret = 0; + size_t idx; + + atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + + global_data.ctx_count--; + if (ctx->coord.global_idx == 0) { + LOG_INFO("finalize atl-ofi ctx, remaining ctx_count ", global_data.ctx_count); + } + + cache.clear(); + + for (idx = 0; idx < ofi_ctx->prov_count; idx++) { + atl_ofi_prov_t* prov = &ofi_ctx->provs[idx]; + atl_ofi_prov_destroy(ctx, prov); + } + + for (idx = 0; idx < ctx->ep_count; idx++) { + atl_ofi_ep_t* ofi_ep = container_of(ctx->eps[idx], atl_ofi_ep_t, ep); + free(ofi_ep); + } + + if (global_data.ctx_count == 0) { + if (global_data.dlhandle) { + dlclose(global_data.dlhandle); + } + + if (ctx->coord.global_idx == 0) { + LOG_INFO("finalized last atl-ofi ctx"); + } + } + + free(ctx->eps); + free(ofi_ctx); + + return RET2ATL(ret); } atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { @@ -53,7 +532,7 @@ atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { ctx->coord.global_count = pmi->get_size(); ctx->coord.global_idx = pmi->get_rank(); - ret = atl_ofi_get_local_proc_coord(ofi_ctx, pmi.get()); + ret = atl_ofi_get_local_proc_coord(ofi_ctx, pmi); if (ret) return RET2ATL(ret); @@ -71,7 +550,7 @@ atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { atl_ofi_print_coord(coord); for (prov_idx = 0; prov_idx < ofi_ctx->prov_count; prov_idx++) { - ret = atl_ofi_prov_eps_connect(ofi_ctx, prov_idx, pmi.get()); + ret = atl_ofi_prov_eps_connect(ofi_ctx, prov_idx, pmi); if (ret) return RET2ATL(ret); } @@ -91,11 +570,47 @@ atl_proc_coord_t* atl_ofi::atl_get_proc_coord() { } atl_status_t atl_ofi::atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) { - return atl_ofi_mr_reg(ctx, buf, len, mr); + int ret; + atl_ofi_ctx_t* ofi_ctx; + ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + atl_ofi_prov_t* prov = &(ofi_ctx->provs[0]); + + atl_ofi_mr_t* ofi_mr; + ofi_mr = (atl_ofi_mr_t*)calloc(1, sizeof(atl_ofi_mr_t)); + if (!ofi_mr) + return ATL_STATUS_FAILURE; + + ret = fi_mr_reg(prov->domain, + buf, + len, + FI_SEND | FI_RECV | FI_READ | FI_WRITE | FI_REMOTE_READ | FI_REMOTE_WRITE, + 0, + 0, + 0, + &ofi_mr->fi_mr, + nullptr); + if (ret) + goto mr_reg_err; + + ofi_mr->mr.buf = (void*)buf; + ofi_mr->mr.len = len; + ofi_mr->mr.remote_key = (uintptr_t)fi_mr_key(ofi_mr->fi_mr); + ofi_mr->mr.local_key = (uintptr_t)fi_mr_desc(ofi_mr->fi_mr); + + *mr = &ofi_mr->mr; + return ATL_STATUS_SUCCESS; + +mr_reg_err: + free(ofi_mr); + return ATL_STATUS_FAILURE; } atl_status_t atl_ofi::atl_mr_dereg(atl_mr_t* mr) { - return atl_ofi_mr_dereg(ctx, mr); + atl_ofi_mr_t* ofi_mr; + ofi_mr = container_of(mr, atl_ofi_mr_t, mr); + int ret = fi_close(&ofi_mr->fi_mr->fid); + free(ofi_mr); + return RET2ATL(ret); } atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, @@ -104,7 +619,43 @@ atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, int dst_proc_idx, uint64_t tag, atl_req_t* req) { - return atl_ofi_ep_send(ep, buf, len, dst_proc_idx, tag, req); + ssize_t ret; + + atl_ofi_prov_t* prov; + atl_ofi_prov_ep_t* prov_ep; + atl_ofi_req_t* ofi_req; + + prov = atl_ofi_get_prov(ep, dst_proc_idx, len); + prov_ep = &(prov->eps[ep->idx]); + ofi_req = ((atl_ofi_req_t*)req->internal); + + req->tag = tag; + req->remote_proc_idx = dst_proc_idx; + ofi_req->comp_state = ATL_OFI_COMP_POSTED; + + ofi_req->prov_ep = prov_ep; + ofi_req->fi_ep = prov_ep->tx; + + cache.get(ep->idx, prov->domain, const_cast(buf), len, &ofi_req->mr); + void* desc = (ofi_req->mr) ? fi_mr_desc(ofi_req->mr) : nullptr; + + struct iovec iov; + iov.iov_base = const_cast(buf); + iov.iov_len = len; + + struct fi_msg_tagged msg; + msg.desc = &desc; + msg.msg_iov = &iov; + msg.iov_count = 1; + msg.tag = tag; + msg.ignore = 0; + msg.addr = atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx); + msg.context = &ofi_req->fi_ctx; + msg.data = 0; + + ATL_OFI_RETRY(fi_tsendmsg(prov_ep->tx, &msg, 0), ep, ret); + + return RET2ATL(ret); } atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, @@ -113,7 +664,43 @@ atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, int src_proc_idx, uint64_t tag, atl_req_t* req) { - return atl_ofi_ep_recv(ep, buf, len, src_proc_idx, tag, req); + ssize_t ret; + + atl_ofi_prov_t* prov; + atl_ofi_prov_ep_t* prov_ep; + atl_ofi_req_t* ofi_req; + + prov = atl_ofi_get_prov(ep, src_proc_idx, len); + prov_ep = &(prov->eps[ep->idx]); + ofi_req = ((atl_ofi_req_t*)req->internal); + + req->tag = tag; + req->remote_proc_idx = src_proc_idx; + ofi_req->comp_state = ATL_OFI_COMP_POSTED; + + ofi_req->prov_ep = prov_ep; + ofi_req->fi_ep = prov_ep->rx; + + cache.get(ep->idx, prov->domain, const_cast(buf), len, &ofi_req->mr); + void* desc = (ofi_req->mr) ? fi_mr_desc(ofi_req->mr) : nullptr; + + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = len; + + struct fi_msg_tagged msg; + msg.desc = &desc; + msg.msg_iov = &iov; + msg.iov_count = 1; + msg.tag = tag; + msg.ignore = 0; + msg.addr = atl_ofi_get_addr(ep->ctx, prov, src_proc_idx, ep->idx); + msg.context = &ofi_req->fi_ctx; + msg.data = 0; + + ATL_OFI_RETRY(fi_trecvmsg(prov_ep->rx, &msg, 0), ep, ret); + + return RET2ATL(ret); } atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, @@ -121,7 +708,110 @@ atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, uint64_t tag, int* found, size_t* recv_len) { - return atl_ofi_ep_probe(ep, src_proc_idx, tag, found, recv_len); + CCL_THROW("unexpected path"); + + atl_status_t ret; + atl_ofi_req_t reqs[ATL_OFI_MAX_PROV_COUNT]; + struct fi_msg_tagged msgs[ATL_OFI_MAX_PROV_COUNT]; + int flag, len; + ssize_t ofi_ret; + size_t idx; + int do_poll; + + atl_ofi_ctx_t* ofi_ctx; + + ret = ATL_STATUS_SUCCESS; + flag = 0; + len = 0; + ofi_ret = FI_SUCCESS; + do_poll = 1; + + ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); + + for (idx = 0; idx < ofi_ctx->prov_count; idx++) { + atl_ofi_prov_t* prov; + atl_ofi_prov_ep_t* prov_ep; + atl_ofi_req_t* req; + struct fi_msg_tagged* msg; + + prov = &(ofi_ctx->provs[idx]); + prov_ep = &(prov->eps[ep->idx]); + req = &(reqs[idx]); + msg = &(msgs[idx]); + + if (prov->is_shm && + ((src_proc_idx < prov->first_proc_idx) || + (src_proc_idx >= (prov->first_proc_idx + ep->ctx->coord.local_count)))) { + req->prov_ep = nullptr; + continue; + } + + req->comp_state = ATL_OFI_COMP_PEEK_STARTED; + req->prov_ep = prov_ep; + req->fi_ep = prov_ep->rx; + + msg->msg_iov = nullptr; + msg->desc = nullptr; + msg->iov_count = 0; + msg->addr = atl_ofi_get_addr(ep->ctx, prov, src_proc_idx, ep->idx); + msg->tag = tag; + msg->ignore = 0; + msg->context = &(req->fi_ctx); + msg->data = 0; + + ATL_OFI_RETRY(fi_trecvmsg(prov_ep->rx, msg, FI_PEEK | FI_COMPLETION), ep, ofi_ret); + } + + do { + ret = atl_ep_poll(ep); + if (ret != ATL_STATUS_SUCCESS) + return ret; + + for (idx = 0; idx < ofi_ctx->prov_count; idx++) { + atl_ofi_req_t* req; + req = &(reqs[idx]); + + if (!req->prov_ep) + continue; + + if (req->comp_state != ATL_OFI_COMP_PEEK_STARTED) { + do_poll = 0; + + if (req->comp_state == ATL_OFI_COMP_PEEK_FOUND) { + flag = 1; + len = req->recv_len; + req->prov_ep = nullptr; + } + else if (req->comp_state == ATL_OFI_COMP_PEEK_NOT_FOUND) { + req->prov_ep = nullptr; + } + else { + CCL_THROW("unexpected completion state ", req->comp_state); + } + + break; + } + } + } while (do_poll); + + for (idx = 0; idx < ofi_ctx->prov_count; idx++) { + atl_ofi_req_t* req; + req = &(reqs[idx]); + + if (!req->prov_ep) + continue; + + if (fi_cancel(&req->fi_ep->fid, &req->fi_ctx) == 0) { + atl_ofi_wait_cancel_cq(req->prov_ep->cq); + } + } + + if (found) + *found = flag; + if (recv_len) + *recv_len = len; + + return RET2ATL(ofi_ret); } atl_status_t atl_ofi::atl_ep_allgatherv(atl_ep_t* ep, @@ -131,7 +821,7 @@ atl_status_t atl_ofi::atl_ep_allgatherv(atl_ep_t* ep, const int* recv_lens, const int* offsets, atl_req_t* req) { - return atl_ofi_ep_allgatherv(ep, send_buf, send_len, recv_buf, recv_lens, offsets, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_allreduce(atl_ep_t* ep, @@ -141,7 +831,7 @@ atl_status_t atl_ofi::atl_ep_allreduce(atl_ep_t* ep, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { - return atl_ofi_ep_allreduce(ep, send_buf, recv_buf, len, dtype, op, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_alltoall(atl_ep_t* ep, @@ -149,7 +839,7 @@ atl_status_t atl_ofi::atl_ep_alltoall(atl_ep_t* ep, void* recv_buf, int len, atl_req_t* req) { - return atl_ofi_ep_alltoall(ep, send_buf, recv_buf, len, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_alltoallv(atl_ep_t* ep, @@ -160,16 +850,15 @@ atl_status_t atl_ofi::atl_ep_alltoallv(atl_ep_t* ep, const int* recv_lens, const int* recv_offsets, atl_req_t* req) { - return atl_ofi_ep_alltoallv( - ep, send_buf, send_lens, send_offsets, recv_buf, recv_lens, recv_offsets, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - return atl_ofi_ep_barrier(ep, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req) { - return atl_ofi_ep_bcast(ep, buf, len, root, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_reduce(atl_ep_t* ep, @@ -180,7 +869,7 @@ atl_status_t atl_ofi::atl_ep_reduce(atl_ep_t* ep, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { - return atl_ofi_ep_reduce(ep, send_buf, recv_buf, len, root, dtype, op, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_reduce_scatter(atl_ep_t* ep, @@ -190,7 +879,7 @@ atl_status_t atl_ofi::atl_ep_reduce_scatter(atl_ep_t* ep, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { - return atl_ofi_ep_reduce_scatter(ep, send_buf, recv_buf, recv_len, dtype, op, req); + return ATL_STATUS_UNSUPPORTED; } atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, @@ -201,7 +890,34 @@ atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, uintptr_t remote_key, int dst_proc_idx, atl_req_t* req) { - return atl_ofi_ep_read(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); + ssize_t ret; + + atl_ofi_prov_t* prov; + atl_ofi_prov_ep_t* prov_ep; + atl_ofi_req_t* ofi_req; + + prov = atl_ofi_get_prov(ep, dst_proc_idx, len); + prov_ep = &(prov->eps[ep->idx]); + ofi_req = ((atl_ofi_req_t*)req->internal); + + req->tag = 0; + req->remote_proc_idx = dst_proc_idx; + ofi_req->comp_state = ATL_OFI_COMP_POSTED; + + ofi_req->prov_ep = prov_ep; + ofi_req->fi_ep = prov_ep->tx; + + ATL_OFI_RETRY(fi_read(prov_ep->tx, + buf, + len, + (void*)mr->local_key, + atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx), + addr, + remote_key, + &ofi_req->fi_ctx), + ep, + ret); + return RET2ATL(ret); } atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, @@ -212,29 +928,193 @@ atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, uintptr_t remote_key, int dst_proc_idx, atl_req_t* req) { - return atl_ofi_ep_write(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); + ssize_t ret; + + atl_ofi_prov_t* prov; + atl_ofi_prov_ep_t* prov_ep; + atl_ofi_req_t* ofi_req; + + prov = atl_ofi_get_prov(ep, dst_proc_idx, len); + prov_ep = &(prov->eps[ep->idx]); + ofi_req = ((atl_ofi_req_t*)req->internal); + + req->tag = 0; + req->remote_proc_idx = dst_proc_idx; + ofi_req->comp_state = ATL_OFI_COMP_POSTED; + + ofi_req->prov_ep = prov_ep; + ofi_req->fi_ep = prov_ep->tx; + + ATL_OFI_RETRY(fi_write(prov_ep->tx, + buf, + len, + (void*)mr->local_key, + atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx), + addr, + remote_key, + &ofi_req->fi_ctx), + ep, + ret); + return RET2ATL(ret); } atl_status_t atl_ofi::atl_ep_wait(atl_ep_t* ep, atl_req_t* req) { - return atl_ofi_ep_wait(ep, req); + atl_status_t ret; + atl_ofi_req_t* ofi_req; + + ret = ATL_STATUS_SUCCESS; + ofi_req = ((atl_ofi_req_t*)req->internal); + + while ((ofi_req->comp_state != ATL_OFI_COMP_COMPLETED) && + ((ret = atl_ep_poll(ep)) == ATL_STATUS_SUCCESS)) + ; + + return ret; } -atl_status_t atl_ofi::atl_ep_wait_all(atl_ep_t* ep, atl_req_t* req, size_t count) { - return atl_ofi_ep_wait_all(ep, req, count); +atl_status_t atl_ofi::atl_ep_wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t count) { + size_t i; + atl_status_t ret; + + for (i = 0; i < count; i++) { + ret = atl_ep_wait(ep, &reqs[i]); + if (ret != ATL_STATUS_SUCCESS) + return ret; + } + + return ATL_STATUS_SUCCESS; } atl_status_t atl_ofi::atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; + int ret; + atl_ofi_req_t* ofi_req; + + ret = ATL_STATUS_SUCCESS; + ofi_req = ((atl_ofi_req_t*)req->internal); + + ret = fi_cancel(&ofi_req->fi_ep->fid, &ofi_req->fi_ctx); + if (ret == 0) { + return RET2ATL(atl_ofi_wait_cancel_cq(ofi_req->prov_ep->cq)); + } + + return ATL_STATUS_SUCCESS; } atl_status_t atl_ofi::atl_ep_poll(atl_ep_t* ep) { - return atl_ofi_ep_poll(ep); + atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); + if (ofi_ctx->progress_mode == ATL_PROGRESS_POLL) { + atl_ep_progress(ep); + } + return ATL_STATUS_SUCCESS; } atl_status_t atl_ofi::atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) { - return atl_ofi_ep_check(ep, is_completed, req); + CCL_THROW_IF_NOT(is_completed); + + atl_status_t status; + atl_ofi_req_t* ofi_req; + atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); + + status = ATL_STATUS_SUCCESS; + ofi_req = ((atl_ofi_req_t*)req->internal); + + *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); + if (*is_completed) { + return ATL_STATUS_SUCCESS; + } + + if (ofi_ctx->progress_mode == ATL_PROGRESS_CHECK) { + status = atl_ep_progress(ep); + *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); + } + + return status; } + atl_ofi::~atl_ofi() { - if (!is_finalized) + if (!is_finalized) { atl_finalize(); + } +} + +atl_status_t atl_ofi::atl_ep_progress(atl_ep_t* ep) { + ssize_t ret; + size_t idx; + struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; + atl_ofi_ep_t* ofi_ep = container_of(ep, atl_ofi_ep_t, ep); + atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); + size_t ep_idx = ep->idx; + + /* ensure progress for all active providers */ + for (idx = 0; idx < ofi_ep->active_prov_count; idx++) { + atl_ofi_prov_ep_t* prov_ep; + prov_ep = &(ofi_ctx->provs[ofi_ep->active_prov_idxs[idx]].eps[ep_idx]); + do { + ret = fi_cq_read(prov_ep->cq, entries, ATL_OFI_CQ_BUNCH_SIZE); + if (ret > 0) + atl_process_comps(ep, entries, ret); + else if (ret == -FI_EAGAIN) + break; + else + return atl_prov_ep_handle_cq_err(prov_ep); + } while (ret > 0); + } + + return ATL_STATUS_SUCCESS; +} + +atl_status_t atl_ofi::atl_prov_ep_handle_cq_err(atl_ofi_prov_ep_t* ep) { + struct fi_cq_err_entry err_entry; + atl_ofi_req_t* ofi_req; + + int ret = fi_cq_readerr(ep->cq, &err_entry, 0); + if (ret != 1) { + CCL_THROW("unable to read error from cq"); + return ATL_STATUS_FAILURE; + } + else { + ofi_req = container_of(err_entry.op_context, atl_ofi_req_t, fi_ctx); + + if (err_entry.err == FI_ECANCELED) { + return ATL_STATUS_SUCCESS; + } + + if (err_entry.err == FI_ENOMSG && ofi_req->comp_state == ATL_OFI_COMP_PEEK_STARTED) { + ofi_req->comp_state = ATL_OFI_COMP_PEEK_NOT_FOUND; + } + else { + LOG_ERROR("fi_cq_readerr: err: ", + err_entry.err, + ", prov_err: ", + fi_cq_strerror(ep->cq, err_entry.prov_errno, err_entry.err_data, nullptr, 0), + "(", + err_entry.prov_errno, + ")"); + return ATL_STATUS_FAILURE; + } + return ATL_STATUS_SUCCESS; + } +} + +void atl_ofi::atl_process_comps(atl_ep_t* ep, struct fi_cq_tagged_entry* entries, ssize_t ret) { + ssize_t idx; + atl_ofi_req_t* comp_ofi_req; + for (idx = 0; idx < ret; idx++) { + comp_ofi_req = container_of(entries[idx].op_context, atl_ofi_req_t, fi_ctx); + switch (comp_ofi_req->comp_state) { + case ATL_OFI_COMP_POSTED: + comp_ofi_req->comp_state = ATL_OFI_COMP_COMPLETED; + cache.push(ep->idx, comp_ofi_req->mr); + break; + case ATL_OFI_COMP_COMPLETED: break; + case ATL_OFI_COMP_PEEK_STARTED: + comp_ofi_req->comp_state = ATL_OFI_COMP_PEEK_FOUND; + break; + default: CCL_THROW("unexpected completion state ", comp_ofi_req->comp_state); break; + } + + if (entries[idx].flags & FI_RECV) { + comp_ofi_req->recv_len = entries[idx].len; + } + } } diff --git a/src/atl/ofi/atl_ofi.hpp b/src/atl/ofi/atl_ofi.hpp index 7bcac3c5e..a06a45648 100644 --- a/src/atl/ofi/atl_ofi.hpp +++ b/src/atl/ofi/atl_ofi.hpp @@ -15,8 +15,12 @@ */ #include #include +#include +#include #include "atl.h" +#include "atl_ofi_helper.hpp" +#include "common/utils/hash.hpp" class atl_ofi final : public iatl { public: @@ -153,7 +157,49 @@ class atl_ofi final : public iatl { } private: + atl_status_t atl_ep_progress(atl_ep_t* ep); + void atl_process_comps(atl_ep_t* ep, struct fi_cq_tagged_entry* entries, ssize_t ret); + atl_status_t atl_prov_ep_handle_cq_err(atl_ofi_prov_ep_t* ep); + atl_ctx_t* ctx = nullptr; + + class mr_cache { + public: + mr_cache() = default; + ~mr_cache(); + + void clear(); + void get(fid_domain* domain, void* buf, size_t bytes, fid_mr** mr); + void push(fid_mr* mr); + + private: + size_t mr_key = 0; + + using key_t = typename std::tuple; + using value_t = fid_mr*; + std::unordered_multimap cache{}; + }; + + class fi_cache { + public: + fi_cache() = default; + fi_cache(const fi_cache&) = delete; + fi_cache& operator=(const fi_cache&) = delete; + ~fi_cache(); + + void clear(); + + void init(size_t instance_count, int enable_hmem); + void get(size_t idx, fid_domain* domain, void* buf, size_t bytes, fid_mr** mr); + void push(size_t idx, fid_mr* mr); + + private: + int enable_hmem; + std::vector memory_regions; + }; + + fi_cache cache{}; + bool is_finalized{ false }; bool inited{ false }; }; diff --git a/src/atl/ofi/atl_ofi_helper.cpp b/src/atl/ofi/atl_ofi_helper.cpp new file mode 100644 index 000000000..7218da8d1 --- /dev/null +++ b/src/atl/ofi/atl_ofi_helper.cpp @@ -0,0 +1,1320 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "atl_ofi_helper.hpp" + +atl_ofi_global_data_t global_data; + +template +std::string vec_to_string(Container& elems) { + if (elems.empty()) { + return ""; + } + + size_t idx = 0; + std::ostringstream ss; + for (auto elem : elems) { + ss << elem; + idx++; + if (idx < elems.size()) { + ss << " "; + } + } + return ss.str(); +} + +#ifdef CCL_ENABLE_OFI_HMEM +void atl_ofi_init_ze_data() { + atl_ofi_ze_data& ze_data = global_data.ze_data; + + ZE_CALL(zeInit, (ZE_INIT_FLAG_GPU_ONLY)); + + uint32_t count = 1; + ZE_CALL(zeDriverGet, (&count, &ze_data.driver)); + + ze_data.device_count = 0; + ZE_CALL(zeDeviceGet, (ze_data.driver, &ze_data.device_count, nullptr)); + ZE_CALL(zeDeviceGet, (ze_data.driver, &ze_data.device_count, ze_data.devices)); + ZE_CALL(zeContextCreate, (ze_data.driver, &ccl::ze::default_context_desc, &ze_data.context)); + + CCL_THROW_IF_NOT(ze_data.driver, "null ze driver"); + CCL_THROW_IF_NOT(ze_data.context, "null ze context"); + CCL_THROW_IF_NOT( + ze_data.device_count > 0, "unexpected ze device count: ", ze_data.device_count); + + LOG_DEBUG("ze device count: ", ze_data.device_count); +} +#endif // CCL_ENABLE_OFI_HMEM + +void atl_ofi_print_coord(atl_proc_coord_t* coord) { + LOG_DEBUG("coord: global [idx ", + coord->global_idx, + ", cnt ", + coord->global_count, + "], local [idx ", + coord->local_idx, + ", cnt ", + coord->local_count, + "]"); +} + +std::string atl_ofi_get_short_nic_name(const struct fi_info* prov) { + std::stringstream ss; + ss << prov->domain_attr->name; + return ss.str(); +} + +std::string atl_ofi_get_nic_name(const struct fi_info* prov) { + std::stringstream ss; + ss << prov->fabric_attr->prov_name << ":"; + // ss << prov->fabric_attr->name << ":"; + ss << atl_ofi_get_short_nic_name(prov); + return ss.str(); +} + +atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size) { + size_t prov_idx; + atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); + + CCL_THROW_IF_NOT(ofi_ctx->prov_count <= ATL_OFI_MAX_PROV_COUNT, + "unexpected prov_count ", + ofi_ctx->prov_count); + + atl_proc_coord_t* coord = &(ep->ctx->coord); + int my_node_idx = coord->global_idx / coord->local_count; + int peer_node_idx = peer_proc_idx / coord->local_count; + + int has_shm = (ofi_ctx->prov_count == ofi_ctx->nw_prov_count + 1) ? 1 : 0; + + if (has_shm && (my_node_idx == peer_node_idx) && + (msg_size <= ofi_ctx->provs[ofi_ctx->shm_prov_idx].max_msg_size)) { + prov_idx = ofi_ctx->shm_prov_idx; + } + else { + size_t nw_prov_offset = ep->idx % ofi_ctx->nw_prov_count; + prov_idx = ofi_ctx->nw_prov_first_idx + nw_prov_offset; + } + + LOG_DEBUG("get_prov: ep_idx ", + ep->idx, + ", prov_idx ", + prov_idx, + ", my_node_idx ", + my_node_idx, + ", peer_node_idx ", + peer_node_idx, + ", msg_size ", + msg_size, + ", has_shm ", + has_shm); + + /* TODO: add segmentation logic */ + CCL_THROW_IF_NOT(msg_size <= ofi_ctx->provs[prov_idx].max_msg_size, + "msg_size (", + msg_size, + ") is greater than max_msg_size (", + ofi_ctx->provs[prov_idx].max_msg_size, + "), prov_idx ", + prov_idx); + + return &(ofi_ctx->provs[prov_idx]); +} + +fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, atl_ofi_prov_t* prov, int proc_idx, size_t ep_idx) { + return *(prov->addr_table + ((ctx->ep_count * (proc_idx - prov->first_proc_idx)) + ep_idx)); +} + +atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_ptr& pmi) { + CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); + + atl_proc_coord_t* coord = &(ofi_ctx->ctx.coord); + + atl_status_t ret = ATL_STATUS_SUCCESS; + int i; + int local_idx = 0, local_count = 0; + char* all_hostnames = nullptr; + char my_hostname[ATL_MAX_HOSTNAME_LEN] = { 0 }; + size_t my_hostname_len = 0; + int my_global_proc_idx = coord->global_idx; + + gethostname(my_hostname, ATL_MAX_HOSTNAME_LEN - 1); + my_hostname_len = strlen(my_hostname); + coord->hostname_hash = std::hash{}(my_hostname); + + CCL_THROW_IF_NOT(my_hostname_len < ATL_MAX_HOSTNAME_LEN, + "unexpected my_hostname_len ", + my_hostname_len, + ", expected max ", + (size_t)(ATL_MAX_HOSTNAME_LEN)); + + if (ATL_MAX_HOSTNAME_LEN - my_hostname_len <= 10) { + LOG_WARN("hostname is quite long, len: ", my_hostname_len, ", name: ", my_hostname); + } + + snprintf(my_hostname + my_hostname_len, + ATL_MAX_HOSTNAME_LEN - my_hostname_len, + "-%d", + my_global_proc_idx); + + ret = pmi->pmrt_kvs_put((char*)ATL_OFI_HOSTNAME_PM_KEY, + my_global_proc_idx * ATL_OFI_PMI_PROC_MULTIPLIER, + my_hostname, + ATL_MAX_HOSTNAME_LEN); + + if (ret) { + LOG_ERROR("pmrt_kvs_put: ret: ", ret); + goto fn_err; + } + + pmi->pmrt_barrier(); + + all_hostnames = (char*)calloc(1, coord->global_count * ATL_MAX_HOSTNAME_LEN); + if (!all_hostnames) { + LOG_ERROR("can't allocate all_hostnames"); + goto fn_err; + } + + for (i = 0; i < coord->global_count; i++) { + ret = pmi->pmrt_kvs_get((char*)ATL_OFI_HOSTNAME_PM_KEY, + i * ATL_OFI_PMI_PROC_MULTIPLIER, + all_hostnames + i * ATL_MAX_HOSTNAME_LEN, + ATL_MAX_HOSTNAME_LEN); + if (ret) { + LOG_ERROR("pmrt_kvs_get: ret: ", ret); + goto fn_err; + } + } + + for (i = 0; i < coord->global_count; i++) { + if (!strncmp(my_hostname, + all_hostnames + i * ATL_MAX_HOSTNAME_LEN, + my_hostname_len + 1 /* including "-" at the end */)) { + local_count++; + int peer_global_proc_idx; + sscanf(all_hostnames + i * ATL_MAX_HOSTNAME_LEN + my_hostname_len + 1, + "%d", + &peer_global_proc_idx); + if (my_global_proc_idx > peer_global_proc_idx) + local_idx++; + } + } + + coord->local_idx = local_idx; + coord->local_count = local_count; + +fn_exit: + free(all_hostnames); + return ret; + +fn_err: + ret = ATL_STATUS_FAILURE; + goto fn_exit; +} + +atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, + size_t prov_idx, + std::unique_ptr& pmi) { + CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); + + atl_ctx_t* ctx = &(ofi_ctx->ctx); + atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); + + atl_status_t ret = ATL_STATUS_SUCCESS; + int i; + size_t j; + int insert_count; + + size_t addr_idx = 0; + char* ep_names_table; + size_t ep_names_table_len; + + size_t named_ep_count = (prov->sep ? 1 : ctx->ep_count); + + int local_count = ctx->coord.local_count; + int node_idx = ctx->coord.global_idx / local_count; + int shm_start_idx = node_idx * local_count; + int shm_end_idx = (node_idx + 1) * local_count; + + LOG_DEBUG("shm_start_idx ", shm_start_idx, ", shm_end_idx ", shm_end_idx); + + int proc_count = prov->is_shm ? ctx->coord.local_count : ctx->coord.global_count; + + if (proc_count == 0) + return ATL_STATUS_SUCCESS; + + LOG_DEBUG("name ", + atl_ofi_get_nic_name(prov->info), + ", is_shm ", + prov->is_shm, + ", addr_len ", + prov->addr_len, + ", local_count ", + ctx->coord.local_count, + ", global_count ", + ctx->coord.global_count, + ", proc_count ", + proc_count); + + /* allocate OFI EP names table that will contain all published names */ + ep_names_table_len = prov->addr_len * named_ep_count * proc_count; + + if (ep_names_table_len == 0) { + LOG_ERROR("ep_names_table_len == 0, addr_len ", + prov->addr_len, + ", named_ep_count ", + named_ep_count, + ", proc_count ", + proc_count); + return ATL_STATUS_FAILURE; + } + + ep_names_table = (char*)calloc(1, ep_names_table_len); + if (!ep_names_table) { + LOG_ERROR("can't allocate epnames_table"); + return ATL_STATUS_FAILURE; + } + + pmi->pmrt_barrier(); + + /* retrieve all OFI EP names in order */ + for (i = 0; i < ctx->coord.global_count; i++) { + if (prov->is_shm) { + if (!(i >= shm_start_idx && i < shm_end_idx)) { + continue; + } + } + + for (j = 0; j < named_ep_count; j++) { + ret = pmi->pmrt_kvs_get( + (char*)ATL_OFI_FI_ADDR_PM_KEY, + i * ATL_OFI_PMI_PROC_MULTIPLIER + prov_idx * ATL_OFI_PMI_PROV_MULTIPLIER + j, + ep_names_table + addr_idx * prov->addr_len, + prov->addr_len); + + if (ret) { + LOG_ERROR("kvs_get error: ret ", + ret, + ", proc_idx ", + i, + ", ep_idx ", + j, + ", addr_idx ", + addr_idx); + goto err_ep_names; + } + + addr_idx++; + } + } + + LOG_DEBUG( + "kvs_get: ep_count ", named_ep_count, ", proc_count ", proc_count, ", got ", addr_idx); + + if (addr_idx != named_ep_count * proc_count) { + LOG_ERROR("unexpected kvs_get results: expected ", + named_ep_count * proc_count, + ", got ", + addr_idx); + + ret = ATL_STATUS_FAILURE; + goto err_addr_table; + } + + if (prov->addr_table != nullptr) + free(prov->addr_table); + + prov->addr_table = (fi_addr_t*)calloc(1, ctx->ep_count * proc_count * sizeof(fi_addr_t)); + + if (!prov->addr_table) + goto err_ep_names; + + /* insert all the EP names into the AV */ + insert_count = fi_av_insert( + prov->av, ep_names_table, named_ep_count * proc_count, prov->addr_table, 0, nullptr); + + LOG_DEBUG("av_insert: ep_count ", + named_ep_count, + ", proc_count ", + proc_count, + ", inserted ", + insert_count); + + if (insert_count != (int)(named_ep_count * proc_count)) { + LOG_ERROR("unexpected av_insert results: expected ", + named_ep_count * proc_count, + " got ", + insert_count); + ret = ATL_STATUS_FAILURE; + goto err_addr_table; + } + else { + ret = ATL_STATUS_SUCCESS; + } + + if (prov->sep) { + if (named_ep_count != 1) { + LOG_ERROR("unexpected named_ep_count ", named_ep_count); + goto err_addr_table; + } + + fi_addr_t* table; + table = (fi_addr_t*)calloc(1, proc_count * sizeof(fi_addr_t)); + if (table == nullptr) { + LOG_ERROR("memory allocaion failed"); + ret = ATL_STATUS_FAILURE; + goto err_addr_table; + } + memcpy(table, prov->addr_table, proc_count * sizeof(fi_addr_t)); + + for (i = 0; i < proc_count; i++) { + for (j = 0; j < ctx->ep_count; j++) { + prov->addr_table[i * ctx->ep_count + j] = + fi_rx_addr(table[i], j, prov->rx_ctx_bits); + } + } + free(table); + } + + /* normal end of execution */ + free(ep_names_table); + return ret; + + /* abnormal end of execution */ +err_addr_table: + free(prov->addr_table); + +err_ep_names: + free(ep_names_table); + return ret; +} + +atl_status_t atl_ofi_prov_ep_get_name(atl_ofi_prov_t* prov, size_t ep_idx) { + int ret; + + atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); + struct fid_ep* fi_ep = (prov->sep) ? prov->sep : ep->tx; + + ep->name.addr = nullptr; + ep->name.len = 0; + + ret = fi_getname(&fi_ep->fid, ep->name.addr, &(ep->name.len)); + if ((ret != -FI_ETOOSMALL) || ep->name.len <= 0) + ep->name.len = FI_NAME_MAX; + + if (ep->name.addr) + free(ep->name.addr); + + ep->name.addr = calloc(1, ep->name.len); + + if (!(ep->name.addr)) { + LOG_ERROR("can't allocate addr"); + ret = ATL_STATUS_FAILURE; + goto err_addr; + } + + ret = fi_getname(&fi_ep->fid, ep->name.addr, &(ep->name.len)); + if (ret) { + LOG_ERROR("fi_getname error"); + goto err_getname; + } + + prov->addr_len = MAX(prov->addr_len, ep->name.len); + + return ATL_STATUS_SUCCESS; + +err_getname: + free(ep->name.addr); + ep->name.addr = nullptr; + ep->name.len = 0; + +err_addr: + return RET2ATL(ret); +} + +atl_status_t atl_ofi_prov_eps_connect(atl_ofi_ctx_t* ofi_ctx, + size_t prov_idx, + std::unique_ptr& pmi) { + int ret; + size_t ep_idx; + + atl_ctx_t* ctx = &(ofi_ctx->ctx); + atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); + size_t named_ep_count = (prov->sep ? 1 : ctx->ep_count); + atl_proc_coord_t* coord = &(ctx->coord); + + prov->addr_len = 0; + prov->first_proc_idx = + (prov->is_shm) ? ((coord->global_idx / coord->local_count) * coord->local_count) : 0; + + for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { + ret = atl_ofi_prov_ep_get_name(prov, ep_idx); + if (ret) { + LOG_ERROR("atl_ofi_prov_ep_get_name error"); + return ATL_STATUS_FAILURE; + } + } + + for (ep_idx = 0; ep_idx < named_ep_count; ep_idx++) { + atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); + ret = pmi->pmrt_kvs_put((char*)ATL_OFI_FI_ADDR_PM_KEY, + coord->global_idx * ATL_OFI_PMI_PROC_MULTIPLIER + + prov_idx * ATL_OFI_PMI_PROV_MULTIPLIER + ep_idx, + ep->name.addr, + ep->name.len); + if (ret) { + LOG_ERROR("pmrt_kvs_put: ret: ", ret); + return ATL_STATUS_FAILURE; + } + } + + ret = atl_ofi_prov_update_addr_table(ofi_ctx, prov_idx, pmi); + + return RET2ATL(ret); +} + +void atl_ofi_prov_ep_destroy(atl_ofi_prov_t* prov, atl_ofi_prov_ep_t* ep) { + if (ep->rx) + fi_close(&ep->rx->fid); + + if (prov->sep && ep->tx) + fi_close(&ep->tx->fid); + + if (ep->cq) + fi_close(&ep->cq->fid); + + if (ep->name.addr) + free(ep->name.addr); + + ep->rx = ep->tx = nullptr; + ep->cq = nullptr; + ep->name.addr = nullptr; + ep->name.len = 0; +} + +void atl_ofi_prov_destroy(atl_ctx_t* ctx, atl_ofi_prov_t* prov) { + size_t i; + + for (i = 0; i < ctx->ep_count; i++) { + atl_ofi_prov_ep_destroy(prov, &(prov->eps[i])); + } + + free(prov->eps); + free(prov->addr_table); + + if (prov->sep) + fi_close(&prov->sep->fid); + + if (prov->av) + fi_close(&prov->av->fid); + + if (prov->domain) + fi_close(&prov->domain->fid); + + if (prov->fabric) + fi_close(&prov->fabric->fid); + + if (prov->info) { + fi_freeinfo(prov->info); + } +} + +int atl_ofi_wait_cancel_cq(struct fid_cq* cq) { + struct fi_cq_err_entry err_entry; + int ret, i; + struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; + + double time = 0; + clock_t start, end; + + while (time < ATL_OFI_WAIT_SEC) { + for (i = 0; i < ATL_OFI_CQ_READ_ITERS; i++) { + start = clock(); + ret = fi_cq_read(cq, entries, ATL_OFI_CQ_BUNCH_SIZE); + + if (ret < 0 && ret != -FI_EAGAIN) { + ret = fi_cq_readerr(cq, &err_entry, 0); + + if (err_entry.err != FI_ECANCELED) { + LOG_ERROR( + "fi_cq_readerr: err: ", + err_entry.err, + ", prov_err: ", + fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, nullptr, 0), + "(", + err_entry.prov_errno, + ")"); + return ATL_STATUS_FAILURE; + } + return ATL_STATUS_SUCCESS; + } + } + end = clock(); + time += (double)(end - start) / CLOCKS_PER_SEC; + } + + LOG_ERROR("too long for cancel"); + + return ATL_STATUS_FAILURE; +} + +atl_status_t atl_ofi_prov_ep_init(atl_ofi_prov_t* prov, size_t ep_idx) { + ssize_t ret = 0; + + struct fi_cq_attr cq_attr; + struct fi_tx_attr tx_attr; + struct fi_rx_attr rx_attr; + + atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); + + memset(&cq_attr, 0, sizeof(cq_attr)); + cq_attr.format = FI_CQ_FORMAT_TAGGED; + + ATL_OFI_CALL( + fi_cq_open(prov->domain, &cq_attr, &ep->cq, nullptr), ret, return ATL_STATUS_FAILURE); + + if (prov->sep) { + rx_attr = *prov->info->rx_attr; + rx_attr.caps |= FI_TAGGED; + + ATL_OFI_CALL(fi_rx_context(prov->sep, ep_idx, &rx_attr, &ep->rx, nullptr), ret, goto err); + + ATL_OFI_CALL(fi_ep_bind(ep->rx, &ep->cq->fid, FI_RECV), ret, goto err); + + tx_attr = *prov->info->tx_attr; + tx_attr.caps |= FI_TAGGED; + + ATL_OFI_CALL(fi_tx_context(prov->sep, ep_idx, &tx_attr, &ep->tx, nullptr), ret, goto err); + + ATL_OFI_CALL(fi_ep_bind(ep->tx, &ep->cq->fid, FI_SEND), ret, goto err); + + fi_enable(ep->rx); + fi_enable(ep->tx); + } + else { + struct fid_ep* endpoint; + + ATL_OFI_CALL(fi_endpoint(prov->domain, prov->info, &endpoint, nullptr), ret, goto err); + + ep->tx = ep->rx = endpoint; + + ATL_OFI_CALL(fi_ep_bind(endpoint, &ep->cq->fid, FI_SEND | FI_RECV), ret, goto err); + + ATL_OFI_CALL(fi_ep_bind(endpoint, &prov->av->fid, 0), ret, goto err); + + fi_enable(endpoint); + } + + return ATL_STATUS_SUCCESS; + +err: + atl_ofi_prov_ep_destroy(prov, ep); + return ATL_STATUS_FAILURE; +} + +atl_status_t atl_ofi_try_to_drain_cq_err(struct fid_cq* cq) { + struct fi_cq_err_entry err_entry; + int ret = fi_cq_readerr(cq, &err_entry, 0); + if (ret != 1) { + LOG_DEBUG("unable to fi_cq_readerr"); + return ATL_STATUS_FAILURE; + } + else { + if (err_entry.err != FI_ENOMSG && err_entry.err != FI_ECANCELED && + err_entry.err != FI_ETRUNC) { + LOG_ERROR("fi_cq_readerr: err: ", + err_entry.err, + ", prov_err: ", + fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, nullptr, 0), + "(", + err_entry.prov_errno, + ")"); + return ATL_STATUS_FAILURE; + } + return ATL_STATUS_SUCCESS; + } +} + +int atl_ofi_try_to_drain_cq(struct fid_cq* cq) { + int ret = -FI_EAGAIN, i; + double time = 0; + clock_t start, end; + struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; + + while (time < ATL_OFI_WAIT_SEC) { + start = clock(); + for (i = 0; i < ATL_OFI_CQ_READ_ITERS; i++) { + ret = fi_cq_read(cq, entries, ATL_OFI_CQ_BUNCH_SIZE); + + if (ret < 0 && ret != -FI_EAGAIN) { + atl_ofi_try_to_drain_cq_err(cq); + return ret; + } + + if (ret > 0) + return ret; + } + end = clock(); + time += (double)(end - start) / CLOCKS_PER_SEC; + } + + return ret; +} + +void atl_ofi_reset(atl_ctx_t* ctx) { + atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + + int again = 1; + size_t prov_idx, ep_idx; + int recv_buf_len = sizeof(char); + char* recv_buf; + struct fi_context fi_ctx; + recv_buf = (char*)malloc(recv_buf_len); + for (prov_idx = 0; prov_idx < ofi_ctx->prov_count; prov_idx++) { + atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); + + for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { + atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); + + /* complete active sends and receives */ + while (atl_ofi_try_to_drain_cq(ep->cq) != -FI_EAGAIN) { + } + + /* try to complete active incoming sends */ + while (again) { + again = 0; + /* post recv to complete incoming send */ + while (fi_trecv(ep->rx, + recv_buf, + recv_buf_len, + nullptr, + FI_ADDR_UNSPEC, + 0, + UINTMAX_MAX, + &fi_ctx) == -FI_EAGAIN) { + } + + /* wait until recv will be completed or finished by timeout */ + while (atl_ofi_try_to_drain_cq(ep->cq) != -FI_EAGAIN) { + /* something is completed -> send queue not empty */ + again = 1; + } + } + + /* nothing to recv -> cancel last recv */ + fi_cancel(&ep->rx->fid, &fi_ctx); + + atl_ofi_wait_cancel_cq(ep->cq); + } + } + + free(recv_buf); +} + +atl_status_t atl_ofi_adjust_env(const atl_attr_t& attr) { + char* prov_env = getenv("FI_PROVIDER"); + + if (prov_env && strlen(prov_env)) { + CCL_THROW_IF_NOT(strlen(prov_env) < sizeof(global_data.prov_env_copy), + "too long FI_PROVIDER value, max expected length ", + sizeof(global_data.prov_env_copy)); + memcpy(global_data.prov_env_copy, prov_env, strlen(prov_env)); + } + + if (attr.in.enable_shm) { + /* add shm provider in the list of allowed providers */ + if (prov_env && !strstr(prov_env, ATL_OFI_SHM_PROV_NAME)) { + /* whether single provider will be in the final env variable */ + int single_prov = (strlen(prov_env) == 0) ? 1 : 0; + + size_t prov_env_new_size = strlen(prov_env) + strlen(ATL_OFI_SHM_PROV_NAME) + + (single_prov ? 0 : 1) + /* for delimeter */ + 1; /* for terminating null symbol */ + + char* prov_env_new = (char*)calloc(prov_env_new_size, sizeof(char)); + if (prov_env_new == nullptr) { + LOG_ERROR("memory allocaion failed"); + return ATL_STATUS_FAILURE; + } + + if (single_prov) + snprintf(prov_env_new, prov_env_new_size, "%s", ATL_OFI_SHM_PROV_NAME); + else { + snprintf(prov_env_new, prov_env_new_size, "%s,%s", prov_env, ATL_OFI_SHM_PROV_NAME); + } + + LOG_INFO("atl-ofi-shm is requested, modify FI_PROVIDER: old value: ", + prov_env, + ", new value: ", + prov_env_new); + + setenv("FI_PROVIDER", prov_env_new, 1); + + free(prov_env_new); + } + } + + return ATL_STATUS_SUCCESS; +} + +atl_status_t atl_ofi_set_env(const atl_attr_t& attr) { + if (global_data.is_env_inited) { + return ATL_STATUS_SUCCESS; + } + + setenv("FI_PSM2_DELAY", "0", 0); + setenv("FI_PSM2_TIMEOUT", "0", 0); + setenv("FI_PSM2_LOCK_LEVEL", "1", 0); + setenv("FI_PSM2_NAME_SERVER", "0", 0); + setenv("HFI_NO_CPUAFFINITY", "1", 0); + setenv("PSM2_MULTI_EP", "1", 0); + + setenv("FI_PSM3_DELAY", "0", 0); + setenv("FI_PSM3_TIMEOUT", "0", 0); + setenv("FI_PSM3_LOCK_LEVEL", "1", 0); + setenv("FI_PSM3_NAME_SERVER", "0", 0); + setenv("PSM3_NO_CPUAFFINITY", "1", 0); + setenv("PSM3_RDMA", "2", 0); + setenv("PSM3_MR_CACHE_MODE", "0", 0); //TODO temporary + setenv("PSM3_MULTI_EP", "1", 0); + if (attr.in.mnic_type == ATL_MNIC_NONE) + setenv("PSM3_NIC", "any", 0); + + char* hydra_uuid_env = getenv("I_MPI_HYDRA_UUID"); + if (hydra_uuid_env) { + setenv("FI_PSM2_UUID", hydra_uuid_env, 0); + setenv("FI_PSM3_UUID", hydra_uuid_env, 0); + } + + setenv("FI_OFI_RXM_USE_HASH", "0", 0); + setenv("FI_OFI_RXM_USE_SRX", "0", 0); + setenv("FI_OFI_RXM_RX_SIZE", "8192", 0); + setenv("FI_OFI_RXM_TX_SIZE", "8192", 0); + setenv("FI_OFI_RXM_MSG_RX_SIZE", "128", 0); + setenv("FI_OFI_RXM_MSG_TX_SIZE", "128", 0); + + setenv("FI_SHM_TX_SIZE", "8192", 0); + setenv("FI_SHM_RX_SIZE", "8192", 0); + +#ifdef CCL_ENABLE_SYCL + setenv("FI_SHM_DISABLE_CMA", "1", 0); +#endif // CCL_ENABLE_SYCL + + atl_ofi_adjust_env(attr); + + /* + load libfabric symbols into global namespace + to workaround issue with undefined symbols + in case of out-of-tree providers, like OFI/PSM3 + */ + global_data.dlhandle = dlopen("libfabric.so", RTLD_GLOBAL | RTLD_NOW); + if (global_data.dlhandle == nullptr) { + LOG_WARN("dlopen (libfabric.so): ", dlerror()); + } + + global_data.is_env_inited = 1; + + return ATL_STATUS_SUCCESS; +} + +atl_status_t atl_ofi_get_prov_list(atl_ctx_t* ctx, + const char* prov_name, + struct fi_info* base_hints, + struct fi_info** out_prov_list) { + struct fi_info* hints = nullptr; + struct fi_info* prov_list = nullptr; + ssize_t ret = 0; + int fi_version = FI_VERSION(global_data.fi_major_version, global_data.fi_minor_version); + const char* prov_name_str = (prov_name) ? prov_name : ""; + + hints = fi_dupinfo(base_hints); + if (!hints) { + LOG_ERROR("fi_dupinfo error"); + goto err; + } + + *out_prov_list = nullptr; + + LOG_DEBUG("request providers with name: ", prov_name_str); + + hints->fabric_attr->prov_name = (prov_name) ? strdup(prov_name) : nullptr; + + ret = fi_getinfo(fi_version, nullptr, nullptr, 0ULL, hints, &prov_list); + if (ret || !prov_list) { + LOG_ERROR("fi_getinfo error: ret ", ret, ", providers ", (void*)prov_list); + goto err; + } + + if (prov_list->domain_attr->max_ep_tx_ctx > 1) { + hints->ep_attr->tx_ctx_cnt = ctx->ep_count; + hints->ep_attr->rx_ctx_cnt = ctx->ep_count; + } + else { + hints->ep_attr->tx_ctx_cnt = 1; + hints->ep_attr->rx_ctx_cnt = 1; + } + + fi_freeinfo(prov_list); + prov_list = nullptr; + + ret = fi_getinfo(fi_version, nullptr, nullptr, 0ULL, hints, &prov_list); + if (ret || !prov_list) { + LOG_ERROR("fi_getinfo error, prov_name ", prov_name_str); + goto err; + } + + fi_freeinfo(hints); + hints = nullptr; + + *out_prov_list = prov_list; + return ATL_STATUS_SUCCESS; + +err: + LOG_ERROR("can't create providers for name ", prov_name_str); + return ATL_STATUS_FAILURE; +} + +atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, + struct fi_info* info, + atl_ofi_prov_t* prov, + atl_attr_t* attr, + std::unique_ptr& pmi) { + struct fi_av_attr av_attr; + size_t ep_idx = 0; + ssize_t ret = 0; + + memset(&av_attr, 0, sizeof(av_attr)); + + atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + + if (ctx->coord.global_idx == 0) { + LOG_INFO("provider: ", info->fabric_attr->prov_name); + LOG_INFO(" nic: ", atl_ofi_get_nic_name(info)); + LOG_INFO(" mr_mode: ", info->domain_attr->mr_mode); + LOG_INFO(" threading: ", info->domain_attr->threading); + LOG_INFO(" tx_ctx_cnt: ", info->domain_attr->tx_ctx_cnt); + LOG_INFO(" max_ep_tx_ctx: ", info->domain_attr->max_ep_tx_ctx); + LOG_INFO(" max_msg_size: ", info->ep_attr->max_msg_size); + } + + prov->info = fi_dupinfo(info); + + if (!prov->info) { + LOG_ERROR("fi_dupinfo error"); + goto err; + } + + prov->max_msg_size = info->ep_attr->max_msg_size; + + ATL_OFI_CALL(fi_fabric(info->fabric_attr, &prov->fabric, nullptr), ret, goto err); + + ATL_OFI_CALL(fi_domain(prov->fabric, info, &prov->domain, nullptr), ret, goto err); + + av_attr.type = FI_AV_TABLE; + av_attr.rx_ctx_bits = prov->rx_ctx_bits = (int)ceil(log2(prov->info->ep_attr->rx_ctx_cnt)); + + ATL_OFI_CALL(fi_av_open(prov->domain, &av_attr, &prov->av, nullptr), ret, goto err); + + if (info->domain_attr->max_ep_tx_ctx > 1) { + ATL_OFI_CALL(fi_scalable_ep(prov->domain, info, &prov->sep, nullptr), ret, goto err); + ATL_OFI_CALL(fi_scalable_ep_bind(prov->sep, &prov->av->fid, 0), ret, goto err); + } + + prov->eps = (atl_ofi_prov_ep_t*)calloc(1, sizeof(atl_ofi_prov_ep_t) * ctx->ep_count); + if (!prov->eps) { + LOG_ERROR("can't allocate prov->eps"); + goto err; + } + + for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { + ret = atl_ofi_prov_ep_init(prov, ep_idx); + if (ret) { + LOG_ERROR("atl_ofi_prov_ep_init error"); + goto err; + } + } + + if (prov->sep) { + fi_enable(prov->sep); + } + + /* TODO: make separate function to be called on CCL comm creation */ + ret = atl_ofi_prov_eps_connect(ofi_ctx, prov->idx, pmi); + if (ret) { + LOG_ERROR("atl_ofi_prov_eps_connect error, prov_idx ", prov->idx); + goto err; + } + + ATL_CALL(atl_ofi_adjust_out_tag(prov, attr), goto err); + + return ATL_STATUS_SUCCESS; + +err: + LOG_ERROR("can't init provider ", atl_ofi_get_nic_name(info)); + return ATL_STATUS_FAILURE; +} + +atl_status_t atl_ofi_adjust_out_tag(atl_ofi_prov_t* prov, atl_attr_t* attr) { + size_t tag_bits = 64; + uint64_t mem_tag_format = prov->info->ep_attr->mem_tag_format; + while (tag_bits && !(mem_tag_format & ((uint64_t)1 << (tag_bits - 1)))) { + tag_bits--; + } + + attr->out.tag_bits = std::min(attr->out.tag_bits, tag_bits); + + if (attr->out.tag_bits == 64) { + attr->out.max_tag = 0xFFFFFFFFFFFFFFFF; + } + else { + attr->out.max_tag = (((uint64_t)1 << attr->out.tag_bits) - 1); + } + + const char* prov_name = prov->info->fabric_attr->prov_name; + + CCL_THROW_IF_NOT(attr->out.tag_bits > 0, + "unexpected tag_bits ", + attr->out.tag_bits, + " for prov ", + prov_name); + + CCL_THROW_IF_NOT( + attr->out.max_tag > 0, "unexpected max_tag ", attr->out.max_tag, " for prov ", prov_name); + + LOG_INFO(prov_name, + " tag_bits: ", + attr->out.tag_bits, + ", max_tag: ", + attr->out.max_tag, + ", mem_tag_format: ", + mem_tag_format); + + return ATL_STATUS_SUCCESS; +} + +/* determine if NIC has already been included in others */ +int atl_ofi_nic_already_used(const struct fi_info* prov, + struct fi_info** others, + size_t nic_count) { + for (size_t i = 0; i < nic_count; i++) { + if (prov->nic && others[i]->nic && prov->nic->bus_attr->bus_type == FI_BUS_PCI && + others[i]->nic->bus_attr->bus_type == FI_BUS_PCI) { + struct fi_pci_attr pci = prov->nic->bus_attr->attr.pci; + struct fi_pci_attr other_pci = others[i]->nic->bus_attr->attr.pci; + LOG_DEBUG("compare nic ", + prov->fabric_attr->prov_name, + " pci ", + (int)pci.domain_id, + ":", + (int)pci.bus_id, + ":", + (int)pci.device_id, + ":", + (int)pci.function_id, + " with nic ", + others[i]->fabric_attr->prov_name, + " pci ", + (int)other_pci.domain_id, + ":", + (int)other_pci.bus_id, + ":", + (int)other_pci.device_id, + ":", + (int)other_pci.function_id); + if (pci.domain_id == other_pci.domain_id && pci.bus_id == other_pci.bus_id && + pci.device_id == other_pci.device_id && pci.function_id == other_pci.function_id) + return 1; + } + else { + LOG_DEBUG("compare nic ", + atl_ofi_get_nic_name(prov), + " with nic ", + atl_ofi_get_nic_name(others[i])); + if (atl_ofi_get_short_nic_name(prov) == atl_ofi_get_short_nic_name(others[i])) + return 1; + } + } + return 0; +} + +/* return true if the NIC is bound to the same socket as calling process */ +int atl_ofi_is_nic_local(struct fi_info* info) { + if (info->nic && info->nic->bus_attr->bus_type == FI_BUS_PCI) { + struct fi_pci_attr pci = info->nic->bus_attr->attr.pci; + return ccl::global_data::get().hwloc_wrapper->is_dev_close_by_pci( + pci.domain_id, pci.bus_id, pci.device_id, pci.function_id); + } + return 0; +} + +atl_status_t atl_ofi_parse_mnic_name(atl_ctx_t* ctx, std::string str_to_parse) { + atl_status_t ret = ATL_STATUS_SUCCESS; + atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + + std::string include_str; + std::string exclude_str; + + auto pos = str_to_parse.find('^'); + + if (pos == 0) { + exclude_str = str_to_parse.substr(1); + } + else { + if (pos != std::string::npos) { + include_str = str_to_parse.substr(0, pos - 1); + exclude_str = str_to_parse.substr(pos + 1); + } + else { + include_str = str_to_parse.substr(0, pos); + } + } + + if (!include_str.empty()) { + LOG_DEBUG("include names str: ", include_str); + } + + if (!exclude_str.empty()) { + LOG_DEBUG("exclude names str: ", exclude_str); + } + + auto include_names = tokenize>(include_str, ','); + auto exclude_names = tokenize>(exclude_str, ','); + + if (!include_names.empty() && !exclude_names.empty()) { + auto include_set = std::set(include_names.begin(), include_names.end()); + auto exclude_set = std::set(exclude_names.begin(), exclude_names.end()); + + std::set intersect; + std::set_intersection(include_set.begin(), + include_set.end(), + exclude_set.begin(), + exclude_set.end(), + std::inserter(intersect, intersect.begin())); + if (!intersect.empty()) { + LOG_ERROR("include and exclude sets can not intersect"); + ret = ATL_STATUS_FAILURE; + } + + for (auto include_name : include_names) { + for (auto exclude_name : exclude_names) { + std::string& larger_name = + (include_name.size() > exclude_name.size()) ? include_name : exclude_name; + std::string& smaller_name = + (include_name.size() > exclude_name.size()) ? exclude_name : include_name; + if (larger_name.substr(0, smaller_name.size()) == smaller_name) { + LOG_ERROR("include name ", + include_name, + " and exclude name ", + exclude_name, + " have commom prefix"); + ret = ATL_STATUS_FAILURE; + break; + } + } + } + } + + if (ret == ATL_STATUS_SUCCESS) { + LOG_DEBUG("include names: ", vec_to_string(include_names)); + LOG_DEBUG("exclude names: ", vec_to_string(exclude_names)); + ofi_ctx->mnic_include_names = include_names; + ofi_ctx->mnic_exclude_names = exclude_names; + } + + return ret; +} + +int atl_ofi_is_allowed_nic_name(atl_ofi_ctx_t* ofi_ctx, struct fi_info* info) { + auto& include_names = ofi_ctx->mnic_include_names; + auto& exclude_names = ofi_ctx->mnic_exclude_names; + std::string nic_name = atl_ofi_get_short_nic_name(info); + + int should_include = 0; + int should_exclude = 0; + + if (include_names.empty()) { + should_include = 1; + } + + for (auto name : include_names) { + if (nic_name.substr(0, name.size()) == name) { + should_include = 1; + break; + } + } + + for (auto name : exclude_names) { + if (nic_name.substr(0, name.size()) == name) { + should_exclude = 1; + break; + } + } + + return (should_include && !should_exclude); +} + +atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, + struct fi_info* base_hints, + atl_attr_t* attr, + std::unique_ptr& pmi) { + atl_status_t ret = ATL_STATUS_SUCCESS; + struct fi_info* prov_list = nullptr; + struct fi_info* prov_iter = nullptr; + size_t idx = 0, prov_idx = 0; + char* prov_name = nullptr; + atl_ofi_prov_t* prov = nullptr; + size_t name_prov_count = 0; + size_t topo_prov_count = 0; + size_t final_prov_count = 0; + struct fi_info* name_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; + struct fi_info* topo_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; + struct fi_info* final_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; + std::set all_nic_names; + + atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); + + ofi_ctx->nw_prov_count = 0; + + /* 1. get full list of providers */ + if (strlen(global_data.prov_env_copy) && !strstr(global_data.prov_env_copy, ",")) + prov_name = global_data.prov_env_copy; + else + prov_name = nullptr; + ATL_CALL(atl_ofi_get_prov_list(ctx, prov_name, base_hints, &prov_list), goto err); + + /* 2. filter out by names */ + prov_iter = prov_list; + while (prov_iter) { + LOG_DEBUG("name filter: check nic ", atl_ofi_get_nic_name(prov_iter)); + if (!atl_ofi_nic_already_used(prov_iter, name_prov_list, name_prov_count)) { + all_nic_names.insert(atl_ofi_get_short_nic_name(prov_iter)); + if (atl_ofi_is_allowed_nic_name(ofi_ctx, prov_iter)) { + LOG_DEBUG("name filter: found suitable nic ", atl_ofi_get_nic_name(prov_iter)); + name_prov_list[name_prov_count] = fi_dupinfo(prov_iter); + name_prov_count++; + } + } + prov_iter = prov_iter->next; + } + + if (!name_prov_count) { + LOG_ERROR("name filter: can not find network providers", + ", include names: ", + vec_to_string(ofi_ctx->mnic_include_names), + ", exclude names: ", + vec_to_string(ofi_ctx->mnic_exclude_names), + ", all names: ", + vec_to_string(all_nic_names)); + goto err; + } + + /* 3. filter out by topo */ + if (ofi_ctx->mnic_type == ATL_MNIC_NONE) { + topo_prov_list[topo_prov_count] = fi_dupinfo(name_prov_list[0]); + topo_prov_count++; + } + else { + struct fid_nic* nic = nullptr; + for (idx = 0; idx < name_prov_count; idx++) { + prov_iter = name_prov_list[idx]; + LOG_DEBUG("topo filter: check nic ", atl_ofi_get_nic_name(prov_iter)); + nic = prov_iter->nic; + + LOG_DEBUG("topo filter: check nic ", + atl_ofi_get_nic_name(prov_iter), + ", has nic_attr ", + (nic != nullptr)); + + if (!atl_ofi_nic_already_used(prov_iter, topo_prov_list, topo_prov_count)) { + int is_local = atl_ofi_is_nic_local(prov_iter); + LOG_DEBUG( + "topo filter: nic ", atl_ofi_get_nic_name(prov_iter), ", is_local ", is_local); + if (ofi_ctx->mnic_type == ATL_MNIC_GLOBAL || + (ofi_ctx->mnic_type == ATL_MNIC_LOCAL && is_local)) { + LOG_DEBUG("topo filter: found suitable nic ", atl_ofi_get_nic_name(prov_iter)); + topo_prov_list[topo_prov_count] = fi_dupinfo(prov_iter); + topo_prov_count++; + } + } + else { + LOG_DEBUG("topo filter: nic ", atl_ofi_get_nic_name(prov_iter), " already used"); + } + } + } + + if (!topo_prov_count) { + LOG_ERROR("topo filter: can not find network providers, mnic_type ", ofi_ctx->mnic_type); + goto err; + } + + /* 4. filter out by count */ + for (idx = 0; idx < topo_prov_count; idx++) { + prov_iter = topo_prov_list[idx]; + LOG_DEBUG("count filter: check nic ", atl_ofi_get_nic_name(prov_iter)); + if (final_prov_count < ofi_ctx->mnic_count) { + LOG_DEBUG("count filter: found suitable nic ", + atl_ofi_get_nic_name(prov_iter), + ", nic idx ", + final_prov_count); + final_prov_list[final_prov_count] = fi_dupinfo(prov_iter); + final_prov_count++; + } + else { + break; + } + } + + if (!final_prov_count) { + LOG_ERROR("count filter: can not find network providers, mnic_count ", ofi_ctx->mnic_count); + goto err; + } + + /* 5. create network providers */ + LOG_INFO("found ", final_prov_count, " nic(s) according to all filters"); + ofi_ctx->nw_prov_count = final_prov_count; + for (idx = 0; idx < ofi_ctx->nw_prov_count; idx++) { + prov_idx = ofi_ctx->nw_prov_first_idx + idx; + prov = &ofi_ctx->provs[prov_idx]; + prov->idx = prov_idx; + prov->is_shm = 0; + ATL_CALL(atl_ofi_prov_init(ctx, final_prov_list[idx], prov, attr, pmi), goto err); + } + +exit: + for (idx = 0; idx < final_prov_count; idx++) { + if (final_prov_list[idx]) + fi_freeinfo(final_prov_list[idx]); + } + + for (idx = 0; idx < topo_prov_count; idx++) { + if (topo_prov_list[idx]) + fi_freeinfo(topo_prov_list[idx]); + } + + for (idx = 0; idx < name_prov_count; idx++) { + if (name_prov_list[idx]) + fi_freeinfo(name_prov_list[idx]); + } + + fi_freeinfo(prov_list); + + ofi_ctx->prov_count += ofi_ctx->nw_prov_count; + + return ret; + +err: + LOG_ERROR("can not open network providers"); + ret = ATL_STATUS_FAILURE; + goto exit; +} diff --git a/src/atl/ofi/atl_ofi_helper.hpp b/src/atl/ofi/atl_ofi_helper.hpp new file mode 100644 index 000000000..59e776e24 --- /dev/null +++ b/src/atl/ofi/atl_ofi_helper.hpp @@ -0,0 +1,312 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "atl.h" +#include "common/global/global.hpp" +#include "hwloc/hwloc_wrapper.hpp" +#ifdef CCL_ENABLE_OFI_HMEM +#include "sched/entry/gpu/ze_primitives.hpp" +#endif // CCL_ENABLE_OFI_HMEM + +#define ATL_OFI_BASE_PM_KEY "atl-ofi" +#define ATL_OFI_FI_ADDR_PM_KEY ATL_OFI_BASE_PM_KEY "-fiaddr" +#define ATL_OFI_HOSTNAME_PM_KEY ATL_OFI_BASE_PM_KEY "-hostname" + +#define ATL_OFI_MAJOR_VERSION "CCL_ATL_OFI_MAJOR_VERSION" +#define ATL_OFI_MINOR_VERSION "CCL_ATL_OFI_MINOR_VERSION" +#define ATL_OFI_TIMEOUT_SEC_ENV "CCL_ATL_OFI_TIMEOUT_SEC" +#define ATL_OFI_MAX_RETRY_COUNT_ENV "CCL_ATL_OFI_MAX_RETRY_COUNT" + +#define ATL_OFI_DEFAULT_TIMEOUT_SEC 60 +#define ATL_OFI_MAX_RETRY_COUNT 10000 +#define ATL_OFI_WAIT_SEC 10 +#define ATL_OFI_CQ_READ_ITERS 10000 +#define ATL_OFI_CQ_BUNCH_SIZE 8 + +#define ATL_OFI_MAX_PROV_ENV_LEN 128 +#define ATL_OFI_PMI_PROV_MULTIPLIER 100 +#define ATL_OFI_PMI_PROC_MULTIPLIER (ATL_OFI_PMI_PROV_MULTIPLIER * 10) +#define ATL_OFI_MAX_NW_PROV_COUNT 1024 +#define ATL_OFI_MAX_PROV_COUNT (ATL_OFI_MAX_NW_PROV_COUNT + 1) /* NW and SHM providers */ +#define ATL_OFI_MAX_ACTIVE_PROV_COUNT \ + 2 /* by current scheme each EP may use only SHM and 1 NW prov */ +#define ATL_OFI_SHM_PROV_NAME "shm" + +#define ATL_OFI_MAX_ZE_DEV_COUNT 1024 + +#ifndef PRId64 +#define PRId64 "lld" +#endif + +#define MAX(a, b) \ + ({ \ + __typeof__(a) _a = (a); \ + __typeof__(b) _b = (b); \ + _a > _b ? _a : _b; \ + }) + +#define MIN(a, b) \ + ({ \ + __typeof__(a) _a = (a); \ + __typeof__(b) _b = (b); \ + _a < _b ? _a : _b; \ + }) + +#define ATL_OFI_CALL(func, ret_val, err_action) \ + do { \ + (ret_val) = func; \ + if ((ret_val) != FI_SUCCESS) { \ + LOG_ERROR( \ + #func "\n fails with ret: ", ret_val, ", strerror: ", fi_strerror(-(ret_val))); \ + err_action; \ + } \ + } while (0) + +#define ATL_OFI_RETRY(func, ep, ret_val) \ + do { \ + atl_ctx_t* ctx_local = (ep)->ctx; \ + atl_ofi_ctx_t* ofi_ctx_local = container_of(ctx_local, atl_ofi_ctx_t, ctx); \ + size_t max_retry_count = ofi_ctx_local->max_retry_count; \ + size_t retry_count = 0; \ + do { \ + (ret_val) = func; \ + if ((ret_val) == FI_SUCCESS) \ + break; \ + if ((ret_val) != -FI_EAGAIN) { \ + LOG_ERROR(#func "\n fails with ret: ", \ + (ret_val), \ + ", strerror: ", \ + fi_strerror(-(ret_val))); \ + CCL_THROW("OFI function error"); \ + break; \ + } \ + (void)atl_ep_poll(ep); \ + retry_count++; \ + } while (((ret_val) == -FI_EAGAIN) && (retry_count < max_retry_count)); \ + } while (0) + +/* OFI returns 0 or -errno */ +#define RET2ATL(ret) \ + ({ \ + atl_status_t res; \ + if ((ret) == -FI_EAGAIN) \ + res = ATL_STATUS_AGAIN; \ + else \ + res = (ret) ? ATL_STATUS_FAILURE : ATL_STATUS_SUCCESS; \ + res; \ + }) + +inline long int safe_c_strtol(const char* str, char** endptr, int base) { + long int val = strtol(str, endptr, base); + if (val == 0) { + /* if a conversion error occurred, display a message and exit */ + if (errno == EINVAL) { + LOG_ERROR("conversion error occurred for string: ", str); + } + /* if the value provided was out of range, display a error message */ + if (errno == ERANGE) { + LOG_ERROR("the value provided was out of range, string: ", str); + } + } + return val; +} + +typedef enum { + ATL_OFI_COMP_POSTED, + ATL_OFI_COMP_COMPLETED, + ATL_OFI_COMP_PEEK_STARTED, + ATL_OFI_COMP_PEEK_FOUND, + ATL_OFI_COMP_PEEK_NOT_FOUND, +} atl_ofi_comp_state_t; + +typedef struct { + atl_mr_t mr; + struct fid_mr* fi_mr; +} atl_ofi_mr_t; + +typedef struct { + void* addr; + size_t len; +} atl_ofi_prov_ep_name_t; + +typedef struct { + struct fid_ep* tx; + struct fid_ep* rx; + struct fid_cq* cq; + atl_ofi_prov_ep_name_t name; +} atl_ofi_prov_ep_t; + +typedef struct { + size_t idx; + struct fi_info* info; + struct fid_fabric* fabric; + struct fid_domain* domain; + struct fid_av* av; + atl_ofi_prov_ep_t* eps; + + int is_shm; + size_t max_msg_size; + + /* used only in case of SEP supported */ + struct fid_ep* sep; + int rx_ctx_bits; + + /* table[0..proc_count][0..ep_count] */ + fi_addr_t* addr_table; + size_t addr_len; + int first_proc_idx; +} atl_ofi_prov_t; + +typedef struct { + atl_ep_t ep; + + /* used to make progressing only for really used providers */ + size_t active_prov_count; + size_t active_prov_idxs[ATL_OFI_MAX_ACTIVE_PROV_COUNT]; + +} atl_ofi_ep_t; + +typedef struct { + atl_ctx_t ctx; + pm_rt_desc_t* pm_rt; + atl_ofi_prov_t provs[ATL_OFI_MAX_PROV_COUNT]; + size_t prov_count; + size_t nw_prov_count; + size_t nw_prov_first_idx; + size_t shm_prov_idx; + size_t max_retry_count; + atl_progress_mode_t progress_mode; + atl_mnic_t mnic_type; + std::vector mnic_include_names; + std::vector mnic_exclude_names; + size_t mnic_count; + int enable_hmem; +} atl_ofi_ctx_t; + +typedef struct { + struct fi_context fi_ctx; + atl_ofi_prov_ep_t* prov_ep; + struct fid_ep* fi_ep; + atl_ofi_comp_state_t comp_state; + size_t recv_len; + struct fid_mr* mr; +} atl_ofi_req_t; + +#ifdef CCL_ENABLE_OFI_HMEM +typedef struct atl_ofi_ze_data { + ze_driver_handle_t driver; + ze_context_handle_t context; + uint32_t device_count; + ze_device_handle_t devices[ATL_OFI_MAX_ZE_DEV_COUNT]; + + atl_ofi_ze_data() : driver(nullptr), context(nullptr), device_count(0) {} + +} atl_ofi_ze_data_t; +#endif // CCL_ENABLE_OFI_HMEM + +typedef struct atl_ofi_global_data { + size_t ctx_count; + int is_env_inited; + void* dlhandle; + char prov_env_copy[ATL_OFI_MAX_PROV_ENV_LEN]; + + int fi_major_version; + int fi_minor_version; + +#ifdef CCL_ENABLE_OFI_HMEM + atl_ofi_ze_data ze_data; +#endif // CCL_ENABLE_OFI_HMEM + + atl_ofi_global_data() + : ctx_count(0), + is_env_inited(0), + dlhandle(nullptr), + prov_env_copy(), + fi_major_version(1), + fi_minor_version(10) { + memset(prov_env_copy, 0, sizeof(prov_env_copy)); + } +} atl_ofi_global_data_t; + +extern atl_ofi_global_data_t global_data; + +template +std::string vec_to_string(Container& elems); + +#ifdef CCL_ENABLE_OFI_HMEM +void atl_ofi_init_ze_data(); +#endif // CCL_ENABLE_OFI_HMEM + +void atl_ofi_print_coord(atl_proc_coord_t* coord); +std::string atl_ofi_get_short_nic_name(const struct fi_info* prov); +std::string atl_ofi_get_nic_name(const struct fi_info* prov); +atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size); +fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, atl_ofi_prov_t* prov, int proc_idx, size_t ep_idx); +atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_ptr& pmi); +atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, + size_t prov_idx, + std::unique_ptr& pmi); +atl_status_t atl_ofi_prov_ep_get_name(atl_ofi_prov_t* prov, size_t ep_idx); +atl_status_t atl_ofi_prov_eps_connect(atl_ofi_ctx_t* ofi_ctx, + size_t prov_idx, + std::unique_ptr& pmi); +void atl_ofi_prov_ep_destroy(atl_ofi_prov_t* prov, atl_ofi_prov_ep_t* ep); +void atl_ofi_prov_destroy(atl_ctx_t* ctx, atl_ofi_prov_t* prov); +int atl_ofi_wait_cancel_cq(struct fid_cq* cq); +atl_status_t atl_ofi_prov_ep_init(atl_ofi_prov_t* prov, size_t ep_idx); +atl_status_t atl_ofi_try_to_drain_cq_err(struct fid_cq* cq); +int atl_ofi_try_to_drain_cq(struct fid_cq* cq); +void atl_ofi_reset(atl_ctx_t* ctx); +atl_status_t atl_ofi_adjust_env(const atl_attr_t& attr); +atl_status_t atl_ofi_set_env(const atl_attr_t& attr); +atl_status_t atl_ofi_get_prov_list(atl_ctx_t* ctx, + const char* prov_name, + struct fi_info* base_hints, + struct fi_info** out_prov_list); +atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, + struct fi_info* info, + atl_ofi_prov_t* prov, + atl_attr_t* attr, + std::unique_ptr& pmi); +atl_status_t atl_ofi_adjust_out_tag(atl_ofi_prov_t* prov, atl_attr_t* attr); +int atl_ofi_nic_already_used(const struct fi_info* prov, struct fi_info** others, size_t nic_count); +int atl_ofi_is_nic_local(struct fi_info* info); +atl_status_t atl_ofi_parse_mnic_name(atl_ctx_t* ctx, std::string str_to_parse); +int atl_ofi_is_allowed_nic_name(atl_ofi_ctx_t* ofi_ctx, struct fi_info* info); +atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, + struct fi_info* base_hints, + atl_attr_t* attr, + std::unique_ptr& pmi); diff --git a/src/atl/ofi/atl_ofi_impl.cpp b/src/atl/ofi/atl_ofi_impl.cpp deleted file mode 100644 index ea34fef84..000000000 --- a/src/atl/ofi/atl_ofi_impl.cpp +++ /dev/null @@ -1,2142 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atl.h" -#include "hwloc/hwloc_wrapper.h" - -#define ATL_OFI_BASE_PM_KEY "atl-ofi" -#define ATL_OFI_FI_ADDR_PM_KEY ATL_OFI_BASE_PM_KEY "-fiaddr" -#define ATL_OFI_HOSTNAME_PM_KEY ATL_OFI_BASE_PM_KEY "-hostname" - -#define ATL_OFI_TIMEOUT_SEC_ENV "ATL_OFI_TIMEOUT_SEC" -#define ATL_OFI_MAX_RETRY_COUNT_ENV "ATL_OFI_MAX_RETRY_COUNT" - -#define ATL_OFI_DEFAULT_TIMEOUT_SEC 60 -#define ATL_OFI_MAX_RETRY_COUNT 10000 -#define ATL_OFI_MAX_HOSTNAME_LEN 64 -#define ATL_OFI_WAIT_SEC 10 -#define ATL_OFI_CQ_READ_ITERS 10000 -#define ATL_OFI_CQ_BUNCH_SIZE 8 - -#define ATL_OFI_MAX_PROV_ENV_LEN 128 -#define ATL_OFI_PMI_PROV_MULTIPLIER 100 -#define ATL_OFI_PMI_PROC_MULTIPLIER (ATL_OFI_PMI_PROV_MULTIPLIER * 10) -#define ATL_OFI_MAX_NW_PROV_COUNT 32 -#define ATL_OFI_MAX_PROV_COUNT (ATL_OFI_MAX_NW_PROV_COUNT + 1) /* NW and SHM providers */ -#define ATL_OFI_MAX_ACTIVE_PROV_COUNT \ - 2 /* by current scheme each EP may use only SHM and 1 NW prov */ -#define ATL_OFI_SHM_PROV_NAME "shm" - -#ifndef PRId64 -#define PRId64 "lld" -#endif - -#define MAX(a, b) \ - ({ \ - __typeof__(a) _a = (a); \ - __typeof__(b) _b = (b); \ - _a > _b ? _a : _b; \ - }) - -#define MIN(a, b) \ - ({ \ - __typeof__(a) _a = (a); \ - __typeof__(b) _b = (b); \ - _a < _b ? _a : _b; \ - }) - -static inline atl_status_t atl_ofi_ep_poll(atl_ep_t* ep); - -#define ATL_OFI_CALL(func, ret_val, err_action) \ - do { \ - ret_val = func; \ - if (ret_val != FI_SUCCESS) { \ - LOG_ERROR(#func "\n fails with ret: ", ret_val, ", strerror:", fi_strerror(-ret_val)); \ - err_action; \ - } \ - } while (0) - -#define ATL_OFI_RETRY(func, ep, ret_val) \ - do { \ - atl_ctx_t* ctx = ep->ctx; \ - atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); \ - size_t max_retry_count = ofi_ctx->max_retry_count; \ - size_t retry_count = 0; \ - do { \ - ret_val = func; \ - if (ret_val == FI_SUCCESS) \ - break; \ - if (ret_val != -FI_EAGAIN) { \ - LOG_ERROR( \ - #func "\n fails with ret: ", ret_val, ", strerror: ", fi_strerror(-ret_val)); \ - CCL_THROW("OFI function error"); \ - break; \ - } \ - (void)atl_ofi_ep_poll(ep); \ - retry_count++; \ - } while ((ret_val == -FI_EAGAIN) && (retry_count < max_retry_count)); \ - } while (0) - -/* OFI returns 0 or -errno */ -#define RET2ATL(ret) \ - ({ \ - atl_status_t res; \ - if (ret == -FI_EAGAIN) \ - res = ATL_STATUS_AGAIN; \ - else \ - res = (ret) ? ATL_STATUS_FAILURE : ATL_STATUS_SUCCESS; \ - res; \ - }) - -long int safe_c_strtol(const char* str, char** endptr, int base) { - long int val = strtol(str, endptr, base); - if (val == 0) { - /* if a conversion error occurred, display a message and exit */ - if (errno == EINVAL) { - LOG_ERROR("conversion error occurred for string: ", str); - } - /* if the value provided was out of range, display a error message */ - if (errno == ERANGE) { - LOG_ERROR("the value provided was out of range, string: ", str); - } - } - return val; -} - -typedef enum { - ATL_OFI_COMP_POSTED, - ATL_OFI_COMP_COMPLETED, - ATL_OFI_COMP_PEEK_STARTED, - ATL_OFI_COMP_PEEK_FOUND, - ATL_OFI_COMP_PEEK_NOT_FOUND, -} atl_ofi_comp_state_t; - -typedef struct { - atl_mr_t mr; - struct fid_mr* fi_mr; -} atl_ofi_mr_t; - -typedef struct { - void* addr; - size_t len; -} atl_ofi_prov_ep_name_t; - -typedef struct { - struct fid_ep* tx; - struct fid_ep* rx; - struct fid_cq* cq; - atl_ofi_prov_ep_name_t name; -} atl_ofi_prov_ep_t; - -typedef struct { - size_t idx; - struct fi_info* info; - struct fid_fabric* fabric; - struct fid_domain* domain; - struct fid_av* av; - atl_ofi_prov_ep_t* eps; - - int is_shm; - size_t max_msg_size; - - /* used only in case of SEP supported */ - struct fid_ep* sep; - int rx_ctx_bits; - - /* table[0..proc_count][0..ep_count] */ - fi_addr_t* addr_table; - size_t addr_len; - int first_proc_idx; -} atl_ofi_prov_t; - -typedef struct { - atl_ep_t ep; - - /* used to make progressing only for really used providers */ - size_t active_prov_count; - size_t active_prov_idxs[ATL_OFI_MAX_ACTIVE_PROV_COUNT]; - -} atl_ofi_ep_t; - -typedef struct { - atl_ctx_t ctx; - pm_rt_desc_t* pm_rt; - atl_ofi_prov_t provs[ATL_OFI_MAX_PROV_COUNT]; - size_t prov_count; - size_t nw_prov_count; - size_t nw_prov_first_idx; - size_t shm_prov_idx; - size_t max_retry_count; - atl_progress_mode_t progress_mode; - atl_mnic_t mnic_type; - size_t mnic_count; -} atl_ofi_ctx_t; - -typedef struct { - struct fi_context fi_ctx; - atl_ofi_prov_ep_t* prov_ep; - struct fid_ep* fi_ep; - atl_ofi_comp_state_t comp_state; - size_t recv_len; -} atl_ofi_req_t; - -typedef struct atl_ofi_global_data { - size_t ctx_count; - int is_env_inited; - void* dlhandle; - char prov_env_copy[ATL_OFI_MAX_PROV_ENV_LEN]; - - atl_ofi_global_data() : ctx_count(0), is_env_inited(0), dlhandle(NULL) { - memset(prov_env_copy, 0, sizeof(prov_env_copy)); - } -} atl_ofi_global_data_t; - -static atl_ofi_global_data_t global_data; - -static void atl_ofi_print_coord(atl_proc_coord_t* coord) { - LOG_DEBUG("coord: global [idx ", - coord->global_idx, - ", cnt ", - coord->global_count, - "], local [idx ", - coord->local_idx, - ", cnt ", - coord->local_count, - "]"); -} - -static std::string atl_ofi_get_nic_name(const struct fi_info* prov) { - std::stringstream ss; - //ss << prov->fabric_attr->prov_name << ":" << prov->fabric_attr->name << ":" << prov->domain_attr->name; - ss << prov->fabric_attr->prov_name << ":" << prov->domain_attr->name; - return ss.str(); -} - -static inline atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size) { - size_t prov_idx; - atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); - - CCL_THROW_IF_NOT(ofi_ctx->prov_count <= ATL_OFI_MAX_PROV_COUNT, - "unexpected prov_count ", - ofi_ctx->prov_count); - - atl_proc_coord_t* coord = &(ep->ctx->coord); - int my_node_idx = coord->global_idx / coord->local_count; - int peer_node_idx = peer_proc_idx / coord->local_count; - - int has_shm = (ofi_ctx->prov_count == ofi_ctx->nw_prov_count + 1) ? 1 : 0; - - if (has_shm && (my_node_idx == peer_node_idx) && - (msg_size <= ofi_ctx->provs[ofi_ctx->shm_prov_idx].max_msg_size)) { - prov_idx = ofi_ctx->shm_prov_idx; - } - else { - size_t nw_prov_offset = ep->idx % ofi_ctx->nw_prov_count; - prov_idx = ofi_ctx->nw_prov_first_idx + nw_prov_offset; - } - - LOG_DEBUG("get_prov: ep_idx ", - ep->idx, - ", prov_idx ", - prov_idx, - ", my_node_idx ", - my_node_idx, - ", peer_node_idx ", - peer_node_idx, - ", msg_size ", - msg_size, - ", has_shm ", - has_shm); - - /* TODO: add segmentation logic */ - CCL_THROW_IF_NOT(msg_size <= ofi_ctx->provs[prov_idx].max_msg_size, - "msg_size (", - msg_size, - ") is greater than max_msg_size (", - ofi_ctx->provs[prov_idx].max_msg_size, - "), prov_idx ", - prov_idx); - - return &(ofi_ctx->provs[prov_idx]); -} - -static inline fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, - atl_ofi_prov_t* prov, - int proc_idx, - size_t ep_idx) { - return *(prov->addr_table + ((ctx->ep_count * (proc_idx - prov->first_proc_idx)) + ep_idx)); -} - -static atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, ipmi* pmi) { - CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); - - atl_proc_coord_t* coord = &(ofi_ctx->ctx.coord); - - atl_status_t ret = ATL_STATUS_SUCCESS; - int i; - int local_idx = 0, local_count = 0; - char* all_hostnames = NULL; - char my_hostname[ATL_OFI_MAX_HOSTNAME_LEN] = { 0 }; - size_t my_hostname_len = 0; - int my_global_proc_idx = coord->global_idx; - - gethostname(my_hostname, ATL_OFI_MAX_HOSTNAME_LEN - 1); - my_hostname_len = strlen(my_hostname); - - CCL_THROW_IF_NOT(my_hostname_len < ATL_OFI_MAX_HOSTNAME_LEN, - "unexpected my_hostname_len ", - my_hostname_len, - ", expected max ", - (size_t)(ATL_OFI_MAX_HOSTNAME_LEN)); - - if (ATL_OFI_MAX_HOSTNAME_LEN - my_hostname_len <= 10) { - LOG_WARN("hostname is quite long, len: ", my_hostname_len, ", name: ", my_hostname); - } - - snprintf(my_hostname + my_hostname_len, - ATL_OFI_MAX_HOSTNAME_LEN - my_hostname_len, - "-%d", - my_global_proc_idx); - - ret = pmi->pmrt_kvs_put((char*)ATL_OFI_HOSTNAME_PM_KEY, - my_global_proc_idx * ATL_OFI_PMI_PROC_MULTIPLIER, - my_hostname, - ATL_OFI_MAX_HOSTNAME_LEN); - - if (ret) { - LOG_ERROR("pmrt_kvs_put: ret: ", ret); - goto fn_err; - } - - pmi->pmrt_barrier(); - - all_hostnames = (char*)calloc(1, coord->global_count * ATL_OFI_MAX_HOSTNAME_LEN); - if (!all_hostnames) { - LOG_ERROR("can't allocate all_hostnames"); - goto fn_err; - } - - for (i = 0; i < coord->global_count; i++) { - ret = pmi->pmrt_kvs_get((char*)ATL_OFI_HOSTNAME_PM_KEY, - i * ATL_OFI_PMI_PROC_MULTIPLIER, - all_hostnames + i * ATL_OFI_MAX_HOSTNAME_LEN, - ATL_OFI_MAX_HOSTNAME_LEN); - if (ret) { - LOG_ERROR("pmrt_kvs_get: ret: ", ret); - goto fn_err; - } - } - - for (i = 0; i < coord->global_count; i++) { - if (!strncmp(my_hostname, - all_hostnames + i * ATL_OFI_MAX_HOSTNAME_LEN, - my_hostname_len + 1 /* including "-" at the end */)) { - local_count++; - int peer_global_proc_idx; - sscanf(all_hostnames + i * ATL_OFI_MAX_HOSTNAME_LEN + my_hostname_len + 1, - "%d", - &peer_global_proc_idx); - if (my_global_proc_idx > peer_global_proc_idx) - local_idx++; - } - } - - coord->local_idx = local_idx; - coord->local_count = local_count; - -fn_exit: - free(all_hostnames); - return ret; - -fn_err: - ret = ATL_STATUS_FAILURE; - goto fn_exit; -} - -static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, - size_t prov_idx, - ipmi* pmi) { - CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); - - atl_ctx_t* ctx = &(ofi_ctx->ctx); - atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); - - atl_status_t ret = ATL_STATUS_SUCCESS; - int i; - size_t j; - int insert_count; - - size_t addr_idx = 0; - char* ep_names_table; - size_t ep_names_table_len; - - size_t named_ep_count = (prov->sep ? 1 : ctx->ep_count); - - int local_count = ctx->coord.local_count; - int node_idx = ctx->coord.global_idx / local_count; - int shm_start_idx = node_idx * local_count; - int shm_end_idx = (node_idx + 1) * local_count; - - LOG_DEBUG("shm_start_idx ", shm_start_idx, ", shm_end_idx ", shm_end_idx); - - int proc_count = prov->is_shm ? ctx->coord.local_count : ctx->coord.global_count; - - if (proc_count == 0) - return ATL_STATUS_SUCCESS; - - LOG_DEBUG("name ", - atl_ofi_get_nic_name(prov->info), - ", is_shm ", - prov->is_shm, - ", addr_len ", - prov->addr_len, - ", local_count ", - ctx->coord.local_count, - ", global_count ", - ctx->coord.global_count, - ", proc_count ", - proc_count); - - /* allocate OFI EP names table that will contain all published names */ - ep_names_table_len = prov->addr_len * named_ep_count * proc_count; - - if (ep_names_table_len == 0) { - LOG_ERROR("ep_names_table_len == 0, addr_len ", - prov->addr_len, - ", named_ep_count ", - named_ep_count, - ", proc_count ", - proc_count); - return ATL_STATUS_FAILURE; - } - - ep_names_table = (char*)calloc(1, ep_names_table_len); - if (!ep_names_table) { - LOG_ERROR("can't allocate epnames_table"); - return ATL_STATUS_FAILURE; - } - - pmi->pmrt_barrier(); - - /* retrieve all OFI EP names in order */ - for (i = 0; i < ctx->coord.global_count; i++) { - if (prov->is_shm) { - if (!(i >= shm_start_idx && i < shm_end_idx)) { - continue; - } - } - - for (j = 0; j < named_ep_count; j++) { - ret = pmi->pmrt_kvs_get( - (char*)ATL_OFI_FI_ADDR_PM_KEY, - i * ATL_OFI_PMI_PROC_MULTIPLIER + prov_idx * ATL_OFI_PMI_PROV_MULTIPLIER + j, - ep_names_table + addr_idx * prov->addr_len, - prov->addr_len); - - if (ret) { - LOG_ERROR("kvs_get error: ret ", - ret, - ", proc_idx ", - i, - ", ep_idx ", - j, - ", addr_idx ", - addr_idx); - goto err_ep_names; - } - - addr_idx++; - } - } - - LOG_DEBUG( - "kvs_get: ep_count ", named_ep_count, ", proc_count ", proc_count, ", got ", addr_idx); - - if (addr_idx != named_ep_count * proc_count) { - LOG_ERROR("unexpected kvs_get results: expected ", - named_ep_count * proc_count, - ", got ", - addr_idx); - - ret = ATL_STATUS_FAILURE; - goto err_addr_table; - } - - if (prov->addr_table != NULL) - free(prov->addr_table); - - prov->addr_table = (fi_addr_t*)calloc(1, ctx->ep_count * proc_count * sizeof(fi_addr_t)); - - if (!prov->addr_table) - goto err_ep_names; - - /* insert all the EP names into the AV */ - insert_count = fi_av_insert( - prov->av, ep_names_table, named_ep_count * proc_count, prov->addr_table, 0, NULL); - - LOG_DEBUG("av_insert: ep_count ", - named_ep_count, - ", proc_count ", - proc_count, - ", inserted ", - insert_count); - - if (insert_count != (int)(named_ep_count * proc_count)) { - LOG_ERROR("unexpected av_insert results: expected ", - named_ep_count * proc_count, - " got ", - insert_count); - ret = ATL_STATUS_FAILURE; - goto err_addr_table; - } - else { - ret = ATL_STATUS_SUCCESS; - } - - if (prov->sep) { - if (named_ep_count != 1) { - LOG_ERROR("unexpected named_ep_count ", named_ep_count); - goto err_addr_table; - } - - fi_addr_t* table; - table = (fi_addr_t*)calloc(1, proc_count * sizeof(fi_addr_t)); - if (table == NULL) { - LOG_ERROR("memory allocaion failed"); - ret = ATL_STATUS_FAILURE; - goto err_addr_table; - } - memcpy(table, prov->addr_table, proc_count * sizeof(fi_addr_t)); - - for (i = 0; i < proc_count; i++) { - for (j = 0; j < ctx->ep_count; j++) { - prov->addr_table[i * ctx->ep_count + j] = - fi_rx_addr(table[i], j, prov->rx_ctx_bits); - } - } - free(table); - } - - /* normal end of execution */ - free(ep_names_table); - return ret; - - /* abnormal end of execution */ -err_addr_table: - free(prov->addr_table); - -err_ep_names: - free(ep_names_table); - return ret; -} - -static atl_status_t atl_ofi_prov_ep_get_name(atl_ofi_prov_t* prov, size_t ep_idx) { - int ret; - - atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); - struct fid_ep* fi_ep = (prov->sep) ? prov->sep : ep->tx; - - ep->name.addr = NULL; - ep->name.len = 0; - - ret = fi_getname(&fi_ep->fid, ep->name.addr, &(ep->name.len)); - if ((ret != -FI_ETOOSMALL) || ep->name.len <= 0) - ep->name.len = FI_NAME_MAX; - - if (ep->name.addr) - free(ep->name.addr); - - ep->name.addr = calloc(1, ep->name.len); - - if (!(ep->name.addr)) { - LOG_ERROR("can't allocate addr"); - ret = ATL_STATUS_FAILURE; - goto err_addr; - } - - ret = fi_getname(&fi_ep->fid, ep->name.addr, &(ep->name.len)); - if (ret) { - LOG_ERROR("fi_getname error"); - goto err_getname; - } - - prov->addr_len = MAX(prov->addr_len, ep->name.len); - - return ATL_STATUS_SUCCESS; - -err_getname: - free(ep->name.addr); - ep->name.addr = NULL; - ep->name.len = 0; - -err_addr: - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_prov_eps_connect(atl_ofi_ctx_t* ofi_ctx, size_t prov_idx, ipmi* pmi) { - int ret; - size_t ep_idx; - - atl_ctx_t* ctx = &(ofi_ctx->ctx); - atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); - size_t named_ep_count = (prov->sep ? 1 : ctx->ep_count); - atl_proc_coord_t* coord = &(ctx->coord); - - prov->addr_len = 0; - prov->first_proc_idx = - (prov->is_shm) ? ((coord->global_idx / coord->local_count) * coord->local_count) : 0; - - for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { - ret = atl_ofi_prov_ep_get_name(prov, ep_idx); - if (ret) { - LOG_ERROR("atl_ofi_prov_ep_get_name error"); - return ATL_STATUS_FAILURE; - } - } - - for (ep_idx = 0; ep_idx < named_ep_count; ep_idx++) { - atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); - ret = pmi->pmrt_kvs_put((char*)ATL_OFI_FI_ADDR_PM_KEY, - coord->global_idx * ATL_OFI_PMI_PROC_MULTIPLIER + - prov_idx * ATL_OFI_PMI_PROV_MULTIPLIER + ep_idx, - ep->name.addr, - ep->name.len); - if (ret) { - LOG_ERROR("pmrt_kvs_put: ret: ", ret); - return ATL_STATUS_FAILURE; - } - } - - ret = atl_ofi_prov_update_addr_table(ofi_ctx, prov_idx, pmi); - - return RET2ATL(ret); -} - -static void atl_ofi_prov_ep_destroy(atl_ofi_prov_t* prov, atl_ofi_prov_ep_t* ep) { - if (ep->rx) - fi_close(&ep->rx->fid); - - if (prov->sep && ep->tx) - fi_close(&ep->tx->fid); - - if (ep->cq) - fi_close(&ep->cq->fid); - - if (ep->name.addr) - free(ep->name.addr); - - ep->rx = ep->tx = NULL; - ep->cq = NULL; - ep->name.addr = NULL; - ep->name.len = 0; -} - -static void atl_ofi_prov_destroy(atl_ctx_t* ctx, atl_ofi_prov_t* prov) { - size_t i; - - for (i = 0; i < ctx->ep_count; i++) { - atl_ofi_prov_ep_destroy(prov, &(prov->eps[i])); - } - - free(prov->eps); - free(prov->addr_table); - - if (prov->sep) - fi_close(&prov->sep->fid); - - if (prov->av) - fi_close(&prov->av->fid); - - if (prov->domain) - fi_close(&prov->domain->fid); - - if (prov->fabric) - fi_close(&prov->fabric->fid); - - if (prov->info) { - fi_freeinfo(prov->info); - } -} - -static atl_status_t atl_ofi_prov_ep_handle_cq_err(atl_ofi_prov_ep_t* ep) { - struct fi_cq_err_entry err_entry; - atl_ofi_req_t* ofi_req; - - int ret = fi_cq_readerr(ep->cq, &err_entry, 0); - if (ret != 1) { - CCL_THROW("unable to read error from cq"); - return ATL_STATUS_FAILURE; - } - else { - ofi_req = container_of(err_entry.op_context, atl_ofi_req_t, fi_ctx); - - if (err_entry.err == FI_ECANCELED) { - return ATL_STATUS_SUCCESS; - } - - if (err_entry.err == FI_ENOMSG && ofi_req->comp_state == ATL_OFI_COMP_PEEK_STARTED) { - ofi_req->comp_state = ATL_OFI_COMP_PEEK_NOT_FOUND; - } - else { - LOG_ERROR("fi_cq_readerr: err: ", - err_entry.err, - ", prov_err: ", - fi_cq_strerror(ep->cq, err_entry.prov_errno, err_entry.err_data, NULL, 0), - "(", - err_entry.prov_errno, - ")"); - return ATL_STATUS_FAILURE; - } - return ATL_STATUS_SUCCESS; - } -} - -static inline void atl_ofi_process_comps(struct fi_cq_tagged_entry* entries, ssize_t ret) { - ssize_t idx; - atl_ofi_req_t* comp_ofi_req; - for (idx = 0; idx < ret; idx++) { - comp_ofi_req = container_of(entries[idx].op_context, atl_ofi_req_t, fi_ctx); - switch (comp_ofi_req->comp_state) { - case ATL_OFI_COMP_POSTED: comp_ofi_req->comp_state = ATL_OFI_COMP_COMPLETED; break; - case ATL_OFI_COMP_COMPLETED: break; - case ATL_OFI_COMP_PEEK_STARTED: - comp_ofi_req->comp_state = ATL_OFI_COMP_PEEK_FOUND; - break; - default: CCL_THROW("unexpected completion state ", comp_ofi_req->comp_state); break; - } - - if (entries[idx].flags & FI_RECV) { - comp_ofi_req->recv_len = entries[idx].len; - } - } -} - -static int atl_ofi_wait_cancel_cq(struct fid_cq* cq) { - struct fi_cq_err_entry err_entry; - int ret, i; - struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; - - double time = 0; - clock_t start, end; - - while (time < ATL_OFI_WAIT_SEC) { - for (i = 0; i < ATL_OFI_CQ_READ_ITERS; i++) { - start = clock(); - ret = fi_cq_read(cq, entries, ATL_OFI_CQ_BUNCH_SIZE); - - if (ret < 0 && ret != -FI_EAGAIN) { - ret = fi_cq_readerr(cq, &err_entry, 0); - - if (err_entry.err != FI_ECANCELED) { - LOG_ERROR("fi_cq_readerr: err: ", - err_entry.err, - ", prov_err: ", - fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, NULL, 0), - "(", - err_entry.prov_errno, - ")"); - return ATL_STATUS_FAILURE; - } - return ATL_STATUS_SUCCESS; - } - } - end = clock(); - time += (double)(end - start) / CLOCKS_PER_SEC; - } - - LOG_ERROR("too long for cancel"); - - return ATL_STATUS_FAILURE; -} - -static atl_status_t atl_ofi_prov_ep_init(atl_ofi_prov_t* prov, size_t ep_idx) { - ssize_t ret = 0; - - struct fi_cq_attr cq_attr; - struct fi_tx_attr tx_attr; - struct fi_rx_attr rx_attr; - - atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); - - memset(&cq_attr, 0, sizeof(cq_attr)); - cq_attr.format = FI_CQ_FORMAT_TAGGED; - - ATL_OFI_CALL(fi_cq_open(prov->domain, &cq_attr, &ep->cq, NULL), ret, return ATL_STATUS_FAILURE); - - if (prov->sep) { - rx_attr = *prov->info->rx_attr; - rx_attr.caps |= FI_TAGGED; - - ATL_OFI_CALL(fi_rx_context(prov->sep, ep_idx, &rx_attr, &ep->rx, NULL), ret, goto err); - - ATL_OFI_CALL(fi_ep_bind(ep->rx, &ep->cq->fid, FI_RECV), ret, goto err); - - tx_attr = *prov->info->tx_attr; - tx_attr.caps |= FI_TAGGED; - - ATL_OFI_CALL(fi_tx_context(prov->sep, ep_idx, &tx_attr, &ep->tx, NULL), ret, goto err); - - ATL_OFI_CALL(fi_ep_bind(ep->tx, &ep->cq->fid, FI_SEND), ret, goto err); - - fi_enable(ep->rx); - fi_enable(ep->tx); - } - else { - struct fid_ep* endpoint; - - ATL_OFI_CALL(fi_endpoint(prov->domain, prov->info, &endpoint, NULL), ret, goto err); - - ep->tx = ep->rx = endpoint; - - ATL_OFI_CALL(fi_ep_bind(endpoint, &ep->cq->fid, FI_SEND | FI_RECV), ret, goto err); - - ATL_OFI_CALL(fi_ep_bind(endpoint, &prov->av->fid, 0), ret, goto err); - - fi_enable(endpoint); - } - - return ATL_STATUS_SUCCESS; - -err: - atl_ofi_prov_ep_destroy(prov, ep); - return ATL_STATUS_FAILURE; -} - -static atl_status_t atl_ofi_try_to_drain_cq_err(struct fid_cq* cq) { - struct fi_cq_err_entry err_entry; - int ret = fi_cq_readerr(cq, &err_entry, 0); - if (ret != 1) { - LOG_DEBUG("unable to fi_cq_readerr"); - return ATL_STATUS_FAILURE; - } - else { - if (err_entry.err != FI_ENOMSG && err_entry.err != FI_ECANCELED && - err_entry.err != FI_ETRUNC) { - LOG_ERROR("fi_cq_readerr: err: ", - err_entry.err, - ", prov_err: ", - fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, NULL, 0), - "(", - err_entry.prov_errno, - ")"); - return ATL_STATUS_FAILURE; - } - return ATL_STATUS_SUCCESS; - } -} - -static int atl_ofi_try_to_drain_cq(struct fid_cq* cq) { - int ret = -FI_EAGAIN, i; - double time = 0; - clock_t start, end; - struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; - - while (time < ATL_OFI_WAIT_SEC) { - start = clock(); - for (i = 0; i < ATL_OFI_CQ_READ_ITERS; i++) { - ret = fi_cq_read(cq, entries, ATL_OFI_CQ_BUNCH_SIZE); - - if (ret < 0 && ret != -FI_EAGAIN) { - atl_ofi_try_to_drain_cq_err(cq); - return ret; - } - - if (ret > 0) - return ret; - } - end = clock(); - time += (double)(end - start) / CLOCKS_PER_SEC; - } - - return ret; -} - -static void atl_ofi_reset(atl_ctx_t* ctx) { - atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - - int again = 1; - size_t prov_idx, ep_idx; - int recv_buf_len = sizeof(char); - char* recv_buf; - struct fi_context fi_ctx; - recv_buf = (char*)malloc(recv_buf_len); - for (prov_idx = 0; prov_idx < ofi_ctx->prov_count; prov_idx++) { - atl_ofi_prov_t* prov = &(ofi_ctx->provs[prov_idx]); - - for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { - atl_ofi_prov_ep_t* ep = &(prov->eps[ep_idx]); - - /* complete active sends and receives */ - while (atl_ofi_try_to_drain_cq(ep->cq) != -FI_EAGAIN) { - } - - /* try to complete active incoming sends */ - while (again) { - again = 0; - /* post recv to complete incoming send */ - while (fi_trecv(ep->rx, - recv_buf, - recv_buf_len, - NULL, - FI_ADDR_UNSPEC, - 0, - UINTMAX_MAX, - &fi_ctx) == -FI_EAGAIN) { - } - - /* wait until recv will be completed or finished by timeout */ - while (atl_ofi_try_to_drain_cq(ep->cq) != -FI_EAGAIN) { - /* something is completed -> send queue not empty */ - again = 1; - } - } - - /* nothing to recv -> cancel last recv */ - fi_cancel(&ep->rx->fid, &fi_ctx); - - atl_ofi_wait_cancel_cq(ep->cq); - } - } - - free(recv_buf); -} - -static atl_status_t atl_ofi_adjust_env(const atl_attr_t& attr) { - char* prov_env = getenv("FI_PROVIDER"); - - if (prov_env && strlen(prov_env)) { - CCL_THROW_IF_NOT(strlen(prov_env) < sizeof(global_data.prov_env_copy), - "too long FI_PROVIDER value, max expected length ", - sizeof(global_data.prov_env_copy)); - memcpy(global_data.prov_env_copy, prov_env, strlen(prov_env)); - } - - if (attr.in.enable_shm) { - /* add shm provider in the list of allowed providers */ - if (prov_env && !strstr(prov_env, ATL_OFI_SHM_PROV_NAME)) { - /* whether single provider will be in the final env variable */ - int single_prov = (strlen(prov_env) == 0) ? 1 : 0; - - size_t prov_env_new_size = strlen(prov_env) + strlen(ATL_OFI_SHM_PROV_NAME) + - (single_prov ? 0 : 1) + /* for delimeter */ - 1; /* for terminating null symbol */ - - char* prov_env_new = (char*)calloc(prov_env_new_size, sizeof(char)); - if (prov_env_new == NULL) { - LOG_ERROR("memory allocaion failed"); - return ATL_STATUS_FAILURE; - } - - if (single_prov) - snprintf(prov_env_new, prov_env_new_size, "%s", ATL_OFI_SHM_PROV_NAME); - else { - snprintf(prov_env_new, prov_env_new_size, "%s,%s", prov_env, ATL_OFI_SHM_PROV_NAME); - } - - LOG_INFO("atl-ofi-shm is requested, modify FI_PROVIDER: old value: ", - prov_env, - ", new value: ", - prov_env_new); - - setenv("FI_PROVIDER", prov_env_new, 1); - - free(prov_env_new); - } - } - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t atl_ofi_set_env(const atl_attr_t& attr) { - if (global_data.is_env_inited) { - return ATL_STATUS_SUCCESS; - } - - setenv("FI_PSM2_DELAY", "0", 0); - setenv("FI_PSM2_TIMEOUT", "0", 0); - setenv("FI_PSM2_LOCK_LEVEL", "1", 0); - setenv("FI_PSM2_NAME_SERVER", "0", 0); - setenv("HFI_NO_CPUAFFINITY", "1", 0); - setenv("PSM2_MULTI_EP", "1", 0); - - setenv("FI_PSM3_DELAY", "0", 0); - setenv("FI_PSM3_TIMEOUT", "0", 0); - setenv("FI_PSM3_LOCK_LEVEL", "1", 0); - setenv("FI_PSM3_NAME_SERVER", "0", 0); - setenv("PSM3_NO_CPUAFFINITY", "1", 0); - setenv("PSM3_RDMA", "2", 0); - setenv("PSM3_MR_CACHE_MODE", "0", 0); //TODO temporary - setenv("PSM3_MULTI_EP", "1", 0); - if (attr.in.mnic_type == ATL_MNIC_NONE) - setenv("PSM3_NIC", "any", 0); - - char* hydra_uuid_env = getenv("I_MPI_HYDRA_UUID"); - if (hydra_uuid_env) { - setenv("FI_PSM2_UUID", hydra_uuid_env, 0); - setenv("FI_PSM3_UUID", hydra_uuid_env, 0); - } - - setenv("FI_OFI_RXM_USE_HASH", "0", 0); - setenv("FI_OFI_RXM_RX_SIZE", "8192", 0); - setenv("FI_OFI_RXM_TX_SIZE", "8192", 0); - setenv("FI_OFI_RXM_MSG_RX_SIZE", "128", 0); - setenv("FI_OFI_RXM_MSG_TX_SIZE", "128", 0); - - setenv("FI_SHM_TX_SIZE", "8192", 0); - setenv("FI_SHM_RX_SIZE", "8192", 0); - -#ifdef CCL_ENABLE_SYCL - setenv("FI_SHM_DISABLE_CMA", "1", 0); -#endif /* CCL_ENABLE_SYCL */ - - atl_ofi_adjust_env(attr); - - /* - load libfabric symbols into global namespace - to workaround issue with undefined symbols - in case of out-of-tree providers, like OFI/PSM3 - */ - global_data.dlhandle = dlopen("libfabric.so", RTLD_GLOBAL | RTLD_NOW); - if (global_data.dlhandle == NULL) { - CCL_THROW("dlopen (libfabric.so): ", dlerror()); - } - - global_data.is_env_inited = 1; - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t atl_ofi_finalize(atl_ctx_t* ctx) { - int ret = 0; - size_t idx; - - atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - - global_data.ctx_count--; - if (ctx->coord.global_idx == 0) { - LOG_INFO("finalize atl-ofi ctx, remaining ctx_count ", global_data.ctx_count); - } - - for (idx = 0; idx < ofi_ctx->prov_count; idx++) { - atl_ofi_prov_t* prov = &ofi_ctx->provs[idx]; - atl_ofi_prov_destroy(ctx, prov); - } - - for (idx = 0; idx < ctx->ep_count; idx++) { - atl_ofi_ep_t* ofi_ep = container_of(ctx->eps[idx], atl_ofi_ep_t, ep); - free(ofi_ep); - } - - if (global_data.ctx_count == 0) { - if (global_data.dlhandle) { - dlclose(global_data.dlhandle); - } - - if (hwloc_is_initialized()) { - CCL_THROW_IF_NOT(hwloc_finalize() == HWLOC_SUCCESS, "failed to finalize hwloc"); - } - - if (ctx->coord.global_idx == 0) { - LOG_INFO("finalized last atl-ofi ctx"); - } - } - - free(ctx->eps); - free(ofi_ctx); - - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_mr_reg(atl_ctx_t* ctx, const void* buf, size_t len, atl_mr_t** mr) { - int ret; - atl_ofi_ctx_t* ofi_ctx; - ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - atl_ofi_prov_t* prov = &(ofi_ctx->provs[0]); - - atl_ofi_mr_t* ofi_mr; - ofi_mr = (atl_ofi_mr_t*)calloc(1, sizeof(atl_ofi_mr_t)); - if (!ofi_mr) - return ATL_STATUS_FAILURE; - - ret = fi_mr_reg(prov->domain, - buf, - len, - FI_SEND | FI_RECV | FI_READ | FI_WRITE | FI_REMOTE_READ | FI_REMOTE_WRITE, - 0, - 0, - 0, - &ofi_mr->fi_mr, - NULL); - if (ret) - goto mr_reg_err; - - ofi_mr->mr.buf = (void*)buf; - ofi_mr->mr.len = len; - ofi_mr->mr.remote_key = (uintptr_t)fi_mr_key(ofi_mr->fi_mr); - ofi_mr->mr.local_key = (uintptr_t)fi_mr_desc(ofi_mr->fi_mr); - - *mr = &ofi_mr->mr; - return ATL_STATUS_SUCCESS; - -mr_reg_err: - free(ofi_mr); - return ATL_STATUS_FAILURE; -} - -static atl_status_t atl_ofi_mr_dereg(atl_ctx_t* ctx, atl_mr_t* mr) { - atl_ofi_mr_t* ofi_mr; - ofi_mr = container_of(mr, atl_ofi_mr_t, mr); - int ret = fi_close(&ofi_mr->fi_mr->fid); - free(ofi_mr); - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_ep_wait(atl_ep_t* ep, atl_req_t* req); - -static atl_status_t atl_ofi_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) { - ssize_t ret; - - atl_ofi_prov_t* prov; - atl_ofi_prov_ep_t* prov_ep; - atl_ofi_req_t* ofi_req; - - prov = atl_ofi_get_prov(ep, dst_proc_idx, len); - prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - - req->tag = tag; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; - - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; - - ATL_OFI_RETRY(fi_tsend(prov_ep->tx, - buf, - len, - NULL, - atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx), - tag, - &ofi_req->fi_ctx), - ep, - ret); - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) { - ssize_t ret; - - atl_ofi_prov_t* prov; - atl_ofi_prov_ep_t* prov_ep; - atl_ofi_req_t* ofi_req; - - prov = atl_ofi_get_prov(ep, src_proc_idx, len); - prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - - req->tag = tag; - req->remote_proc_idx = src_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; - - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->rx; - - ATL_OFI_RETRY(fi_trecv(prov_ep->rx, - buf, - len, - NULL, - atl_ofi_get_addr(ep->ctx, prov, src_proc_idx, ep->idx), - tag, - 0, - &ofi_req->fi_ctx), - ep, - ret); - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) { - atl_status_t ret; - atl_ofi_req_t reqs[ATL_OFI_MAX_PROV_COUNT]; - struct fi_msg_tagged msgs[ATL_OFI_MAX_PROV_COUNT]; - int flag, len; - ssize_t ofi_ret; - size_t idx; - int do_poll; - - atl_ofi_ctx_t* ofi_ctx; - - ret = ATL_STATUS_SUCCESS; - flag = 0; - len = 0; - ofi_ret = FI_SUCCESS; - do_poll = 1; - - ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); - - for (idx = 0; idx < ofi_ctx->prov_count; idx++) { - atl_ofi_prov_t* prov; - atl_ofi_prov_ep_t* prov_ep; - atl_ofi_req_t* req; - struct fi_msg_tagged* msg; - - prov = &(ofi_ctx->provs[idx]); - prov_ep = &(prov->eps[ep->idx]); - req = &(reqs[idx]); - msg = &(msgs[idx]); - - if (prov->is_shm && - ((src_proc_idx < prov->first_proc_idx) || - (src_proc_idx >= (prov->first_proc_idx + ep->ctx->coord.local_count)))) { - req->prov_ep = NULL; - continue; - } - - req->comp_state = ATL_OFI_COMP_PEEK_STARTED; - req->prov_ep = prov_ep; - req->fi_ep = prov_ep->rx; - - msg->msg_iov = NULL; - msg->desc = NULL; - msg->iov_count = 0; - msg->addr = atl_ofi_get_addr(ep->ctx, prov, src_proc_idx, ep->idx); - msg->tag = tag; - msg->ignore = 0; - msg->context = &(req->fi_ctx); - msg->data = 0; - - ATL_OFI_RETRY(fi_trecvmsg(prov_ep->rx, msg, FI_PEEK | FI_COMPLETION), ep, ofi_ret); - } - - do { - ret = atl_ofi_ep_poll(ep); - if (ret != ATL_STATUS_SUCCESS) - return ret; - - for (idx = 0; idx < ofi_ctx->prov_count; idx++) { - atl_ofi_req_t* req; - req = &(reqs[idx]); - - if (!req->prov_ep) - continue; - - if (req->comp_state != ATL_OFI_COMP_PEEK_STARTED) { - do_poll = 0; - - if (req->comp_state == ATL_OFI_COMP_PEEK_FOUND) { - flag = 1; - len = req->recv_len; - req->prov_ep = NULL; - } - else if (req->comp_state == ATL_OFI_COMP_PEEK_NOT_FOUND) { - req->prov_ep = NULL; - } - else { - CCL_THROW("unexpected completion state ", req->comp_state); - } - - break; - } - } - } while (do_poll); - - for (idx = 0; idx < ofi_ctx->prov_count; idx++) { - atl_ofi_req_t* req; - req = &(reqs[idx]); - - if (!req->prov_ep) - continue; - - if (fi_cancel(&req->fi_ep->fid, &req->fi_ctx) == 0) { - atl_ofi_wait_cancel_cq(req->prov_ep->cq); - } - } - - if (found) - *found = flag; - if (recv_len) - *recv_len = len; - - return RET2ATL(ofi_ret); -} - -static atl_status_t atl_ofi_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - int root, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_ofi_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - ssize_t ret; - - atl_ofi_prov_t* prov; - atl_ofi_prov_ep_t* prov_ep; - atl_ofi_req_t* ofi_req; - - prov = atl_ofi_get_prov(ep, dst_proc_idx, len); - prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - - req->tag = 0; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; - - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; - - ATL_OFI_RETRY(fi_read(prov_ep->tx, - buf, - len, - (void*)mr->local_key, - atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx), - addr, - remote_key, - &ofi_req->fi_ctx), - ep, - ret); - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - ssize_t ret; - - atl_ofi_prov_t* prov; - atl_ofi_prov_ep_t* prov_ep; - atl_ofi_req_t* ofi_req; - - prov = atl_ofi_get_prov(ep, dst_proc_idx, len); - prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - - req->tag = 0; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; - - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; - - ATL_OFI_RETRY(fi_write(prov_ep->tx, - buf, - len, - (void*)mr->local_key, - atl_ofi_get_addr(ep->ctx, prov, dst_proc_idx, ep->idx), - addr, - remote_key, - &ofi_req->fi_ctx), - ep, - ret); - return RET2ATL(ret); -} - -static atl_status_t atl_ofi_ep_wait(atl_ep_t* ep, atl_req_t* req) { - atl_status_t ret; - atl_ofi_req_t* ofi_req; - - ret = ATL_STATUS_SUCCESS; - ofi_req = ((atl_ofi_req_t*)req->internal); - - while ((ofi_req->comp_state != ATL_OFI_COMP_COMPLETED) && - ((ret = atl_ofi_ep_poll(ep)) == ATL_STATUS_SUCCESS)) - ; - - return ret; -} - -static atl_status_t atl_ofi_ep_wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t count) { - size_t i; - atl_status_t ret; - - for (i = 0; i < count; i++) { - ret = atl_ofi_ep_wait(ep, &reqs[i]); - if (ret != ATL_STATUS_SUCCESS) - return ret; - } - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t atl_ofi_ep_cancel(atl_ep_t* ep, atl_req_t* req) { - int ret; - atl_ofi_req_t* ofi_req; - - ret = ATL_STATUS_SUCCESS; - ofi_req = ((atl_ofi_req_t*)req->internal); - - ret = fi_cancel(&ofi_req->fi_ep->fid, &ofi_req->fi_ctx); - if (ret == 0) { - return RET2ATL(atl_ofi_wait_cancel_cq(ofi_req->prov_ep->cq)); - } - - return ATL_STATUS_SUCCESS; -} - -static inline atl_status_t atl_ofi_ep_progress(atl_ep_t* ep, atl_ofi_req_t* req /* unused */) { - ssize_t ret; - size_t idx; - struct fi_cq_tagged_entry entries[ATL_OFI_CQ_BUNCH_SIZE]; - atl_ofi_ep_t* ofi_ep = container_of(ep, atl_ofi_ep_t, ep); - atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); - size_t ep_idx = ep->idx; - - /* ensure progress for all active providers */ - for (idx = 0; idx < ofi_ep->active_prov_count; idx++) { - atl_ofi_prov_ep_t* prov_ep; - prov_ep = &(ofi_ctx->provs[ofi_ep->active_prov_idxs[idx]].eps[ep_idx]); - do { - ret = fi_cq_read(prov_ep->cq, entries, ATL_OFI_CQ_BUNCH_SIZE); - if (ret > 0) - atl_ofi_process_comps(entries, ret); - else if (ret == -FI_EAGAIN) - break; - else - return atl_ofi_prov_ep_handle_cq_err(prov_ep); - } while (ret > 0); - } - - return ATL_STATUS_SUCCESS; -} - -static inline atl_status_t atl_ofi_ep_poll(atl_ep_t* ep) { - atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); - if (ofi_ctx->progress_mode == ATL_PROGRESS_POLL) { - atl_ofi_ep_progress(ep, NULL /* ofi_req */); - } - return ATL_STATUS_SUCCESS; -} - -static atl_status_t atl_ofi_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) { - CCL_THROW_IF_NOT(is_completed); - - atl_status_t status; - atl_ofi_req_t* ofi_req; - atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); - - status = ATL_STATUS_SUCCESS; - ofi_req = ((atl_ofi_req_t*)req->internal); - - *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); - if (*is_completed) { - return ATL_STATUS_SUCCESS; - } - - if (ofi_ctx->progress_mode == ATL_PROGRESS_CHECK) { - status = atl_ofi_ep_progress(ep, ofi_req); - *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); - } - - return status; -} - -static atl_ops_t atl_ofi_ops = { - .finalize = atl_ofi_finalize, -}; - -static atl_mr_ops_t atl_ofi_mr_ops = { - .mr_reg = atl_ofi_mr_reg, - .mr_dereg = atl_ofi_mr_dereg, -}; - -static atl_p2p_ops_t atl_ofi_ep_p2p_ops = { - .send = atl_ofi_ep_send, - .recv = atl_ofi_ep_recv, - .probe = atl_ofi_ep_probe, -}; - -static atl_coll_ops_t atl_ofi_ep_coll_ops = { - .allgatherv = atl_ofi_ep_allgatherv, - .allreduce = atl_ofi_ep_allreduce, - .alltoall = atl_ofi_ep_alltoall, - .alltoallv = atl_ofi_ep_alltoallv, - .barrier = atl_ofi_ep_barrier, - .bcast = atl_ofi_ep_bcast, - .reduce = atl_ofi_ep_reduce, - .reduce_scatter = atl_ofi_ep_reduce_scatter, -}; - -static atl_rma_ops_t atl_ofi_ep_rma_ops = { - .read = atl_ofi_ep_read, - .write = atl_ofi_ep_write, -}; - -static atl_comp_ops_t atl_ofi_ep_comp_ops = { .wait = atl_ofi_ep_wait, - .wait_all = atl_ofi_ep_wait_all, - .cancel = atl_ofi_ep_cancel, - .poll = atl_ofi_ep_poll, - .check = atl_ofi_ep_check }; - -static atl_status_t atl_ofi_get_prov_list(atl_ctx_t* ctx, - const char* prov_name, - struct fi_info* base_hints, - struct fi_info** out_prov_list) { - struct fi_info* hints = NULL; - struct fi_info* prov_list = NULL; - ssize_t ret = 0; - int fi_version = FI_VERSION(FI_MAJOR_VERSION, FI_MINOR_VERSION); - const char* prov_name_str = (prov_name) ? prov_name : ""; - - hints = fi_dupinfo(base_hints); - if (!hints) { - LOG_ERROR("fi_dupinfo error"); - goto err; - } - - *out_prov_list = NULL; - - LOG_DEBUG("request providers with name: ", prov_name_str); - - hints->fabric_attr->prov_name = (prov_name) ? strdup(prov_name) : NULL; - - ret = fi_getinfo(fi_version, NULL, NULL, 0ULL, hints, &prov_list); - if (ret || !prov_list) { - LOG_ERROR("fi_getinfo error: ret ", ret, ", providers ", (void*)prov_list); - goto err; - } - - if (prov_list->domain_attr->max_ep_tx_ctx > 1) { - hints->ep_attr->tx_ctx_cnt = ctx->ep_count; - hints->ep_attr->rx_ctx_cnt = ctx->ep_count; - } - else { - hints->ep_attr->tx_ctx_cnt = 1; - hints->ep_attr->rx_ctx_cnt = 1; - } - - fi_freeinfo(prov_list); - prov_list = NULL; - - ret = fi_getinfo(fi_version, NULL, NULL, 0ULL, hints, &prov_list); - if (ret || !prov_list) { - LOG_ERROR("fi_getinfo error, prov_name ", prov_name_str); - goto err; - } - - fi_freeinfo(hints); - hints = NULL; - - *out_prov_list = prov_list; - return ATL_STATUS_SUCCESS; - -err: - LOG_ERROR("can't create providers for name ", prov_name_str); - return ATL_STATUS_FAILURE; -} - -static atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, - struct fi_info* info, - atl_ofi_prov_t* prov, - ipmi* pmi) { - struct fi_av_attr av_attr; - size_t ep_idx = 0; - ssize_t ret = 0; - - memset(&av_attr, 0, sizeof(av_attr)); - - atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - - if (ctx->coord.global_idx == 0) { - LOG_INFO("provider: ", info->fabric_attr->prov_name); - LOG_INFO(" nic: ", atl_ofi_get_nic_name(info)); - LOG_INFO(" mr_mode: ", info->domain_attr->mr_mode); - LOG_INFO(" threading: ", info->domain_attr->threading); - LOG_INFO(" tx_ctx_cnt: ", info->domain_attr->tx_ctx_cnt); - LOG_INFO(" max_ep_tx_ctx: ", info->domain_attr->max_ep_tx_ctx); - LOG_INFO(" max_msg_size: ", info->ep_attr->max_msg_size); - } - - prov->info = fi_dupinfo(info); - - if (!prov->info) { - LOG_ERROR("fi_dupinfo error"); - goto err; - } - - prov->max_msg_size = info->ep_attr->max_msg_size; - - ATL_OFI_CALL(fi_fabric(info->fabric_attr, &prov->fabric, NULL), ret, goto err); - - ATL_OFI_CALL(fi_domain(prov->fabric, info, &prov->domain, NULL), ret, goto err); - - av_attr.type = FI_AV_TABLE; - av_attr.rx_ctx_bits = prov->rx_ctx_bits = (int)ceil(log2(prov->info->ep_attr->rx_ctx_cnt)); - - ATL_OFI_CALL(fi_av_open(prov->domain, &av_attr, &prov->av, NULL), ret, goto err); - - if (info->domain_attr->max_ep_tx_ctx > 1) { - ATL_OFI_CALL(fi_scalable_ep(prov->domain, info, &prov->sep, NULL), ret, goto err); - ATL_OFI_CALL(fi_scalable_ep_bind(prov->sep, &prov->av->fid, 0), ret, goto err); - } - - prov->eps = (atl_ofi_prov_ep_t*)calloc(1, sizeof(atl_ofi_prov_ep_t) * ctx->ep_count); - if (!prov->eps) { - LOG_ERROR("can't allocate prov->eps"); - goto err; - } - - for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { - ret = atl_ofi_prov_ep_init(prov, ep_idx); - if (ret) { - LOG_ERROR("atl_ofi_prov_ep_init error"); - goto err; - } - } - - if (prov->sep) { - fi_enable(prov->sep); - } - - /* TODO: make separate function to be called on CCL comm creation */ - ret = atl_ofi_prov_eps_connect(ofi_ctx, prov->idx, pmi); - if (ret) { - LOG_ERROR("atl_ofi_prov_eps_connect error, prov_idx ", prov->idx); - goto err; - } - - return ATL_STATUS_SUCCESS; - -err: - LOG_ERROR("can't init provider ", atl_ofi_get_nic_name(info)); - return ATL_STATUS_FAILURE; -} - -/* determine if NIC has already been included in others */ -static int atl_ofi_nic_already_used(const struct fi_info* prov, - struct fi_info** others, - size_t nic_count) { - for (size_t i = 0; i < nic_count; i++) { - if (prov->nic->bus_attr->bus_type == FI_BUS_PCI && - others[i]->nic->bus_attr->bus_type == FI_BUS_PCI) { - struct fi_pci_attr pci = prov->nic->bus_attr->attr.pci; - struct fi_pci_attr other_pci = others[i]->nic->bus_attr->attr.pci; - LOG_DEBUG("compare nic ", - prov->fabric_attr->prov_name, - " pci ", - (int)pci.domain_id, - ":", - (int)pci.bus_id, - ":", - (int)pci.device_id, - ":", - (int)pci.function_id, - " with nic ", - others[i]->fabric_attr->prov_name, - " pci ", - (int)other_pci.domain_id, - ":", - (int)other_pci.bus_id, - ":", - (int)other_pci.device_id, - ":", - (int)other_pci.function_id); - if (pci.domain_id == other_pci.domain_id && pci.bus_id == other_pci.bus_id && - pci.device_id == other_pci.device_id && pci.function_id == other_pci.function_id) - return 1; - } - else { - LOG_DEBUG("compare nic ", - atl_ofi_get_nic_name(prov), - " with nic ", - atl_ofi_get_nic_name(others[i])); - if (!strcmp(prov->domain_attr->name, others[i]->domain_attr->name)) - return 1; - } - } - return 0; -} - -/* return true if the NIC is bound to the same socket as calling process */ -static int atl_ofi_is_nic_local(struct fi_info* info) { - if (info->nic->bus_attr->bus_type == FI_BUS_PCI) { - struct fi_pci_attr pci = info->nic->bus_attr->attr.pci; - return hwloc_is_dev_close_by_pci(pci.domain_id, pci.bus_id, pci.device_id, pci.function_id); - } - return 0; -} - -static atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, struct fi_info* base_hints, ipmi* pmi) { - struct fi_info* prov_list = NULL; - size_t idx = 0, prov_idx = 0; - char* prov_name = NULL; - atl_ofi_prov_t* prov = NULL; - - atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - - if (strlen(global_data.prov_env_copy) && !strstr(global_data.prov_env_copy, ",")) - prov_name = global_data.prov_env_copy; - else - prov_name = NULL; - - ATL_CALL(atl_ofi_get_prov_list(ctx, prov_name, base_hints, &prov_list), goto err); - - if (ofi_ctx->mnic_type == ATL_MNIC_NONE) { - prov_idx = ofi_ctx->nw_prov_first_idx; - prov = &ofi_ctx->provs[prov_idx]; - prov->idx = prov_idx; - prov->is_shm = 0; - ATL_CALL(atl_ofi_prov_init(ctx, prov_list, prov, pmi), goto err); - ofi_ctx->nw_prov_count++; - } - else { - /* calculate the number of NICs */ - struct fi_info* prov_iter = prov_list; - struct fi_info* filtered_prov_list[ATL_OFI_MAX_NW_PROV_COUNT]; - size_t nic_count = 0; - struct fid_nic* nic = NULL; - - while (prov_iter && (nic_count < ofi_ctx->mnic_count)) { - nic = prov_iter->nic; - if (nic) { - LOG_DEBUG("check nic ", atl_ofi_get_nic_name(prov_iter)); - if (!atl_ofi_nic_already_used(prov_iter, filtered_prov_list, nic_count)) { - int is_local = atl_ofi_is_nic_local(prov_iter); - LOG_DEBUG("nic ", atl_ofi_get_nic_name(prov_iter), ", is_local ", is_local); - - if (ofi_ctx->mnic_type == ATL_MNIC_GLOBAL || - (ofi_ctx->mnic_type == ATL_MNIC_LOCAL && is_local)) { - LOG_INFO("found suitable nic ", atl_ofi_get_nic_name(prov_iter)); - filtered_prov_list[nic_count] = fi_dupinfo(prov_iter); - nic_count++; - } - } - else { - LOG_DEBUG("nic ", atl_ofi_get_nic_name(prov_iter), " already used"); - } - } - prov_iter = prov_iter->next; - } - - if (nic_count == 0) { - LOG_INFO("can not find nic(s) according to mnic_type ", - ofi_ctx->mnic_type, - ", use first available nic ", - atl_ofi_get_nic_name(prov_list)); - ofi_ctx->nw_prov_count = 1; - filtered_prov_list[0] = fi_dupinfo(prov_list); - } - else { - LOG_INFO("found ", nic_count, " nic(s) according to mnic_type ", ofi_ctx->mnic_type); - ofi_ctx->nw_prov_count = nic_count; - } - - for (idx = 0; idx < ofi_ctx->nw_prov_count; idx++) { - prov_idx = ofi_ctx->nw_prov_first_idx + idx; - prov = &ofi_ctx->provs[prov_idx]; - prov->idx = prov_idx; - prov->is_shm = 0; - ATL_CALL(atl_ofi_prov_init(ctx, filtered_prov_list[idx], prov, pmi), goto err); - } - - for (idx = 0; idx < ofi_ctx->nw_prov_count; idx++) { - fi_freeinfo(filtered_prov_list[idx]); - } - } - ofi_ctx->prov_count += ofi_ctx->nw_prov_count; - - fi_freeinfo(prov_list); - - return ATL_STATUS_SUCCESS; - -err: - LOG_ERROR("can not open network providers"); - return ATL_STATUS_FAILURE; -} - -static atl_status_t atl_ofi_init(int* argc, - char*** argv, - atl_attr_t* attr, - atl_ctx_t** out_ctx, - const char* main_addr, - ipmi* pmi) { - struct fi_info *prov_list = NULL, *base_hints = NULL, *prov_hints = NULL; - int fi_version; - ssize_t ret = 0; - size_t idx = 0, ep_idx = 0, prov_idx = 0; - char* prov_name = NULL; - atl_ofi_prov_t* prov = NULL; - char *max_retry_count_env = NULL, *progress_mode_env = NULL; - int open_nw_provs = 1; - int enable_shm = 0; - - CCL_THROW_IF_NOT((sizeof(atl_ofi_req_t) <= sizeof(atl_req_t) - offsetof(atl_req_t, internal)), - "unexpected offset: atl_ofi_request size ", - sizeof(atl_ofi_req_t), - ", atl_request size ", - sizeof(atl_req_t), - ", expected offset ", - offsetof(atl_req_t, internal)); - - if (global_data.ctx_count == 0) { - ret = atl_ofi_set_env(*attr); - if (ret != ATL_STATUS_SUCCESS) { - LOG_ERROR("atl_ofi_set_env error"); - return ATL_STATUS_FAILURE; - } - } - global_data.ctx_count++; - - atl_ofi_ctx_t* ofi_ctx; - ofi_ctx = (atl_ofi_ctx_t*)calloc(1, sizeof(atl_ofi_ctx_t)); - if (!ofi_ctx) - return ATL_STATUS_FAILURE; - - atl_ctx_t* ctx = &(ofi_ctx->ctx); - - ctx->ops = &atl_ofi_ops; - ctx->mr_ops = &atl_ofi_mr_ops; - ctx->ep_count = attr->in.ep_count; - ctx->eps = (atl_ep**)calloc(1, sizeof(void*) * attr->in.ep_count); - if (!ctx->eps) - goto err; - - ctx->coord.global_count = pmi->get_size(); - ctx->coord.global_idx = pmi->get_rank(); - - ret = atl_ofi_get_local_proc_coord(ofi_ctx, pmi); - if (ret) { - LOG_ERROR("atl_ofi_get_local_proc_coord error"); - goto err; - } - - atl_proc_coord_t* coord; - coord = &(ctx->coord); - - base_hints = fi_allocinfo(); - if (!base_hints) { - LOG_ERROR("can't alloc base_hints"); - goto err; - } - - base_hints->mode = FI_CONTEXT; - base_hints->ep_attr->type = FI_EP_RDM; - base_hints->domain_attr->resource_mgmt = FI_RM_ENABLED; - base_hints->domain_attr->control_progress = FI_PROGRESS_MANUAL; - base_hints->domain_attr->data_progress = FI_PROGRESS_MANUAL; - base_hints->caps = FI_TAGGED; - base_hints->caps |= FI_DIRECTED_RECV; - - fi_version = FI_VERSION(FI_MAJOR_VERSION, FI_MINOR_VERSION); - - if (coord->global_idx == 0) - LOG_INFO("libfabric version: ", fi_tostr("1" /* ignored */, FI_TYPE_VERSION)); - - char* prov_env; - prov_env = getenv("FI_PROVIDER"); - if (prov_env && !strcmp(prov_env, ATL_OFI_SHM_PROV_NAME)) { - if (coord->global_count != coord->local_count) { - LOG_ERROR("shm provider is requested as primary provider but global_count (", - coord->global_count, - ") != local_count (", - coord->local_count, - ")"); - goto err; - } - - if (!attr->in.enable_shm) { - LOG_ERROR( - "shm provider is requested through FI_PROVIDER but not requested from CCL level"); - goto err; - } - } - - atl_ofi_print_coord(coord); - - enable_shm = attr->in.enable_shm; - if (enable_shm) { - prov_hints = fi_dupinfo(base_hints); - prov_hints->fabric_attr->prov_name = strdup(ATL_OFI_SHM_PROV_NAME); - ret = fi_getinfo(fi_version, NULL, NULL, 0ULL, prov_hints, &prov_list); - if (ret || !prov_list) { - enable_shm = 0; - LOG_INFO("shm provider is requested but not available"); - } - else { - LOG_INFO("shm provider is requested and available"); - } - - fi_freeinfo(prov_list); - prov_list = NULL; - - fi_freeinfo(prov_hints); - prov_hints = NULL; - } - - ofi_ctx->prov_count = 0; - ofi_ctx->nw_prov_count = 0; - ofi_ctx->shm_prov_idx = 0; - ofi_ctx->nw_prov_first_idx = (enable_shm) ? 1 : 0; - ofi_ctx->mnic_type = attr->in.mnic_type; - ofi_ctx->mnic_count = std::min(attr->in.mnic_count, (size_t)(ATL_OFI_MAX_NW_PROV_COUNT)); - ofi_ctx->mnic_count = std::min(ofi_ctx->mnic_count, attr->in.ep_count); - ofi_ctx->mnic_count = std::max(ofi_ctx->mnic_count, (size_t)(1)); - - if ((ofi_ctx->mnic_type != ATL_MNIC_NONE) && !hwloc_is_initialized()) { - hwloc_status_t hwloc_status = hwloc_init(); - if (hwloc_status != HWLOC_SUCCESS) { - ofi_ctx->mnic_type = ATL_MNIC_NONE; - ofi_ctx->mnic_count = 1; - LOG_WARN("can't init hwloc, disable multi-nic") - } - } - - /* open SHM provider */ - if (enable_shm) { - prov_idx = ofi_ctx->shm_prov_idx; - prov_name = strdup(ATL_OFI_SHM_PROV_NAME); - prov = &ofi_ctx->provs[prov_idx]; - prov->idx = prov_idx; - prov->is_shm = 1; - ATL_CALL(atl_ofi_get_prov_list(ctx, prov_name, base_hints, &prov_list), goto err); - ATL_CALL(atl_ofi_prov_init(ctx, prov_list, prov, pmi), goto err); - free(prov_name); - fi_freeinfo(prov_list); - ofi_ctx->prov_count++; - } - - /* open NW provider(s) */ - if (prov_env && !strcmp(prov_env, ATL_OFI_SHM_PROV_NAME) && enable_shm) { - open_nw_provs = 0; - } - - if (open_nw_provs) { - ATL_CALL(atl_ofi_open_nw_provs(ctx, base_hints, pmi), goto err); - ofi_ctx->mnic_count = ofi_ctx->nw_prov_count; - } - - for (ep_idx = 0; ep_idx < ctx->ep_count; ep_idx++) { - atl_ofi_ep_t* ofi_ep; - ofi_ep = (atl_ofi_ep_t*)calloc(1, sizeof(atl_ofi_ep_t)); - if (!ofi_ep) { - LOG_ERROR("can't alloc ofi_ep, idx ", ep_idx); - goto err; - } - - atl_ep_t* ep; - ep = &(ofi_ep->ep); - ep->idx = ep_idx; - ep->ctx = ctx; - ep->p2p_ops = &atl_ofi_ep_p2p_ops; - ep->coll_ops = &atl_ofi_ep_coll_ops; - ep->rma_ops = &atl_ofi_ep_rma_ops; - ep->comp_ops = &atl_ofi_ep_comp_ops; - - ofi_ep->active_prov_count = 0; - if (enable_shm) { - ofi_ep->active_prov_idxs[ofi_ep->active_prov_count] = ofi_ctx->shm_prov_idx; - ofi_ep->active_prov_count++; - } - if (open_nw_provs) { - ofi_ep->active_prov_idxs[ofi_ep->active_prov_count] = - ofi_ctx->nw_prov_first_idx + ep_idx % ofi_ctx->nw_prov_count; - ofi_ep->active_prov_count++; - } - CCL_THROW_IF_NOT(ofi_ep->active_prov_count, "no active providers for ep_idx ", ep_idx); - - if (coord->global_idx == 0) { - std::stringstream ss; - for (idx = 0; idx < ofi_ep->active_prov_count; idx++) { - ss << ofi_ep->active_prov_idxs[idx] << " "; - } - LOG_INFO("ep_idx: ", ep_idx, ", active_prov_idxs: ", ss.str()); - } - - ctx->eps[ep_idx] = ep; - } - - pmi->pmrt_barrier(); - - max_retry_count_env = getenv(ATL_OFI_MAX_RETRY_COUNT_ENV); - if (max_retry_count_env) { - ofi_ctx->max_retry_count = safe_c_strtol(max_retry_count_env, NULL, 10); - } - else { - ofi_ctx->max_retry_count = ATL_OFI_MAX_RETRY_COUNT; - } - - if ((coord->global_count == coord->local_count) && (coord->global_count <= 4)) { - ofi_ctx->progress_mode = ATL_PROGRESS_CHECK; - } - else { - ofi_ctx->progress_mode = ATL_PROGRESS_POLL; - } - - progress_mode_env = getenv(ATL_PROGRESS_MODE_ENV); - if (progress_mode_env) { - ofi_ctx->progress_mode = static_cast(atoi(progress_mode_env)); - } - - if (coord->global_idx == 0) { - LOG_INFO("atl-ofi-ctx:"); - LOG_INFO(" new ctx_count: ", global_data.ctx_count); - LOG_INFO(" prov_count: ", ofi_ctx->prov_count); - LOG_INFO(" nw_prov_count: ", ofi_ctx->nw_prov_count); - LOG_INFO(" nw_prov_first_idx: ", ofi_ctx->nw_prov_first_idx); - LOG_INFO(" mnic_type: ", ofi_ctx->mnic_type); - if (ofi_ctx->mnic_type != ATL_MNIC_NONE) - LOG_INFO(" mnic_count: ", ofi_ctx->mnic_count); - LOG_INFO(" max_retry_count: ", ofi_ctx->max_retry_count); - LOG_INFO(" progress_mode: ", ofi_ctx->progress_mode); - } - - *out_ctx = ctx; - - fi_freeinfo(base_hints); - base_hints = NULL; - - /* report actual attributes back to upper level */ - attr->out.enable_shm = enable_shm; - attr->out.enable_rma = 0; - attr->out.enable_device_buf = 0; - attr->out.mnic_type = ofi_ctx->mnic_type; - attr->out.mnic_count = ofi_ctx->mnic_count; - attr->out.tag_bits = 64; - attr->out.max_tag = 0xFFFFFFFFFFFFFFFF; - attr->out.max_order_waw_size = 0; - - return ATL_STATUS_SUCCESS; - -err: - LOG_ERROR("can't find suitable provider"); - - if (prov_list) { - fi_freeinfo(prov_list); - } - - if (base_hints) { - fi_freeinfo(base_hints); - } - - if (prov_hints) { - fi_freeinfo(prov_hints); - } - - if (ctx != NULL) - atl_ofi_finalize(ctx); - - return ATL_STATUS_FAILURE; -} - -atl_status_t atl_ofi_main_addr_reserve(char* main_addr) { - return ATL_STATUS_UNSUPPORTED; -} diff --git a/src/atl/util/pm/codec/pm_rt_codec.h b/src/atl/util/pm/codec/pm_rt_codec.h index 270637338..5fcc09274 100644 --- a/src/atl/util/pm/codec/pm_rt_codec.h +++ b/src/atl/util/pm/codec/pm_rt_codec.h @@ -67,4 +67,4 @@ static inline int decode(const char *inval, void *outval, int outvallen) { return 0; } -#endif /* PMI_RT_CODEC_H */ +#endif // PMI_RT_CODEC_H diff --git a/src/atl/util/pm/pm_rt.h b/src/atl/util/pm/pm_rt.h index 8057df881..ac328fddc 100644 --- a/src/atl/util/pm/pm_rt.h +++ b/src/atl/util/pm/pm_rt.h @@ -189,4 +189,4 @@ class ipmi { virtual size_t get_ranks_per_process() = 0; }; #endif -#endif /* PM_RT_H */ +#endif // PM_RT_H diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h index 3c03dab70..98b06efc6 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h @@ -15,6 +15,15 @@ */ #pragma once +#include +#include +#include +#include +#include +#include + +#include "common/log/log.hpp" + //TODO: change exit to something more useful #define SET_STR(dst, size, ...) \ do { \ @@ -140,13 +149,6 @@ extern char my_hostname[MAX_KVS_VAL_LENGTH]; -#include -#include -#include -#include -#include -#include - void inline kvs_str_copy(char* dst, const char* src, size_t bytes) { strncpy(dst, src, bytes - 1); dst[bytes - 1] = '\0'; @@ -158,21 +160,20 @@ void inline kvs_str_copy_known_sizes(char* dst, const char* src, size_t bytes) { } long int inline safe_strtol(const char* str, char** endptr, int base) { + errno = 0; auto val = strtol(str, endptr, base); - if (val == 0) { - /* if a conversion error occurred, display a message and exit */ + + if (errno != 0) { if (errno == EINVAL) { - throw std::runtime_error( - std::string(__PRETTY_FUNCTION__) + - ": conversion error occurred from: " + std::to_string((int)val)); + CCL_THROW("conversion error occurred from: ", str); } - - /* if the value provided was out of range, display a warning message */ - if (errno == ERANGE) { - throw std::runtime_error( - std::string(__PRETTY_FUNCTION__) + - ": the value provided was out of range, value: " + std::to_string((int)val)); + else if (errno == ERANGE) { + CCL_THROW("the value provided was out of range: ", str); + } + else { + CCL_THROW("strtol error: ", strerror(errno), ", str: ", str); } } + return val; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp index 1e693a9d9..3a69b0434 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp @@ -309,16 +309,13 @@ size_t internal_kvs::init_main_server_by_k8s() { char port_str[MAX_KVS_VAL_LENGTH]; request_k8s_kvs_init(); - SET_STR(port_str, INT_STR_SIZE, "%d", local_server_address.sin_port); + SET_STR(port_str, INT_STR_SIZE, "%d", local_server_address->get_sin_port()); request_k8s_kvs_get_master(local_host_ip, main_host_ip, port_str); main_port = safe_strtol(port_str, nullptr, 10); - main_server_address.sin_port = main_port; - if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - LOG_ERROR("invalid address/ address not supported: ", main_host_ip); - return 1; - } + main_server_address->set_sin_port(main_port); + main_server_address->set_sin_addr(main_host_ip); return 0; } @@ -345,32 +342,26 @@ size_t internal_kvs::init_main_server_by_env() { port++; main_port = safe_strtol(port, nullptr, 10); - main_server_address.sin_port = main_port; - - if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - LOG_ERROR("ivalid address / address not supported: ", main_host_ip); - return 1; - } + main_server_address->set_sin_port(main_port); + main_server_address->set_sin_addr(main_host_ip); return 0; } size_t internal_kvs::init_main_server_by_string(const char* main_addr) { char* port = nullptr; - local_server_address.sin_family = AF_INET; - local_server_address.sin_addr.s_addr = inet_addr(local_host_ip); - local_server_address.sin_port = default_start_port; + local_server_address->set_sin_addr(local_host_ip); - main_server_address.sin_family = AF_INET; - - if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { LOG_ERROR("init_main_server_by_string: server_listen_sock init"); exit(EXIT_FAILURE); } + size_t sin_port = local_server_address->get_sin_port(); while (bind(server_listen_sock, - (const struct sockaddr*)&local_server_address, - sizeof(local_server_address)) < 0) { - local_server_address.sin_port++; + local_server_address->get_sock_addr_ptr(), + local_server_address->size()) < 0) { + sin_port++; + local_server_address->set_sin_port(sin_port); } memset(main_host_ip, 0, CCL_IP_LEN); @@ -387,14 +378,8 @@ size_t internal_kvs::init_main_server_by_string(const char* main_addr) { port++; main_port = safe_strtol(port, nullptr, 10); - main_server_address.sin_port = main_port; - - if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - LOG_ERROR("init_main_server_by_string: invalid address / address not supported: ", - main_host_ip); - LOG_ERROR("init_main_server_by_string: inet_pton"); - return 1; - } + main_server_address->set_sin_port(main_port); + main_server_address->set_sin_addr(main_host_ip); return 0; } @@ -402,18 +387,28 @@ int internal_kvs::fill_local_host_ip() { struct ifaddrs *ifaddr, *ifa; int family = AF_UNSPEC; char local_ip[CCL_IP_LEN]; + bool is_supported_iface = false; if (getifaddrs(&ifaddr) < 0) { LOG_ERROR("fill_local_host_ip: can not get host IP"); return -1; } const char iface_name[] = "lo"; + char* iface_name_env = std::getenv(CCL_KVS_IFACE_ENV.c_str()); local_host_ips.clear(); + local_host_ipv6s.clear(); + local_host_ipv4s.clear(); for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL) continue; - if (strstr(ifa->ifa_name, iface_name) == NULL) { + if (iface_name_env) { + is_supported_iface = strstr(ifa->ifa_name, iface_name_env); + } + else { + is_supported_iface = strstr(ifa->ifa_name, iface_name) == NULL; + } + if (is_supported_iface) { family = ifa->ifa_addr->sa_family; if (family == AF_INET || family == AF_INET6) { memset(local_ip, 0, CCL_IP_LEN); @@ -431,17 +426,58 @@ int internal_kvs::fill_local_host_ip() { LOG_ERROR(s.c_str()); return -1; } + local_host_ips.push_back(local_ip); + if (family == AF_INET6) { + char* scope_id_ptr = nullptr; + if ((scope_id_ptr = strchr(local_ip, SCOPE_ID_DELIM))) { + uint32_t scope_id = ((struct sockaddr_in6*)(ifa->ifa_addr))->sin6_scope_id; + sprintf(scope_id_ptr + 1, "%u", scope_id); + } + local_host_ipv6s.push_back(local_ip); + } + else { + local_host_ipv4s.push_back(local_ip); + } } } } if (local_host_ips.empty()) { - LOG_ERROR("fill_local_host_ip: can't find interface to get host IP"); + LOG_ERROR("fill_local_host_ip: can't find interface ", + iface_name_env ? iface_name_env : "", + " to get host IP"); return -1; } memset(local_host_ip, 0, CCL_IP_LEN); - kvs_str_copy(local_host_ip, local_host_ips.front().c_str(), CCL_IP_LEN); + + char* kvs_prefer_ipv6 = std::getenv(CCL_KVS_PREFER_IPV6_ENV.c_str()); + size_t is_kvs_prefer_ipv6 = kvs_prefer_ipv6 ? safe_strtol(kvs_prefer_ipv6, nullptr, 10) : 0; + + if (is_kvs_prefer_ipv6) { + if (!local_host_ipv6s.empty()) { + address_family = AF_INET6; + } + else { + LOG_WARN("ipv6 addresses are not found, fallback to ipv4"); + address_family = AF_INET; + } + } + else { + address_family = (!local_host_ipv4s.empty()) ? AF_INET : AF_INET6; + } + + if (address_family == AF_INET) { + main_server_address = std::shared_ptr(new sockaddr_v4()); + local_server_address = std::shared_ptr(new sockaddr_v4()); + kvs_str_copy(local_host_ip, local_host_ipv4s.front().c_str(), CCL_IP_LEN); + } + else { + main_server_address = std::shared_ptr(new sockaddr_v6()); + local_server_address = std::shared_ptr(new sockaddr_v6()); + kvs_str_copy(local_host_ip, local_host_ipv6s.front().c_str(), CCL_IP_LEN); + } + LOG_DEBUG("use ", address_family == AF_INET ? "ipv4" : "ipv6", ": ", local_host_ip); freeifaddrs(ifaddr); return 0; @@ -456,30 +492,29 @@ size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { exit(EXIT_FAILURE); } - if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { LOG_ERROR("reserve_main_address: server_listen_sock init"); exit(EXIT_FAILURE); } - main_server_address.sin_family = AF_INET; - main_server_address.sin_addr.s_addr = inet_addr(local_host_ip); - main_server_address.sin_port = default_start_port; - local_server_address.sin_family = AF_INET; - local_server_address.sin_addr.s_addr = inet_addr(local_host_ip); + main_server_address->set_sin_addr(local_host_ip); + local_server_address->set_sin_addr(local_host_ip); + size_t sin_port = main_server_address->get_sin_port(); while (bind(server_listen_sock, - (const struct sockaddr*)&main_server_address, - sizeof(main_server_address)) < 0) { - main_server_address.sin_port++; + main_server_address->get_sock_addr_ptr(), + main_server_address->size()) < 0) { + sin_port++; + main_server_address->set_sin_port(sin_port); } - local_server_address.sin_port = main_server_address.sin_port; + local_server_address->set_sin_port(main_server_address->get_sin_port()); memset(main_address, '\0', CCL_IP_LEN); snprintf(main_address, CCL_IP_LEN, "%s", local_host_ip); snprintf(main_address + strlen(local_host_ip), INT_STR_SIZE + 1, "_%d", - main_server_address.sin_port); + main_server_address->get_sin_port()); return 0; } @@ -487,10 +522,11 @@ size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { size_t internal_kvs::init_main_server_address(const char* main_addr) { char* ip_getting_type = std::getenv(CCL_KVS_IP_EXCHANGE_ENV.c_str()); - memset(local_host_ip, 0, CCL_IP_LEN); - if (fill_local_host_ip() < 0) { - LOG_ERROR("init_main_server_address: failed to get local host IP"); - exit(EXIT_FAILURE); + if (local_host_ips.empty()) { + if (fill_local_host_ip() < 0) { + LOG_ERROR("init_main_server_address: failed to get local host ip"); + exit(EXIT_FAILURE); + } } if (ip_getting_type) { @@ -518,13 +554,9 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { ip_getting_mode = IGT_ENV; } - local_server_address.sin_family = AF_INET; - local_server_address.sin_addr.s_addr = inet_addr(local_host_ip); - local_server_address.sin_port = default_start_port; - - main_server_address.sin_family = AF_INET; + local_server_address->set_sin_addr(local_host_ip); - if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { ; LOG_ERROR("init_main_server_address: server_listen_sock init"); exit(EXIT_FAILURE); @@ -532,13 +564,15 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { switch (ip_getting_mode) { case IGT_K8S: { + size_t sin_port = local_server_address->get_sin_port(); while (bind(server_listen_sock, - (const struct sockaddr*)&local_server_address, - sizeof(local_server_address)) < 0) { - local_server_address.sin_port++; + local_server_address->get_sock_addr_ptr(), + local_server_address->size()) < 0) { + sin_port++; + local_server_address->set_sin_port(sin_port); } - local_port = local_server_address.sin_port; + local_port = local_server_address->get_sin_port(); return init_main_server_by_k8s(); } case IGT_ENV: { @@ -558,32 +592,34 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { is_master_node = 1; memset(local_host_ip, 0, CCL_IP_LEN); kvs_str_copy_known_sizes(local_host_ip, main_host_ip, CCL_IP_LEN); - local_server_address.sin_addr.s_addr = inet_addr(local_host_ip); + local_server_address->set_sin_addr(local_host_ip); } } if (is_master_node) { if (bind(server_listen_sock, - (const struct sockaddr*)&main_server_address, - sizeof(main_server_address)) < 0) { - printf("port [%d] is busy\n", main_server_address.sin_port); + main_server_address->get_sock_addr_ptr(), + main_server_address->size()) < 0) { + LOG_INFO("port [", main_server_address->get_sin_port(), "] is busy"); + local_port = local_server_address->get_sin_port(); while (bind(server_listen_sock, - (const struct sockaddr*)&local_server_address, - sizeof(local_server_address)) < 0) { - local_server_address.sin_port++; + local_server_address->get_sock_addr_ptr(), + local_server_address->size()) < 0) { + local_port++; + local_server_address->set_sin_port(local_port); } - local_port = local_server_address.sin_port; } else { - local_port = main_server_address.sin_port; + local_port = main_server_address->get_sin_port(); } } else { + local_port = local_server_address->get_sin_port(); while (bind(server_listen_sock, - (const struct sockaddr*)&local_server_address, - sizeof(local_server_address)) < 0) { - local_server_address.sin_port++; + local_server_address->get_sock_addr_ptr(), + local_server_address->size()) < 0) { + local_port++; + local_server_address->set_sin_port(local_port); } - local_port = local_server_address.sin_port; } return res; @@ -598,24 +634,10 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { size_t internal_kvs::kvs_init(const char* main_addr) { int err; socklen_t len = 0; - struct sockaddr_in addr; + std::shared_ptr addr; + time_t start_time; time_t connection_time = 0; - memset(&addr, 0, sizeof(struct sockaddr_in)); - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - addr.sin_port = default_start_port; - - if ((client_op_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("kvs_init: client_op_sock init"); - return 1; - } - - if ((server_control_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("kvs_init: server_control_sock init"); - return 1; - } if (init_main_server_address(main_addr)) { LOG_ERROR("kvs_init: init main server address error"); @@ -626,8 +648,29 @@ size_t internal_kvs::kvs_init(const char* main_addr) { return 1; } - while (bind(server_control_sock, (const struct sockaddr*)&addr, sizeof(addr)) < 0) { - addr.sin_port++; + if (address_family == AF_INET) { + addr = std::shared_ptr(new sockaddr_v4()); + addr->set_sin_addr("127.0.0.1"); + } + else { + addr = std::shared_ptr(new sockaddr_v6()); + addr->set_sin_addr("::1"); + } + + if ((client_op_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { + LOG_ERROR("kvs_init: client_op_sock init"); + return 1; + } + + if ((server_control_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { + LOG_ERROR("kvs_init: server_control_sock init"); + return 1; + } + + size_t sin_port = addr->get_sin_port(); + while (bind(server_control_sock, addr->get_sock_addr_ptr(), addr->size()) < 0) { + sin_port++; + addr->set_sin_port(sin_port); } if (listen(server_control_sock, 1) < 0) { @@ -635,9 +678,9 @@ size_t internal_kvs::kvs_init(const char* main_addr) { exit(EXIT_FAILURE); } - getsockname(server_control_sock, (struct sockaddr*)&addr, &len); + getsockname(server_control_sock, addr->get_sock_addr_ptr(), &len); server_args args; - args.args = &addr; + args.args = addr; args.sock_listener = server_listen_sock; err = pthread_create(&kvs_thread, nullptr, kvs_server_init, &args); if (err) { @@ -654,7 +697,7 @@ size_t internal_kvs::kvs_init(const char* main_addr) { start_time = time(nullptr); do { err = connect( - client_op_sock, (struct sockaddr*)&main_server_address, sizeof(main_server_address)); + client_op_sock, main_server_address->get_sock_addr_ptr(), main_server_address->size()); connection_time = time(nullptr) - start_time; } while ((err < 0) && (connection_time < CONNECTION_TIMEOUT)); @@ -718,7 +761,59 @@ size_t internal_kvs::kvs_finalize(void) { return 0; } + internal_kvs::~internal_kvs() { if (is_inited) kvs_finalize(); } + +void sockaddr_v4::set_sin_addr(const char* src) { + int ret = inet_pton(addr.sin_family, src, &(addr.sin_addr)); + if (ret <= 0) { + if (ret == 0) { + LOG_ERROR( + "inet_pton error - invalid network address, af: ", addr.sin_family, ", src: ", src); + } + else if (ret < 0) { + LOG_ERROR("inet_pton error - af: ", + addr.sin_family, + ", src: ", + src, + ", error: ", + strerror(errno)); + } + exit(1); + } +} + +void sockaddr_v6::set_sin_addr(const char* src) { + char src_copy[internal_kvs::CCL_IP_LEN] = { 0 }; + kvs_str_copy(src_copy, src, internal_kvs::CCL_IP_LEN); + + char* scope_id_ptr = nullptr; + if ((scope_id_ptr = strchr(src_copy, internal_kvs::SCOPE_ID_DELIM))) { + addr.sin6_scope_id = safe_strtol(scope_id_ptr + 1, nullptr, 10); + *scope_id_ptr = '\0'; + } + + int ret = inet_pton(addr.sin6_family, src_copy, &(addr.sin6_addr)); + if (ret <= 0) { + if (ret == 0) { + LOG_ERROR("inet_pton error - invalid network address, af: ", + addr.sin6_family, + ", src_copy: ", + src_copy); + } + else if (ret < 0) { + LOG_ERROR("inet_pton error - af: ", + addr.sin6_family, + ", src_copy: ", + src_copy, + ", error: ", + strerror(errno)); + } + exit(1); + } + + LOG_DEBUG("addr: ", src_copy, ", scope_id: ", addr.sin6_scope_id); +} diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h index cd54c77ee..7460426d2 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h @@ -19,9 +19,25 @@ #include #include #include +#include #include "ikvs_wrapper.h" +class isockaddr { +public: + virtual in_port_t get_sin_port() = 0; + virtual void set_sin_port(in_port_t) = 0; + virtual const void* get_sin_addr_ptr() = 0; + virtual void set_sin_addr(const char*) = 0; + virtual struct sockaddr* get_sock_addr_ptr() = 0; + virtual sa_family_t sin_family() = 0; + virtual size_t size() = 0; + virtual ~isockaddr() = default; + +protected: + const size_t default_start_port = 4096; +}; + class internal_kvs final : public ikvs_wrapper { public: size_t kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) override; @@ -60,8 +76,10 @@ class internal_kvs final : public ikvs_wrapper { server_address = server_addr; } -private: static const int CCL_IP_LEN = 128; + static const char SCOPE_ID_DELIM = '%'; + +private: size_t init_main_server_by_string(const char* main_addr); size_t init_main_server_by_env(); size_t init_main_server_by_k8s(); @@ -73,6 +91,8 @@ class internal_kvs final : public ikvs_wrapper { char main_host_ip[CCL_IP_LEN]; std::list local_host_ips; + std::list local_host_ipv4s; + std::list local_host_ipv6s; char local_host_ip[CCL_IP_LEN]; size_t main_port; @@ -80,8 +100,8 @@ class internal_kvs final : public ikvs_wrapper { size_t is_master = 0; std::mutex client_memory_mutex; - struct sockaddr_in main_server_address; - struct sockaddr_in local_server_address; + std::shared_ptr main_server_address; + std::shared_ptr local_server_address; int client_op_sock; /* used on client side to send commands and to recv result to/from server */ @@ -97,11 +117,81 @@ class internal_kvs final : public ikvs_wrapper { const std::string CCL_KVS_IP_PORT_ENV = "CCL_KVS_IP_PORT"; const std::string CCL_KVS_IP_EXCHANGE_ENV = "CCL_KVS_IP_EXCHANGE"; + const std::string CCL_KVS_PREFER_IPV6_ENV = "CCL_KVS_PREFER_IPV6"; + const std::string CCL_KVS_IFACE_ENV = "CCL_KVS_IFACE"; + const std::string CCL_KVS_IP_EXCHANGE_VAL_ENV = "env"; const std::string CCL_KVS_IP_EXCHANGE_VAL_K8S = "k8s"; const int CONNECTION_TIMEOUT = 120; + int server_listen_sock; /* used on server side to handle new incoming connect requests from clients */ std::string server_address{}; - const size_t default_start_port = 4096; + + sa_family_t address_family{ AF_UNSPEC }; +}; + +class sockaddr_v4 : public isockaddr { +public: + sockaddr_v4() { + memset(&addr, 0, sizeof(sockaddr_in)); + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_family = AF_INET; + addr.sin_port = default_start_port; + } + in_port_t get_sin_port() override { + return addr.sin_port; + } + void set_sin_port(in_port_t sin_port) override { + addr.sin_port = sin_port; + } + struct sockaddr* get_sock_addr_ptr() override { + return (struct sockaddr*)&addr; + } + const void* get_sin_addr_ptr() override { + return &(addr.sin_addr); + } + void set_sin_addr(const char* src) override; + sa_family_t sin_family() override { + return addr.sin_family; + } + size_t size() override { + return sizeof(addr); + } + ~sockaddr_v4() override = default; + +private: + struct sockaddr_in addr; +}; +class sockaddr_v6 : public isockaddr { +public: + sockaddr_v6() { + memset(&addr, 0, sizeof(sockaddr_in6)); + addr.sin6_addr = IN6ADDR_ANY_INIT; + addr.sin6_family = AF_INET6; + addr.sin6_port = default_start_port; + } + in_port_t get_sin_port() override { + return addr.sin6_port; + } + void set_sin_port(in_port_t sin_port) override { + addr.sin6_port = sin_port; + } + const void* get_sin_addr_ptr() override { + return &(addr.sin6_addr); + } + void set_sin_addr(const char* src) override; + struct sockaddr* get_sock_addr_ptr() override { + return (struct sockaddr*)&addr; + } + sa_family_t sin_family() override { + return addr.sin6_family; + } + size_t size() override { + return sizeof(addr); + } + ~sockaddr_v6() override = default; + +private: + struct sockaddr_in6 addr; }; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp index b53e00b40..7a4cba47f 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp @@ -78,22 +78,26 @@ class server { std::map> requests; const int free_socket = -1; std::vector poll_fds; + + sa_family_t address_family{ AF_UNSPEC }; }; void server::try_to_connect_new() { if (poll_fds[FDI_LISTENER].revents != 0) { - struct sockaddr_in addr; + std::shared_ptr addr; - memset(&addr, 0, sizeof(addr)); + if (address_family == AF_INET) { + addr = std::shared_ptr(new sockaddr_v4()); + } + else { + addr = std::shared_ptr(new sockaddr_v6()); + } - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = 0; int new_socket; - socklen_t peer_addr_size = sizeof(addr); - if ((new_socket = accept( - poll_fds[FDI_LISTENER].fd, (struct sockaddr*)&addr, (socklen_t*)&peer_addr_size)) < - 0) { + socklen_t peer_addr_size = addr->size(); + if ((new_socket = accept(poll_fds[FDI_LISTENER].fd, + addr->get_sock_addr_ptr(), + (socklen_t*)&peer_addr_size)) < 0) { perror("server: server_listen_sock accept"); exit(EXIT_FAILURE); } @@ -379,19 +383,13 @@ bool server::check_finalize() { void server::run(void* args) { bool should_stop = false; int so_reuse = 1; - struct sockaddr_in addr; poll_fds.resize(client_count_increase); for (auto& it : poll_fds) { it.fd = free_socket; it.events = POLLIN; } poll_fds[FDI_LISTENER].fd = ((server_args_t*)args)->sock_listener; - - memset(&addr, 0, sizeof(addr)); - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = 0; + address_family = ((server_args_t*)args)->args->sin_family(); #ifdef SO_REUSEPORT setsockopt(poll_fds[FDI_LISTENER].fd, SOL_SOCKET, SO_REUSEPORT, &so_reuse, sizeof(so_reuse)); @@ -404,14 +402,14 @@ void server::run(void* args) { exit(EXIT_FAILURE); } - if ((poll_fds[FDI_CONTROL].fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + if ((poll_fds[FDI_CONTROL].fd = socket(address_family, SOCK_STREAM, 0)) < 0) { perror("server: server_control_sock init"); exit(EXIT_FAILURE); } while (connect(poll_fds[FDI_CONTROL].fd, - (struct sockaddr*)(((server_args_t*)args)->args), - sizeof(addr)) < 0) { + ((server_args_t*)args)->args->get_sock_addr_ptr(), + ((server_args_t*)args)->args->size()) < 0) { } while (!should_stop || client_count > 0) { if (poll(poll_fds.data(), poll_fds.size(), -1) < 0) { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp index ccd2ae695..12590aa54 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp @@ -15,6 +15,7 @@ */ #pragma once #include "util/pm/pmi_resizable_rt/pmi_resizable/def.h" +#include "internal_kvs.h" typedef enum kvs_access_mode { AM_PUT = 2, @@ -39,7 +40,7 @@ typedef struct kvs_request { typedef struct server_args { int sock_listener; - struct sockaddr_in* args; + std::shared_ptr args; } server_args_t; void* kvs_server_init(void* args); diff --git a/src/atl/util/pm/pmi_rt/pmi/mpl.h b/src/atl/util/pm/pmi_rt/pmi/mpl.h index 04b7344a6..fbdb67474 100644 --- a/src/atl/util/pm/pmi_rt/pmi/mpl.h +++ b/src/atl/util/pm/pmi_rt/pmi/mpl.h @@ -181,4 +181,4 @@ static inline int MPL_internal_error_printf(const char *str, ...) { return n; } -#endif /* MPL_H */ +#endif // MPL_H diff --git a/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.h b/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.h index 6c3a749ef..ce14d5cf9 100644 --- a/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.h +++ b/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.h @@ -34,7 +34,7 @@ #if defined HAVE_ARPA_INET_H #include -#endif /* HAVE_ARPA_INET_H */ +#endif // HAVE_ARPA_INET_H /* prototypes for PMIU routines */ void PMIU_Set_rank(int PMI_rank); diff --git a/src/ccl_api_functions.cpp b/src/ccl_api_functions.cpp index 280a720b0..4ae113e2b 100644 --- a/src/ccl_api_functions.cpp +++ b/src/ccl_api_functions.cpp @@ -25,7 +25,6 @@ #include "ccl_api_functions_generators.hpp" #include "common/global/global.hpp" -#include "ccl_gpu_module.hpp" namespace ccl { @@ -42,44 +41,9 @@ struct impl_dispatch { } }; -#ifdef MULTI_GPU_SUPPORT -/* register a gpu module */ -void register_gpu_module(std::string kernels_path) { - if (!kernels_path.empty()) { - if (*kernels_path.rbegin() != '/') { - kernels_path += '/'; - } - } - - LOG_INFO("SPIRV kernels directory: ", kernels_path); - - load_gpu_module( - kernels_path + "ring_allgatherv.spv", ccl::device_topology_type::ring, ccl_coll_allgatherv); - load_gpu_module( - kernels_path + "ring_allreduce.spv", ccl::device_topology_type::ring, ccl_coll_allreduce); - load_gpu_module( - kernels_path + "ring_alltoallv.spv", ccl::device_topology_type::ring, ccl_coll_alltoallv); - load_gpu_module( - kernels_path + "ring_bcast.spv", ccl::device_topology_type::ring, ccl_coll_bcast); - load_gpu_module( - kernels_path + "ring_reduce.spv", ccl::device_topology_type::ring, ccl_coll_reduce); - load_gpu_module(kernels_path + "ring_reduce_scatter.spv", - ccl::device_topology_type::ring, - ccl_coll_reduce_scatter); -} -#endif //MULTI_GPU_SUPPORT - void init(const init_attr& attr) { auto& env = detail::environment::instance(); (void)env; - -#ifdef MULTI_GPU_SUPPORT - const auto& env_object = ccl::global_data::env(); - //WA - if (!env_object.comm_kernels_path.empty()) { - register_gpu_module(env_object.comm_kernels_path); - } -#endif //MULTI_GPU_SUPPORT } /******************** ENVIRONMENT ********************/ @@ -127,31 +91,6 @@ stream create_stream() { return default_stream; } -#ifdef CCL_ENABLE_SYCL -communicator create_single_device_communicator(const int comm_size, - const int rank, - const cl::sycl::device& device, - const cl::sycl::context& context, - shared_ptr_class kvs) { - return detail::environment::instance().create_single_device_communicator( - comm_size, rank, device, context, kvs); -} -#endif // CCL_ENABLE_SYCL - -// communicator create_single_device_communicator(const size_t world_size, -// const int rank, -// cl::sycl::queue queue, -// shared_ptr_class kvs) const; - -// template -// communicator create_single_device_communicator(const size_t world_size, -// const int rank, -// const DeviceSelectorType& selector, -// shared_ptr_class kvs) const -// { -// return return detail::environment::instance().create_single_device_communicator(world_size, rank, cl::sycl::device(selector), kvs); -// } - } // namespace v1 namespace preview { diff --git a/src/ccl_cpp_environment.cpp b/src/ccl_cpp_environment.cpp index e69d61455..7c7297170 100644 --- a/src/ccl_cpp_environment.cpp +++ b/src/ccl_cpp_environment.cpp @@ -25,8 +25,6 @@ #include -#include "common/comm/single_device_communicator/single_device_communicator.hpp" - namespace ccl { namespace detail { @@ -102,50 +100,8 @@ size_t environment::get_datatype_size(ccl::datatype dtype) const { return ccl::global_data::get().dtypes->get(dtype).size(); } -/******************** STREAM ********************/ - -stream environment::create_stream(typename unified_device_type::ccl_native_t device) { - auto version = utils::get_library_version(); - return stream{ stream_provider_dispatcher::create(device, version) }; -} - -stream environment::create_stream(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context) { - auto version = utils::get_library_version(); - return stream{ stream_provider_dispatcher::create(device, context, version) }; -} - /******************** COMMUNICATOR ********************/ -#ifdef CCL_ENABLE_SYCL -communicator environment::create_single_device_communicator( - const int comm_size, - const int rank, - const cl::sycl::device& device, - const cl::sycl::context& context, - ccl::shared_ptr_class kvs) const { - LOG_TRACE("Create single device communicator from SYCL device"); - - std::shared_ptr kvs_wrapper(new users_kvs(kvs)); - std::shared_ptr atl = - std::shared_ptr(new atl_wrapper(comm_size, { rank }, kvs_wrapper)); - - ccl::communicator_interface_ptr impl = - ccl::communicator_interface::create_communicator_impl(device, - context, - rank, - comm_size, - create_comm_split_attr(), - atl, - ccl::group_split_type::single); - - //TODO use gpu_comm_attr to automatically visit() - auto single_dev_comm = std::dynamic_pointer_cast(impl); - //single_dev_comm->set_context(context); - return communicator(std::move(impl)); -} -#endif - communicator environment::create_communicator(const comm_attr& attr) const { return communicator::create_communicator(attr); } @@ -169,15 +125,11 @@ communicator environment::create_communicator(const size_t size, /******************** TypeGenerations ********************/ -CREATE_DEV_COMM_INSTANTIATION(ccl::device, ccl::context) -CREATE_DEV_COMM_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, - typename ccl::unified_context_type::ccl_native_t) -CREATE_DEV_COMM_INSTANTIATION(ccl::device_index_type, - typename ccl::unified_context_type::ccl_native_t) +CREATE_COMM_INSTANTIATION(ccl::device, ccl::context) +CREATE_COMM_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) +CREATE_COMM_INSTANTIATION(ccl::device_index_type, typename ccl::unified_context_type::ccl_native_t) CREATE_STREAM_INSTANTIATION(typename ccl::unified_stream_type::ccl_native_t) -CREATE_STREAM_EXT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, - typename ccl::unified_context_type::ccl_native_t) - CREATE_CONTEXT_INSTANTIATION(typename ccl::unified_context_type::ccl_native_t) CREATE_DEVICE_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t) diff --git a/src/ccl_cpp_stream.cpp b/src/ccl_cpp_stream.cpp index 35c434c60..952e1fec6 100644 --- a/src/ccl_cpp_stream.cpp +++ b/src/ccl_cpp_stream.cpp @@ -44,10 +44,6 @@ CCL_API stream& stream::operator=(const stream& src) { return *this; } -CCL_API void stream::build_from_params() { - get_impl()->build_from_params(); -} - CCL_API stream::native_t& stream::get_native() { return const_cast(static_cast(this)->get_native()); } @@ -62,23 +58,5 @@ CCL_API const stream::native_t& stream::get_native() const { } // namespace ccl API_STREAM_CREATION_FORCE_INSTANTIATION(typename ccl::unified_stream_type::ccl_native_t) -API_STREAM_CREATION_EXT_FORCE_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, - typename ccl::unified_context_type::ccl_native_t) -#ifdef CCL_ENABLE_SYCL -API_STREAM_CREATION_FORCE_INSTANTIATION(cl_command_queue) -#else -//API_STREAM_CREATION_FORCE_INSTANTIATION(ccl::empty_t) -#endif - API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::version, ccl::library_version); -API_STREAM_FORCE_INSTANTIATION_GET( - ccl::stream_attr_id::native_handle); //, typename ccl::unified_stream_type::ccl_native_t); -API_STREAM_FORCE_INSTANTIATION_GET( - ccl::stream_attr_id::device); //, typename ccl::unified_device_type::ccl_native_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::context, - typename ccl::unified_context_type::ccl_native_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::ordinal, uint32_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::index, uint32_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::flags, size_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::mode, size_t); -API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::priority, size_t); +API_STREAM_FORCE_INSTANTIATION_GET(ccl::stream_attr_id::native_handle); diff --git a/src/ccl_gpu_modules.cpp b/src/ccl_gpu_modules.cpp deleted file mode 100644 index 0b400d359..000000000 --- a/src/ccl_gpu_modules.cpp +++ /dev/null @@ -1,95 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include -#include - -#include "ccl_gpu_module.hpp" - -#ifdef MULTI_GPU_SUPPORT - -#include "common/comm/l0/modules/specific_modules_source_data.hpp" -#include "common/comm/l0/device_group_routing_schema.hpp" -#include "coll/algorithms/algorithms_enum.hpp" - -ccl::status load_gpu_module(const std::string& path, - ccl::device_topology_type topo_type, - ccl_coll_type coll_type) { - char pwd[PATH_MAX]; - char* ret = getcwd(pwd, sizeof(pwd)); - (void)ret; - - LOG_INFO("loading GPU module for collective: \"", - ccl_coll_type_to_str(coll_type), - "\", topology: \"", - to_string(topo_type), - "\" by path: ", - path, - ", current directory is: ", - pwd); - - try { - if (path.empty()) { - throw std::runtime_error("path is empty"); - } - - switch (coll_type) { - case ccl_coll_allgatherv: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - case ccl_coll_allreduce: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - case ccl_coll_alltoallv: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - case ccl_coll_bcast: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - case ccl_coll_reduce: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - case ccl_coll_reduce_scatter: - native::specific_modules_source_data_storage::instance() - .load_kernel_source(path, topo_type); - break; - default: - throw std::runtime_error( - std::string(__PRETTY_FUNCTION__) + - " - unexpected collective type: " + std::to_string(coll_type)); - break; - } - } - catch (const std::exception& ex) { - LOG_ERROR("cannot load GPU module from: ", path, ", error: ", ex.what()); - CCL_ASSERT(false); - return ccl::status::runtime_error; - } - - LOG_INFO("GPU module for collective: \"", - ccl_coll_type_to_str(coll_type), - "\", topology: \"", - to_string(topo_type), - "\" loaded succesfully"); - - return ccl::status::success; -} - -#endif //MULTI_GPU_SUPPORT diff --git a/src/coll/algorithms/algorithm_utils.cpp b/src/coll/algorithms/algorithm_utils.cpp index 98214594a..48b5c00cf 100644 --- a/src/coll/algorithms/algorithm_utils.cpp +++ b/src/coll/algorithms/algorithm_utils.cpp @@ -36,6 +36,7 @@ const char* ccl_coll_type_to_str(ccl_coll_type type) { case ccl_coll_reduce_scatter: return "reduce_scatter"; case ccl_coll_sparse_allreduce: return "sparse_allreduce"; case ccl_coll_internal: return "internal"; + case ccl_coll_partial: return "partial"; default: return "unknown"; } return "unknown"; diff --git a/src/coll/algorithms/algorithms.hpp b/src/coll/algorithms/algorithms.hpp index 57b5e8456..712de99f8 100644 --- a/src/coll/algorithms/algorithms.hpp +++ b/src/coll/algorithms/algorithms.hpp @@ -38,6 +38,15 @@ ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, int root, ccl_comm* comm); +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm); +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + ccl::status ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm); ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, @@ -49,6 +58,17 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, int root, ccl_comm* comm); +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +ccl::status ccl_coll_build_gpu_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm); +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, @@ -98,6 +118,16 @@ ccl::status ccl_coll_build_starlike_allreduce(ccl_sched* sched, ccl::reduction reduction, ccl_comm* comm); +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +ccl::status ccl_coll_build_gpu_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + ccl::status ccl_coll_build_naive_allgatherv(ccl_sched* sched, ccl_buffer send_buf, size_t send_count, @@ -186,6 +216,15 @@ ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, const ccl_datatype& dtype, ccl_comm* comm); +ccl::status ccl_coll_build_flat_allgatherv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param); + +ccl::status ccl_coll_build_multi_bcast_allgatherv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param, + size_t data_partition_count); + ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, std::vector& scheds, const ccl_coll_param& coll_param); diff --git a/src/coll/algorithms/algorithms_enum.hpp b/src/coll/algorithms/algorithms_enum.hpp index 7d52cbc30..880991e44 100644 --- a/src/coll/algorithms/algorithms_enum.hpp +++ b/src/coll/algorithms/algorithms_enum.hpp @@ -24,16 +24,18 @@ ccl_coll_sparse_allreduce enum ccl_coll_allgatherv_algo { + ccl_coll_allgatherv_undefined = 0, + ccl_coll_allgatherv_direct, ccl_coll_allgatherv_naive, ccl_coll_allgatherv_ring, ccl_coll_allgatherv_flat, - ccl_coll_allgatherv_multi_bcast, - - ccl_coll_allgatherv_last_value + ccl_coll_allgatherv_multi_bcast }; enum ccl_coll_allreduce_algo { + ccl_coll_allreduce_undefined = 0, + ccl_coll_allreduce_direct, ccl_coll_allreduce_rabenseifner, ccl_coll_allreduce_starlike, @@ -42,66 +44,84 @@ enum ccl_coll_allreduce_algo { ccl_coll_allreduce_double_tree, ccl_coll_allreduce_recursive_doubling, ccl_coll_allreduce_2d, - - ccl_coll_allreduce_last_value + ccl_coll_allreduce_topo_ring }; enum ccl_coll_alltoall_algo { + ccl_coll_alltoall_undefined = 0, + ccl_coll_alltoall_direct, ccl_coll_alltoall_naive, ccl_coll_alltoall_scatter, - ccl_coll_alltoall_scatter_barrier, - - ccl_coll_alltoall_last_value + ccl_coll_alltoall_scatter_barrier }; enum ccl_coll_alltoallv_algo { + ccl_coll_alltoallv_undefined = 0, + ccl_coll_alltoallv_direct, ccl_coll_alltoallv_naive, ccl_coll_alltoallv_scatter, - ccl_coll_alltoallv_scatter_barrier, - - ccl_coll_alltoallv_last_value + ccl_coll_alltoallv_scatter_barrier }; enum ccl_coll_barrier_algo { - ccl_coll_barrier_direct, - ccl_coll_barrier_ring, + ccl_coll_barrier_undefined = 0, - ccl_coll_barrier_last_value + ccl_coll_barrier_direct, + ccl_coll_barrier_ring }; enum ccl_coll_bcast_algo { + ccl_coll_bcast_undefined = 0, + ccl_coll_bcast_direct, ccl_coll_bcast_ring, ccl_coll_bcast_double_tree, ccl_coll_bcast_naive, - - ccl_coll_bcast_last_value + ccl_coll_bcast_topo_ring }; enum ccl_coll_reduce_algo { + ccl_coll_reduce_undefined = 0, + ccl_coll_reduce_direct, ccl_coll_reduce_rabenseifner, ccl_coll_reduce_tree, ccl_coll_reduce_double_tree, - - ccl_coll_reduce_last_value + ccl_coll_reduce_topo_ring }; enum ccl_coll_reduce_scatter_algo { - ccl_coll_reduce_scatter_direct, - ccl_coll_reduce_scatter_ring, + ccl_coll_reduce_scatter_undefined = 0, - ccl_coll_reduce_scatter_last_value + ccl_coll_reduce_scatter_direct, + ccl_coll_reduce_scatter_ring }; enum ccl_coll_sparse_allreduce_algo { + ccl_coll_sparse_allreduce_undefined = 0, + ccl_coll_sparse_allreduce_ring, ccl_coll_sparse_allreduce_mask, - ccl_coll_sparse_allreduce_3_allgatherv, + ccl_coll_sparse_allreduce_3_allgatherv +}; - ccl_coll_sparse_allreduce_last_value +union ccl_coll_algo { + ccl_coll_allgatherv_algo allgatherv; + ccl_coll_allreduce_algo allreduce; + ccl_coll_alltoall_algo alltoall; + ccl_coll_alltoallv_algo alltoallv; + ccl_coll_barrier_algo barrier; + ccl_coll_bcast_algo bcast; + ccl_coll_reduce_algo reduce; + ccl_coll_reduce_scatter_algo reduce_scatter; + int value; + + ccl_coll_algo() : value(0) {} + bool has_value() const { + return (value != 0); + } }; enum ccl_coll_type { @@ -114,7 +134,10 @@ enum ccl_coll_type { ccl_coll_reduce, ccl_coll_reduce_scatter, ccl_coll_sparse_allreduce, + ccl_coll_last_regular = ccl_coll_sparse_allreduce, + ccl_coll_internal, + ccl_coll_partial, ccl_coll_last_value }; @@ -158,7 +181,9 @@ enum ccl_coll_reduction { ccl::reduction::sum, ccl::reduction::prod, ccl::reduction::min, \ ccl::reduction::max /*, ccl::reduction::custom*/ -using ccl_reductions = utils::enum_to_str(ccl::reduction::custom)>; +using ccl_reductions = + utils::enum_to_str::type>( + ccl::reduction::custom)>; inline const std::string reduction_to_str(ccl::reduction reduction_type) { return ccl_reductions({ "sum", "prod", "min", "max" }).choose(reduction_type, "INVALID_VALUE"); } diff --git a/src/coll/algorithms/allgatherv.cpp b/src/coll/algorithms/allgatherv.cpp index 329c50752..98d1bd160 100644 --- a/src/coll/algorithms/allgatherv.cpp +++ b/src/coll/algorithms/allgatherv.cpp @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "coll/algorithms/algorithms.hpp" +#include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" @@ -131,3 +134,175 @@ ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, CCL_FREE(offsets); return status; } + +ccl::status ccl_coll_get_allgatherv_bufs_and_offsets(const ccl_coll_param& coll_param, + std::vector& recv_bufs, + std::vector& recv_offsets) { + int comm_size = coll_param.comm->size(); + size_t dtype_size = coll_param.dtype.size(); + + recv_bufs.resize(comm_size); + recv_offsets.resize(comm_size); + + if (coll_param.recv_bufs.size() > 1) { + CCL_THROW_IF_NOT((int)coll_param.recv_bufs.size() == comm_size, + "unexpected recv_bufs.size ", + coll_param.recv_bufs.size(), + ", expected ", + comm_size); + + for (int idx = 0; idx < comm_size; idx++) { + recv_bufs[idx].set(coll_param.get_recv_buf_ptr(idx), + coll_param.get_recv_count(idx) * dtype_size, + ccl_buffer_type::INDIRECT); + recv_offsets[idx] = 0; + } + } + else { + size_t offset = 0; + size_t dtype_size = coll_param.dtype.size(); + for (int idx = 0; idx < comm_size; idx++) { + size_t bytes = coll_param.get_recv_count(idx) * dtype_size; + recv_bufs[idx].set( + coll_param.get_recv_buf_ptr(), offset + bytes, offset, ccl_buffer_type::INDIRECT); + recv_offsets[idx] = offset; + offset += bytes; + } + } + + return ccl::status::success; +} + +ccl::status ccl_coll_build_flat_allgatherv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param) { + LOG_DEBUG("build flat allgatherv"); + + ccl_comm* comm = coll_param.comm; + const ccl_datatype& dtype = coll_param.dtype; + + int comm_rank = comm->rank(); + int comm_size = comm->size(); + size_t sched_count = scheds.size(); + size_t dtype_size = dtype.size(); + + bool inplace = coll_param.is_inplace(); + + std::vector recv_bufs; + std::vector recv_offsets; + ccl_coll_get_allgatherv_bufs_and_offsets(coll_param, recv_bufs, recv_offsets); + + auto send_seg = ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT); + + if (!inplace) { + entry_factory::make_entry(scheds[2 * comm_rank % sched_count], + ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT), + recv_bufs[comm_rank], + coll_param.get_recv_count(comm_rank), + dtype); + } + else { + size_t total_recv_bytes = + std::accumulate(coll_param.recv_counts.begin(), coll_param.recv_counts.end(), 0) * + dtype_size; + send_seg = ccl_buffer(coll_param.get_send_buf_ptr(), + total_recv_bytes, + recv_offsets[comm_rank], + ccl_buffer_type::INDIRECT); + } + + CCL_THROW_IF_NOT(static_cast(sched_count) == comm_size, + "unexpected sched_count ", + sched_count, + ", expected ", + comm_size); + + for (size_t idx = 0; idx < sched_count; idx++) { + if (static_cast(idx) == comm_rank) + continue; + + entry_factory::make_entry(scheds[(comm_rank + idx) % sched_count], + recv_bufs[idx], + coll_param.get_recv_count(idx), + dtype, + idx, + comm); + + entry_factory::make_entry(scheds[(comm_rank + idx) % sched_count], + send_seg, + coll_param.get_recv_count(comm_rank), + dtype, + idx, + comm); + } + main_sched->sync_partial_scheds(); + + return ccl::status::success; +} + +ccl::status ccl_coll_build_multi_bcast_allgatherv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param, + size_t data_partition_count) { + LOG_DEBUG("build multi_bcast allgatherv"); + + CCL_THROW_IF_NOT(data_partition_count > 0, "data_partition_count should be > 0 "); + + ccl_comm* comm = coll_param.comm; + const ccl_datatype& dtype = coll_param.dtype; + + int comm_rank = comm->rank(); + int comm_size = comm->size(); + size_t sched_count = scheds.size(); + size_t dtype_size = dtype.size(); + + bool inplace = coll_param.is_inplace(); + + std::vector recv_bufs; + std::vector recv_offsets; + ccl_coll_get_allgatherv_bufs_and_offsets(coll_param, recv_bufs, recv_offsets); + + if (!inplace) { + std::vector copy_counts(data_partition_count); + std::vector copy_offsets(data_partition_count); + for (size_t idx = 0; idx < data_partition_count; idx++) { + copy_counts[idx] = coll_param.get_recv_count(comm_rank) / data_partition_count; + copy_offsets[idx] = idx * copy_counts[idx] * dtype_size; + } + copy_counts[data_partition_count - 1] += + coll_param.get_recv_count(comm_rank) % data_partition_count; + + CCL_ASSERT(scheds.size() >= data_partition_count); + + for (size_t idx = 0; idx < data_partition_count; idx++) { + entry_factory::make_entry( + scheds[idx], + ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + copy_offsets[idx], + ccl_buffer_type::INDIRECT), + recv_bufs[comm_rank] + copy_offsets[idx], + copy_counts[idx], + dtype); + } + main_sched->sync_partial_scheds(); + } + + for (int idx = 0; idx < comm_size; idx++) { + ccl_coll_entry_param param{}; + param.ctype = ccl_coll_bcast; + param.recv_buf = recv_bufs[idx]; + param.count = coll_param.get_recv_count(idx); + param.dtype = dtype; + param.root = idx; + param.comm = comm; + param.stream = coll_param.stream; + coll_entry_helper::add_coll_entry(scheds[idx % sched_count], param); + } + + return ccl::status::success; +} diff --git a/src/coll/algorithms/allreduce/allreduce.cpp b/src/coll/algorithms/allreduce/allreduce.cpp index ab512a617..89c6d36fb 100644 --- a/src/coll/algorithms/allreduce/allreduce.cpp +++ b/src/coll/algorithms/allreduce/allreduce.cpp @@ -21,6 +21,8 @@ */ #include "coll/algorithms/algorithms.hpp" +#include "common/comm/host_communicator/host_communicator.hpp" +#include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" @@ -497,3 +499,164 @@ ccl::status ccl_coll_build_ring_allreduce(ccl_sched* sched, return status; } + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + +ccl::status ccl_coll_build_gpu_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { + LOG_DEBUG("build gpu allreduce"); + + const std::vector in_buffers{ + { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + { recv_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 1 + }; + + ccl_coll_entry_param barrier_param{}; + barrier_param.ctype = ccl_coll_barrier; + barrier_param.comm = comm; + barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + + ccl_comm* pair_comm = comm->get_host_comm()->get_pair_comm().get()->get_ccl_comm().get(); + ccl_comm* even_comm = comm->get_host_comm()->get_even_comm().get()->get_ccl_comm().get(); + ccl_comm* node_comm = comm->get_host_comm()->get_node_comm().get()->get_ccl_comm().get(); + ccl_comm* r2r_comm = comm->get_host_comm()->get_r2r_comm().get()->get_ccl_comm().get(); + + int skip_rank = -1; + if (ccl::global_data::env().enable_kernel_1s_ipc_wa) { + skip_rank = ccl::global_data::env().kernel_1s_lead; + } + + if (sched->coll_attr.to_cache) { + sched->set_entry_exec_mode(ccl_sched_entry_exec_once); + entry_factory::make_entry( + sched, node_comm, in_buffers, skip_rank); + sched->add_barrier(); + sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + + // TODO: no need barrier for the first iteration where ze_handle_exchange_entry exists + // TODO: think about the right way + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + else { + entry_factory::make_entry( + sched, node_comm, in_buffers, skip_rank); + } + + sched->add_barrier(); + + if (comm->size() == 4) { + LOG_DEBUG("node_comm: id: ", + node_comm->id(), + ", size: ", + node_comm->size(), + ", rank: ", + node_comm->rank()); + + if (node_comm->size() == 2) { + LOG_DEBUG("r2r_comm: id: ", + r2r_comm->id(), + ", size: ", + r2r_comm->size(), + ", rank: ", + r2r_comm->rank()); + + if (node_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + entry_factory::make_entry( + sched, send_buf, recv_buf, count, dtype, op, node_comm->rank(), node_comm); + sched->add_barrier(); + ccl_buffer host_buf = sched->alloc_buffer(count * dtype.size()); + entry_factory::make_entry( + sched, recv_buf, host_buf, count, dtype, copy_attr(copy_direction::d2h)); + sched->add_barrier(); + ccl_coll_build_allreduce(sched, host_buf, host_buf, count, dtype, op, r2r_comm); + sched->add_barrier(); + entry_factory::make_entry( + sched, host_buf, recv_buf, count, dtype, copy_attr(copy_direction::h2d)); + sched->add_barrier(); + entry_factory::make_entry( + sched, + recv_buf, + ccl_buffer(), + count, + dtype, + copy_attr((node_comm->rank() + 1) % node_comm->size(), 1, copy_direction::d2d)); + sched->add_barrier(); + } + barrier_param.comm = comm; + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + else if (node_comm->size() == 4) { + LOG_DEBUG("pair_comm: id: ", + pair_comm->id(), + ", size: ", + pair_comm->size(), + ", rank: ", + pair_comm->rank()); + + LOG_DEBUG("even_comm: id: ", + even_comm->id(), + ", size: ", + even_comm->size(), + ", rank: ", + even_comm->rank()); + + if (pair_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + entry_factory::make_entry( + sched, send_buf, recv_buf, count, dtype, op, pair_comm->rank(), pair_comm); + sched->add_barrier(); + + barrier_param.comm = even_comm; + coll_entry_helper::add_coll_entry(sched, barrier_param); + sched->add_barrier(); + + if (even_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + entry_factory::make_entry( + sched, recv_buf, recv_buf, count, dtype, op, even_comm); + sched->add_barrier(); + } + } + + barrier_param.comm = comm; + coll_entry_helper::add_coll_entry(sched, barrier_param); + sched->add_barrier(); + + if (pair_comm->rank() != ccl::global_data::env().kernel_1s_lead) { + entry_factory::make_entry( + sched, + ccl_buffer(), + recv_buf, + count, + dtype, + copy_attr((pair_comm->rank() + 1) % pair_comm->size(), + 1, + copy_direction::d2d, + pair_comm)); + sched->add_barrier(); + } + } + else { + CCL_THROW("unexpected node_comm size: ", node_comm->size()); + } + } + else if (comm->size() == 2) { + if (comm->rank() == ccl::global_data::env().kernel_1s_lead) { + entry_factory::make_entry( + sched, send_buf, recv_buf, count, dtype, op, comm); + sched->add_barrier(); + } + barrier_param.comm = comm; + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + else { + CCL_THROW("unexpected comm size: ", comm->size()); + } + + return ccl::status::success; +} + +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT diff --git a/src/coll/algorithms/allreduce/allreduce_2d.cpp b/src/coll/algorithms/allreduce/allreduce_2d.cpp index 1a734e343..54e5b2719 100644 --- a/src/coll/algorithms/allreduce/allreduce_2d.cpp +++ b/src/coll/algorithms/allreduce/allreduce_2d.cpp @@ -54,12 +54,14 @@ ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, ((idx) ? " " : "") + std::to_string(second_dim_comm->get_global_rank(idx)); } - LOG_DEBUG("allreduce_2d:"); - LOG_DEBUG(" base_size: ", base_size); - LOG_DEBUG(" switch_dims: ", switch_dims); - LOG_DEBUG(" first_dim_comm: size ", first_dim_comm->size(), ", ranks ", first_dim_ranks); - LOG_DEBUG( - " second_dim_comm: size ", second_dim_comm->size(), ", ranks ", second_dim_ranks); + std::stringstream ss; + ss << "{" + << "base: " << base_size << ", switch: " << switch_dims + << ", 1st dim: {size:" << first_dim_comm->size() << ", ranks:" << first_dim_ranks << "}" + << ", 2nd dim: {size:" << second_dim_comm->size() << ", ranks:" << second_dim_ranks + << "}" + << "}"; + LOG_DEBUG(ss.str()); } } diff --git a/src/coll/algorithms/alltoallv.cpp b/src/coll/algorithms/alltoallv.cpp index a411440d7..caf063b4c 100644 --- a/src/coll/algorithms/alltoallv.cpp +++ b/src/coll/algorithms/alltoallv.cpp @@ -82,16 +82,22 @@ ccl::status ccl_coll_calculate_alltoallv_counts(const ccl_coll_param& coll_param size_t dtype_size = dtype.size(); if (coll_type == ccl_coll_alltoall) { - send_counts.resize(comm_size, coll_param.count); - recv_counts.resize(comm_size, coll_param.count); + send_counts.resize(comm_size, coll_param.get_send_count()); + recv_counts.resize(comm_size, coll_param.get_recv_count()); } else if (coll_type == ccl_coll_alltoallv) { - CCL_ASSERT(coll_param.send_counts); - CCL_ASSERT(coll_param.recv_counts); - send_counts.assign((size_t*)coll_param.send_counts, - (size_t*)coll_param.send_counts + comm_size); - recv_counts.assign((size_t*)coll_param.recv_counts, - (size_t*)coll_param.recv_counts + comm_size); + CCL_THROW_IF_NOT(static_cast(coll_param.send_counts.size()) == comm_size, + "unexpected send_counts size ", + coll_param.send_counts.size(), + ", expected ", + comm_size); + CCL_THROW_IF_NOT(static_cast(coll_param.recv_counts.size()) == comm_size, + "unexpected recv_counts size ", + coll_param.recv_counts.size(), + ", expected ", + comm_size); + send_counts = coll_param.send_counts; + recv_counts = coll_param.recv_counts; } send_offsets.resize(comm_size, 0); @@ -137,8 +143,7 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, size_t total_send_count = 0, total_recv_count = 0; size_t total_send_bytes = 0, total_recv_bytes = 0; - bool inplace = - (coll_param.send_buf && (coll_param.send_buf == coll_param.recv_buf)) ? true : false; + bool inplace = coll_param.is_inplace(); ccl_coll_calculate_alltoallv_counts(coll_param, send_counts, @@ -153,11 +158,11 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; entry_factory::make_entry(scheds[sched_idx], - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[comm_rank], ccl_buffer_type::INDIRECT), - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[comm_rank], ccl_buffer_type::INDIRECT), @@ -176,7 +181,7 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, if (inplace) recv_buf = scheds[sched_idx]->alloc_buffer(recv_counts[idx] * dtype_size); else - recv_buf = ccl_buffer((void*)(&(coll_param.recv_buf)), + recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[idx], ccl_buffer_type::INDIRECT); @@ -186,7 +191,7 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, entry_factory::make_chunked_send_entry(scheds, sched_idx, - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[idx], ccl_buffer_type::INDIRECT), @@ -199,7 +204,7 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, scheds[sched_idx]->add_barrier(); entry_factory::make_entry(scheds[sched_idx], recv_buf, - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[idx], ccl_buffer_type::INDIRECT), @@ -229,8 +234,7 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, size_t total_send_count = 0, total_recv_count = 0; size_t total_send_bytes = 0, total_recv_bytes = 0; - bool inplace = - (coll_param.send_buf && (coll_param.send_buf == coll_param.recv_buf)) ? true : false; + bool inplace = coll_param.is_inplace(); std::vector recv_bufs; if (inplace) @@ -249,11 +253,11 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; entry_factory::make_entry(scheds[sched_idx], - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[comm_rank], ccl_buffer_type::INDIRECT), - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[comm_rank], ccl_buffer_type::INDIRECT), @@ -276,7 +280,7 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, recv_bufs[src] = recv_buf; } else - recv_buf = ccl_buffer((void*)(&(coll_param.recv_buf)), + recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[src], ccl_buffer_type::INDIRECT); @@ -296,7 +300,7 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, entry_factory::make_chunked_send_entry(scheds, sched_idx, - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[dst], ccl_buffer_type::INDIRECT), @@ -319,7 +323,7 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, entry_factory::make_entry(scheds[sched_idx], recv_bufs[idx], - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[idx], ccl_buffer_type::INDIRECT), @@ -354,8 +358,7 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche } } - bool inplace = - (coll_param.send_buf && (coll_param.send_buf == coll_param.recv_buf)) ? true : false; + bool inplace = coll_param.is_inplace(); ccl_coll_calculate_alltoallv_counts(coll_param, send_counts, @@ -391,11 +394,11 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; entry_factory::make_entry(recv_scheds[sched_idx], - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[comm_rank], ccl_buffer_type::INDIRECT), - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[comm_rank], ccl_buffer_type::INDIRECT), @@ -421,7 +424,7 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche recv_bufs[src] = recv_buf; } else - recv_buf = ccl_buffer((void*)(&(coll_param.recv_buf)), + recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[src], ccl_buffer_type::INDIRECT); @@ -441,7 +444,7 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche entry_factory::make_chunked_send_entry(send_scheds, sched_idx, - ccl_buffer((void*)(&(coll_param.send_buf)), + ccl_buffer(coll_param.get_send_buf_ptr(), total_send_bytes, send_offsets[dst], ccl_buffer_type::INDIRECT), @@ -464,7 +467,7 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche entry_factory::make_entry(scheds[sched_idx], recv_bufs[idx], - ccl_buffer((void*)(&(coll_param.recv_buf)), + ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, recv_offsets[idx], ccl_buffer_type::INDIRECT), diff --git a/src/coll/algorithms/bcast.cpp b/src/coll/algorithms/bcast.cpp index 59dcf76e0..c4ba99976 100644 --- a/src/coll/algorithms/bcast.cpp +++ b/src/coll/algorithms/bcast.cpp @@ -21,6 +21,7 @@ */ #include "coll/algorithms/algorithms.hpp" +#include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/entry_factory.hpp" #define MIN(a, b) std::min(a, b) @@ -232,3 +233,50 @@ ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, fn_exit: return status; } + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + +ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm) { + LOG_DEBUG("build gpu bcast"); + + const std::vector buffers{ + { buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + }; + LOG_DEBUG("BCAST buf = ", buf.get_ptr(), " and root = ", root); + + ccl_coll_entry_param barrier_param{}; + barrier_param.ctype = ccl_coll_barrier; + barrier_param.comm = comm; + barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + + if (sched->coll_attr.to_cache) { + sched->set_entry_exec_mode(ccl_sched_entry_exec_once); + entry_factory::make_entry(sched, comm, buffers); + sched->add_barrier(); + sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + else { + entry_factory::make_entry(sched, comm, buffers); + } + + sched->add_barrier(); + + if (comm->rank() != root) { + entry_factory::make_entry( + sched, ccl_buffer(), buf, count, dtype, copy_attr(root, 0, copy_direction::d2d)); + sched->add_barrier(); + } + + coll_entry_helper::add_coll_entry(sched, barrier_param); + + return ccl::status::success; +} + +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT diff --git a/src/coll/algorithms/reduce.cpp b/src/coll/algorithms/reduce.cpp index 65707afed..54a9a55d0 100644 --- a/src/coll/algorithms/reduce.cpp +++ b/src/coll/algorithms/reduce.cpp @@ -21,6 +21,7 @@ */ #include "coll/algorithms/algorithms.hpp" +#include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/entry_factory.hpp" /* An implementation of Rabenseifner's reduce algorithm (see @@ -447,3 +448,56 @@ ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, return status; } + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + +ccl::status ccl_coll_build_gpu_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm) { + LOG_DEBUG("build gpu reduce"); + + int skip_rank = -1; + + const std::vector in_buffers{ + { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + }; + + ccl_coll_entry_param barrier_param{}; + barrier_param.ctype = ccl_coll_barrier; + barrier_param.comm = comm; + barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + + if (sched->coll_attr.to_cache) { + sched->set_entry_exec_mode(ccl_sched_entry_exec_once); + entry_factory::make_entry(sched, comm, in_buffers, skip_rank); + sched->add_barrier(); + sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + + // TODO: no need barrier for the first iteration where ze_handle_exchange_entry exists + // TODO: think about the right way + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + else { + entry_factory::make_entry(sched, comm, in_buffers, skip_rank); + } + + sched->add_barrier(); + + if (comm->rank() == root) { + entry_factory::make_entry( + sched, send_buf, recv_buf, count, dtype, reduction, root, comm); + sched->add_barrier(); + } + + // TODO: think about the right way + coll_entry_helper::add_coll_entry(sched, barrier_param); + + return ccl::status::success; +} + +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT diff --git a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp index 4e24e172a..32c341a7a 100644 --- a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp +++ b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp @@ -98,58 +98,57 @@ /* TODO: used for ring and mask, refactor to work with dst_ibuf, dst_vbuf */ #define CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK() \ ({ \ - if (sa_handler->comm_size == 1) { \ - *sa_handler->recv_icount = iv_map_cnt; \ - *sa_handler->recv_vcount = iv_map_cnt * sa_handler->val_dim_cnt; \ - *sa_handler->recv_ibuf = sa_handler->dst_buf; \ - *sa_handler->recv_vbuf = \ - (char*)sa_handler->dst_buf + sa_handler->itype_size * iv_map_cnt; \ + if (sa_hndl->comm_size == 1) { \ + *sa_hndl->recv_icount = iv_map_cnt; \ + *sa_hndl->recv_vcount = iv_map_cnt * sa_hndl->val_dim_cnt; \ + *sa_hndl->recv_ibuf = sa_hndl->dst_buf; \ + *sa_hndl->recv_vbuf = (char*)sa_hndl->dst_buf + sa_hndl->itype_size * iv_map_cnt; \ } \ }) #define CCL_SPARSE_ALLREDUCE_CREATE_HANDLER() \ do { \ /* create handler for sched function callbacks */ \ - sa_handler = static_cast( \ + sa_hndl = static_cast( \ sched->alloc_buffer(sizeof(ccl_sparse_allreduce_handler)).get_ptr()); \ \ - sa_handler->comm = comm; \ - sa_handler->comm_size = comm_size; \ - sa_handler->val_dim_cnt = val_dim_cnt; \ - sa_handler->itype_size = itype_size; \ - sa_handler->vtype_size = vtype_size; \ - sa_handler->index_dtype = index_dtype; \ - sa_handler->value_dtype = value_dtype; \ - sa_handler->op = op; \ - sa_handler->recv_ibuf = r_ind_buf; \ - sa_handler->recv_vbuf = r_val_buf; \ - sa_handler->recv_vcount = recv_val_count; \ - sa_handler->recv_icount = recv_ind_count; \ - sa_handler->sched = sched; \ + sa_hndl->comm = comm; \ + sa_hndl->comm_size = comm_size; \ + sa_hndl->val_dim_cnt = val_dim_cnt; \ + sa_hndl->itype_size = itype_size; \ + sa_hndl->vtype_size = vtype_size; \ + sa_hndl->index_dtype = index_dtype; \ + sa_hndl->value_dtype = value_dtype; \ + sa_hndl->op = op; \ + sa_hndl->recv_ibuf = r_ind_buf; \ + sa_hndl->recv_vbuf = r_val_buf; \ + sa_hndl->recv_vcount = recv_val_count; \ + sa_hndl->recv_icount = recv_ind_count; \ + sa_hndl->sched = sched; \ \ - sa_handler->size_per_rank = \ + sa_hndl->size_per_rank = \ static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); \ \ for (int i = 0; i < comm_size; i++) \ - sa_handler->size_per_rank[i] = sizeof(size_t); \ + sa_hndl->size_per_rank[i] = sizeof(size_t); \ \ - sa_handler->send_ibuf = send_ind_buf.get_ptr(); \ - sa_handler->send_vbuf = send_val_buf.get_ptr(); \ + sa_hndl->send_ibuf = send_ind_buf.get_ptr(); \ + sa_hndl->send_vbuf = send_val_buf.get_ptr(); \ \ - sa_handler->send_count[0] = send_ind_count; \ - sa_handler->send_count[1] = send_val_count; \ + sa_hndl->send_count[0] = send_ind_count; \ + sa_hndl->send_count[1] = send_val_count; \ \ - if (sa_handler->sched->coll_attr.sparse_coalesce_mode == \ + if (sa_hndl->sched->coll_attr.sparse_coalesce_mode == \ ccl::sparse_coalesce_mode::keep_precision && \ - sa_handler->value_dtype.idx() == ccl::datatype::bfloat16) { \ - sa_handler->tmp = \ + sa_hndl->value_dtype.idx() == ccl::datatype::bfloat16) { \ + sa_hndl->tmp = \ static_cast(sched->alloc_buffer(sizeof(float) * val_dim_cnt).get_ptr()); \ - sa_handler->acc = \ + sa_hndl->acc = \ static_cast(sched->alloc_buffer(sizeof(float) * val_dim_cnt).get_ptr()); \ } \ else { \ - sa_handler->tmp = nullptr; \ - sa_handler->acc = nullptr; \ + sa_hndl->tmp = nullptr; \ + sa_hndl->acc = nullptr; \ } \ } while (0) @@ -157,10 +156,10 @@ do { \ ccl_coll_entry_param param_nnz{}; \ param_nnz.ctype = ccl_coll_allgatherv; \ - param_nnz.send_buf = ccl_buffer(sa_handler->send_count, sizeof(size_t)); \ - param_nnz.recv_buf = ccl_buffer(sa_handler->recv_counts, sizeof(size_t) * comm_size); \ + param_nnz.send_buf = ccl_buffer(sa_hndl->send_count, sizeof(size_t)); \ + param_nnz.recv_buf = ccl_buffer(sa_hndl->recv_counts, sizeof(size_t) * comm_size); \ param_nnz.send_count = sizeof(size_t); \ - param_nnz.recv_counts = sa_handler->size_per_rank; \ + param_nnz.recv_counts = sa_hndl->size_per_rank; \ param_nnz.dtype = ccl_datatype_int8; \ param_nnz.comm = comm; \ \ @@ -290,11 +289,11 @@ void sparse_coalesce(ccl_sparse_allreduce_handler* sah) { template ccl::status sparse_reduce_ring(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; /* Having received the msg we should prepare it for further send operation to the next neighbour. - sa_handler->recv_counts contains all the nnz count for all the ranks. And every iteration - (sa_handler->iter) we need to take corresponding nnz count from recv_counts array, according to + sa_hndl->recv_counts contains all the nnz count for all the ranks. And every iteration + (sa_hndl->iter) we need to take corresponding nnz count from recv_counts array, according to the scheme rank id: 0 1 2 3 send_buf_id in iter 0: | 0 | -> | 1 | -> | 2 | -> | 3 | @@ -312,34 +311,34 @@ ccl::status sparse_reduce_ring(const void* ctx) { send_buf_id in iter 3: | 1 | -> | 2 | -> | 3 | -> | 0 | ↑__________________________| */ - sa_handler->send_count[0] = - sa_handler->recv_counts[(sa_handler->recv_from - sa_handler->iter + sa_handler->comm_size) % - sa_handler->comm_size]; - sa_handler->send_count[1] = sa_handler->send_count[0] * sa_handler->val_dim_cnt; + sa_hndl->send_count[0] = + sa_hndl->recv_counts[(sa_hndl->recv_from - sa_hndl->iter + sa_hndl->comm_size) % + sa_hndl->comm_size]; + sa_hndl->send_count[1] = sa_hndl->send_count[0] * sa_hndl->val_dim_cnt; - i_type* snd_i = (i_type*)(sa_handler->dst_buf); + i_type* snd_i = (i_type*)(sa_hndl->dst_buf); v_type* snd_v = - (v_type*)((char*)(sa_handler->dst_buf) + sa_handler->itype_size * sa_handler->dst_count[0]); + (v_type*)((char*)(sa_hndl->dst_buf) + sa_hndl->itype_size * sa_hndl->dst_count[0]); /* copy data from recv_buf so that it would be easier to identify unique indices */ - size_t idx_size = sa_handler->itype_size * sa_handler->send_count[0]; - i_type* rcv_i = (i_type*)sa_handler->recv_buf; - v_type* rcv_v = (v_type*)((char*)(sa_handler->recv_buf) + idx_size); + size_t idx_size = sa_hndl->itype_size * sa_hndl->send_count[0]; + i_type* rcv_i = (i_type*)sa_hndl->recv_buf; + v_type* rcv_v = (v_type*)((char*)(sa_hndl->recv_buf) + idx_size); std::vector unique_indices_ids; /* look at received indices and the ones we already have. Check if there are equal ones, then the values could be reduced right away. The indices left will be copied along with correspoinding values*/ - for (size_t idx = 0; idx < sa_handler->send_count[0]; idx++) { - auto it = sa_handler->iv_map->find(rcv_i[idx]); - if (it != sa_handler->iv_map->end()) { - ccl_comp_reduce(sa_handler->sched, - (void*)(rcv_v + idx * sa_handler->val_dim_cnt), - sa_handler->val_dim_cnt, + for (size_t idx = 0; idx < sa_hndl->send_count[0]; idx++) { + auto it = sa_hndl->iv_map->find(rcv_i[idx]); + if (it != sa_hndl->iv_map->end()) { + ccl_comp_reduce(sa_hndl->sched, + (void*)(rcv_v + idx * sa_hndl->val_dim_cnt), + sa_hndl->val_dim_cnt, snd_v + it->second[0], nullptr, - sa_handler->value_dtype, - sa_handler->op, + sa_hndl->value_dtype, + sa_hndl->op, nullptr, nullptr); } @@ -352,172 +351,161 @@ ccl::status sparse_reduce_ring(const void* ctx) { /* were there any unique indices? */ if (unique_indices_ids.size() > 0) { /* prepare buf for combined data */ - size_t merge_idx_len = sa_handler->iv_map->size() + unique_indices_ids.size(); + size_t merge_idx_len = sa_hndl->iv_map->size() + unique_indices_ids.size(); std::vector buf_i(merge_idx_len); - std::vector buf_v(merge_idx_len * sa_handler->val_dim_cnt); + std::vector buf_v(merge_idx_len * sa_hndl->val_dim_cnt); /* copy what we already have reduced*/ - ccl_comp_copy(snd_i, - buf_i.data(), - sa_handler->itype_size * sa_handler->dst_count[0], - ccl_datatype_int8); - ccl_comp_copy(snd_v, - buf_v.data(), - sa_handler->vtype_size * sa_handler->dst_count[1], - ccl_datatype_int8); + ccl_comp_copy( + snd_i, buf_i.data(), sa_hndl->itype_size * sa_hndl->dst_count[0], ccl_datatype_int8); + ccl_comp_copy( + snd_v, buf_v.data(), sa_hndl->vtype_size * sa_hndl->dst_count[1], ccl_datatype_int8); size_t idx_offset = 0; for (auto id : unique_indices_ids) { - buf_i[sa_handler->dst_count[0] + idx_offset] = rcv_i[id]; + buf_i[sa_hndl->dst_count[0] + idx_offset] = rcv_i[id]; - for (size_t k = 0; k < sa_handler->val_dim_cnt; k++) { - buf_v[sa_handler->dst_count[1] + idx_offset * sa_handler->val_dim_cnt + k] = - rcv_v[id * sa_handler->val_dim_cnt + k]; + for (size_t k = 0; k < sa_hndl->val_dim_cnt; k++) { + buf_v[sa_hndl->dst_count[1] + idx_offset * sa_hndl->val_dim_cnt + k] = + rcv_v[id * sa_hndl->val_dim_cnt + k]; } /* upd the map */ - std::vector tmp = { sa_handler->dst_count[1] + - idx_offset * sa_handler->val_dim_cnt }; + std::vector tmp = { sa_hndl->dst_count[1] + idx_offset * sa_hndl->val_dim_cnt }; tmp.reserve(CCL_COALESCE_RESERVE_SIZE); - sa_handler->iv_map->emplace(rcv_i[id], tmp); + sa_hndl->iv_map->emplace(rcv_i[id], tmp); idx_offset++; } /* we definitely have to increase the size of dst buffer because of the unique indices that came from our neighbour */ - size_t new_dst_size = merge_idx_len * (sa_handler->vtype_size * sa_handler->val_dim_cnt + - sa_handler->itype_size); - sa_handler->dst_buf = - (sa_handler->sched->update_buffer( - ccl_buffer(sa_handler->dst_buf, - sa_handler->dst_count[0] * sa_handler->itype_size + - sa_handler->dst_count[1] * sa_handler->vtype_size), - new_dst_size)) - .get_ptr(); + size_t new_dst_size = + merge_idx_len * (sa_hndl->vtype_size * sa_hndl->val_dim_cnt + sa_hndl->itype_size); + sa_hndl->dst_buf = (sa_hndl->sched->update_buffer( + ccl_buffer(sa_hndl->dst_buf, + sa_hndl->dst_count[0] * sa_hndl->itype_size + + sa_hndl->dst_count[1] * sa_hndl->vtype_size), + new_dst_size)) + .get_ptr(); ccl_comp_copy(buf_i.data(), - (i_type*)(sa_handler->dst_buf), - sa_handler->itype_size * merge_idx_len, + (i_type*)(sa_hndl->dst_buf), + sa_hndl->itype_size * merge_idx_len, ccl_datatype_int8); - ccl_comp_copy( - buf_v.data(), - (v_type*)((char*)(sa_handler->dst_buf) + sa_handler->itype_size * merge_idx_len), - sa_handler->vtype_size * merge_idx_len * sa_handler->val_dim_cnt, - ccl_datatype_int8); + ccl_comp_copy(buf_v.data(), + (v_type*)((char*)(sa_hndl->dst_buf) + sa_hndl->itype_size * merge_idx_len), + sa_hndl->vtype_size * merge_idx_len * sa_hndl->val_dim_cnt, + ccl_datatype_int8); - sa_handler->dst_count[0] = merge_idx_len; - sa_handler->dst_count[1] = merge_idx_len * sa_handler->val_dim_cnt; + sa_hndl->dst_count[0] = merge_idx_len; + sa_hndl->dst_count[1] = merge_idx_len * sa_hndl->val_dim_cnt; } // if unique_indices > 0 - ccl_comp_copy(sa_handler->recv_buf, - sa_handler->send_tmp_buf, - idx_size + sa_handler->send_count[1] * sa_handler->vtype_size, + ccl_comp_copy(sa_hndl->recv_buf, + sa_hndl->send_tmp_buf, + idx_size + sa_hndl->send_count[1] * sa_hndl->vtype_size, ccl_datatype_int8); - sa_handler->iter++; + sa_hndl->iter++; return ccl::status::success; } template ccl::status sparse_prepare_result_ring(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; /* data should be returned as sorted in the result buffer */ - i_type* ibuf = (i_type*)(sa_handler->dst_buf); - v_type* vbuf = (v_type*)((i_type*)(sa_handler->dst_buf) + sa_handler->iv_map->size()); - std::vector tmp(vbuf, vbuf + sa_handler->iv_map->size() * sa_handler->val_dim_cnt); + i_type* ibuf = (i_type*)(sa_hndl->dst_buf); + v_type* vbuf = (v_type*)((i_type*)(sa_hndl->dst_buf) + sa_hndl->iv_map->size()); + std::vector tmp(vbuf, vbuf + sa_hndl->iv_map->size() * sa_hndl->val_dim_cnt); size_t idx_offset = 0; - for (auto& it : *sa_handler->iv_map) { + for (auto& it : *sa_hndl->iv_map) { ibuf[idx_offset] = it.first; std::copy(tmp.begin() + it.second[0], - tmp.begin() + it.second[0] + sa_handler->val_dim_cnt, - vbuf + idx_offset * sa_handler->val_dim_cnt); + tmp.begin() + it.second[0] + sa_hndl->val_dim_cnt, + vbuf + idx_offset * sa_hndl->val_dim_cnt); idx_offset++; } - *sa_handler->recv_icount = sa_handler->iv_map->size(); - *sa_handler->recv_vcount = *sa_handler->recv_icount * sa_handler->val_dim_cnt; + *sa_hndl->recv_icount = sa_hndl->iv_map->size(); + *sa_hndl->recv_vcount = *sa_hndl->recv_icount * sa_hndl->val_dim_cnt; - *sa_handler->recv_ibuf = sa_handler->dst_buf; - *sa_handler->recv_vbuf = - ((char*)sa_handler->dst_buf + sa_handler->itype_size * (*sa_handler->recv_icount)); + *sa_hndl->recv_ibuf = sa_hndl->dst_buf; + *sa_hndl->recv_vbuf = ((char*)sa_hndl->dst_buf + sa_hndl->itype_size * (*sa_hndl->recv_icount)); - sa_handler->iv_map->clear(); + sa_hndl->iv_map->clear(); return ccl::status::success; } ccl::status sparse_get_send_count_ring(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; - *cnt_ptr = sa_handler->send_count[0] * - (sa_handler->itype_size + sa_handler->val_dim_cnt * sa_handler->vtype_size); + *cnt_ptr = + sa_hndl->send_count[0] * (sa_hndl->itype_size + sa_hndl->val_dim_cnt * sa_hndl->vtype_size); return ccl::status::success; } ccl::status sparse_get_send_buf_ring(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->send_tmp_buf); + buf_ptr->set(sa_hndl->send_tmp_buf); return ccl::status::success; } ccl::status sparse_get_recv_count_ring(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; - size_t nnz = - sa_handler->recv_counts[(sa_handler->recv_from - sa_handler->iter + sa_handler->comm_size) % - sa_handler->comm_size]; + size_t nnz = sa_hndl->recv_counts[(sa_hndl->recv_from - sa_hndl->iter + sa_hndl->comm_size) % + sa_hndl->comm_size]; - *cnt_ptr = nnz * (sa_handler->itype_size + sa_handler->val_dim_cnt * sa_handler->vtype_size); + *cnt_ptr = nnz * (sa_hndl->itype_size + sa_hndl->val_dim_cnt * sa_hndl->vtype_size); return ccl::status::success; } ccl::status sparse_get_recv_buf_ring(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->recv_buf); + buf_ptr->set(sa_hndl->recv_buf); return ccl::status::success; } ccl::status sparse_set_max_buf_size_ring(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - size_t max_nnz = sa_handler->recv_counts[0]; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + size_t max_nnz = sa_hndl->recv_counts[0]; - for (int i = 1; i < sa_handler->comm_size; i++) { - if (max_nnz < sa_handler->recv_counts[i]) { - max_nnz = sa_handler->recv_counts[i]; + for (int i = 1; i < sa_hndl->comm_size; i++) { + if (max_nnz < sa_hndl->recv_counts[i]) { + max_nnz = sa_hndl->recv_counts[i]; } } - size_t common_size_part = - sa_handler->itype_size + sa_handler->vtype_size * sa_handler->val_dim_cnt; + size_t common_size_part = sa_hndl->itype_size + sa_hndl->vtype_size * sa_hndl->val_dim_cnt; size_t max_size = max_nnz * common_size_part; - sa_handler->send_tmp_buf = sa_handler->sched->alloc_buffer(max_size).get_ptr(); - CCL_MEMCPY( - sa_handler->send_tmp_buf, sa_handler->dst_buf, sa_handler->dst_count[0] * common_size_part); - sa_handler->recv_buf = sa_handler->sched->alloc_buffer(max_size).get_ptr(); + sa_hndl->send_tmp_buf = sa_hndl->sched->alloc_buffer(max_size).get_ptr(); + CCL_MEMCPY(sa_hndl->send_tmp_buf, sa_hndl->dst_buf, sa_hndl->dst_count[0] * common_size_part); + sa_hndl->recv_buf = sa_hndl->sched->alloc_buffer(max_size).get_ptr(); return ccl::status::success; } template ccl::status sparse_coalesce_ring(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; - sparse_coalesce(sa_handler); + sparse_coalesce(sa_hndl); - size_t iv_map_cnt = sa_handler->iv_map->size(); + size_t iv_map_cnt = sa_hndl->iv_map->size(); - sa_handler->send_count[0] = iv_map_cnt; /* index count */ - sa_handler->send_count[1] = iv_map_cnt * sa_handler->val_dim_cnt; /* value count */ - CCL_MEMCPY(&sa_handler->dst_count, &sa_handler->send_count, sizeof(size_t) * 2); + sa_hndl->send_count[0] = iv_map_cnt; /* index count */ + sa_hndl->send_count[1] = iv_map_cnt * sa_hndl->val_dim_cnt; /* value count */ + CCL_MEMCPY(&sa_hndl->dst_count, &sa_hndl->send_count, sizeof(size_t) * 2); CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK(); return ccl::status::success; @@ -555,7 +543,7 @@ ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, void** r_ind_buf = recv_ind_buf; void** r_val_buf = recv_val_buf; - ccl_sparse_allreduce_handler* sa_handler; + ccl_sparse_allreduce_handler* sa_hndl; CCL_SPARSE_ALLREDUCE_CREATE_HANDLER(); /* send from left to right (ring)*/ @@ -565,45 +553,44 @@ ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, /* send to the right neighbour */ int send_to = (rank + 1) % comm_size; - sa_handler->recv_from = recv_from; - sa_handler->iter = 0; + sa_hndl->recv_from = recv_from; + sa_hndl->iter = 0; - sa_handler->recv_counts = + sa_hndl->recv_counts = static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); - entry_factory::make_entry( - sched, sparse_coalesce_ring, sa_handler); + entry_factory::make_entry(sched, sparse_coalesce_ring, sa_hndl); sched->add_barrier(); if (comm_size > 1) { CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry(sched, sparse_set_max_buf_size_ring, sa_handler); + entry_factory::make_entry(sched, sparse_set_max_buf_size_ring, sa_hndl); sched->add_barrier(); for (int i = 0; i < comm_size - 1; i++) { /* send local data to the right neighbour */ send_entry* se = entry_factory::make_entry( sched, ccl_buffer(), 0, ccl_datatype_int8, send_to, comm); - se->set_field_fn(sparse_get_send_buf_ring, sa_handler); - se->set_field_fn(sparse_get_send_count_ring, sa_handler); + se->set_field_fn(sparse_get_send_buf_ring, sa_hndl); + se->set_field_fn(sparse_get_send_count_ring, sa_hndl); /* receive data from the left neighbour */ recv_entry* re = entry_factory::make_entry( sched, ccl_buffer(), 0, ccl_datatype_int8, recv_from, comm); - re->set_field_fn(sparse_get_recv_buf_ring, sa_handler); - re->set_field_fn(sparse_get_recv_count_ring, sa_handler); + re->set_field_fn(sparse_get_recv_buf_ring, sa_hndl); + re->set_field_fn(sparse_get_recv_count_ring, sa_hndl); sched->add_barrier(); /* reduce data */ entry_factory::make_entry( - sched, sparse_reduce_ring, sa_handler); + sched, sparse_reduce_ring, sa_hndl); sched->add_barrier(); } /* copy all reduced data to recv_buf */ entry_factory::make_entry( - sched, sparse_prepare_result_ring, sa_handler); + sched, sparse_prepare_result_ring, sa_hndl); sched->add_barrier(); } @@ -612,132 +599,130 @@ ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, template ccl::status sparse_create_matrix_mask(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - LOG_TRACE("sa_handler: ", - sa_handler, - ", sa_handler->recv_buf_count: ", - sa_handler->recv_buf_count, - ", sa_handler->recv_buf: ", - sa_handler->recv_buf); + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + LOG_TRACE("sa_hndl: ", + sa_hndl, + ", sa_hndl->recv_buf_count: ", + sa_hndl->recv_buf_count, + ", sa_hndl->recv_buf: ", + sa_hndl->recv_buf); /* get rid of the duplicates in allgathered indices list */ - std::set idx_set( - static_cast(sa_handler->recv_buf), - static_cast(sa_handler->recv_buf) + sa_handler->recv_buf_count); + std::set idx_set(static_cast(sa_hndl->recv_buf), + static_cast(sa_hndl->recv_buf) + sa_hndl->recv_buf_count); /* create a matrix expanded with zeros for indices that are not present in the unique indices list specified for this very process */ - size_t value_line_size = sa_handler->vtype_size * sa_handler->val_dim_cnt; + size_t value_line_size = sa_hndl->vtype_size * sa_hndl->val_dim_cnt; size_t idx_cnt = idx_set.size(); size_t matrix_size = idx_cnt * value_line_size; v_type* matrix = static_cast(CCL_MALLOC(matrix_size, "matrix")); v_type* values = - (v_type*)((char*)(sa_handler->dst_buf) + sa_handler->itype_size * sa_handler->dst_count[0]); + (v_type*)((char*)(sa_hndl->dst_buf) + sa_hndl->itype_size * sa_hndl->dst_count[0]); - v_type mask_value = get_mask(sa_handler->op); + v_type mask_value = get_mask(sa_hndl->op); size_t idx_offset = 0; for (typename std::set::iterator it = idx_set.begin(); it != idx_set.end(); ++it) { - auto elem = sa_handler->iv_map->find(*it); - if (elem != sa_handler->iv_map->end()) { + auto elem = sa_hndl->iv_map->find(*it); + if (elem != sa_hndl->iv_map->end()) { /* copy values from dst_buf to matrix */ - CCL_MEMCPY(matrix + idx_offset * sa_handler->val_dim_cnt, + CCL_MEMCPY(matrix + idx_offset * sa_hndl->val_dim_cnt, values + elem->second[0], value_line_size); } else { /* no index was found locally, fill the line with mask */ - std::fill(matrix + idx_offset * sa_handler->val_dim_cnt, - matrix + (idx_offset + 1) * sa_handler->val_dim_cnt, + std::fill(matrix + idx_offset * sa_hndl->val_dim_cnt, + matrix + (idx_offset + 1) * sa_hndl->val_dim_cnt, mask_value); } idx_offset++; } - sa_handler->dst_buf = - sa_handler->sched - ->find_and_realloc_buffer(sa_handler->dst_buf, - idx_cnt * sa_handler->itype_size + matrix_size, - sa_handler->itype_size * sa_handler->dst_count[0] + - sa_handler->vtype_size * sa_handler->dst_count[1]) + sa_hndl->dst_buf = + sa_hndl->sched + ->find_and_realloc_buffer(sa_hndl->dst_buf, + idx_cnt * sa_hndl->itype_size + matrix_size, + sa_hndl->itype_size * sa_hndl->dst_count[0] + + sa_hndl->vtype_size * sa_hndl->dst_count[1]) .get_ptr(); ccl_comp_copy(matrix, - (char*)sa_handler->dst_buf + idx_cnt * sa_handler->itype_size, + (char*)sa_hndl->dst_buf + idx_cnt * sa_hndl->itype_size, matrix_size, ccl_datatype_int8); CCL_FREE(matrix); - sa_handler->iv_map->clear(); - std::copy(idx_set.begin(), idx_set.end(), (i_type*)(sa_handler->dst_buf)); + sa_hndl->iv_map->clear(); + std::copy(idx_set.begin(), idx_set.end(), (i_type*)(sa_hndl->dst_buf)); - *sa_handler->recv_icount = idx_cnt; - *sa_handler->recv_vcount = idx_cnt * sa_handler->val_dim_cnt; + *sa_hndl->recv_icount = idx_cnt; + *sa_hndl->recv_vcount = idx_cnt * sa_hndl->val_dim_cnt; - *sa_handler->recv_ibuf = sa_handler->dst_buf; - *sa_handler->recv_vbuf = ((char*)sa_handler->dst_buf + sa_handler->itype_size * idx_cnt); + *sa_hndl->recv_ibuf = sa_hndl->dst_buf; + *sa_hndl->recv_vbuf = ((char*)sa_hndl->dst_buf + sa_hndl->itype_size * idx_cnt); return ccl::status::success; } ccl::status sparse_get_allreduce_buf_mask(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(*sa_handler->recv_vbuf); + buf_ptr->set(*sa_hndl->recv_vbuf); return ccl::status::success; } ccl::status sparse_get_allreduce_count_mask(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; - *cnt_ptr = *sa_handler->recv_vcount; + *cnt_ptr = *sa_hndl->recv_vcount; return ccl::status::success; } ccl::status sparse_nnz_per_rank_mask(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - sa_handler->recv_buf_count = 0; - for (int i = 0; i < sa_handler->comm_size; i++) { - sa_handler->recv_buf_count += sa_handler->recv_counts[i]; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + sa_hndl->recv_buf_count = 0; + for (int i = 0; i < sa_hndl->comm_size; i++) { + sa_hndl->recv_buf_count += sa_hndl->recv_counts[i]; } - sa_handler->recv_buf = - sa_handler->sched->alloc_buffer(sa_handler->itype_size * sa_handler->recv_buf_count) - .get_ptr(); + sa_hndl->recv_buf = + sa_hndl->sched->alloc_buffer(sa_hndl->itype_size * sa_hndl->recv_buf_count).get_ptr(); return ccl::status::success; } ccl::status sparse_get_allgatherv_buf_mask(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->recv_buf); + buf_ptr->set(sa_hndl->recv_buf); return ccl::status::success; } ccl::status sparse_get_send_buf_mask(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->dst_buf); + buf_ptr->set(sa_hndl->dst_buf); return ccl::status::success; } ccl::status sparse_get_send_count_mask(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; size_t* count = (size_t*)field_ptr; - *count = sa_handler->dst_count[0]; + *count = sa_hndl->dst_count[0]; return ccl::status::success; } template ccl::status sparse_coalesce_mask(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; - sparse_coalesce(sa_handler); + sparse_coalesce(sa_hndl); - size_t iv_map_cnt = sa_handler->iv_map->size(); + size_t iv_map_cnt = sa_hndl->iv_map->size(); - sa_handler->dst_count[0] = iv_map_cnt; - sa_handler->dst_count[1] = iv_map_cnt * sa_handler->val_dim_cnt; + sa_hndl->dst_count[0] = iv_map_cnt; + sa_hndl->dst_count[1] = iv_map_cnt * sa_hndl->val_dim_cnt; CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK(); return ccl::status::success; @@ -774,20 +759,19 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, void** r_ind_buf = recv_ind_buf; void** r_val_buf = recv_val_buf; - ccl_sparse_allreduce_handler* sa_handler; + ccl_sparse_allreduce_handler* sa_hndl; CCL_SPARSE_ALLREDUCE_CREATE_HANDLER(); - sa_handler->recv_counts = + sa_hndl->recv_counts = static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); - entry_factory::make_entry( - sched, sparse_coalesce_mask, sa_handler); + entry_factory::make_entry(sched, sparse_coalesce_mask, sa_hndl); sched->add_barrier(); if (comm_size > 1) { CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry(sched, sparse_nnz_per_rank_mask, sa_handler); + entry_factory::make_entry(sched, sparse_nnz_per_rank_mask, sa_hndl); sched->add_barrier(); ccl_coll_entry_param param_allgatherv{}; @@ -795,19 +779,19 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, param_allgatherv.send_buf = ccl_buffer(); param_allgatherv.recv_buf = ccl_buffer(); param_allgatherv.send_count = 0; - param_allgatherv.recv_counts = sa_handler->recv_counts; + param_allgatherv.recv_counts = sa_hndl->recv_counts; param_allgatherv.dtype = index_dtype; param_allgatherv.comm = comm; /* gather indices from all the processes */ coll_entry* e = entry_factory::make_entry(sched, param_allgatherv); - e->set_field_fn(sparse_get_send_buf_mask, sa_handler); - e->set_field_fn(sparse_get_allgatherv_buf_mask, sa_handler); - e->set_field_fn(sparse_get_send_count_mask, sa_handler); + e->set_field_fn(sparse_get_send_buf_mask, sa_hndl); + e->set_field_fn(sparse_get_allgatherv_buf_mask, sa_hndl); + e->set_field_fn(sparse_get_send_count_mask, sa_hndl); sched->add_barrier(); entry_factory::make_entry( - sched, sparse_create_matrix_mask, sa_handler); + sched, sparse_create_matrix_mask, sa_hndl); sched->add_barrier(); ccl_coll_entry_param param_allreduce{}; @@ -821,9 +805,9 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, /* coll allreduce on matrix data */ coll_entry* ce = entry_factory::make_entry(sched, param_allreduce); - ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_handler); - ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_handler); - ce->set_field_fn(sparse_get_allreduce_count_mask, sa_handler); + ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_hndl); + ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_hndl); + ce->set_field_fn(sparse_get_allreduce_count_mask, sa_hndl); sched->add_barrier(); } @@ -831,119 +815,117 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, } ccl::status sparse_alloc_result_buf_allgatherv(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; - sa_handler->recv_buf_count = 0; - for (int i = 0; i < sa_handler->comm_size; i++) { - sa_handler->recv_buf_count += sa_handler->recv_counts[i]; + sa_hndl->recv_buf_count = 0; + for (int i = 0; i < sa_hndl->comm_size; i++) { + sa_hndl->recv_buf_count += sa_hndl->recv_counts[i]; } LOG_TRACE("sa_handle: ", - sa_handler, + sa_hndl, ",allocate all buffers - indices size: ", - sa_handler->recv_buf_count * sa_handler->itype_size, + sa_hndl->recv_buf_count * sa_hndl->itype_size, ", values size: ", - sa_handler->recv_buf_count * sa_handler->vtype_size * sa_handler->val_dim_cnt, - ", sa_handler->recv_counts: ", - sa_handler->recv_counts); + sa_hndl->recv_buf_count * sa_hndl->vtype_size * sa_hndl->val_dim_cnt, + ", sa_hndl->recv_counts: ", + sa_hndl->recv_counts); - ccl_sched* sched = sa_handler->sched; + ccl_sched* sched = sa_hndl->sched; if (sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable && sched->coll_attr.sparse_allreduce_alloc_fn) { /* with coalesce_disable the final buffers are allocated here, so use alloc_fn */ - sched->coll_attr.sparse_allreduce_alloc_fn( - sa_handler->recv_buf_count, - sa_handler->index_dtype.idx(), - sa_handler->recv_buf_count * sa_handler->val_dim_cnt, - sa_handler->value_dtype.idx(), - sched->coll_attr.sparse_allreduce_fn_ctx, - &sa_handler->all_idx_buf, - &sa_handler->all_val_buf); + sched->coll_attr.sparse_allreduce_alloc_fn(sa_hndl->recv_buf_count, + sa_hndl->index_dtype.idx(), + sa_hndl->recv_buf_count * sa_hndl->val_dim_cnt, + sa_hndl->value_dtype.idx(), + sched->coll_attr.sparse_allreduce_fn_ctx, + &sa_hndl->all_idx_buf, + &sa_hndl->all_val_buf); } else { - sa_handler->all_idx_buf = - sched->alloc_buffer(sa_handler->recv_buf_count * sa_handler->itype_size).get_ptr(); - sa_handler->all_val_buf = + sa_hndl->all_idx_buf = + sched->alloc_buffer(sa_hndl->recv_buf_count * sa_hndl->itype_size).get_ptr(); + sa_hndl->all_val_buf = sched - ->alloc_buffer(sa_handler->recv_buf_count * sa_handler->vtype_size * - sa_handler->val_dim_cnt) + ->alloc_buffer(sa_hndl->recv_buf_count * sa_hndl->vtype_size * sa_hndl->val_dim_cnt) .get_ptr(); } - CCL_THROW_IF_NOT(sa_handler->all_idx_buf && sa_handler->all_val_buf); + CCL_THROW_IF_NOT(sa_hndl->all_idx_buf && sa_hndl->all_val_buf); return ccl::status::success; } template ccl::status sparse_set_v_counts_allgatherv(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - size_t stride = stride_per_comm * sa_handler->comm_size; - for (int i = 0; i < sa_handler->comm_size; i++) { - sa_handler->recv_counts[i + stride] = sa_handler->recv_counts[i] * sa_handler->val_dim_cnt; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + size_t stride = stride_per_comm * sa_hndl->comm_size; + for (int i = 0; i < sa_hndl->comm_size; i++) { + sa_hndl->recv_counts[i + stride] = sa_hndl->recv_counts[i] * sa_hndl->val_dim_cnt; } return ccl::status::success; } ccl::status sparse_return_gathered_allgatherv(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - *sa_handler->recv_icount = sa_handler->recv_buf_count; - *sa_handler->recv_vcount = sa_handler->recv_buf_count * sa_handler->val_dim_cnt; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + *sa_hndl->recv_icount = sa_hndl->recv_buf_count; + *sa_hndl->recv_vcount = sa_hndl->recv_buf_count * sa_hndl->val_dim_cnt; - *sa_handler->recv_ibuf = sa_handler->all_idx_buf; - *sa_handler->recv_vbuf = sa_handler->all_val_buf; + *sa_hndl->recv_ibuf = sa_hndl->all_idx_buf; + *sa_hndl->recv_vbuf = sa_hndl->all_val_buf; return ccl::status::success; } template ccl::status sparse_reduce_gathered_allgatherv(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; - i_type* indices = static_cast(sa_handler->all_idx_buf); - v_type* values = static_cast(sa_handler->all_val_buf); + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; + i_type* indices = static_cast(sa_hndl->all_idx_buf); + v_type* values = static_cast(sa_hndl->all_val_buf); std::unique_ptr iv_map(new idx_offset_map); - for (size_t i = 0; i < sa_handler->recv_buf_count; i++) { + for (size_t i = 0; i < sa_hndl->recv_buf_count; i++) { auto it = iv_map->find(indices[i]); if (it == iv_map->end()) { - std::vector tmp = { i * sa_handler->val_dim_cnt }; + std::vector tmp = { i * sa_hndl->val_dim_cnt }; tmp.reserve(CCL_COALESCE_RESERVE_SIZE); iv_map->emplace(indices[i], tmp); } else { - it->second.push_back(i * sa_handler->val_dim_cnt); + it->second.push_back(i * sa_hndl->val_dim_cnt); } } size_t idx_cnt = iv_map->size(); - size_t i_new_size = sa_handler->itype_size * idx_cnt; - size_t v_new_size = sa_handler->vtype_size * idx_cnt * sa_handler->val_dim_cnt; + size_t i_new_size = sa_hndl->itype_size * idx_cnt; + size_t v_new_size = sa_hndl->vtype_size * idx_cnt * sa_hndl->val_dim_cnt; i_type* i_recv = nullptr; v_type* v_recv = nullptr; - ccl_sched* sched = sa_handler->sched; + ccl_sched* sched = sa_hndl->sched; if (sched->coll_attr.sparse_allreduce_alloc_fn) { sched->coll_attr.sparse_allreduce_alloc_fn(idx_cnt, - sa_handler->index_dtype.idx(), - idx_cnt * sa_handler->val_dim_cnt, - sa_handler->value_dtype.idx(), + sa_hndl->index_dtype.idx(), + idx_cnt * sa_hndl->val_dim_cnt, + sa_hndl->value_dtype.idx(), sched->coll_attr.sparse_allreduce_fn_ctx, - &sa_handler->dst_ibuf, - &sa_handler->dst_vbuf); + &sa_hndl->dst_ibuf, + &sa_hndl->dst_vbuf); - i_recv = (i_type*)sa_handler->dst_ibuf; - v_recv = (v_type*)sa_handler->dst_vbuf; + i_recv = (i_type*)sa_hndl->dst_ibuf; + v_recv = (v_type*)sa_hndl->dst_vbuf; } else { - sa_handler->dst_ibuf = sched->alloc_buffer(i_new_size).get_ptr(); - sa_handler->dst_vbuf = sched->alloc_buffer(v_new_size).get_ptr(); + sa_hndl->dst_ibuf = sched->alloc_buffer(i_new_size).get_ptr(); + sa_hndl->dst_vbuf = sched->alloc_buffer(v_new_size).get_ptr(); - i_recv = (i_type*)sa_handler->dst_ibuf; - v_recv = (v_type*)sa_handler->dst_vbuf; + i_recv = (i_type*)sa_hndl->dst_ibuf; + v_recv = (v_type*)sa_hndl->dst_vbuf; } CCL_THROW_IF_NOT(i_recv && v_recv); @@ -953,9 +935,9 @@ ccl::status sparse_reduce_gathered_allgatherv(const void* ctx) { for (auto& it : *iv_map) { i_recv[idx_offset] = it.first; - val_offset = idx_offset * sa_handler->val_dim_cnt; + val_offset = idx_offset * sa_hndl->val_dim_cnt; std::copy(values + it.second[0], - values + it.second[0] + sa_handler->val_dim_cnt, + values + it.second[0] + sa_hndl->val_dim_cnt, v_recv + val_offset); it.second[0] = val_offset; @@ -963,70 +945,70 @@ ccl::status sparse_reduce_gathered_allgatherv(const void* ctx) { if (it.second.size() > 1) { ccl_comp_batch_reduce(values, it.second, - sa_handler->val_dim_cnt, + sa_hndl->val_dim_cnt, v_recv + val_offset, nullptr, - sa_handler->value_dtype, - sa_handler->op, + sa_hndl->value_dtype, + sa_hndl->op, nullptr, nullptr, sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::keep_precision && - sa_handler->value_dtype.idx() == ccl::datatype::bfloat16, - sa_handler->tmp, - sa_handler->acc); + sa_hndl->value_dtype.idx() == ccl::datatype::bfloat16, + sa_hndl->tmp, + sa_hndl->acc); } idx_offset++; } iv_map->clear(); - *sa_handler->recv_icount = idx_cnt; - *sa_handler->recv_vcount = idx_cnt * sa_handler->val_dim_cnt; + *sa_hndl->recv_icount = idx_cnt; + *sa_hndl->recv_vcount = idx_cnt * sa_hndl->val_dim_cnt; - *sa_handler->recv_ibuf = i_recv; - *sa_handler->recv_vbuf = v_recv; + *sa_hndl->recv_ibuf = i_recv; + *sa_hndl->recv_vbuf = v_recv; return ccl::status::success; } ccl::status sparse_get_i_recv_allgatherv(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->all_idx_buf); + buf_ptr->set(sa_hndl->all_idx_buf); return ccl::status::success; } ccl::status sparse_get_i_send_allgatherv(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->dst_ibuf); + buf_ptr->set(sa_hndl->dst_ibuf); return ccl::status::success; } template ccl::status sparse_get_send_count_allgatherv(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; size_t* send_buf_count = (size_t*)field_ptr; - *send_buf_count = sa_handler->send_count[send_count_src_index]; + *send_buf_count = sa_hndl->send_count[send_count_src_index]; return ccl::status::success; } ccl::status sparse_get_v_recv_allgatherv(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - buf_ptr->set(sa_handler->all_val_buf); + buf_ptr->set(sa_hndl->all_val_buf); return ccl::status::success; } ccl::status sparse_get_v_send_allgatherv(const void* ctx, void* field_ptr) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; - if (sa_handler->sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable) { - buf_ptr->set(sa_handler->send_vbuf); + if (sa_hndl->sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable) { + buf_ptr->set(sa_hndl->send_vbuf); } else { - buf_ptr->set(sa_handler->dst_vbuf); + buf_ptr->set(sa_hndl->dst_vbuf); } return ccl::status::success; @@ -1034,20 +1016,20 @@ ccl::status sparse_get_v_send_allgatherv(const void* ctx, void* field_ptr) { template ccl::status sparse_coalesce_allgatherv(const void* ctx) { - ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; + ccl_sparse_allreduce_handler* sa_hndl = (ccl_sparse_allreduce_handler*)ctx; - sparse_coalesce(sa_handler); + sparse_coalesce(sa_hndl); - size_t iv_map_cnt = sa_handler->iv_map->size(); - sa_handler->iv_map->clear(); - sa_handler->send_count[0] = iv_map_cnt; - sa_handler->send_count[1] = iv_map_cnt * sa_handler->val_dim_cnt; + size_t iv_map_cnt = sa_hndl->iv_map->size(); + sa_hndl->iv_map->clear(); + sa_hndl->send_count[0] = iv_map_cnt; + sa_hndl->send_count[1] = iv_map_cnt * sa_hndl->val_dim_cnt; - if (sa_handler->comm_size == 1) { - *sa_handler->recv_icount = iv_map_cnt; - *sa_handler->recv_vcount = iv_map_cnt * sa_handler->val_dim_cnt; - *sa_handler->recv_ibuf = sa_handler->dst_ibuf; - *sa_handler->recv_vbuf = sa_handler->dst_vbuf; + if (sa_hndl->comm_size == 1) { + *sa_hndl->recv_icount = iv_map_cnt; + *sa_hndl->recv_vcount = iv_map_cnt * sa_hndl->val_dim_cnt; + *sa_hndl->recv_ibuf = sa_hndl->dst_ibuf; + *sa_hndl->recv_vbuf = sa_hndl->dst_vbuf; } return ccl::status::success; @@ -1084,41 +1066,40 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, void** r_ind_buf = recv_ind_buf; void** r_val_buf = recv_val_buf; - ccl_sparse_allreduce_handler* sa_handler; + ccl_sparse_allreduce_handler* sa_hndl; CCL_SPARSE_ALLREDUCE_CREATE_HANDLER(); constexpr size_t parallel_requests_count = 2; //indices + values - sa_handler->recv_counts = static_cast( + sa_hndl->recv_counts = static_cast( sched->alloc_buffer(sizeof(size_t) * comm_size * parallel_requests_count).get_ptr()); - LOG_DEBUG("sa_handler: ", - sa_handler, - ", sa_handler->recv_ibuf: ", - sa_handler->recv_ibuf, - ", sa_handler->recv_vbuf: ", - sa_handler->recv_vbuf, - ", sa_handler->val_dim_cnt: ", - sa_handler->val_dim_cnt, - ", sa_handler->recv_counts: ", - sa_handler->recv_counts); + LOG_DEBUG("sa_hndl: ", + sa_hndl, + ", sa_hndl->recv_ibuf: ", + sa_hndl->recv_ibuf, + ", sa_hndl->recv_vbuf: ", + sa_hndl->recv_vbuf, + ", sa_hndl->val_dim_cnt: ", + sa_hndl->val_dim_cnt, + ", sa_hndl->recv_counts: ", + sa_hndl->recv_counts); if (sched->coll_attr.sparse_coalesce_mode != ccl::sparse_coalesce_mode::disable) { entry_factory::make_entry( - sched, sparse_coalesce_allgatherv, sa_handler); + sched, sparse_coalesce_allgatherv, sa_hndl); sched->add_barrier(); if (comm_size == 1) return status; } else { - sa_handler->dst_ibuf = sa_handler->send_ibuf; - sa_handler->dst_vbuf = sa_handler->send_vbuf; + sa_hndl->dst_ibuf = sa_hndl->send_ibuf; + sa_hndl->dst_vbuf = sa_hndl->send_vbuf; } CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry( - sched, sparse_alloc_result_buf_allgatherv, sa_handler); + entry_factory::make_entry(sched, sparse_alloc_result_buf_allgatherv, sa_hndl); sched->add_barrier(); // allgather indices @@ -1128,16 +1109,16 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, param_i.send_buf = ccl_buffer(); param_i.recv_buf = ccl_buffer(); param_i.send_count = 0; - param_i.recv_counts = sa_handler->recv_counts; + param_i.recv_counts = sa_hndl->recv_counts; param_i.dtype = index_dtype; param_i.comm = comm; coll_entry* ce = entry_factory::make_entry(sched, param_i, parallel_request_index); - ce->set_field_fn(sparse_get_i_send_allgatherv, sa_handler); - ce->set_field_fn(sparse_get_i_recv_allgatherv, sa_handler); + ce->set_field_fn(sparse_get_i_send_allgatherv, sa_hndl); + ce->set_field_fn(sparse_get_i_recv_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_send_count_allgatherv<0>, - sa_handler); - entry_factory::make_entry(sched, sparse_set_v_counts_allgatherv<1>, sa_handler); + sa_hndl); + entry_factory::make_entry(sched, sparse_set_v_counts_allgatherv<1>, sa_hndl); // allgather values parallel_request_index++; @@ -1146,24 +1127,24 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, param_v.send_buf = ccl_buffer(); param_v.recv_buf = ccl_buffer(); param_v.send_count = 0; - param_v.recv_counts = &sa_handler->recv_counts[comm_size]; + param_v.recv_counts = &sa_hndl->recv_counts[comm_size]; param_v.dtype = value_dtype; param_v.comm = comm; ce = entry_factory::make_entry(sched, param_v, parallel_request_index); - ce->set_field_fn(sparse_get_v_send_allgatherv, sa_handler); - ce->set_field_fn(sparse_get_v_recv_allgatherv, sa_handler); + ce->set_field_fn(sparse_get_v_send_allgatherv, sa_hndl); + ce->set_field_fn(sparse_get_v_recv_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_send_count_allgatherv<1>, - sa_handler); + sa_hndl); sched->add_barrier(); if (sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable) { entry_factory::make_entry( - sched, sparse_return_gathered_allgatherv, sa_handler); + sched, sparse_return_gathered_allgatherv, sa_hndl); } else { entry_factory::make_entry( - sched, sparse_reduce_gathered_allgatherv, sa_handler); + sched, sparse_reduce_gathered_allgatherv, sa_hndl); } sched->add_barrier(); diff --git a/src/coll/coll.cpp b/src/coll/coll.cpp index fe1622f9e..0dca06c03 100644 --- a/src/coll/coll.cpp +++ b/src/coll/coll.cpp @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "oneapi/ccl/types.hpp" #include "oneapi/ccl/aliases.hpp" @@ -47,6 +45,7 @@ #include "coll/ccl_reduce_op_attr.hpp" #include "coll/ccl_reduce_scatter_op_attr.hpp" #include "coll/ccl_sparse_allreduce_op_attr.hpp" +#include "coll/coll_check.hpp" #include "coll/coll_param.hpp" #include "common/global/global.hpp" @@ -60,193 +59,29 @@ #include "fusion/fusion.hpp" #include "unordered_coll/unordered_coll.hpp" -#define COPY_COMMON_OP_ATTRS(from, to) \ - to->prologue_fn = nullptr; /*from.get().get();*/ \ - to->epilogue_fn = nullptr; /*from.get().get();*/ \ - to->priority = from.get(); \ - to->synchronous = from.get(); \ - to->to_cache = (from.get().length()) \ - ? from.get() \ - : false; \ - to->match_id = from.get(); \ - if (to->to_cache != from.get()) \ - LOG_INFO("collective caching is requested but no match_id is provided, disable caching"); - -//TODO temporary solution for type convertation, ccl_coll_attr would be depreacated -ccl_coll_attr::ccl_coll_attr(const ccl::allgatherv_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::allreduce_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); - - reduction_fn = attr.get().get(); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::alltoall_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::alltoallv_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::barrier_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::broadcast_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::reduce_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); - - reduction_fn = attr.get().get(); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::reduce_scatter_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); - - reduction_fn = attr.get().get(); -} - -ccl_coll_attr::ccl_coll_attr(const ccl::sparse_allreduce_attr& attr) { - COPY_COMMON_OP_ATTRS(attr, this); - - sparse_allreduce_completion_fn = attr.get().get(); - sparse_allreduce_alloc_fn = attr.get().get(); - sparse_allreduce_fn_ctx = attr.get(); - sparse_coalesce_mode = attr.get(); -} - -static void ccl_coll_validate_and_adjust(const ccl_coll_param& param) { - // not SYCL, don't need validation - if (param.stream == nullptr) { - return; - } - - // skip validation if it was requested explicitly (e.g. for sycl::buffer) - if (param.skip_validation) { - return; - } - -#ifdef CCL_ENABLE_SYCL - std::vector bufs = {}; - - switch (param.ctype) { - case ccl_coll_alltoallv: { - // if the sum of the counts is 0 this means that the buf pointer could be anything, - // including nullptr and invalid pointer. We should neither validate nor dereference it. - // TODO: make const void* - if (std::accumulate(param.send_counts, param.send_counts + param.comm->size(), 0) > 0) { - bufs.push_back((void*)(param.send_buf)); - } - - if (std::accumulate(param.recv_counts, param.recv_counts + param.comm->size(), 0) > 0) { - bufs.push_back((void*)(param.recv_buf)); - } - break; - } - case ccl_coll_allreduce: - case ccl_coll_allgatherv: - case ccl_coll_alltoall: - case ccl_coll_reduce: - case ccl_coll_reduce_scatter: - bufs = { (void*)param.send_buf, (void*)param.recv_buf }; - break; - case ccl_coll_bcast: bufs = { (void*)param.recv_buf }; break; - case ccl_coll_sparse_allreduce: - bufs = { (void*)param.sparse_param.send_ind_buf, - (void*)param.sparse_param.send_val_buf, - (void*)param.sparse_param.recv_ind_buf, - (void*)param.sparse_param.recv_val_buf }; - break; - default: - // everything that is not a collective, i.e. barrier doesn't require validation - return; - } - - auto q = param.stream->get_native_stream(); - CCL_THROW_IF_NOT( - native::detail::check_assoc_device_memory(bufs, q.get_device(), q.get_context()) != - native::detail::usm_support_mode::prohibited, - "unsupported usm type"); -#endif /* CCL_ENABLE_SYCL */ -} - /* param is not const because param.comm can be updated for unordered colls */ -static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& attr) { - // perform a validation and adjustion if necessary - ccl_coll_validate_and_adjust(param); - - ccl::global_data& data = ccl::global_data::get(); - - /* 1. decide whether schedule should be postponed (this includes caching and staring) */ - bool postpone_schedule = false; - if (ccl::global_data::env().enable_unordered_coll) { - if (!attr.match_id.empty()) { - auto comm = - param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); - if (!comm) { - if (attr.synchronous) { - CCL_THROW("unsupported collective (synchronous && unordered && !communicator)"); - } - LOG_DEBUG("didn't find comm for match_id ", attr.match_id, ", postpone schedule"); - postpone_schedule = true; - } - else { - LOG_DEBUG("found comm ", comm->id(), " for match_id ", attr.match_id); - param.comm = comm; - } - } - else { - /* use comm provided by user, it is ordered collective */ - } - } - - /* 2. create or get schedule */ - ccl_master_sched* sched = ccl_master_sched::create(param, attr); - - /* 3. fuse schedule */ - if (!postpone_schedule && ccl::global_data::env().enable_fusion) { - if (data.fusion_manager->add(sched)) { - LOG_DEBUG("sched ", - sched, - ", ctype ", - ccl_coll_type_to_str(sched->coll_param.ctype), - " will be fused"); - return sched; - } - } - - /* 4. parallelize schedule */ - sched->commit(data.parallelizer.get()); +static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& in_attr) { + ccl_coll_attr& attr = const_cast(in_attr); - /* 5. postpone unordered coll schedule */ - if (postpone_schedule) { - /* - user has provided match_id that has not been resolved yet. - schedule will be postponed until comm resolution - */ - return param.comm->unordered_coll_manager->postpone(sched); - } +#ifdef CCL_ENABLE_SYCL + if (ccl::global_data::env().enable_op_sync) + attr.synchronous = 1; +#endif // CCL_ENABLE_SYCL - /* 6. regular schedule execution */ - ccl_request* request = sched->start(data.executor.get()); - if (sched->coll_attr.synchronous) { - ccl_wait_impl(data.executor.get(), request); - request = nullptr; - } + LOG_DEBUG("\n{\n", + " param: ", + param.to_string(), + "\n" + " attr: ", + attr.to_string(), + "\n" + "}"); - return request; -} + ccl_coll_validate_user_input(param, attr); -//TODO duplicated code - make `ccl_coll_create` templated -static ccl_request* ccl_gpu_coll_create(ccl_coll_param& param, const ccl_coll_attr& attr) { ccl::global_data& data = ccl::global_data::get(); - /* 1. decide whether schedule should be postponed */ + /* 1. decide whether schedule should be postponed (this includes caching and starting) */ bool postpone_schedule = false; if (ccl::global_data::env().enable_unordered_coll) { if (!attr.match_id.empty()) { @@ -277,7 +112,7 @@ static ccl_request* ccl_gpu_coll_create(ccl_coll_param& param, const ccl_coll_at if (data.fusion_manager->add(sched)) { LOG_DEBUG("sched ", sched, - ", ctype ", + ", coll ", ccl_coll_type_to_str(sched->coll_param.ctype), " will be fused"); return sched; @@ -320,7 +155,8 @@ ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, param.recv_counts = recv_counts; param.dtype = dtype; param.comm = comm; - param.vector_buf = sched->coll_attr.vector_buf; + param.is_vector_buf = sched->coll_attr.is_vector_buf; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -341,6 +177,7 @@ ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, CCL_FATAL("unexpected allgatherv_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; } + return status; } @@ -351,6 +188,7 @@ ccl::status ccl_coll_build_allreduce(ccl_sched* sched, const ccl_datatype& dtype, ccl::reduction reduction, ccl_comm* comm) { + CCL_ASSERT(sched != nullptr && comm != nullptr); ccl::status status = ccl::status::success; ccl_selector_param param; @@ -358,6 +196,12 @@ ccl::status ccl_coll_build_allreduce(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; + param.buf = send_buf.get_ptr(); +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -401,6 +245,12 @@ ccl::status ccl_coll_build_allreduce(ccl_sched* sched, CCL_CALL(comm->allreduce_2d_builder->build( sched, send_buf, recv_buf, count, dtype, reduction)); break; +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + case ccl_coll_allreduce_topo_ring: + CCL_CALL(ccl_coll_build_gpu_allreduce( + sched, send_buf, recv_buf, count, dtype, reduction, comm)); + break; +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT default: CCL_FATAL("unexpected allreduce_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -422,6 +272,7 @@ ccl::status ccl_coll_build_alltoall(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -450,6 +301,7 @@ ccl::status ccl_coll_build_alltoallv(ccl_sched* sched, param.ctype = ccl_coll_alltoallv; param.dtype = dtype; param.comm = comm; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -474,6 +326,7 @@ ccl::status ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm) { param.count = 0; param.dtype = ccl_datatype_int8; param.comm = comm; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -503,6 +356,11 @@ ccl::status ccl_coll_build_bcast(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -529,6 +387,11 @@ ccl::status ccl_coll_build_bcast(ccl_sched* sched, case ccl_coll_bcast_naive: CCL_CALL(ccl_coll_build_naive_bcast(sched, buf, count, dtype, root, comm)); break; +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + case ccl_coll_bcast_topo_ring: + CCL_CALL(ccl_coll_build_gpu_bcast(sched, buf, count, dtype, root, comm)); + break; +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT default: CCL_FATAL("unexpected bcast_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -551,6 +414,8 @@ ccl::status ccl_coll_build_reduce(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -579,6 +444,12 @@ ccl::status ccl_coll_build_reduce(ccl_sched* sched, root == 0 ? comm->dtree() : comm->dtree().copy_with_new_root(root), comm)); break; +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + case ccl_coll_reduce_topo_ring: + CCL_CALL(ccl_coll_build_gpu_reduce( + sched, send_buf, recv_buf, count, dtype, reduction, root, comm)); + break; +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT default: CCL_FATAL("unexpected reduce_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -602,6 +473,7 @@ ccl::status ccl_coll_build_reduce_scatter(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -752,20 +624,9 @@ ccl_request* ccl_allgatherv_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_allgatherv; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.send_count = send_count; - param.recv_counts = recv_counts; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_allgatherv_param( + send_buf, send_count, recv_buf, recv_counts, dtype, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req); @@ -780,20 +641,9 @@ ccl_request* ccl_allreduce_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_allreduce; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.count = count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.reduction = reduction; - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_allreduce_param( + send_buf, recv_buf, count, dtype, reduction, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req, " count ", count); @@ -807,19 +657,9 @@ ccl_request* ccl_alltoall_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_alltoall; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.count = count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_alltoall_param( + send_buf, recv_buf, count, dtype, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req, " count ", count); @@ -834,66 +674,19 @@ ccl_request* ccl_alltoallv_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_alltoallv; - param.send_buf = send_buf; - param.send_counts = send_counts; - param.recv_buf = recv_buf; - param.recv_counts = recv_counts; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_alltoallv_param( + send_buf, send_counts, recv_buf, recv_counts, dtype, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req); return req; } -/* Unused function */ -ccl_request* ccl_allreduce_gpu_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl_coll_attr& attr, - ccl_comm* comm, - const ccl_stream* stream, - const std::vector& deps) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_allreduce; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.count = count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.reduction = reduction; - param.stream = stream; - param.comm = comm; - copy_deps(deps, param.deps); - - auto req = ccl_gpu_coll_create(param, attr); - LOG_DEBUG( - "GPU coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req, " count ", count); - return req; -} - void ccl_barrier_impl(ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_barrier; - param.dtype = ccl_datatype_int8; - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_barrier_param(comm, stream, deps); ccl_coll_attr attr{}; attr.synchronous = 1; @@ -915,19 +708,9 @@ ccl_request* ccl_broadcast_impl(void* buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_bcast; - param.send_buf = param.recv_buf = buf; - param.count = count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.root = root; - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = + ccl_coll_param::create_broadcast_param(buf, count, dtype, root, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req); @@ -943,21 +726,9 @@ ccl_request* ccl_reduce_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_reduce; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.count = count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.reduction = reduction; - param.root = root; - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_reduce_param( + send_buf, recv_buf, count, dtype, reduction, root, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req); @@ -972,20 +743,9 @@ ccl_request* ccl_reduce_scatter_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { - ccl_coll_param param{}; - - param.ctype = ccl_coll_reduce_scatter; - param.send_buf = send_buf; - param.recv_buf = recv_buf; - param.count = recv_count; - param.dtype = ccl::global_data::get().dtypes->get(dtype); - param.reduction = reduction; - param.stream = stream; - param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + const std::vector& deps) { + ccl_coll_param param = ccl_coll_param::create_reduce_scatter_param( + send_buf, recv_buf, recv_count, dtype, reduction, attr, comm, stream, deps); auto req = ccl_coll_create(param, attr); LOG_DEBUG("coll ", ccl_coll_type_to_str(param.ctype), " created, req ", req); @@ -1006,8 +766,9 @@ ccl_request* ccl_sparse_allreduce_impl(const void* send_ind_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation) { + const std::vector& deps) { + CCL_THROW("unsupported path"); + ccl_coll_param param{}; param.ctype = ccl_coll_sparse_allreduce; @@ -1022,10 +783,9 @@ ccl_request* ccl_sparse_allreduce_impl(const void* send_ind_buf, param.dtype = ccl::global_data::get().dtypes->get(value_dtype); param.sparse_param.itype = ccl::global_data::get().dtypes->get(index_dtype); param.reduction = reduction; - param.stream = stream; + param.stream = (ccl_stream*)stream; param.comm = comm; - param.skip_validation = skip_validation; - copy_deps(deps, param.deps); + param.copy_deps(deps); ccl_coll_attr internal_attr(attr); internal_attr.to_cache = 0; /* skip to_cache flag, unsupported yet */ diff --git a/src/coll/coll.hpp b/src/coll/coll.hpp index 76601f01a..b455eefe7 100644 --- a/src/coll/coll.hpp +++ b/src/coll/coll.hpp @@ -109,8 +109,7 @@ ccl_request* ccl_allgatherv_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_allreduce_impl(const void* send_buf, void* recv_buf, @@ -120,18 +119,7 @@ ccl_request* ccl_allreduce_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); -template -ccl_request* ccl_allreduce_gpu_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl_coll_attr& attr, - ccl_comm* comm, - const ccl_stream* stream, - const std::vector& deps); + const std::vector& deps); ccl_request* ccl_alltoall_impl(const void* send_buf, void* recv_buf, @@ -140,8 +128,7 @@ ccl_request* ccl_alltoall_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_alltoallv_impl(const void* send_buf, const size_t* send_counts, @@ -151,13 +138,11 @@ ccl_request* ccl_alltoallv_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); void ccl_barrier_impl(ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_broadcast_impl(void* buf, size_t count, @@ -166,8 +151,7 @@ ccl_request* ccl_broadcast_impl(void* buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_reduce_impl(const void* send_buf, void* recv_buf, @@ -178,8 +162,7 @@ ccl_request* ccl_reduce_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_reduce_scatter_impl(const void* send_buf, void* recv_buf, @@ -189,8 +172,7 @@ ccl_request* ccl_reduce_scatter_impl(const void* send_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); ccl_request* ccl_sparse_allreduce_impl(const void* send_ind_buf, size_t send_ind_count, @@ -206,5 +188,4 @@ ccl_request* ccl_sparse_allreduce_impl(const void* send_ind_buf, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream, - const std::vector& deps, - bool skip_validation = false); + const std::vector& deps); diff --git a/src/coll/coll_check.cpp b/src/coll/coll_check.cpp new file mode 100644 index 000000000..86d2c3e40 --- /dev/null +++ b/src/coll/coll_check.cpp @@ -0,0 +1,138 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include +#include + +#include "coll/coll.hpp" +#include "coll/coll_check.hpp" +#include "common/env/env.hpp" +#include "common/global/global.hpp" +#include "common/utils/sycl_utils.hpp" + +#ifdef CCL_ENABLE_SYCL +void ccl_check_usm_pointers(const ccl_coll_param& param) { + auto bufs = param.get_all_non_zero_bufs(); + if (bufs.empty()) { + return; + } + + auto dev = param.stream->get_native_stream().get_device(); + auto ctx = param.stream->get_native_stream().get_context(); + + std::set usm_types; + for (size_t idx = 0; idx < bufs.size(); idx++) { + usm_types.insert(sycl::get_pointer_type(bufs[idx], ctx)); + } + + if (usm_types.size() != 1) { + auto first_usm_type = *usm_types.begin(); + auto second_usm_type = *(++usm_types.begin()); + CCL_THROW("coll: ", + ccl_coll_type_to_str(param.ctype), + ", mixed usm pointer types (", + ccl::utils::usm_type_to_str(first_usm_type), + ", ", + ccl::utils::usm_type_to_str(second_usm_type), + ") within single operation are not supported, ", + "device type: ", + ccl::utils::sycl_device_to_str(dev)); + } + + sycl::usm::alloc usm_type = *usm_types.begin(); + bool is_valid_type = true; + + if ((usm_type == sycl::usm::alloc::host) && (dev.is_gpu() || dev.is_accelerator())) + is_valid_type = false; + + if ((usm_type == sycl::usm::alloc::device) && !(dev.is_gpu() || dev.is_accelerator())) + is_valid_type = false; + + if (usm_type == sycl::usm::alloc::unknown) + is_valid_type = false; + + LOG_DEBUG("coll: ", + ccl_coll_type_to_str(param.ctype), + ", usm pointer type: ", + ccl::utils::usm_type_to_str(usm_type), + ", device type: ", + ccl::utils::sycl_device_to_str(dev)); + + CCL_THROW_IF_NOT(is_valid_type, + "coll: ", + ccl_coll_type_to_str(param.ctype), + " - invalid usm pointer type: ", + ccl::utils::usm_type_to_str(usm_type), + " for device type: ", + ccl::utils::sycl_device_to_str(dev)); +} +#endif // CCL_ENABLE_SYCL + +void ccl_coll_validate_user_input(const ccl_coll_param& param, const ccl_coll_attr& attr) { + CCL_THROW_IF_NOT(ccl::global_data::env().atl_transport == ccl_atl_ofi || !(attr.reduction_fn), + "custom reduction is supported for OFI transport only"); + + CCL_THROW_IF_NOT(ccl_datatype_storage::is_predefined_datatype(param.dtype.idx()) || + ccl::global_data::env().atl_transport == ccl_atl_ofi, + "custom datatype is supported for OFI transport only"); + + CCL_THROW_IF_NOT((param.ctype != ccl_coll_allreduce && param.ctype != ccl_coll_reduce && + param.ctype != ccl_coll_sparse_allreduce) || + ccl_datatype_storage::is_predefined_datatype(param.dtype.idx()) || + attr.reduction_fn, + "custom datatype requires custom reduction"); + + CCL_THROW_IF_NOT(param.ctype == ccl_coll_allreduce || + !(attr.prologue_fn || attr.epilogue_fn || attr.reduction_fn), + "prologue/epilogue/custom reduction is supported for allreduce only"); + + CCL_THROW_IF_NOT(param.ctype == ccl_coll_allgatherv || !(attr.is_vector_buf), + "vector buffer is supported for allgatherv only"); + + if (param.ctype == ccl_coll_sparse_allreduce) { + CCL_THROW_IF_NOT( + ccl::global_data::env().sparse_allreduce_algo_raw != "mask" || !(attr.reduction_fn), + "mask algorithm for sparse_allreduce does not support custom reduction"); + + CCL_THROW_IF_NOT( + (attr.sparse_allreduce_completion_fn || attr.sparse_allreduce_alloc_fn) && + !(reinterpret_cast(attr.sparse_allreduce_completion_fn) & + reinterpret_cast(attr.sparse_allreduce_alloc_fn)), + "sparse_allreduce requires completion callback only or allocation callback only"); + } + + if (param.ctype == ccl_coll_bcast || param.ctype == ccl_coll_reduce) { + CCL_THROW_IF_NOT(param.root < param.comm->size(), + "unexpected root ", + param.root, + ", comm size ", + param.comm->size()); + } + + if (param.stream) { +#ifdef CCL_ENABLE_SYCL + /* SYCL specific validation */ + + /* TODO: compare stream dev/ctx and comm dev/ctx */ + // sycl::device stream_dev = param.stream->get_native().get_context(); + // sycl::device stream_ctx = param.stream->get_native().get_device(); + + if (!attr.is_sycl_buf) { + /* check whether USM pointers have expected type */ + ccl_check_usm_pointers(param); + } +#endif // CCL_ENABLE_SYCL + } +} diff --git a/src/ccl_gpu_module.hpp b/src/coll/coll_check.hpp similarity index 59% rename from src/ccl_gpu_module.hpp rename to src/coll/coll_check.hpp index 80cc430a3..c7462f66b 100644 --- a/src/ccl_gpu_module.hpp +++ b/src/coll/coll_check.hpp @@ -15,12 +15,16 @@ */ #pragma once -#include "oneapi/ccl/types.hpp" -#include "coll/algorithms/algorithms_enum.hpp" -#include "internal_types.hpp" +#include "coll/coll_param.hpp" -#ifdef MULTI_GPU_SUPPORT -ccl::status load_gpu_module(const std::string& path, - ccl::device_topology_type topo_type, - ccl_coll_type coll_type); -#endif //MULTI_GPU_SUPPORT +#ifdef CCL_ENABLE_SYCL +#include +#endif // CCL_ENABLE_SYCL + +#ifdef CCL_ENABLE_SYCL +void ccl_check_usm_pointers(const std::vector& ptrs, + const sycl::device& dev, + const sycl::context& ctx); +#endif // CCL_ENABLE_SYCL + +void ccl_coll_validate_user_input(const ccl_coll_param& param, const ccl_coll_attr& attr); diff --git a/src/coll/coll_param.cpp b/src/coll/coll_param.cpp index a050cdfd1..d3cb9b54b 100644 --- a/src/coll/coll_param.cpp +++ b/src/coll/coll_param.cpp @@ -13,7 +13,66 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "coll/coll_param.hpp" +#include "common/global/global.hpp" + +#define COPY_COMMON_OP_ATTRS(from, to) \ + to->prologue_fn = nullptr; /*from.get().get();*/ \ + to->epilogue_fn = nullptr; /*from.get().get();*/ \ + to->priority = from.get(); \ + to->synchronous = from.get(); \ + to->to_cache = (from.get().length()) \ + ? from.get() \ + : false; \ + to->match_id = from.get(); \ + if (to->to_cache != from.get()) \ + LOG_INFO("collective caching is requested but no match_id is provided, disable caching"); + +ccl_coll_attr::ccl_coll_attr(const ccl::allgatherv_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::allreduce_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); + + reduction_fn = attr.get().get(); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::alltoall_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::alltoallv_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::barrier_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::broadcast_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::reduce_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); + reduction_fn = attr.get().get(); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::reduce_scatter_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); + reduction_fn = attr.get().get(); +} + +ccl_coll_attr::ccl_coll_attr(const ccl::sparse_allreduce_attr& attr) { + COPY_COMMON_OP_ATTRS(attr, this); + sparse_allreduce_completion_fn = attr.get().get(); + sparse_allreduce_alloc_fn = attr.get().get(); + sparse_allreduce_fn_ctx = attr.get(); + sparse_coalesce_mode = attr.get(); +} bool operator==(const coll_param_gpu& lhs, const coll_param_gpu& rhs) { CCL_ASSERT((lhs.is_reduction() && rhs.is_reduction()) || @@ -29,40 +88,568 @@ bool operator==(const coll_param_gpu& lhs, const coll_param_gpu& rhs) { return res; } -void copy_deps(const std::vector& in, std::vector& out) { +std::string ccl_coll_attr::to_string() const { + std::stringstream ss; + + ss << "{ " + << "priority: " << priority << ", sync: " << synchronous << ", to_cache: " << to_cache + << ", match_id: " << (!match_id.empty() ? match_id : ""); + + if (is_vector_buf) { + ss << ", vector_buf"; + } + #ifdef CCL_ENABLE_SYCL - out.clear(); - for (size_t idx = 0; idx < in.size(); idx++) { - try { - auto sycl_event = in[idx].get_native(); - out.push_back(ccl::create_event(sycl_event)); - } - catch (ccl::exception&) { - } + if (is_sycl_buf) { + ss << ", sycl_buf"; } -#else /* CCL_ENABLE_SYCL */ - CCL_THROW_IF_NOT(in.size() == 0, "host deps are not supported yet"); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL + + ss << " }"; + + return ss.str(); +} + +ccl_coll_param::ccl_coll_param() { + ctype = ccl_coll_last_value; + send_bufs.reserve(1); + recv_bufs.reserve(1); + send_counts.reserve(1); + recv_counts.reserve(1); + stream = nullptr; + comm = nullptr; } ccl_coll_param::ccl_coll_param(const ccl_coll_param& other) { ctype = other.ctype; - send_buf = other.send_buf; - recv_buf = other.recv_buf; - count = other.count; - send_count = other.send_count; + send_bufs = other.send_bufs; + recv_bufs = other.recv_bufs; + device_send_bufs = other.device_send_bufs; + device_recv_bufs = other.device_recv_bufs; send_counts = other.send_counts; recv_counts = other.recv_counts; dtype = other.dtype; reduction = other.reduction; root = other.root; - stream = other.stream; - copy_deps(other.deps, deps); comm = other.comm; + stream = other.stream; + copy_deps(other.deps); sparse_param = other.sparse_param; + validate(); +} + +std::string ccl_coll_param::to_string() const { + std::stringstream ss; + + ss << "{ "; + ss << "coll: " << ccl_coll_type_to_str(ctype); + + if (!send_bufs.empty()) + ss << ", sb: " << get_send_buf() << ", sc: " << get_send_count(); + + if (!recv_bufs.empty()) + ss << ", rb: " << get_recv_buf() << ", rc: " << get_recv_count(); + + if (ctype != ccl_coll_barrier) + ss << ", dt: " << ccl::global_data::get().dtypes->name(dtype); + + if (ctype == ccl_coll_allreduce || ctype == ccl_coll_reduce || + ctype == ccl_coll_reduce_scatter) { + ss << ", rt: " << ccl_reduction_to_str(reduction); + } + + if (ctype == ccl_coll_bcast || ctype == ccl_coll_reduce) { + ss << ", root: " << root; + } + + ss << ", comm: "; + if (comm) + ss << "{ rank: " << comm->rank() << ", size: " << comm->size() << " }"; + else + ss << "null"; + +#ifdef CCL_ENABLE_SYCL + if (stream) + ss << ", stream: " << stream->to_string(); +#endif // CCL_ENABLE_SYCL + + if (!deps.empty()) + ss << ", deps: " << deps.size(); + + ss << " }"; + + return ss.str(); +} + +void* ccl_coll_param::get_send_buf(size_t idx, ccl_coll_param::buf_type type) const { + auto& vec = (type == ccl_coll_param::buf_type::regular) ? send_bufs : device_send_bufs; + CCL_THROW_IF_NOT(idx < vec.size(), "coll ", ctype, ", unexpected idx ", idx); + return vec[idx]; +} + +void* ccl_coll_param::get_recv_buf(size_t idx, ccl_coll_param::buf_type type) const { + auto& vec = (type == ccl_coll_param::buf_type::regular) ? recv_bufs : device_recv_bufs; + CCL_THROW_IF_NOT(idx < vec.size(), "coll ", ctype, ", unexpected idx ", idx); + return vec[idx]; +} + +void* ccl_coll_param::get_send_buf_ptr(size_t idx, ccl_coll_param::buf_type type) const { + auto& vec = (type == ccl_coll_param::buf_type::regular) ? send_bufs : device_send_bufs; + CCL_THROW_IF_NOT(idx < vec.size(), "coll ", ctype, ", unexpected idx ", idx); + void* res = (void*)(&vec[idx]); + return res; +} + +void* ccl_coll_param::get_recv_buf_ptr(size_t idx, ccl_coll_param::buf_type type) const { + auto& vec = (type == ccl_coll_param::buf_type::regular) ? recv_bufs : device_recv_bufs; + CCL_THROW_IF_NOT(idx < vec.size(), "coll ", ctype, ", unexpected idx ", idx); + void* res = (void*)(&vec[idx]); + return res; +} + +size_t ccl_coll_param::get_send_count(size_t idx) const { + CCL_THROW_IF_NOT(idx < send_counts.size(), "coll ", ctype, ", unexpected idx ", idx); + return send_counts[idx]; +} + +size_t ccl_coll_param::get_recv_count(size_t idx) const { + CCL_THROW_IF_NOT(idx < recv_counts.size(), "coll ", ctype, ", unexpected idx ", idx); + return recv_counts[idx]; +} + +bool ccl_coll_param::is_inplace(buf_type type) const { + if (ctype == ccl_coll_barrier || ctype == ccl_coll_bcast) { + return true; + } + + void* send_buf = get_send_buf(0, type); + void* recv_buf = get_recv_buf(0, type); + return (send_buf && (send_buf == recv_buf)) ? true : false; +} + +std::vector ccl_coll_param::get_all_non_zero_bufs() const { + std::vector bufs; + switch (ctype) { + case ccl_coll_alltoallv: { + /* + if the sum of the counts is 0 this means that the buf pointer could be anything, + including nullptr and invalid pointer + don't validate nor dereference it + */ + if (std::accumulate(send_counts.begin(), send_counts.end(), 0) > 0) { + bufs.push_back(get_send_buf()); + } + + if (std::accumulate(recv_counts.begin(), recv_counts.end(), 0) > 0) { + bufs.push_back(get_recv_buf()); + } + break; + } + case ccl_coll_allgatherv: { + if (get_send_count()) { + bufs.push_back(get_send_buf()); + } + + if (std::accumulate(recv_counts.begin(), recv_counts.end(), 0) > 0) { + if (recv_bufs.size() == 1) { + bufs.push_back(get_recv_buf()); + } + else { + for (size_t idx = 0; idx < recv_counts.size(); idx++) { + if (recv_counts[idx]) + bufs.push_back(get_recv_buf(idx)); + } + } + } + break; + } + case ccl_coll_allreduce: + case ccl_coll_alltoall: + case ccl_coll_bcast: + case ccl_coll_reduce: + case ccl_coll_reduce_scatter: + if (get_send_count()) { + bufs.push_back(get_send_buf()); + } + + if (get_recv_count()) { + bufs.push_back(get_recv_buf()); + } + break; + case ccl_coll_sparse_allreduce: + bufs = { (void*)sparse_param.send_ind_buf, + (void*)sparse_param.send_val_buf, + (void*)sparse_param.recv_ind_buf, + (void*)sparse_param.recv_val_buf }; + break; + default: break; + } + return bufs; +} + +void ccl_coll_param::validate() const { + if (ctype > ccl_coll_last_regular) { + return; + } + + LOG_TRACE("validate coll_param, ctype: ", ccl_coll_type_to_str(ctype)); + CCL_THROW_IF_NOT(!send_counts.empty(), "empty send_counts"); + CCL_THROW_IF_NOT(!recv_counts.empty(), "empty recv_counts"); + CCL_THROW_IF_NOT(comm, "null comm"); + + if (ctype == ccl_coll_barrier) { + return; + } + + CCL_THROW_IF_NOT(!send_bufs.empty(), "empty send_bufs"); + CCL_THROW_IF_NOT(!recv_bufs.empty(), "empty recv_bufs"); + + switch (ctype) { + case ccl_coll_alltoallv: { + CCL_THROW_IF_NOT( + (send_bufs.size() == 1) || (static_cast(send_bufs.size()) == comm->size()), + "send_bufs size ", + send_bufs.size(), + ", comm size ", + comm->size()); + + CCL_THROW_IF_NOT( + (recv_bufs.size() == 1) || (static_cast(recv_bufs.size()) == comm->size()), + "recv_bufs size ", + recv_bufs.size(), + ", comm size ", + comm->size()); + + CCL_THROW_IF_NOT(send_counts[comm->rank()] == recv_counts[comm->rank()], + "send_count[rank] ", + send_counts[comm->rank()], + ", recv_counts[rank] ", + recv_counts[comm->rank()]); + + if (send_counts.size() > 1) { + CCL_THROW_IF_NOT(static_cast(send_counts.size()) == comm->size(), + "send_counts size ", + send_counts.size(), + ", comm size ", + comm->size()); + } + + if (recv_counts.size() > 1) { + CCL_THROW_IF_NOT(static_cast(recv_counts.size()) == comm->size(), + "recv_counts size ", + recv_counts.size(), + ", comm size ", + comm->size()); + } + break; + } + case ccl_coll_allgatherv: { + CCL_THROW_IF_NOT( + (recv_bufs.size() == 1) || (static_cast(recv_bufs.size()) == comm->size()), + "recv_bufs size ", + recv_bufs.size(), + ", comm size ", + comm->size()); + + CCL_THROW_IF_NOT( + send_counts.size() == 1, "unexpected send_counts size ", send_counts.size()); + + CCL_THROW_IF_NOT(get_send_count() == recv_counts[comm->rank()], + "send_count ", + get_send_count(), + ", recv_counts[rank] ", + recv_counts[comm->rank()]); + + if (recv_counts.size() > 1) { + CCL_THROW_IF_NOT(static_cast(recv_counts.size()) == comm->size(), + "recv_counts size ", + recv_counts.size(), + ", comm size ", + comm->size()); + } + break; + } + case ccl_coll_allreduce: + case ccl_coll_alltoall: + case ccl_coll_bcast: + case ccl_coll_reduce: + case ccl_coll_reduce_scatter: + CCL_THROW_IF_NOT(send_bufs.size() == send_counts.size(), + "send_bufs size ", + send_bufs.size(), + ", send_counts size ", + send_counts.size()); + + CCL_THROW_IF_NOT(recv_bufs.size() == recv_counts.size(), + "recv_bufs size ", + recv_bufs.size(), + ", recv_counts size ", + recv_counts.size()); + + if (ctype == ccl_coll_bcast) { + CCL_THROW_IF_NOT(get_send_buf() == get_recv_buf(), + "send_buf ", + get_send_buf(), + ", recv_buf ", + get_recv_buf()); + } + + CCL_THROW_IF_NOT( + send_counts.size() == 1, "unexpected send_counts size ", send_counts.size()); + + if (ctype == ccl_coll_reduce_scatter) { + CCL_THROW_IF_NOT(get_send_count() == get_recv_count() * comm->size(), + "send_count ", + get_send_count(), + ", recv_count * comm_size ", + get_recv_count() * comm->size()); + } + else { + CCL_THROW_IF_NOT(get_send_count() == get_recv_count(), + "send_count ", + get_send_count(), + ", recv_count ", + get_recv_count()); + } + break; + default: break; + } +} + +// Optional extra event(from submit_barrier call) to add to our deps list +void ccl_coll_param::copy_deps(const std::vector& d, ccl::event* extra) { +#ifdef CCL_ENABLE_SYCL + deps.clear(); + for (size_t idx = 0; idx < d.size(); idx++) { + try { + auto sycl_event = d[idx].get_native(); + deps.push_back(ccl::create_event(sycl_event)); + } + catch (ccl::exception&) { + } + } + + if (extra) { + try { + auto sycl_event = extra->get_native(); + deps.push_back(ccl::create_event(sycl_event)); + } + catch (ccl::exception&) { + } + } +#else // CCL_ENABLE_SYCL + CCL_THROW_IF_NOT(d.size() == 0, "host deps are not supported yet"); +#endif // CCL_ENABLE_SYCL +} +void ccl_coll_param::set_common_fields(ccl::datatype d, + ccl_comm* c, + const ccl_stream* s, + const std::vector& ds) { + dtype = ccl::global_data::get().dtypes->get(d); + comm = c; + stream = (ccl_stream*)s; + + sync_deps(s, ds); +} + +// Submit a barrier if necessary to sync queue. The event from the barrier is added +// to other deps +void ccl_coll_param::sync_deps(const ccl_stream* s, const std::vector& ds) { #ifdef CCL_ENABLE_SYCL - device_send_buf = other.device_send_buf; - device_recv_buf = other.device_recv_buf; -#endif /* CCL_ENABLE_SYCL */ + // The main purpose of the barrier is to sync user's in-order queue with our out-of-order + // queue, so we don't execute anything before the user's tasks are completed. + // We don't really need anything like this for the case when user has out-of-order queue as + // there is no ordering requirement unless dependencies are explicitly provided and which we + // handle as well. + if (s != nullptr && s->is_sycl_device_stream() && s->get_native_stream().is_in_order()) { + // TODO: it would be nice to pass here all the dependencies as parameters to submit_barrier + // and get a single event to use later. Note: submit_barrier with empty event vector doesn't + // do anything and just return an empty event as opposed to submit_barrier without paramers + // which submits a full queue barrier. And there is a bug which leads to a crash if + // empty sycl event is passed to the function. + auto sycl_ev = s->get_native_stream().submit_barrier(); + auto e = ccl::create_event(sycl_ev); + copy_deps(ds, &e); + return; + } +#endif // CCL_ENABLE_SYCL + copy_deps(ds); +} + +ccl_coll_param ccl_coll_param::create_allgatherv_param(const void* send_buf, + size_t send_count, + void* recv_buf, + const size_t* recv_counts, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_allgatherv; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.push_back(send_count); + if (attr.is_vector_buf) { + param.recv_bufs.assign((void**)recv_buf, (void**)recv_buf + comm->size()); + } + else { + param.recv_bufs.push_back(recv_buf); + } + param.recv_counts.assign((size_t*)recv_counts, (size_t*)recv_counts + comm->size()); + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_allreduce_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_allreduce; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.push_back(count); + param.recv_bufs.push_back(recv_buf); + param.recv_counts.push_back(count); + param.reduction = reduction; + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_alltoall_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_alltoall; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.push_back(count); + param.recv_bufs.push_back(recv_buf); + param.recv_counts.push_back(count); + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_alltoallv_param(const void* send_buf, + const size_t* send_counts, + void* recv_buf, + const size_t* recv_counts, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_alltoallv; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.assign((size_t*)send_counts, (size_t*)send_counts + comm->size()); + param.recv_bufs.push_back(recv_buf); + param.recv_counts.assign((size_t*)recv_counts, (size_t*)recv_counts + comm->size()); + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_barrier_param(ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_barrier; + param.send_counts.push_back(0); + param.recv_counts.push_back(0); + param.set_common_fields(ccl::datatype::int8, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_broadcast_param(void* buf, + size_t count, + ccl::datatype dtype, + int root, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_bcast; + param.send_bufs.push_back(buf); + param.send_counts.push_back(count); + param.recv_bufs.push_back(buf); + param.recv_counts.push_back(count); + param.root = root; + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_reduce_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + int root, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_reduce; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.push_back(count); + param.recv_bufs.push_back(recv_buf); + param.recv_counts.push_back(count); + param.reduction = reduction; + param.root = root; + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; +} + +ccl_coll_param ccl_coll_param::create_reduce_scatter_param(const void* send_buf, + void* recv_buf, + size_t recv_count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps) { + ccl_coll_param param; + + param.ctype = ccl_coll_reduce_scatter; + param.send_bufs.push_back((void*)send_buf); + param.send_counts.push_back(comm->size() * recv_count); + param.recv_bufs.push_back(recv_buf); + param.recv_counts.push_back(recv_count); + param.reduction = reduction; + param.set_common_fields(dtype, comm, stream, deps); + param.validate(); + + return param; } diff --git a/src/coll/coll_param.hpp b/src/coll/coll_param.hpp index 927e0ca3f..049f5c56a 100644 --- a/src/coll/coll_param.hpp +++ b/src/coll/coll_param.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "coll/algorithms/algorithms_enum.hpp" #include "common/datatype/datatype.hpp" #include "oneapi/ccl.hpp" @@ -23,7 +25,6 @@ class ccl_comm; #ifdef CCL_ENABLE_SYCL #include -typedef cl::sycl::buffer ccl_sycl_buffer_t; template using ccl_sycl_typed_buffer_t = cl::sycl::buffer; @@ -41,7 +42,7 @@ using ccl_sycl_buffer_one_dim_types = std::tuple ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t>; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #define CCL_INVALID_PROC_IDX (-1) @@ -50,7 +51,6 @@ struct ccl_coll_attr { ccl_coll_attr(const ccl_coll_attr&) = default; ccl_coll_attr& operator=(const ccl_coll_attr&) = default; - //TODO temporary solution for type convertation, ccl_coll_attr would be depreacated ccl_coll_attr(const ccl::allgatherv_attr& attr); ccl_coll_attr(const ccl::allreduce_attr& attr); ccl_coll_attr(const ccl::alltoall_attr& attr); @@ -64,6 +64,8 @@ struct ccl_coll_attr { ccl_coll_attr(ccl_coll_attr&&) = default; ccl_coll_attr& operator=(ccl_coll_attr&&) = default; + std::string to_string() const; + ccl::prologue_fn prologue_fn = nullptr; ccl::epilogue_fn epilogue_fn = nullptr; ccl::reduction_fn reduction_fn = nullptr; @@ -71,9 +73,15 @@ struct ccl_coll_attr { size_t priority = 0; int synchronous = 0; int to_cache = 0; - int vector_buf = 0; std::string match_id{}; + /* change how user-supplied buffers have to be interpreted */ + int is_vector_buf = 0; + +#ifdef CCL_ENABLE_SYCL + int is_sycl_buf = 0; +#endif // CCL_ENABLE_SYCL + ccl::sparse_allreduce_completion_fn sparse_allreduce_completion_fn = nullptr; ccl::sparse_allreduce_alloc_fn sparse_allreduce_alloc_fn = nullptr; const void* sparse_allreduce_fn_ctx = nullptr; @@ -92,32 +100,133 @@ struct ccl_coll_sparse_param { ccl_datatype itype; }; -void copy_deps(const std::vector& in, std::vector& out); - struct ccl_coll_param { + enum class buf_type { regular, device }; + ccl_coll_type ctype; - const void* send_buf; - void* recv_buf; - size_t count; - size_t send_count; - const size_t* send_counts; - const size_t* recv_counts; + + std::vector send_bufs; + std::vector recv_bufs; + + /* + filled if pre-post copy is used + to keep original send/recv buffers + send_buf and recv_buf fields are replaced by staging buffers + */ + std::vector device_send_bufs; + std::vector device_recv_bufs; + + std::vector send_counts; + std::vector recv_counts; + ccl_datatype dtype; ccl::reduction reduction; int root; - const ccl_stream* stream; - std::vector deps; + ccl_stream* stream; ccl_comm* comm; - ccl_coll_sparse_param sparse_param; - bool skip_validation; + std::vector deps; -#ifdef CCL_ENABLE_SYCL - ccl_sycl_buffer_t* device_send_buf; - ccl_sycl_buffer_t* device_recv_buf; -#endif /* CCL_ENABLE_SYCL */ + ccl_coll_sparse_param sparse_param; - ccl_coll_param() {} + ccl_coll_param(); ccl_coll_param(const ccl_coll_param& other); + + std::string to_string() const; + + void* get_send_buf(size_t idx = 0, buf_type type = buf_type::regular) const; + void* get_recv_buf(size_t idx = 0, buf_type type = buf_type::regular) const; + + void* get_send_buf_ptr(size_t idx = 0, buf_type type = buf_type::regular) const; + void* get_recv_buf_ptr(size_t idx = 0, buf_type type = buf_type::regular) const; + + size_t get_send_count(size_t idx = 0) const; + size_t get_recv_count(size_t idx = 0) const; + + bool is_inplace(buf_type type = buf_type::regular) const; + + std::vector get_all_non_zero_bufs() const; + + void validate() const; + + void copy_deps(const std::vector& d, ccl::event* extra = nullptr); + void set_common_fields(ccl::datatype dtype, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps); + void sync_deps(const ccl_stream* s, const std::vector& ds); + + static ccl_coll_param create_allgatherv_param(const void* send_buf, + size_t send_count, + void* recv_buf, + const size_t* recv_counts, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_allreduce_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_alltoall_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_alltoallv_param(const void* send_buf, + const size_t* send_counts, + void* recv_buf, + const size_t* recv_counts, + ccl::datatype dtype, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_barrier_param(ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_broadcast_param(void* buf, + size_t count, + ccl::datatype dtype, + int root, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_reduce_param(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + int root, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); + + static ccl_coll_param create_reduce_scatter_param(const void* send_buf, + void* recv_buf, + size_t recv_count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl_coll_attr& attr, + ccl_comm* comm, + const ccl_stream* stream, + const std::vector& deps = {}); }; class coll_param_gpu { @@ -160,18 +269,3 @@ class coll_param_gpu { }; bool operator==(const coll_param_gpu& lhs, const coll_param_gpu& rhs); - -/* - explicitly split coll_param and coll_param_copy - to separate coll_param structure which is used for interaction between different modules - and coll_param_copy which is used as storage for user options -*/ -struct ccl_coll_param_copy { - /* keep copy of user options which can be invalidated after collective call */ - - std::vector ag_recv_bufs; - std::vector ag_recv_counts; - - std::vector a2av_send_counts; - std::vector a2av_recv_counts; -}; diff --git a/src/coll/selection/selection.cpp b/src/coll/selection/selection.cpp new file mode 100644 index 000000000..a90c83600 --- /dev/null +++ b/src/coll/selection/selection.cpp @@ -0,0 +1,185 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "coll/selection/selection.hpp" +#include "common/global/global.hpp" + +bool ccl_is_direct_algo(const ccl_selector_param& param) { + bool res = false; + + auto& selector = ccl::global_data::get().algorithm_selector; + + if (param.ctype == ccl_coll_allgatherv) { + res = (selector->get(param) == ccl_coll_allgatherv_direct); + } + else if (param.ctype == ccl_coll_allreduce) { + res = (selector->get(param) == ccl_coll_allreduce_direct); + } + else if (param.ctype == ccl_coll_alltoall) { + res = (selector->get(param) == ccl_coll_alltoall_direct); + } + else if (param.ctype == ccl_coll_alltoallv) { + res = (selector->get(param) == ccl_coll_alltoallv_direct); + } + else if (param.ctype == ccl_coll_barrier) { + res = (selector->get(param) == ccl_coll_barrier_direct); + } + else if (param.ctype == ccl_coll_bcast) { + res = (selector->get(param) == ccl_coll_bcast_direct); + } + else if (param.ctype == ccl_coll_reduce) { + res = (selector->get(param) == ccl_coll_reduce_direct); + } + else if (param.ctype == ccl_coll_reduce_scatter) { + res = (selector->get(param) == ccl_coll_reduce_scatter_direct); + } + + return res; +} + +static bool ccl_is_device_side_algo(ccl_coll_algo algo, const ccl_selector_param& param) { + if (param.ctype == ccl_coll_allreduce) { + return algo.allreduce == ccl_coll_allreduce_topo_ring; + } + else if (param.ctype == ccl_coll_reduce) { + return algo.reduce == ccl_coll_reduce_topo_ring; + } + else if (param.ctype == ccl_coll_bcast) { + return algo.bcast == ccl_coll_bcast_topo_ring; + } + + return false; +} + +bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param) { + // A regular type, so we don't need to check for an additional support + if (param.dtype.idx() != ccl::datatype::bfloat16 && + param.dtype.idx() != ccl::datatype::float16) { + return true; + } + + bool can_use = true; + + bool device_side_algo = ccl_is_device_side_algo(algo, param); + + // Algorithms running on GPU device support both fp16 and bf16, so we don't need to require their + // support on the host. + if (!device_side_algo) { + if (param.dtype.idx() == ccl::datatype::bfloat16) { + bool bf16_hw_support = + ccl::global_data::env().bf16_impl_type != ccl_bf16_no_hardware_support; + bool bf16_compiler_support = + ccl::global_data::env().bf16_impl_type != ccl_bf16_no_compiler_support; + + can_use = bf16_compiler_support && bf16_hw_support; + + if (!can_use) { + LOG_DEBUG("BF16 datatype is requested for ", + ccl_coll_type_to_str(param.ctype), + " running on CPU but not fully supported: hw: ", + bf16_hw_support, + " compiler: ", + bf16_compiler_support); + } + } + else if (param.dtype.idx() == ccl::datatype::float16) { + bool fp16_hw_support = + ccl::global_data::env().fp16_impl_type != ccl_fp16_no_hardware_support; + bool fp16_compiler_support = + ccl::global_data::env().fp16_impl_type != ccl_fp16_no_compiler_support; + + can_use = fp16_hw_support && fp16_compiler_support; + + if (!can_use) { + LOG_DEBUG("FP16 datatype is requested for ", + ccl_coll_type_to_str(param.ctype), + " running on CPU but not fully supported: hw: ", + fp16_hw_support, + " compiler: ", + fp16_compiler_support); + } + } + } + + return can_use; +} + +bool ccl_is_topo_ring_algo(const ccl_selector_param& param) { +#ifndef CCL_ENABLE_SYCL + return false; +#endif // CCL_ENABLE_SYCL + + if ((param.ctype != ccl_coll_allreduce) && (param.ctype != ccl_coll_bcast) && + (param.ctype != ccl_coll_reduce)) { + return false; + } + + bool res = false; + + auto& selector = ccl::global_data::get().algorithm_selector; + + if (param.ctype == ccl_coll_allreduce) { + res = (selector->get(param) == ccl_coll_allreduce_topo_ring); + } + else if (param.ctype == ccl_coll_bcast) { + res = (selector->get(param) == ccl_coll_bcast_topo_ring); + } + else if (param.ctype == ccl_coll_reduce) { + res = (selector->get(param) == ccl_coll_reduce_topo_ring); + } + + return res; +} + +bool ccl_can_use_topo_ring_algo(const ccl_selector_param& param) { + if ((param.ctype != ccl_coll_allreduce) && (param.ctype != ccl_coll_bcast) && + (param.ctype != ccl_coll_reduce)) { + return false; + } + + bool is_sycl_buf = false; + bool is_device_buf = true; + bool is_l0_backend = false; + + size_t local_proc_count = ccl::global_data::get().executor->get_local_proc_count(); + +#ifdef CCL_ENABLE_SYCL + is_sycl_buf = param.is_sycl_buf; + if (param.buf && param.stream) { + auto ctx = param.stream->get_native_stream().get_context(); + is_device_buf = + (sycl::get_pointer_type(param.buf, ctx) == sycl::usm::alloc::device) ? true : false; + } +#ifdef MULTI_GPU_SUPPORT + if (param.stream && param.stream->get_backend() == sycl::backend::level_zero) { + is_l0_backend = true; + } +#endif // MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL + + if ((param.comm->size() != 2 && param.comm->size() != 4) || + (param.comm->size() == 2 && param.comm->size() != static_cast(local_proc_count)) || + (param.comm->size() == 4 && local_proc_count != 2 && local_proc_count != 4) || + (param.comm->size() != 2 && (ccl::global_data::env().atl_transport == ccl_atl_mpi)) || + !param.stream || (param.stream->get_type() != stream_type::gpu) || is_sycl_buf || + !is_device_buf || !is_l0_backend || ccl::global_data::env().enable_fusion || + ccl::global_data::env().enable_unordered_coll || + (ccl::global_data::env().priority_mode != ccl_priority_none) || + (ccl::global_data::env().worker_count != 1)) { + return false; + } + + return true; +} diff --git a/src/coll/selection/selection.hpp b/src/coll/selection/selection.hpp index 527e9c1ff..2a8fe2a28 100644 --- a/src/coll/selection/selection.hpp +++ b/src/coll/selection/selection.hpp @@ -15,49 +15,10 @@ */ #pragma once -#include "coll/selection/selector.hpp" -#include "common/utils/tuple.hpp" +#include "coll/selection/selector_wrapper.hpp" -#include +bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param); -template -class ccl_algorithm_selector_wrapper { -public: - struct selector_init_functor { - template - void operator()(T& t) const { - t.init(); - } - }; - - struct selector_print_functor { - template - void operator()(T& t) const { - t.print(); - } - }; - - void init() { - ccl_tuple_for_each(selectors, selector_init_functor()); - } - - void print() { - ccl_tuple_for_each(selectors, selector_print_functor()); - } - - template - typename ccl_algorithm_selector::type get(const ccl_selector_param& param) const { - CCL_THROW_IF_NOT(coll_id == param.ctype); - return std::get(selectors).get(param); - } - - template - bool is_direct(const ccl_selector_param& param) const { - CCL_THROW_IF_NOT(coll_id == param.ctype); - return std::get(selectors).is_direct(param); - } - -private: - using algo_selectors = std::tuple...>; - algo_selectors selectors; -}; +bool ccl_is_direct_algo(const ccl_selector_param& param); +bool ccl_is_topo_ring_algo(const ccl_selector_param& param); +bool ccl_can_use_topo_ring_algo(const ccl_selector_param& param); diff --git a/src/coll/selection/selector.hpp b/src/coll/selection/selector.hpp index c781c07ef..31a25d019 100644 --- a/src/coll/selection/selector.hpp +++ b/src/coll/selection/selector.hpp @@ -36,14 +36,22 @@ enum ccl_selection_border_type { }; struct ccl_selector_param { - ccl_coll_type ctype; - size_t count; - ccl_datatype dtype; - ccl_comm* comm; + ccl_coll_type ctype = ccl_coll_last_value; + size_t count = 0; + ccl_datatype dtype = ccl_datatype_int8; + ccl_comm* comm = nullptr; + ccl_stream* stream = nullptr; + void* buf = nullptr; - const size_t* send_counts; - const size_t* recv_counts; - int vector_buf; + const size_t* send_counts = nullptr; + const size_t* recv_counts = nullptr; + int is_vector_buf = 0; + +#ifdef CCL_ENABLE_SYCL + int is_sycl_buf = 0; +#endif // CCL_ENABLE_SYCL + + ccl_coll_algo hint_algo = {}; /* tmp fields to avoid selection of algorithms which don't support all coalesce modes or alloc_fn */ ccl::sparse_coalesce_mode sparse_coalesce_mode; @@ -73,7 +81,6 @@ using ccl_selection_table_iter_t = typename ccl_selection_table_t::ccl_algorithm_selector() { insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allgatherv_flat); } -template <> -bool ccl_algorithm_selector_helper::is_direct( - ccl_coll_allgatherv_algo algo) { - return (algo == ccl_coll_allgatherv_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_allgatherv_algo algo, @@ -52,7 +46,7 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (param.vector_buf && algo != ccl_coll_allgatherv_flat && + if (param.is_vector_buf && algo != ccl_coll_allgatherv_flat && algo != ccl_coll_allgatherv_multi_bcast) can_use = false; else if (ccl::global_data::env().atl_transport == ccl_atl_mpi && diff --git a/src/coll/selection/selector_allreduce.cpp b/src/coll/selection/selector_allreduce.cpp index 101adaf5f..2ca2aa831 100644 --- a/src/coll/selection/selector_allreduce.cpp +++ b/src/coll/selection/selector_allreduce.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "coll/selection/selection.hpp" -#include "exec/exec.hpp" template <> std::map @@ -26,10 +25,14 @@ std::map std::make_pair(ccl_coll_allreduce_ring_rma, "ring_rma"), std::make_pair(ccl_coll_allreduce_double_tree, "double_tree"), std::make_pair(ccl_coll_allreduce_recursive_doubling, "recursive_doubling"), - std::make_pair(ccl_coll_allreduce_2d, "2d") + std::make_pair(ccl_coll_allreduce_2d, "2d"), + std::make_pair(ccl_coll_allreduce_topo_ring, "topo_ring") }; ccl_algorithm_selector::ccl_algorithm_selector() { +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_topo_ring); +#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_ring); insert(main_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); @@ -37,19 +40,13 @@ ccl_algorithm_selector::ccl_algorithm_selector() { CCL_ALLREDUCE_SHORT_MSG_SIZE + 1, CCL_ALLREDUCE_MEDIUM_MSG_SIZE, ccl_coll_allreduce_starlike); - insert( - fallback_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); } else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_direct); +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_ring); -} - -template <> -bool ccl_algorithm_selector_helper::is_direct( - ccl_coll_allreduce_algo algo) { - return (algo == ccl_coll_allreduce_direct) ? true : false; + insert(fallback_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); } template <> @@ -59,7 +56,12 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_allreduce_rabenseifner && (int)param.count < param.comm->pof2()) + ccl_coll_algo algo_param; + algo_param.allreduce = algo; + can_use = ccl_can_use_datatype(algo_param, param); + + if (algo == ccl_coll_allreduce_rabenseifner && + static_cast(param.count) < param.comm->pof2()) can_use = false; else if (algo == ccl_coll_allreduce_ring_rma && !atl_wrapper::attr.out.enable_rma) can_use = false; @@ -71,6 +73,8 @@ bool ccl_algorithm_selector_helper::can_use( else if (algo == ccl_coll_allreduce_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; + else if (algo == ccl_coll_allreduce_topo_ring && !ccl_can_use_topo_ring_algo(param)) + can_use = false; return can_use; } diff --git a/src/coll/selection/selector_alltoall.cpp b/src/coll/selection/selector_alltoall.cpp index 70e92a7e9..a01290798 100644 --- a/src/coll/selection/selector_alltoall.cpp +++ b/src/coll/selection/selector_alltoall.cpp @@ -40,11 +40,6 @@ ccl_algorithm_selector::ccl_algorithm_selector() { insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_alltoall_naive); } -template <> -bool ccl_algorithm_selector_helper::is_direct(ccl_coll_alltoall_algo algo) { - return (algo == ccl_coll_alltoall_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_alltoall_algo algo, diff --git a/src/coll/selection/selector_alltoallv.cpp b/src/coll/selection/selector_alltoallv.cpp index e09d1fcfb..c56e1baa8 100644 --- a/src/coll/selection/selector_alltoallv.cpp +++ b/src/coll/selection/selector_alltoallv.cpp @@ -37,13 +37,7 @@ ccl_algorithm_selector::ccl_algorithm_selector() { CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_alltoallv_scatter_barrier); - insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_alltoallv_naive); -} - -template <> -bool ccl_algorithm_selector_helper::is_direct( - ccl_coll_alltoallv_algo algo) { - return (algo == ccl_coll_alltoallv_direct) ? true : false; + insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_alltoallv_scatter_barrier); } template <> @@ -53,7 +47,10 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_alltoallv_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + if (param.is_vector_buf && algo != ccl_coll_alltoallv_scatter_barrier) + can_use = false; + else if (algo == ccl_coll_alltoallv_direct && + (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_barrier.cpp b/src/coll/selection/selector_barrier.cpp index cf47cd7e1..f18a12db1 100644 --- a/src/coll/selection/selector_barrier.cpp +++ b/src/coll/selection/selector_barrier.cpp @@ -31,11 +31,6 @@ ccl_algorithm_selector::ccl_algorithm_selector() { insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_barrier_ring); } -template <> -bool ccl_algorithm_selector_helper::is_direct(ccl_coll_barrier_algo algo) { - return (algo == ccl_coll_barrier_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_barrier_algo algo, diff --git a/src/coll/selection/selector_bcast.cpp b/src/coll/selection/selector_bcast.cpp index aecd3e985..786bc22e1 100644 --- a/src/coll/selection/selector_bcast.cpp +++ b/src/coll/selection/selector_bcast.cpp @@ -21,25 +21,26 @@ std::map std::make_pair(ccl_coll_bcast_direct, "direct"), std::make_pair(ccl_coll_bcast_ring, "ring"), std::make_pair(ccl_coll_bcast_double_tree, "double_tree"), - std::make_pair(ccl_coll_bcast_naive, "naive") + std::make_pair(ccl_coll_bcast_naive, "naive"), + std::make_pair(ccl_coll_bcast_topo_ring, "topo_ring") }; ccl_algorithm_selector::ccl_algorithm_selector() { +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_topo_ring); +#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_naive); insert(main_table, 0, CCL_BCAST_SHORT_MSG_SIZE, ccl_coll_bcast_double_tree); } - else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) + else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_direct); + } +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_naive); } -template <> -bool ccl_algorithm_selector_helper::is_direct(ccl_coll_bcast_algo algo) { - return (algo == ccl_coll_bcast_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_bcast_algo algo, @@ -47,6 +48,10 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; + ccl_coll_algo algo_param; + algo_param.bcast = algo; + can_use = ccl_can_use_datatype(algo_param, param); + if (ccl::global_data::env().enable_unordered_coll && algo == ccl_coll_bcast_double_tree) { /* TODO: stabilize double_tree bcast for unordered_coll case */ can_use = false; @@ -54,6 +59,8 @@ bool ccl_algorithm_selector_helper::can_use( else if (algo == ccl_coll_bcast_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; + else if (algo == ccl_coll_bcast_topo_ring && !ccl_can_use_topo_ring_algo(param)) + can_use = false; return can_use; } diff --git a/src/coll/selection/selector_helper.hpp b/src/coll/selection/selector_helper.hpp index c452d4074..bee95ae2d 100644 --- a/src/coll/selection/selector_helper.hpp +++ b/src/coll/selection/selector_helper.hpp @@ -19,13 +19,13 @@ #include #include "coll/algorithms/algorithms.hpp" +#include "exec/exec.hpp" template struct ccl_algorithm_selector_helper { static bool can_use(algo_group_type algo, const ccl_selector_param& param, const ccl_selection_table_t& table); - static bool is_direct(algo_group_type algo); static const std::string& get_str_to_parse(); static ccl_coll_type get_coll_id(); static size_t get_count(const ccl_selector_param& param); diff --git a/src/coll/selection/selector_impl.hpp b/src/coll/selection/selector_impl.hpp index e80319a13..92bbe443f 100644 --- a/src/coll/selection/selector_impl.hpp +++ b/src/coll/selection/selector_impl.hpp @@ -54,9 +54,6 @@ void ccl_algorithm_selector_base::init() { const std::string& str_to_parse = ccl_algorithm_selector_helper::get_str_to_parse(); - if (!str_to_parse.length()) - return; - size_t elem_size; algo_group_type elem_algo; ccl_selection_border_type elem_border; @@ -231,13 +228,7 @@ void ccl_algorithm_selector_base::print() const { << "]: " << ccl_coll_algorithm_to_str(elem_algo) << std::endl; } } - LOG_TRACE(str.str()); -} - -template -bool ccl_algorithm_selector_base::is_direct( - const ccl_selector_param& param) const { - return ccl_algorithm_selector_helper::is_direct(get(param)); + LOG_DEBUG(str.str()); } template @@ -248,6 +239,30 @@ algo_group_type ccl_algorithm_selector_base::get( ccl_selection_border_type elem_border; size_t count = ccl_algorithm_selector_helper::get_count(param); + + if (param.hint_algo.has_value()) { + elem_algo = static_cast(param.hint_algo.value); + if (!ccl_algorithm_selector_helper::can_use( + elem_algo, param, main_table)) { + LOG_DEBUG("can't select hint algorithm: coll ", + ccl_coll_type_to_str(param.ctype), + ", count ", + count, + ", algo ", + ccl_coll_algorithm_to_str(elem_algo), + ", switch to regular selection"); + } + else { + LOG_DEBUG("selected hint algo: coll ", + ccl_coll_type_to_str(param.ctype), + ", count ", + count, + ", algo ", + ccl_coll_algorithm_to_str(elem_algo)); + return elem_algo; + } + } + size_t size = count * param.dtype.size(); auto lower_bound = main_table.lower_bound(size); ccl_selection_unpack_elem(elem_size, elem_algo, elem_border, lower_bound, main_table); @@ -257,17 +272,19 @@ algo_group_type ccl_algorithm_selector_base::get( lower_bound = fallback_table.lower_bound(size); ccl_selection_unpack_elem(elem_size, elem_algo, elem_border, lower_bound, fallback_table); CCL_THROW_IF_NOT(lower_bound != fallback_table.end(), - "can't select algorithm: coll_type ", + "can't select algorithm: coll ", ccl_coll_type_to_str(param.ctype), - ", selection_count ", + ", count ", count); CCL_THROW_IF_NOT(ccl_algorithm_selector_helper::can_use( - elem_algo, param, fallback_table)); + elem_algo, param, fallback_table), + "can't select algorithm in fallback_table: coll ", + ccl_coll_type_to_str(param.ctype)); } - LOG_DEBUG("selected algo: coll_type ", + LOG_DEBUG("selected algo: coll ", ccl_coll_type_to_str(param.ctype), - ", selection_count ", + ", count ", count, ", algo ", ccl_coll_algorithm_to_str(elem_algo)); diff --git a/src/coll/selection/selector_reduce.cpp b/src/coll/selection/selector_reduce.cpp index 68cbe0ca9..4c13ea035 100644 --- a/src/coll/selection/selector_reduce.cpp +++ b/src/coll/selection/selector_reduce.cpp @@ -21,23 +21,25 @@ std::map std::make_pair(ccl_coll_reduce_direct, "direct"), std::make_pair(ccl_coll_reduce_rabenseifner, "rabenseifner"), std::make_pair(ccl_coll_reduce_tree, "tree"), - std::make_pair(ccl_coll_reduce_double_tree, "double_tree") + std::make_pair(ccl_coll_reduce_double_tree, "double_tree"), + std::make_pair(ccl_coll_reduce_topo_ring, "topo_ring") }; ccl_algorithm_selector::ccl_algorithm_selector() { - if (ccl::global_data::env().atl_transport == ccl_atl_ofi) +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_topo_ring); +#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_tree); - else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) + } + else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_direct); + } +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_tree); } -template <> -bool ccl_algorithm_selector_helper::is_direct(ccl_coll_reduce_algo algo) { - return (algo == ccl_coll_reduce_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_reduce_algo algo, @@ -45,13 +47,17 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; + ccl_coll_algo algo_param; + algo_param.reduce = algo; + can_use = ccl_can_use_datatype(algo_param, param); + if (algo == ccl_coll_reduce_rabenseifner && (int)param.count < param.comm->pof2()) can_use = false; else if (algo == ccl_coll_reduce_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; - - return can_use; + else if (algo == ccl_coll_reduce_topo_ring && !ccl_can_use_topo_ring_algo(param)) + can_use = false; return can_use; } diff --git a/src/coll/selection/selector_reduce_scatter.cpp b/src/coll/selection/selector_reduce_scatter.cpp index 12825c795..3d8f67e01 100644 --- a/src/coll/selection/selector_reduce_scatter.cpp +++ b/src/coll/selection/selector_reduce_scatter.cpp @@ -31,12 +31,6 @@ ccl_algorithm_selector::ccl_algorithm_selector() { insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_scatter_ring); } -template <> -bool ccl_algorithm_selector_helper::is_direct( - ccl_coll_reduce_scatter_algo algo) { - return (algo == ccl_coll_reduce_scatter_direct) ? true : false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_reduce_scatter_algo algo, diff --git a/src/coll/selection/selector_sparse_allreduce.cpp b/src/coll/selection/selector_sparse_allreduce.cpp index db02d88fd..449eb512b 100644 --- a/src/coll/selection/selector_sparse_allreduce.cpp +++ b/src/coll/selection/selector_sparse_allreduce.cpp @@ -35,12 +35,6 @@ ccl_algorithm_selector::ccl_algorithm_selector() { } } -template <> -bool ccl_algorithm_selector_helper::is_direct( - ccl_coll_sparse_allreduce_algo algo) { - return false; -} - template <> bool ccl_algorithm_selector_helper::can_use( ccl_coll_sparse_allreduce_algo algo, diff --git a/src/coll/selection/selector_wrapper.hpp b/src/coll/selection/selector_wrapper.hpp new file mode 100644 index 000000000..6b962021d --- /dev/null +++ b/src/coll/selection/selector_wrapper.hpp @@ -0,0 +1,58 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "coll/selection/selector.hpp" +#include "common/utils/tuple.hpp" + +#include + +template +class ccl_algorithm_selector_wrapper { +public: + struct selector_init_functor { + template + void operator()(T& t) const { + t.init(); + } + }; + + struct selector_print_functor { + template + void operator()(T& t) const { + t.print(); + } + }; + + void init() { + ccl_tuple_for_each(selectors, selector_init_functor()); + } + + void print() { + ccl_tuple_for_each(selectors, selector_print_functor()); + } + + template + typename ccl_algorithm_selector::type get(const ccl_selector_param& param) const { + CCL_THROW_IF_NOT( + coll_id == param.ctype, "expected coll_id ", coll_id, ", got ", param.ctype); + return std::get(selectors).get(param); + } + +private: + using algo_selectors = std::tuple...>; + algo_selectors selectors; +}; diff --git a/src/common/comm/atl_tag.cpp b/src/common/comm/atl_tag.cpp index 504983b80..2bca7ba61 100644 --- a/src/common/comm/atl_tag.cpp +++ b/src/common/comm/atl_tag.cpp @@ -24,35 +24,26 @@ void ccl_atl_tag::print() { LOG_INFO(" pof2: ", ccl_pof2(max_tag)); } -uint64_t ccl_atl_tag::create(ccl_comm_id_t comm_id, - int rank, +uint64_t ccl_atl_tag::create(int rank, + ccl_comm_id_t comm_id, ccl_sched_id_t sched_id, ccl_op_id_t op_id) { uint64_t tag = 0; - if (tag_bits == 32) { - tag |= (((uint64_t)op_id) << op_id_shift) & op_id_mask; - tag |= (((uint64_t)sched_id) << sched_id_shift) & sched_id_mask; - } - else if (tag_bits == 64) { - tag |= (((uint64_t)op_id) << op_id_shift) & op_id_mask; - tag |= (((uint64_t)sched_id) << sched_id_shift) & sched_id_mask; - tag |= (((uint64_t)rank) << rank_shift) & rank_mask; - tag |= (((uint64_t)comm_id) << comm_id_shift) & comm_id_mask; - } - else { - CCL_ASSERT(0); - } + tag |= (((uint64_t)op_id) << op_id_shift) & op_id_mask; + tag |= (((uint64_t)sched_id) << sched_id_shift) & sched_id_mask; + tag |= (((uint64_t)comm_id) << comm_id_shift) & comm_id_mask; + tag |= (((uint64_t)rank) << rank_shift) & rank_mask; if (tag > max_tag) tag &= max_tag_mask; LOG_DEBUG("tag ", tag, - " (comm_id: ", - comm_id, - ", rank ", + " (rank ", rank, + ", comm_id: ", + comm_id, ", sched_id: ", sched_id, ", op_id: ", @@ -64,10 +55,10 @@ uint64_t ccl_atl_tag::create(ccl_comm_id_t comm_id, tag, ", max_tag ", max_tag, - " (comm_id: ", - comm_id, - ", rank ", + " (rank ", rank, + ", comm_id: ", + comm_id, ", sched_id: ", sched_id, ", op_id: ", diff --git a/src/common/comm/atl_tag.hpp b/src/common/comm/atl_tag.hpp index ea4570121..4c9a46cfc 100644 --- a/src/common/comm/atl_tag.hpp +++ b/src/common/comm/atl_tag.hpp @@ -25,6 +25,8 @@ using ccl_comm_id_t = uint16_t; class ccl_atl_tag { public: ccl_atl_tag(size_t tag_bits, size_t max_tag) : tag_bits(tag_bits), max_tag(max_tag) { + CCL_THROW_IF_NOT(tag_bits >= 32, "unexpected tag_bits ", tag_bits); + if (max_tag == ccl_pof2(max_tag) * 2 - 1) max_tag_mask = max_tag; else @@ -43,20 +45,21 @@ class ccl_atl_tag { /** * Generates the tag to be used by ATL communication operations + * @param rank identifier of the rank within communicator * @param comm_id identifier of the communicator - * @param sched_id identifier if the schedule - * @param op_id local operation ID. Used to generate unique ATL tag when the rest of input parameters do not change + * @param sched_id identifier of the schedule within communicator + * @param op_id local operation id, used as sub-schedule identifier * @return ATL communication tag */ - uint64_t create(ccl_comm_id_t comm_id, int rank, ccl_sched_id_t sched_id, ccl_op_id_t op_id); + uint64_t create(int rank, ccl_comm_id_t comm_id, ccl_sched_id_t sched_id, ccl_op_id_t op_id); private: /********************************************************************************** * atl tag layout * * ******************************************************************************** - * 01234567 01234567 | 01234567 01234567 01234567 | 01234567 01234567 | 01234567 | - * | | | | - * comm_id | rank | sched_id(per comm) | op_id | + * 01234567 01234567 01234567 | 01234567 01234567 | 01234567 01234567 | 01234567 | + * | | | | + * rank | comm_id | sched_id | op_id | *********************************************************************************/ size_t tag_bits; @@ -65,11 +68,11 @@ class ccl_atl_tag { const int op_id_shift = 0; const int sched_id_shift = 8; - const int rank_shift = 24; - const int comm_id_shift = 48; + const int comm_id_shift = 24; + const int rank_shift = 40; const uint64_t op_id_mask = 0x00000000000000FF; const uint64_t sched_id_mask = 0x0000000000FFFF00; - const uint64_t rank_mask = 0x0000FFFFFF000000; - const uint64_t comm_id_mask = 0xFFFF000000000000; + const uint64_t comm_id_mask = 0x000000FFFF000000; + const uint64_t rank_mask = 0xFFFFFF0000000000; }; diff --git a/src/common/comm/comm.cpp b/src/common/comm/comm.cpp index 7c4f5c88b..fe6e8062e 100644 --- a/src/common/comm/comm.cpp +++ b/src/common/comm/comm.cpp @@ -16,6 +16,7 @@ #include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h" #include "exec/exec.hpp" #include "common/comm/comm.hpp" +#include "common/comm/host_communicator/host_communicator.hpp" #include "common/global/global.hpp" #include "sched/sched.hpp" #include "oneapi/ccl/types.hpp" @@ -43,21 +44,30 @@ ccl_comm::ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, std::shared_ptr atl, - bool share_resources) - : ccl_comm(rank, size, std::move(id), ccl_rank2rank_map{}, atl, share_resources) {} + bool share_resources, + ccl::host_communicator* host_comm) + : ccl_comm(rank, + size, + std::move(id), + ccl_rank2rank_map{}, + atl, + share_resources, + host_comm) {} ccl_comm::ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, ccl_rank2rank_map&& rank_map, std::shared_ptr atl, - bool share_resources) + bool share_resources, + ccl::host_communicator* host_comm) : atl(atl), m_id(std::move(id)), m_local2global_map(std::move(rank_map)), m_dtree(size, rank), thread_number(1), - on_process_ranks_number(1) { + on_process_ranks_number(1), + host_comm(host_comm) { reset(rank, size); if (!share_resources) { @@ -79,10 +89,12 @@ ccl_comm::ccl_comm(const std::vector& local_ranks, int comm_size, std::shared_ptr kvs_instance, ccl_comm_id_storage::comm_id&& id, - bool share_resources) + bool share_resources, + ccl::host_communicator* host_comm) : m_id(std::move(id)), m_local2global_map(), - m_dtree(local_ranks.size(), comm_size) { + m_dtree(local_ranks.size(), comm_size), + host_comm(host_comm) { std::shared_ptr kvs_wrapper(new users_kvs(kvs_instance)); atl = std::shared_ptr(new atl_wrapper(comm_size, local_ranks, kvs_wrapper)); @@ -148,8 +160,13 @@ ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, std::shared_ptr ccl_comm::clone_with_new_id(ccl_comm_id_storage::comm_id&& id) { ccl_rank2rank_map rank_map{ m_local2global_map }; - return std::make_shared( - m_rank, m_size, std::move(id), std::move(rank_map), atl, true /*share_resources*/); + return std::make_shared(m_rank, + m_size, + std::move(id), + std::move(rank_map), + atl, + true /*share_resources*/, + get_host_comm()); } int ccl_comm::get_global_rank(int rank) const { diff --git a/src/common/comm/comm.hpp b/src/common/comm/comm.hpp index 5a9e2072a..77505c705 100644 --- a/src/common/comm/comm.hpp +++ b/src/common/comm/comm.hpp @@ -31,6 +31,7 @@ using ccl_rank2rank_map = std::vector; namespace ccl { +class host_communicator; namespace v1 { class kvs_interface; } @@ -38,7 +39,12 @@ class kvs_interface; class alignas(CACHELINE_SIZE) ccl_comm { public: - //TODO + static constexpr int invalid_rank = -1; + + ccl::host_communicator* get_host_comm() { + return host_comm; + } + static void ccl_comm_reset_thread_barrier(); ccl_comm() = delete; ccl_comm(const ccl_comm& other) = delete; @@ -48,13 +54,16 @@ class alignas(CACHELINE_SIZE) ccl_comm { int size, ccl_comm_id_storage::comm_id&& id, std::shared_ptr atl, - bool share_resources = false); + bool share_resources = false, + ccl::host_communicator* host_comm = nullptr); + ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, ccl_rank2rank_map&& ranks, std::shared_ptr atl, - bool share_resources = false); + bool share_resources = false, + ccl::host_communicator* host_comm = nullptr); //TODO non-implemented //1) cluster_devices_count (devices 1000) -> (processes 10) @@ -70,7 +79,8 @@ class alignas(CACHELINE_SIZE) ccl_comm { int comm_size, std::shared_ptr kvs_instance, ccl_comm_id_storage::comm_id&& id, - bool share_resources = false); + bool share_resources = false, + ccl::host_communicator* host_comm = nullptr); ~ccl_comm() = default; @@ -178,4 +188,5 @@ class alignas(CACHELINE_SIZE) ccl_comm { size_t thread_number; size_t on_process_ranks_number; + ccl::host_communicator* host_comm; }; diff --git a/src/common/comm/comm_interface.hpp b/src/common/comm/comm_interface.hpp index 02d30b0e8..a72256413 100644 --- a/src/common/comm/comm_interface.hpp +++ b/src/common/comm/comm_interface.hpp @@ -161,6 +161,6 @@ struct communicator_interface : public communicator_interface_dispatcher { COMM_INTERFACE_COLL_METHODS(DECLARATION); #ifdef CCL_ENABLE_SYCL SYCL_COMM_INTERFACE_COLL_METHODS(DECLARATION); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL }; } // namespace ccl diff --git a/src/common/comm/compiler_comm_interface_dispatcher.cpp b/src/common/comm/compiler_comm_interface_dispatcher.cpp index 6734322bf..3c1373925 100644 --- a/src/common/comm/compiler_comm_interface_dispatcher.cpp +++ b/src/common/comm/compiler_comm_interface_dispatcher.cpp @@ -29,19 +29,11 @@ #include "comm_split_attr_impl.hpp" #include "common/global/global.hpp" -//#include "native_device_api/compiler_ccl_wrappers_dispatcher.hpp" + #ifdef MULTI_GPU_SUPPORT -#include "common/comm/l0/communicator/device_group/device_ring_communicator.hpp" -#include "common/comm/l0/communicator/device_group/device_a2a_communicator.hpp" -#include "common/comm/l0/communicator/thread_group/thread_ring_communicator.hpp" -#include "common/comm/l0/communicator/thread_group/thread_a2a_communicator.hpp" -#include "common/comm/l0/communicator/process_group/process_ring_communicator.hpp" -#include "common/comm/l0/communicator/process_group/process_a2a_communicator.hpp" #include "supported_topologies.hpp" - #endif -#include "common/comm/single_device_communicator/single_device_communicator.hpp" #include "common/comm/host_communicator/host_communicator_impl.hpp" namespace ccl { @@ -112,8 +104,6 @@ communicator_interface_ptr communicator_interface_dispatcher::create_communicato atl, preferred_topology_group); #else - //. static_assert(std::is_same::value, - // "Unsupported 'DeviceType'"); return communicator_interface_dispatcher::create_communicator_from_unified_device( unified_device_type(device_id), unified_context_type(context), @@ -134,111 +124,33 @@ communicator_interface_dispatcher::create_communicator_from_unified_device( const ccl::comm_split_attr& attr, std::shared_ptr atl, ccl::group_split_type preferred_topology_group /* = ccl::group_split_type::undetermined */) { - // TODO ring by default at now. Choose preferred a2a if availbale - ccl::device_topology_type preferred_topology_class = ccl::device_topology_type::ring; - - // Use process class if not specified otherwise - // TODO: implement a proper dispatching for other types if (preferred_topology_group == ccl::group_split_type::undetermined) { preferred_topology_group = ccl::group_split_type::cluster; } - // read comm split attributes - // TODO: we don't have support for communicator splitting yet, there is chance that - // we might not need the attr here as the split routine will be moved into a separate - // function if (attr.is_valid()) { throw ccl::exception(std::string(__FUNCTION__) + " - not implemented for 'group'"); if (attr.is_valid()) { throw ccl::exception(std::string( - "Invalid `comm_split_attr`: both `color` and `group` set. Only one is supported")); + "invalid `comm_split_attr`: both `color` and `group` set, only one is supported")); } } else if (attr.is_valid()) { throw ccl::exception(std::string(__FUNCTION__) + " - not implemented for 'color'"); } - // TODO creation host communicator from device - // if (device is host ?) - // return new host_communicator(atl); - - //TODO check device_id or sycl device validity before communicator creation - switch (preferred_topology_class) { - case device_topology_type::ring: { - switch (preferred_topology_group) { + switch (preferred_topology_group) { #if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - case ccl::group_split_type::single: { - auto comm_impl = new single_device_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr); - ccl::global_data& data = ccl::global_data::get(); - auto comm = std::shared_ptr( - new ccl_comm(thread_idx, process_idx, data.comm_ids->acquire(), atl)); - comm_impl->set_ccl_comm(std::move(comm)); - return communicator_interface_ptr(comm_impl); - } -#endif -#ifdef MULTI_GPU_SUPPORT - case ccl::group_split_type::thread: - return communicator_interface_ptr(new device_group_ring_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); - case ccl::group_split_type::process: - return communicator_interface_ptr(new thread_device_group_ring_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); - case ccl::group_split_type::cluster: - return communicator_interface_ptr(new process_ring_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); -#endif //MULTI_GPU_SUPPORT - default: - throw ccl::exception( - std::string( - "Invalid `comm_split_attr` value for `ccl_device_preferred_group`: ") + - ::to_string(preferred_topology_group)); - } - break; + case ccl::group_split_type::single: { + return communicator_interface_ptr( + new host_communicator(std::move(device_id), std::move(context), atl)); } - case device_topology_type::a2a: { - switch (preferred_topology_group) { -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - case ccl::group_split_type::single: - return communicator_interface_ptr(new single_device_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); #endif -#ifdef MULTI_GPU_SUPPORT - case ccl::group_split_type::thread: - return communicator_interface_ptr(new device_group_a2a_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); - case ccl::group_split_type::process: - return communicator_interface_ptr(new thread_device_group_a2a_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); - case ccl::group_split_type::cluster: - return communicator_interface_ptr(new process_a2a_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr)); -#endif - default: - throw ccl::exception( - std::string( - "Invalid `comm_split_attr` value for `ccl_device_preferred_group`: ") + - ::to_string(preferred_topology_group)); - } - break; - } -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - case device_topology_type::undetermined: { - auto comm_impl = new single_device_communicator( - std::move(device_id), std::move(context), thread_idx, process_idx, attr); - ccl::global_data& data = ccl::global_data::get(); - auto comm = std::shared_ptr( - new ccl_comm(thread_idx, process_idx, data.comm_ids->acquire(), atl)); - comm_impl->set_ccl_comm(std::move(comm)); - return communicator_interface_ptr(comm_impl); - } -#endif //MULTI_GPU_SUPPORT - default: { + default: throw ccl::exception( - std::string( - "Invalid `comm_split_attr` value for `ccl_device_preferred_topology_class`: ") + - ::to_string(preferred_topology_class)); - } + std::string("Invalid `comm_split_attr` value for `ccl_device_preferred_group`: ") + + ::to_string(preferred_topology_group)); + break; } return std::unique_ptr(); diff --git a/src/common/comm/compiler_comm_interface_dispatcher.hpp b/src/common/comm/compiler_comm_interface_dispatcher.hpp index 89a7ad66b..ad643f0a9 100644 --- a/src/common/comm/compiler_comm_interface_dispatcher.hpp +++ b/src/common/comm/compiler_comm_interface_dispatcher.hpp @@ -41,6 +41,8 @@ struct communicator_interface_dispatcher { using device_t = typename ccl::unified_device_type::ccl_native_t; using context_t = typename ccl::unified_context_type::ccl_native_t; + virtual ~communicator_interface_dispatcher() = default; + #ifdef MULTI_GPU_SUPPORT virtual void visit(ccl::gpu_comm_attr& comm_attr) = 0; #endif //MULTI_GPU_SUPPORT diff --git a/src/common/comm/host_communicator/host_communicator.cpp b/src/common/comm/host_communicator/host_communicator.cpp index 0f7c5cd9d..b0883492d 100644 --- a/src/common/comm/host_communicator/host_communicator.cpp +++ b/src/common/comm/host_communicator/host_communicator.cpp @@ -38,10 +38,17 @@ namespace ccl { using ccl::preview::create_comm_split_attr; -host_communicator::host_communicator() : comm_attr(create_comm_split_attr()) {} +host_communicator::host_communicator() + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()) {} host_communicator::host_communicator(int size, shared_ptr_class kvs) - : comm_attr(create_comm_split_attr()), + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), comm_rank(0), comm_size(size) { if (size <= 0) { @@ -50,7 +57,10 @@ host_communicator::host_communicator(int size, shared_ptr_class kv } host_communicator::host_communicator(int size, int rank, shared_ptr_class kvs) - : comm_attr(create_comm_split_attr()), + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), comm_rank(rank), comm_size(size) { if (rank > size || size <= 0) { @@ -62,34 +72,52 @@ host_communicator::host_communicator(int size, int rank, shared_ptr_class atl_tmp = std::shared_ptr(new atl_wrapper(size, { rank }, kvs)); - comm_impl = - std::shared_ptr(new ccl_comm(rank, size, data.comm_ids->acquire(), atl_tmp)); + comm_impl = std::shared_ptr( + new ccl_comm(rank, size, data.comm_ids->acquire(), atl_tmp, false, this)); + create_sub_comms(atl_tmp); } +host_communicator::host_communicator(ccl::unified_device_type&& d, + ccl::unified_context_type&& c, + std::shared_ptr atl) + : host_communicator(atl) {} + host_communicator::host_communicator(std::shared_ptr atl) - : comm_attr(create_comm_split_attr()), + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), comm_rank(atl->get_rank()), comm_size(atl->get_size()) { int rank = atl->get_rank(); int size = atl->get_size(); if (rank > size || size <= 0) { - throw ccl::exception("Incorrect rank or size value when creating \ - a host communicator: rank" + - std::to_string(rank) + " size: " + std::to_string(size)); + throw ccl::exception("incorrect rank or size when creating \ + a host communicator: rank: " + + std::to_string(rank) + ", size: " + std::to_string(size)); } LOG_DEBUG("ctor"); ccl::global_data& data = ccl::global_data::get(); - comm_impl = std::shared_ptr(new ccl_comm(rank, size, data.comm_ids->acquire(), atl)); + comm_impl = std::shared_ptr( + new ccl_comm(rank, size, data.comm_ids->acquire(), atl, false, this)); + create_sub_comms(atl); } -host_communicator::host_communicator(std::shared_ptr impl) +host_communicator::host_communicator(std::shared_ptr impl, bool is_sub_communicator) : comm_impl(impl), + device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), comm_attr(create_comm_split_attr()), comm_rank(impl->rank()), - comm_size(impl->size()) {} + comm_size(impl->size()) { + if (!is_sub_communicator) { + create_sub_comms(comm_impl.get()->atl); + } +} int host_communicator::rank() const { return comm_rank; @@ -133,6 +161,39 @@ void host_communicator::exchange_colors(std::vector& colors) { .wait(); } +void host_communicator::create_sub_comms(std::shared_ptr atl) { + bool is_sub_comm = true; + if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { + r2r_comm = + std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); + node_comm = + std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); + pair_comm = + std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); + even_comm = + std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); + } + else { + ccl::global_data& data = ccl::global_data::get(); + r2r_comm = std::shared_ptr( + new host_communicator(std::shared_ptr(this->create_with_color( + atl->get_r2r_color(), data.comm_ids.get(), comm_impl.get())), + is_sub_comm)); + node_comm = std::shared_ptr( + new host_communicator(std::shared_ptr(this->create_with_color( + atl->get_host_color(), data.comm_ids.get(), comm_impl.get())), + is_sub_comm)); + even_comm = std::shared_ptr(new host_communicator( + std::shared_ptr(this->create_with_color( + atl->get_host_color() + atl->get_rank() % 2, data.comm_ids.get(), comm_impl.get())), + is_sub_comm)); + pair_comm = std::shared_ptr(new host_communicator( + std::shared_ptr(this->create_with_color( + atl->get_host_color() + atl->get_rank() / 2, data.comm_ids.get(), comm_impl.get())), + is_sub_comm)); + } +} + ccl_comm* host_communicator::create_with_color(int color, ccl_comm_id_storage* comm_ids, const ccl_comm* parent_comm) { @@ -165,22 +226,17 @@ ccl::communicator_interface_ptr host_communicator::split(const comm_split_attr& new host_communicator(std::shared_ptr(new_comm))); } -ccl::event host_communicator::barrier(const ccl::stream::impl_value_t& op_stream, +ccl::event host_communicator::barrier(const ccl::stream::impl_value_t& stream, const ccl::barrier_attr& attr, const ccl::vector_class& deps) { - return get_impl()->barrier_impl(op_stream, attr, deps); + return get_impl()->barrier_impl(stream, attr, deps); } -ccl::event host_communicator::barrier_impl(const ccl::stream::impl_value_t& op_stream, +ccl::event host_communicator::barrier_impl(const ccl::stream::impl_value_t& stream, const ccl::barrier_attr& attr, const ccl::vector_class& deps) { - // TODO what exactly we need to do with 'attr' here? - - ccl_barrier_impl(comm_impl.get(), op_stream.get(), deps); - - // TODO what exactly we need to return here? ccl_barrier_impl() is void func - ccl_request* req = nullptr; - return std::unique_ptr(new ccl::host_event_impl(req)); + ccl_barrier_impl(comm_impl.get(), stream.get(), deps); + return std::unique_ptr(new ccl::host_event_impl(nullptr)); } /* allgatherv */ @@ -199,7 +255,7 @@ ccl::event host_communicator::allgatherv_impl(const void* send_buf, dtype, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -214,7 +270,7 @@ ccl::event host_communicator::allgatherv_impl(const void* send_buf, const ccl::allgatherv_attr& attr, const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); - internal_attr.vector_buf = 1; + internal_attr.is_vector_buf = 1; ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(send_buf), send_count, @@ -223,7 +279,7 @@ ccl::event host_communicator::allgatherv_impl(const void* send_buf, dtype, internal_attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -238,8 +294,15 @@ ccl::event host_communicator::allreduce_impl(const void* send_buf, const ccl::stream::impl_value_t& stream, const ccl::allreduce_attr& attr, const ccl::vector_class& deps) { - ccl_request* req = ccl_allreduce_impl( - send_buf, recv_buf, count, dtype, reduction, attr, comm_impl.get(), nullptr, deps); + ccl_request* req = ccl_allreduce_impl(send_buf, + recv_buf, + count, + dtype, + reduction, + attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); return std::unique_ptr(new ccl::host_event_impl(req)); } @@ -252,8 +315,8 @@ ccl::event host_communicator::alltoall_impl(const void* send_buf, const ccl::stream::impl_value_t& stream, const ccl::alltoall_attr& attr, const ccl::vector_class& deps) { - ccl_request* req = - ccl_alltoall_impl(send_buf, recv_buf, count, dtype, attr, comm_impl.get(), nullptr, deps); + ccl_request* req = ccl_alltoall_impl( + send_buf, recv_buf, count, dtype, attr, comm_impl.get(), get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); } @@ -286,7 +349,7 @@ ccl::event host_communicator::alltoallv_impl(const void* send_buf, dtype, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -313,8 +376,8 @@ ccl::event host_communicator::broadcast_impl(void* buf, const ccl::stream::impl_value_t& stream, const ccl::broadcast_attr& attr, const ccl::vector_class& deps) { - ccl_request* req = - ccl_broadcast_impl(buf, count, dtype, root, attr, comm_impl.get(), nullptr, deps); + ccl_request* req = ccl_broadcast_impl( + buf, count, dtype, root, attr, comm_impl.get(), get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); } @@ -329,8 +392,16 @@ ccl::event host_communicator::reduce_impl(const void* send_buf, const ccl::stream::impl_value_t& stream, const ccl::reduce_attr& attr, const ccl::vector_class& deps) { - ccl_request* req = ccl_reduce_impl( - send_buf, recv_buf, count, dtype, reduction, root, attr, comm_impl.get(), nullptr, deps); + ccl_request* req = ccl_reduce_impl(send_buf, + recv_buf, + count, + dtype, + reduction, + root, + attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); return std::unique_ptr(new ccl::host_event_impl(req)); } @@ -344,8 +415,15 @@ ccl::event host_communicator::reduce_scatter_impl(const void* send_buf, const ccl::stream::impl_value_t& stream, const ccl::reduce_scatter_attr& attr, const ccl::vector_class& deps) { - ccl_request* req = ccl_reduce_scatter_impl( - send_buf, recv_buf, recv_count, dtype, reduction, attr, comm_impl.get(), nullptr, deps); + ccl_request* req = ccl_reduce_scatter_impl(send_buf, + recv_buf, + recv_count, + dtype, + reduction, + attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); return std::unique_ptr(new ccl::host_event_impl(req)); } @@ -378,7 +456,7 @@ ccl::event host_communicator::sparse_allreduce_impl(const void* send_ind_buf, reduction, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -388,6 +466,26 @@ std::shared_ptr host_communicator::get_atl() { return comm_impl->atl; } +std::shared_ptr host_communicator::get_r2r_comm() { + return r2r_comm; +} + +std::shared_ptr host_communicator::get_node_comm() { + return node_comm; +} + +std::shared_ptr host_communicator::get_pair_comm() { + return pair_comm; +} + +std::shared_ptr host_communicator::get_even_comm() { + return even_comm; +} + +std::shared_ptr host_communicator::get_ccl_comm() { + return comm_impl; +} + std::string host_communicator::to_string() const { return std::string("host communicator, rank (") + std::to_string(rank()) + "/" + std::to_string(size()); @@ -396,6 +494,6 @@ std::string host_communicator::to_string() const { COMM_INTERFACE_COLL_INSTANTIATION(host_communicator); #ifdef CCL_ENABLE_SYCL SYCL_COMM_INTERFACE_COLL_INSTANTIATION(host_communicator); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } // namespace ccl diff --git a/src/common/comm/host_communicator/host_communicator.hpp b/src/common/comm/host_communicator/host_communicator.hpp index 08b604b2a..53bb642ce 100644 --- a/src/common/comm/host_communicator/host_communicator.hpp +++ b/src/common/comm/host_communicator/host_communicator.hpp @@ -15,7 +15,9 @@ */ #pragma once +#include "atl/atl_wrapper.h" #include "common/comm/comm.hpp" +#include "common/stream/stream.hpp" #include "oneapi/ccl/types.hpp" #include "oneapi/ccl/types_policy.hpp" #include "oneapi/ccl/comm_split_attr_ids.hpp" @@ -24,7 +26,6 @@ #include "oneapi/ccl/types.hpp" #include "oneapi/ccl/type_traits.hpp" #include "oneapi/ccl/types_policy.hpp" - #include "oneapi/ccl/event.hpp" #include "oneapi/ccl/coll_attr_ids.hpp" #include "oneapi/ccl/coll_attr_ids_traits.hpp" @@ -33,11 +34,17 @@ #include "common/comm/communicator_traits.hpp" #include "common/comm/comm_interface.hpp" #include "types_generator_defines.hpp" -#include "atl/atl_wrapper.h" class ikvs_wrapper; namespace ccl { +inline ccl_stream* get_stream_ptr(const ccl::stream::impl_value_t& stream) { + if (stream.get() && stream->is_sycl_device_stream()) + return stream.get(); + else + return nullptr; +} + class host_communicator : public ccl::communicator_interface { public: using traits = ccl::host_communicator_traits; @@ -111,7 +118,7 @@ class host_communicator : public ccl::communicator_interface { COMM_INTERFACE_COLL_METHODS(DEFINITION); #ifdef CCL_ENABLE_SYCL SYCL_COMM_INTERFACE_COLL_METHODS(DEFINITION); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL COMM_IMPL_DECLARATION; COMM_IMPL_CLASS_DECLARATION @@ -121,33 +128,49 @@ class host_communicator : public ccl::communicator_interface { host_communicator(); host_communicator(int size, shared_ptr_class kvs); host_communicator(int size, int rank, shared_ptr_class kvs); + host_communicator(ccl::unified_device_type&& device, + ccl::unified_context_type&& context, + std::shared_ptr atl); host_communicator(std::shared_ptr atl); - host_communicator(std::shared_ptr impl); + host_communicator(std::shared_ptr impl, bool is_sub_communicator = false); host_communicator(host_communicator& src) = delete; host_communicator(host_communicator&& src) = default; host_communicator& operator=(host_communicator& src) = delete; host_communicator& operator=(host_communicator&& src) = default; ~host_communicator() = default; std::shared_ptr get_atl(); + std::shared_ptr get_r2r_comm(); + std::shared_ptr get_node_comm(); + std::shared_ptr get_even_comm(); + std::shared_ptr get_pair_comm(); + std::shared_ptr get_ccl_comm(); // troubleshooting std::string to_string() const; private: friend struct group_context; + std::shared_ptr comm_impl; + + ccl::unified_device_type device; + //ccl::unified_context_type context; + + std::shared_ptr r2r_comm; + std::shared_ptr node_comm; + std::shared_ptr even_comm; + std::shared_ptr pair_comm; ccl::comm_split_attr comm_attr; int comm_rank; int comm_size; ccl::group_unique_key owner_id; - // ccl::unified_device_type device; - // ccl::unified_context_type context; host_communicator* get_impl() { return this; } void exchange_colors(std::vector& colors); + void create_sub_comms(std::shared_ptr atl); ccl_comm* create_with_color(int color, ccl_comm_id_storage* comm_ids, const ccl_comm* parent_comm); diff --git a/src/common/comm/host_communicator/host_communicator_impl.hpp b/src/common/comm/host_communicator/host_communicator_impl.hpp index a958a117a..00d8bd879 100644 --- a/src/common/comm/host_communicator/host_communicator_impl.hpp +++ b/src/common/comm/host_communicator/host_communicator_impl.hpp @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once + #include "common/comm/host_communicator/host_communicator.hpp" #include "oneapi/ccl/native_device_api/interop_utils.hpp" @@ -42,7 +43,7 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, ccl::native_type_info::dtype, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -57,7 +58,7 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, const ccl::allgatherv_attr& attr, const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); - internal_attr.vector_buf = 1; + internal_attr.is_vector_buf = 1; ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(send_buf), send_count, @@ -66,7 +67,7 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, ccl::native_type_info::dtype, internal_attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -80,23 +81,47 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::allgatherv_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(&send_buf), + send_count, + reinterpret_cast(&recv_buf), + recv_counts.data(), + ccl::native_type_info::dtype, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + return std::unique_ptr(new ccl::host_event_impl(req)); } template ccl::event host_communicator::allgatherv_impl( const buffer_type& send_buf, size_t send_count, - ccl::vector_class>& recv_buf, + ccl::vector_class>& recv_bufs, const ccl::vector_class& recv_counts, const ccl::stream::impl_value_t& stream, const ccl::allgatherv_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); + internal_attr.is_vector_buf = 1; +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(&send_buf), + send_count, + (void*)(recv_bufs.data()), + recv_counts.data(), + ccl::native_type_info::dtype, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } /* allreduce */ @@ -115,7 +140,7 @@ ccl::event host_communicator::allreduce_impl(const buffer_type* send_buf, reduction, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -129,9 +154,21 @@ ccl::event host_communicator::allreduce_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::allreduce_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_allreduce_impl(reinterpret_cast(&send_buf), + reinterpret_cast(&recv_buf), + count, + ccl::native_type_info::dtype, + reduction, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } /* alltoall */ @@ -148,7 +185,7 @@ ccl::event host_communicator::alltoall_impl(const buffer_type* send_buf, ccl::native_type_info::dtype, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -161,7 +198,6 @@ ccl::event host_communicator::alltoall_impl(const ccl::vector_class& deps) { - // TODO not implemented throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } @@ -173,9 +209,20 @@ ccl::event host_communicator::alltoall_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::alltoall_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_alltoall_impl(reinterpret_cast(&send_buf), + reinterpret_cast(&recv_buf), + count, + ccl::native_type_info::dtype, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } template @@ -186,7 +233,6 @@ ccl::event host_communicator::alltoall_impl( const ccl::stream::impl_value_t& stream, const ccl::alltoall_attr& attr, const ccl::vector_class& dep) { - // TODO not implemented throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } @@ -207,7 +253,7 @@ ccl::event host_communicator::alltoallv_impl(const buffer_type* send_buf, ccl::native_type_info::dtype, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -221,7 +267,6 @@ ccl::event host_communicator::alltoallv_impl(const ccl::vector_class& dep) { - // TODO not implemented throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } @@ -234,10 +279,23 @@ ccl::event host_communicator::alltoallv_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::alltoallv_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_alltoallv_impl(reinterpret_cast(&send_buf), + send_counts.data(), + reinterpret_cast(&recv_buf), + recv_counts.data(), + ccl::native_type_info::dtype, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } + template ccl::event host_communicator::alltoallv_impl( const ccl::vector_class>& send_buf, @@ -247,7 +305,6 @@ ccl::event host_communicator::alltoallv_impl( const ccl::stream::impl_value_t& stream, const ccl::alltoallv_attr& attr, const ccl::vector_class& dep) { - // TODO not implemented throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } @@ -266,7 +323,7 @@ ccl::event host_communicator::broadcast_impl(buffer_type* buf, root, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -279,9 +336,20 @@ ccl::event host_communicator::broadcast_impl(buffer_type& buf, const ccl::stream::impl_value_t& stream, const ccl::broadcast_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_broadcast_impl(reinterpret_cast(&buf), + count, + ccl::native_type_info::dtype, + root, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } /* reduce */ @@ -302,7 +370,7 @@ ccl::event host_communicator::reduce_impl(const buffer_type* send_buf, root, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -317,9 +385,22 @@ ccl::event host_communicator::reduce_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::reduce_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_reduce_impl(reinterpret_cast(&send_buf), + reinterpret_cast(&recv_buf), + count, + ccl::native_type_info::dtype, + reduction, + root, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } /* reduce_scatter */ @@ -338,7 +419,7 @@ ccl::event host_communicator::reduce_scatter_impl(const buffer_type* send_buf, reduction, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -352,9 +433,21 @@ ccl::event host_communicator::reduce_scatter_impl(const buffer_type& send_buf, const ccl::stream::impl_value_t& stream, const ccl::reduce_scatter_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; + ccl_coll_attr internal_attr(attr); +#ifdef CCL_ENABLE_SYCL + internal_attr.is_sycl_buf = 1; +#endif // CCL_ENABLE_SYCL + ccl_request* req = ccl_reduce_scatter_impl(reinterpret_cast(&send_buf), + reinterpret_cast(&recv_buf), + recv_count, + ccl::native_type_info::dtype, + reduction, + internal_attr, + comm_impl.get(), + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); } /* sparse_allreduce */ @@ -384,7 +477,7 @@ ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_type* sen reduction, attr, comm_impl.get(), - nullptr, + get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); @@ -403,7 +496,6 @@ ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_container const ccl::stream::impl_value_t& stream, const ccl::sparse_allreduce_attr& attr, const ccl::vector_class& deps) { - // TODO not implemented throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } diff --git a/src/common/comm/l0/comm_context_impl.hpp b/src/common/comm/l0/comm_context_impl.hpp index f69645f42..f35795bd1 100644 --- a/src/common/comm/l0/comm_context_impl.hpp +++ b/src/common/comm/l0/comm_context_impl.hpp @@ -49,7 +49,7 @@ ccl::communicator_interface_ptr ccl::comm_group::create_communicator_from_group( auto host_comm = pimpl->get_host_communicator(); - if (device_count_per_process == 1 && !ccl::global_data::env().enable_comm_kernels) { + if (device_count_per_process == 1 /*&& !ccl::global_data::env().enable_comm_kernels*/) { /* special single device case */ LOG_TRACE("create single device communicator from SYCL device"); //TODO diff --git a/src/common/comm/l0/comm_context_storage.cpp b/src/common/comm/l0/comm_context_storage.cpp index cafd14344..1c31b1e4a 100644 --- a/src/common/comm/l0/comm_context_storage.cpp +++ b/src/common/comm/l0/comm_context_storage.cpp @@ -31,7 +31,7 @@ group_context::comm_group_t group_context::group_by_kvs( const std::vector& local_thread_device_group_ranks, int cluster_device_group_size, std::shared_ptr kvs) { - LOG_INFO("thread acquire by barrier"); + LOG_DEBUG("thread acquire by barrier"); std::shared_ptr atl = std::shared_ptr( new atl_wrapper(cluster_device_group_size, local_thread_device_group_ranks, kvs)); @@ -40,42 +40,43 @@ group_context::comm_group_t group_context::group_by_kvs( * Most of the cases are handled in communicator_impl_details.hpp, but here we check the case * when we have multiple threads and each of them has 1 device. And we don't know the total number * of ranks in the process until we sync them above */ - if (atl->get_ranks_per_process() > 1 && !ccl::global_data::env().enable_comm_kernels) { + if (atl->get_ranks_per_process() > 1 /* && !ccl::global_data::env().enable_comm_kernels*/) { throw ccl::unimplemented("API", "create_communicators", "for multiple devices"); } - LOG_INFO("thread released by barrier"); - LOG_INFO("cluster_device_group size: ", - cluster_device_group_size, - "\nThread device group ranks size: ", - local_thread_device_group_ranks.size()); + LOG_DEBUG("thread released by barrier"); + LOG_DEBUG("cluster_device_group size: ", + cluster_device_group_size, + "\nthread device group ranks size: ", + local_thread_device_group_ranks.size()); for (size_t i = 0; i < local_thread_device_group_ranks.size(); i++) { - LOG_INFO("\nLocal thread device group ranks: ", local_thread_device_group_ranks[i]); + LOG_DEBUG("\nlocal thread device group ranks: ", local_thread_device_group_ranks[i]); } + // register group slot in global context table, based on communicator id comm_group_t group = group_context::group_by_comm(atl); - // sync existing group: blocking operation - wait for all groups - LOG_INFO("group thread barrier acquired: ", static_cast(group.get())); - group->sync_group_size(local_thread_device_group_ranks.size()); - LOG_INFO("group thread barrier released: ", static_cast(group.get())); + // if (ccl::global_data::env().enable_comm_kernels) { + // // sync existing group: blocking operation - wait for all groups + // LOG_DEBUG("group thread barrier acquired: ", static_cast(group.get())); + // group->sync_group_size(local_thread_device_group_ranks.size()); + // LOG_DEBUG("group thread barrier released: ", static_cast(group.get())); + // } + return group; } group_context::comm_group_t group_context::group_by_comm(std::shared_ptr atl) { - LOG_INFO("\n", - "\nATL info:", - "\n threads per process: ", - atl->get_threads_per_process(), - "\n ranks per process: ", - atl->get_ranks_per_process(), - "\n atl size: ", - atl->get_size(), - "\n rank: ", - atl->get_rank(), - "\n unique id of atl: ", - atl->get_id(), - "\n") + std::stringstream ss; + ss << "\n{\n" + << " ATL info:\n" + << " rank: " << atl->get_rank() << "\n" + << " size: " << atl->get_size() << "\n" + << " id: " << atl->get_id() << "\n" + << " ranks per process: " << atl->get_ranks_per_process() << "\n" + << " threads per process: " << atl->get_threads_per_process() << "\n" + << "}"; + LOG_INFO(ss.str()); comm_group_t group; { @@ -91,21 +92,21 @@ group_context::comm_group_t group_context::group_by_comm(std::shared_ptr(group.get()), - " has been created for unique_id: ", - unique_id, - ", threads per process: ", - threads_per_process, - ", ranks per process: ", - ranks_per_process); + LOG_DEBUG("comm group: ", + static_cast(group.get()), + " has been created for unique_id: ", + unique_id, + ", threads per process: ", + threads_per_process, + ", ranks per process: ", + ranks_per_process); } else { group = ctx_it->second; - LOG_INFO("get existing comm group: ", - static_cast(group.get()), - " for unique_id: ", - unique_id); + LOG_DEBUG("get existing comm group: ", + static_cast(group.get()), + " for unique_id: ", + unique_id); } } return group; diff --git a/src/common/comm/l0/communicator/device_group/device_a2a_communicator.cpp b/src/common/comm/l0/communicator/device_group/device_a2a_communicator.cpp deleted file mode 100644 index 42b8afe38..000000000 --- a/src/common/comm/l0/communicator/device_group/device_a2a_communicator.cpp +++ /dev/null @@ -1,272 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/l0/communicator/device_group/device_a2a_communicator_impl.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/thread_group_ctx.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" - -using namespace ccl; - -device_group_a2a_communicator::device_group_a2a_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx /*, comm_attr*/, attr) { -} - -void device_group_a2a_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - auto process_ctx = comm_attr.get_process_context(); - auto thread_ctx = process_ctx->get_thread_context(process_id); - auto device_ctx = thread_ctx->get_device_group_ctx(thread_id); - - ctx = device_ctx; - - //get rank & size - this->initialize_comm_addr(get_device_path(), - ctx->get_group_topology()); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} - -ccl::event device_group_a2a_communicator::barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event device_group_a2a_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_a2a_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event device_group_a2a_communicator::allreduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - int comm_rank = rank(); - LOG_DEBUG("communicator for device idx: ", get_device_path(), ", rank idx: ", comm_rank); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(); - - const auto& in_process_gpu_storage = community->get_devices(); - const auto& virtual_process_gpu_storage = community->get_devices(); - ; - - device_group_scheduler::schedule_ptr schedule; - - //source for collective operation is real gpu or virtual gpu - auto real_device_it = in_process_gpu_storage.find(comm_rank); - if (real_device_it != in_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", real_device_it->second->to_string()); - - /* TODO - - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - schedule = - ctx->scheduler_impl->submit_entry(*device_community_impl, - real_device_it->second,send_entry_buffer, - recv_entry_buffer, - count, - reduction); - */ - } - else { - auto virtual_device_it = virtual_process_gpu_storage.find(comm_rank); - if (virtual_device_it != virtual_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", virtual_device_it->second->to_string()); - /* TODO - - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - schedule = - ctx->scheduler_impl->submit_entry(*device_community_impl, - virtual_device_it->second,send_entry_buffer, - recv_entry_buffer, - count, - reduction); - */ - } - } - - //if sched is not ready - send NULL - if (schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr(new ccl::gpu_event_impl(std::move(schedule))); -} - -/* alltoall */ -ccl::event device_group_a2a_communicator::alltoall_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_a2a_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event device_group_a2a_communicator::alltoallv_impl( - const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event device_group_a2a_communicator::broadcast_impl( - void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -ccl::event device_group_a2a_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -ccl::event device_group_a2a_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -ccl::event device_group_a2a_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(device_group_a2a_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(device_group_a2a_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/device_group/device_a2a_communicator.hpp b/src/common/comm/l0/communicator/device_group/device_a2a_communicator.hpp deleted file mode 100644 index f9520c0e1..000000000 --- a/src/common/comm/l0/communicator/device_group/device_a2a_communicator.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/typed_base_communicator.hpp" - -namespace native { -struct device_group_context; -} - -class device_group_a2a_communicator : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = native::device_variant_t, - native::ccl_numa_proxy*/>; - - device_group_a2a_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t proces_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; diff --git a/src/common/comm/l0/communicator/device_group/device_a2a_communicator_impl.hpp b/src/common/comm/l0/communicator/device_group/device_a2a_communicator_impl.hpp deleted file mode 100644 index f509c5d4e..000000000 --- a/src/common/comm/l0/communicator/device_group/device_a2a_communicator_impl.hpp +++ /dev/null @@ -1,322 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/device_group/device_a2a_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/device_group_ctx.hpp" -#include "common/comm/l0/scheduler/device_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" - -#include "common/comm/l0/communicator/device_group/device_communicator_utils.hpp" - -/* allgatherv */ -template -ccl::event device_group_a2a_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event device_group_a2a_communicator::allreduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event device_group_a2a_communicator::allreduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoall */ -template -ccl::event device_group_a2a_communicator::alltoall_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::alltoall_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event device_group_a2a_communicator::alltoallv_impl( - const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::alltoallv_impl( - const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event device_group_a2a_communicator::broadcast_impl( - buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::broadcast_impl( - buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event device_group_a2a_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -template -ccl::event device_group_a2a_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_a2a_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event device_group_a2a_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_a2a_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/device_group/device_communicator_utils.hpp b/src/common/comm/l0/communicator/device_group/device_communicator_utils.hpp deleted file mode 100644 index 93ae547cc..000000000 --- a/src/common/comm/l0/communicator/device_group/device_communicator_utils.hpp +++ /dev/null @@ -1,68 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" - -template - class algorithm> -struct communication_device_expander { - template - void operator()(native::device_t_ptr& comm_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - Args&&... args) { - if (comm_device) { - LOG_DEBUG("Invoke: ", comm_device->to_string()); - - using gpu_entry = algorithm; - - schedule = ctx->scheduler_impl - ->submit_entry( - *community, comm_device, std::forward(args)...); - } - } - - std::unique_ptr schedule; -}; - -template - class algorithm, - class... Args> -std::unique_ptr do_collective_op( - native::device_variant_t& - communication_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - native::ccl_driver_context_ptr native_context, - Args&&... args) { - communication_device_expander expander; - ccl_tuple_for_each_args(communication_device, - expander, - ctx, - community, - native_context, - std::forward(args)...); - if (expander.schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr( - new ccl::gpu_shared_event_impl(std::move(expander.schedule))); -} diff --git a/src/common/comm/l0/communicator/device_group/device_ring_communicator.cpp b/src/common/comm/l0/communicator/device_group/device_ring_communicator.cpp deleted file mode 100644 index 95ed26cfe..000000000 --- a/src/common/comm/l0/communicator/device_group/device_ring_communicator.cpp +++ /dev/null @@ -1,250 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/l0/communicator/device_group/device_ring_communicator_impl.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/thread_group_ctx.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" - -using namespace ccl; - -device_group_ring_communicator::device_group_ring_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx /*, comm_attr*/, attr) { -} - -void device_group_ring_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - auto process_ctx = comm_attr.get_process_context(); - auto thread_ctx = process_ctx->get_thread_context(process_id); - auto device_ctx = thread_ctx->get_device_group_ctx(thread_id); - - ctx = device_ctx; - - //get rank & size - - this->initialize_comm_addr(get_device_path(), - ctx->get_group_topology()); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} - -ccl::event device_group_ring_communicator::barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event device_group_ring_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_ring_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event device_group_ring_communicator::allreduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - size_t ring_index = 0; - - int comm_rank = rank(); - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: , ring_index: ", - comm_rank, - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_allreduce, dtype, reduction); - - return do_collective_op( - communication_device, - ctx, - community, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - count, - params, - stream); -} - -/* alltoall */ -ccl::event device_group_ring_communicator::alltoall_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_ring_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event device_group_ring_communicator::alltoallv_impl( - const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event device_group_ring_communicator::broadcast_impl( - void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -ccl::event device_group_ring_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -ccl::event device_group_ring_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -ccl::event device_group_ring_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(device_group_ring_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(device_group_ring_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/device_group/device_ring_communicator.hpp b/src/common/comm/l0/communicator/device_group/device_ring_communicator.hpp deleted file mode 100644 index 93255e666..000000000 --- a/src/common/comm/l0/communicator/device_group/device_ring_communicator.hpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/typed_base_communicator.hpp" - -namespace native { -struct device_group_context; -} - -class device_group_ring_communicator - : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = native::device_variant_t, - native::ccl_numa_proxy*/>; - - device_group_ring_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t proces_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; diff --git a/src/common/comm/l0/communicator/device_group/device_ring_communicator_impl.hpp b/src/common/comm/l0/communicator/device_group/device_ring_communicator_impl.hpp deleted file mode 100644 index b6dfaea7a..000000000 --- a/src/common/comm/l0/communicator/device_group/device_ring_communicator_impl.hpp +++ /dev/null @@ -1,328 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/device_group/device_ring_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/device_group_ctx.hpp" -#include "common/comm/l0/scheduler/device_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" - -#include "common/comm/l0/communicator/device_group/device_communicator_utils.hpp" - -/* allgatherv */ -template -ccl::event device_group_ring_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event device_group_ring_communicator::allreduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event device_group_ring_communicator::allreduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoall */ -template -ccl::event device_group_ring_communicator::alltoall_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::alltoall_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event device_group_ring_communicator::alltoallv_impl( - const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::alltoallv_impl( - const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event device_group_ring_communicator::broadcast_impl( - buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::broadcast_impl( - buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event device_group_ring_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -/* reduce_scatter */ -template -ccl::event device_group_ring_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event device_group_ring_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event device_group_ring_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event device_group_ring_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/process_group/process_a2a_communicator.cpp b/src/common/comm/l0/communicator/process_group/process_a2a_communicator.cpp deleted file mode 100644 index 4f37872e9..000000000 --- a/src/common/comm/l0/communicator/process_group/process_a2a_communicator.cpp +++ /dev/null @@ -1,315 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/l0/communicator/process_group/process_a2a_communicator_impl.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" - -using namespace ccl; - -process_a2a_communicator::process_a2a_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx, /*comm_attr, */ attr) { -} - -void process_a2a_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - ctx = comm_attr.get_process_context(); - - //get rank & size - auto topology = ctx->get_process_topology(process_id, thread_id); - this->initialize_comm_addr(get_device_path(), topology); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} - -ccl::event process_a2a_communicator::barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event process_a2a_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event process_a2a_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event process_a2a_communicator::allreduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - LOG_DEBUG("communicator for device idx: ", get_device_path(), ", rank idx: ", comm_rank); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(); - - const auto& in_process_gpu_storage = community->get_devices(); - const auto& virtual_process_gpu_storage = community->get_devices(); - - auto& ipc_gpu_storage = community->get_devices(); - (void)ipc_gpu_storage; - auto& in_process_ipc_source_real_gpu_storage = - community->get_devices>(); - auto& in_process_ipc_source_virtual_gpu_storage = - community->get_devices>(); - - allied_process_group_scheduler::thread_schedule_ptr schedule; - //source for collective operation is ipc sources, real gpu or virtual gpu - auto ipc_src_real_it = in_process_ipc_source_real_gpu_storage.find(comm_rank); - if (ipc_src_real_it != in_process_ipc_source_real_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", ipc_src_real_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry, - group_id>; - - schedule = - ctx->scheduler_impl->submit_entry_ipc(process_id, - thread_id, - *device_community_impl, - ipc_src_real_it->second, - send_entry_buffer, - recv_entry_buffer, - count, - dtype, - reduction); - */ - } - else { - auto ipc_src_virt_it = in_process_ipc_source_virtual_gpu_storage.find(comm_rank); - if (ipc_src_virt_it != in_process_ipc_source_virtual_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", ipc_src_virt_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry, - group_id>; - - schedule = - ctx->scheduler_impl->submit_entry_ipc(process_id, - thread_id, - *device_community_impl, - ipc_src_virt_it->second, - send_entry_buffer, - recv_entry_buffer, - count, - dtype, - reduction); - */ - } - else { - auto real_device_it = in_process_gpu_storage.find(comm_rank); - if (real_device_it != in_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", real_device_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - schedule = - ctx->scheduler_impl->submit_entry(process_id, - thread_id, - *device_community_impl, - real_device_it->second,send_entry_buffer, - recv_entry_buffer, - count, - dtype, - reduction); - */ - } - else { - auto virtual_device_it = virtual_process_gpu_storage.find(comm_rank); - if (virtual_device_it != virtual_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", virtual_device_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - schedule = - ctx->scheduler_impl->submit_entry(process_id, - thread_id, - *device_community_impl, - virtual_device_it->second,send_entry_buffer, - recv_entry_buffer, - count, - dtype, - reduction); - */ - } - } - } - } - - //if sched is not ready - send NULL - if (schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr(new ccl::gpu_shared_event_impl(std::move(schedule))); -} - -/* alltoall */ -ccl::event process_a2a_communicator::alltoall_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event process_a2a_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event process_a2a_communicator::alltoallv_impl(const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event process_a2a_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event process_a2a_communicator::broadcast_impl(void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -ccl::event process_a2a_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -ccl::event process_a2a_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -ccl::event process_a2a_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(process_a2a_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(process_a2a_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/process_group/process_a2a_communicator.hpp b/src/common/comm/l0/communicator/process_group/process_a2a_communicator.hpp deleted file mode 100644 index c266ebd8f..000000000 --- a/src/common/comm/l0/communicator/process_group/process_a2a_communicator.hpp +++ /dev/null @@ -1,65 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/typed_base_communicator.hpp" - -namespace native { -struct process_group_context; -} - -class process_a2a_communicator : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = native::device_variant_t, - native::ccl_ipc_source_gpu_comm - /*, TODO disabled t now - native::ccl_numa_proxy, - native::ccl_numa_proxy*/>; - - process_a2a_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t proces_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; diff --git a/src/common/comm/l0/communicator/process_group/process_a2a_communicator_impl.hpp b/src/common/comm/l0/communicator/process_group/process_a2a_communicator_impl.hpp deleted file mode 100644 index 3f06af5e6..000000000 --- a/src/common/comm/l0/communicator/process_group/process_a2a_communicator_impl.hpp +++ /dev/null @@ -1,314 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/process_group/process_a2a_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" -#include "common/comm/l0/scheduler/allied_process_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" - -#include "common/comm/l0/communicator/process_group/process_communicator_utils.hpp" -/* allgatherv */ -template -ccl::event process_a2a_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::allgatherv_impl(const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event process_a2a_communicator::allreduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event process_a2a_communicator::allreduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoall */ -template -ccl::event process_a2a_communicator::alltoall_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::alltoall_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event process_a2a_communicator::alltoallv_impl(const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::alltoallv_impl(const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event process_a2a_communicator::broadcast_impl(buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::broadcast_impl(buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event process_a2a_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -/* reduce_scatter */ -template -ccl::event process_a2a_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_a2a_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event process_a2a_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_a2a_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/process_group/process_communicator_utils.hpp b/src/common/comm/l0/communicator/process_group/process_communicator_utils.hpp deleted file mode 100644 index 0984c5911..000000000 --- a/src/common/comm/l0/communicator/process_group/process_communicator_utils.hpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" - -template - class algorithm> -struct communication_process_device_expander { - template - void operator()(native::device_t_ptr& comm_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - size_t process_id, - size_t thread_id, - Args&&... args) { - if (comm_device) { - LOG_DEBUG("Invoke: ", comm_device->to_string()); - - using gpu_entry = algorithm; - - schedule = ctx->scheduler_impl - ->submit_entry( - process_id, - thread_id, - *community, - comm_device, - std::forward(args)...); - } - } - - std::shared_ptr schedule; -}; - -template - class algorithm, - class... Args> -std::unique_ptr do_collective_op( - // TODO: can we avoid using device_variant here? Because it creates an instantiation of entry for each device which - // makes it slow to compile - native::device_variant_t, - native::ccl_ipc_source_gpu_comm, - native::ccl_numa_proxy, - native::ccl_numa_proxy, - native::ccl_scaleout_proxy, - native::ccl_scaleout_proxy>& - communication_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - size_t process_id, - size_t thread_id, - native::ccl_driver_context_ptr native_context, - Args&&... args) { - communication_process_device_expander expander; - ccl_tuple_for_each_args(communication_device, - expander, - ctx, - community, - process_id, - thread_id, - native_context, - std::forward(args)...); - if (expander.schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr( - new ccl::gpu_shared_event_impl(std::move(expander.schedule))); -} diff --git a/src/common/comm/l0/communicator/process_group/process_ring_communicator.cpp b/src/common/comm/l0/communicator/process_group/process_ring_communicator.cpp deleted file mode 100644 index 4c79f883e..000000000 --- a/src/common/comm/l0/communicator/process_group/process_ring_communicator.cpp +++ /dev/null @@ -1,464 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "common/comm/l0/communicator/process_group/process_ring_communicator_impl.hpp" - -#include "common/comm/l0/gpu_comm_attr.hpp" - -using namespace ccl; - -process_ring_communicator::process_ring_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx, /*comm_attr,*/ attr) {} - -void process_ring_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - ctx = comm_attr.get_process_context(); - - //get rank & size - auto topology = ctx->get_process_topology(process_id, thread_id); - this->initialize_comm_addr(get_device_path(), topology); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} -/* -size_t process_ring_communicator::group_size() const -{ - return get_device_count() + - get_device_count>() + - / * get_device_count() + do no participate in group communication* / - get_device_count(); - -} -*/ - -ccl::event process_ring_communicator::barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event process_ring_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - send_count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, send_count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_allgatherv, dtype); - - return do_collective_op( - communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - send_entry_buffer, - send_count, - recv_entry_buffer, - recv_counts.data(), - params, - stream); -} -ccl::event process_ring_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event process_ring_communicator::allreduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index: ", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - // TODO: we can get dtype value from buffer_type template, no need to introduce a new parameter - const coll_param_gpu params(ccl_coll_allreduce, dtype, reduction); - - return do_collective_op( - communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - count, - params, - stream); -} - -/* alltoall */ -ccl::event process_ring_communicator::alltoall_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event process_ring_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event process_ring_communicator::alltoallv_impl(const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - size_t total_send_counts = std::accumulate(std::begin(send_counts), std::end(send_counts), 0); - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - total_send_counts * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - - size_t total_recv_counts = std::accumulate(std::begin(recv_counts), std::end(recv_counts), 0); - ccl_buffer recv_entry_buffer( - &recv_buf, total_recv_counts * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_alltoallv, dtype); - - return do_collective_op( - communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - send_entry_buffer, - send_counts.data(), - total_send_counts, - recv_entry_buffer, - recv_counts.data(), - total_recv_counts, - params, - stream); -} -ccl::event process_ring_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event process_ring_communicator::broadcast_impl(void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer entry_buffer( - &buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_bcast, dtype); - - return do_collective_op(communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - entry_buffer, - count, - root, - params, - stream); -} - -/* reduce */ -ccl::event process_ring_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_allreduce, dtype, reduction); - - return do_collective_op(communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - count, - reduction, - root, - params, - stream); -} - -/* reduce_scatter */ -ccl::event process_ring_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - recv_count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, recv_count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_reduce_scatter, dtype, reduction); - - return do_collective_op( - communication_device, - ctx, - community, - process_id, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - recv_count, - params, - stream); -} - -/* sparse_allreduce */ -ccl::event process_ring_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(process_ring_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(process_ring_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/process_group/process_ring_communicator.hpp b/src/common/comm/l0/communicator/process_group/process_ring_communicator.hpp deleted file mode 100644 index efbe5c801..000000000 --- a/src/common/comm/l0/communicator/process_group/process_ring_communicator.hpp +++ /dev/null @@ -1,69 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/typed_base_communicator.hpp" -#include "common/comm/usm_visitor/usm_visitors.hpp" - -namespace native { -struct process_group_context; -} - -class process_ring_communicator : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = - native::device_variant_t, - native::ccl_ipc_source_gpu_comm, - native::ccl_numa_proxy, - native::ccl_numa_proxy, - native::ccl_scaleout_proxy, - native::ccl_scaleout_proxy>; - using coll_request_t = ccl::event; - - process_ring_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; diff --git a/src/common/comm/l0/communicator/process_group/process_ring_communicator_impl.hpp b/src/common/comm/l0/communicator/process_group/process_ring_communicator_impl.hpp deleted file mode 100644 index 889cb32a3..000000000 --- a/src/common/comm/l0/communicator/process_group/process_ring_communicator_impl.hpp +++ /dev/null @@ -1,347 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/process_group/process_ring_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" -#include "common/comm/l0/scheduler/allied_process_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" -#include "common/comm/l0/communicator/process_group/process_communicator_utils.hpp" - -/* allgatherv */ -template -ccl::event process_ring_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - return allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -template -ccl::event process_ring_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_ring_communicator::allgatherv_impl(const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_ring_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event process_ring_communicator::allreduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event process_ring_communicator::allreduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoall */ -template -ccl::event process_ring_communicator::alltoall_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_ring_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_ring_communicator::alltoall_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_ring_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event process_ring_communicator::alltoallv_impl(const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - return alltoallv_impl(static_cast(send_buf), - send_counts, - static_cast(recv_buf), - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -template -ccl::event process_ring_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_ring_communicator::alltoallv_impl(const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event process_ring_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event process_ring_communicator::broadcast_impl(buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - return broadcast_impl(static_cast(buf), - count, - ccl::native_type_info::dtype, - root, - stream, - attr, - deps); -} - -template -ccl::event process_ring_communicator::broadcast_impl(buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event process_ring_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - return reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - root, - stream, - attr, - deps); -} - -template -ccl::event process_ring_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -/* reduce_scatter */ -template -ccl::event process_ring_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - return reduce_scatter_impl(static_cast(send_buf), - static_cast(recv_buf), - recv_count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} -template -ccl::event process_ring_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event process_ring_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event process_ring_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.cpp b/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.cpp deleted file mode 100644 index 52b714ecd..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.cpp +++ /dev/null @@ -1,290 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/l0/communicator/thread_group/thread_a2a_communicator_impl.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" - -using namespace ccl; - -thread_device_group_a2a_communicator::thread_device_group_a2a_communicator( - ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx, /*comm_attr, */ attr) { -} - -void thread_device_group_a2a_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - auto process_ctx = comm_attr.get_process_context(); - auto thread_ctx = process_ctx->get_thread_context(process_id); - auto device_ctx = thread_ctx->get_device_group_ctx(thread_id); - (void)device_ctx; - - ctx = thread_ctx; - - //get rank & size - auto topology = ctx->get_thread_topology(thread_id); - this->initialize_comm_addr(get_device_path(), topology); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} -/* -size_t thread_device_group_ring_communicator::group_size() const -{ - return get_device_count() + - / * get_device_count>() + Will add further* / - get_device_count(); - -} -*/ -ccl::event thread_device_group_a2a_communicator::barrier( - const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event thread_device_group_a2a_communicator::allreduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - LOG_DEBUG("communicator for device idx: ", get_device_path(), ", rank idx: ", comm_rank); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(); - - const auto& in_process_gpu_storage = community->get_devices(); - const auto& virtual_process_gpu_storage = community->get_devices(); - - auto& ipc_gpu_storage = community->get_devices(); - (void)ipc_gpu_storage; - - thread_group_scheduler::thread_schedule_ptr schedule; - //source for collective operation is real gpu or virtual gpu - auto real_device_it = in_process_gpu_storage.find(comm_rank); - if (real_device_it != in_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", real_device_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - schedule = - ctx->scheduler_impl->submit_entry(thread_id, - *ctx->get_thread_topology(thread_id), - real_device_it->second,send_entry_buffer, - recv_entry_buffer, - count, - reduction); - */ - } - else { - auto virtual_device_it = virtual_process_gpu_storage.find(comm_rank); - if (virtual_device_it != virtual_process_gpu_storage.end()) { - LOG_DEBUG("Invoke: ", virtual_device_it->second->to_string()); - /* - using gpu_allreduce_entry = l0_allreduce_typed_entry; - - - schedule = - ctx->scheduler_impl->submit_entry(thread_id, - *ctx->get_thread_topology(thread_id), - virtual_device_it->second, send_entry_buffer, - recv_entry_buffer, - count, - reduction); - */ - } - } - - //if sched is not ready - send NULL - if (schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr(new ccl::gpu_shared_event_impl(std::move(schedule))); -} - -/* alltoall */ -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event thread_device_group_a2a_communicator::broadcast_impl( - void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -ccl::event thread_device_group_a2a_communicator::reduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -ccl::event thread_device_group_a2a_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -ccl::event thread_device_group_a2a_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(thread_device_group_a2a_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(thread_device_group_a2a_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.hpp b/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.hpp deleted file mode 100644 index 2454bc670..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/typed_base_communicator.hpp" - -namespace native { -struct thread_group_context; -} - -class thread_device_group_a2a_communicator - : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = native::device_variant_t, - native::ccl_numa_proxy*/>; - - thread_device_group_a2a_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t proces_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; - -//size_t group_size() const; diff --git a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator_impl.hpp b/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator_impl.hpp deleted file mode 100644 index 54bf28168..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_a2a_communicator_impl.hpp +++ /dev/null @@ -1,329 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/communicator/thread_group/thread_a2a_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/thread_group_ctx.hpp" -#include "common/comm/l0/scheduler/thread_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" -#include "common/comm/l0/communicator/thread_group/thread_communicator_utils.hpp" - -/* allgatherv */ -template -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event thread_device_group_a2a_communicator::allreduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::allreduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -/* alltoall */ -template -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event thread_device_group_a2a_communicator::broadcast_impl( - buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::broadcast_impl( - buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event thread_device_group_a2a_communicator::reduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::reduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -/* reduce_scatter */ -template -ccl::event thread_device_group_a2a_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_a2a_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event thread_device_group_a2a_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_a2a_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/thread_group/thread_communicator_utils.hpp b/src/common/comm/l0/communicator/thread_group/thread_communicator_utils.hpp deleted file mode 100644 index 1e8123435..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_communicator_utils.hpp +++ /dev/null @@ -1,71 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" - -template - class algorithm> -struct communication_thread_device_expander { - template - void operator()(native::device_t_ptr& comm_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - size_t thread_id, - Args&&... args) { - if (comm_device) { - LOG_DEBUG("Invoke: ", comm_device->to_string()); - - using gpu_entry = algorithm; - - schedule = ctx->scheduler_impl - ->submit_entry( - thread_id, *community, comm_device, std::forward(args)...); - } - } - - std::shared_ptr schedule; -}; - -template - class algorithm, - class... Args> -std::unique_ptr do_collective_op( - native::device_variant_t& - communication_device, - std::shared_ptr& ctx, - typename native::device_community_container::element_type community, - size_t thread_id, - native::ccl_driver_context_ptr native_context, - Args&&... args) { - communication_thread_device_expander expander; - ccl_tuple_for_each_args(communication_device, - expander, - ctx, - community, - thread_id, - native_context, - std::forward(args)...); - if (expander.schedule) { - LOG_DEBUG("Device group finalized"); - } - return std::unique_ptr( - new ccl::gpu_shared_event_impl(std::move(expander.schedule))); -} diff --git a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.cpp b/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.cpp deleted file mode 100644 index a64625156..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.cpp +++ /dev/null @@ -1,475 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/l0/communicator/thread_group/thread_ring_communicator_impl.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" - -using namespace ccl; - -thread_device_group_ring_communicator::thread_device_group_ring_communicator( - ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), std::move(ctx), thread_idx, process_idx, /*comm_attr, */ attr) { -} - -void thread_device_group_ring_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - auto process_ctx = comm_attr.get_process_context(); - auto thread_ctx = process_ctx->get_thread_context(process_id); - auto device_ctx = thread_ctx->get_device_group_ctx(thread_id); - (void)device_ctx; - - ctx = thread_ctx; - - //get rank & size - auto topology = ctx->get_thread_topology(thread_id); - this->initialize_comm_addr(get_device_path(), topology); - - this->set_comm_group_id(comm_attr.get_unique_id()); -} - -/* -size_t thread_device_group_ring_communicator::group_size() const -{ - return get_device_count() + - / * get_device_count>() + Will add further* / - get_device_count(); - -} -*/ -ccl::event thread_device_group_ring_communicator::barrier( - const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented yet"); -} - -/* allgatherv */ -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - send_count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, send_count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - coll_param_gpu params(ccl_coll_allgatherv, dtype); - - return do_collective_op( - communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - send_entry_buffer, - send_count, - recv_entry_buffer, - recv_counts.data(), - params, - stream); -} -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -ccl::event thread_device_group_ring_communicator::allreduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - const coll_param_gpu params(ccl_coll_allreduce, dtype, reduction); - - return do_collective_op( - communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - count, - params, - stream); -} - -/* alltoall */ -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - size_t total_send_counts = std::accumulate(std::begin(send_counts), std::end(send_counts), 0); - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - total_send_counts * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - - size_t total_recv_counts = std::accumulate(std::begin(recv_counts), std::end(recv_counts), 0); - ccl_buffer recv_entry_buffer( - &recv_buf, total_recv_counts * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - coll_param_gpu params(ccl_coll_alltoallv, dtype); - - return do_collective_op( - communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - send_entry_buffer, - send_counts.data(), - total_send_counts, - recv_entry_buffer, - recv_counts.data(), - total_recv_counts, - params, - stream); -} -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event thread_device_group_ring_communicator::broadcast_impl( - void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer entry_buffer( - &buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - coll_param_gpu params(ccl_coll_bcast, dtype); - - return do_collective_op(communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - entry_buffer, - count, - root, - params, - stream); -} - -/* reduce */ -ccl::event thread_device_group_ring_communicator::reduce_impl( - const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - coll_param_gpu params(ccl_coll_reduce, dtype, reduction); - - return do_collective_op(communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - count, - reduction, - root, - params, - stream); -} - -/* reduce_scatter */ -ccl::event thread_device_group_ring_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - using namespace native; - - static constexpr ccl::group_split_type group_id = base_t::topology_type(); - static constexpr ccl::device_topology_type class_id = base_t::topology_class(); - - if (!is_ready()) { - throw ccl::exception(std::string( - "Device communicator for group_id: " + ::to_string(group_id) + - " is not ready yet. Not all сommunicators are created in group. Please create them before usage")); - } - - int comm_rank = rank(); - size_t ring_index = 0; - LOG_DEBUG("communicator for device idx: ", - get_device_path(), - ", rank idx: ", - comm_rank, - ", ring_index :", - ring_index); - - //TODO make const! - ccl_buffer send_entry_buffer(const_cast(&send_buf), - recv_count * ccl::get_datatype_size(dtype), - 0, - ccl_buffer_type::INDIRECT); - ccl_buffer recv_entry_buffer( - &recv_buf, recv_count * ccl::get_datatype_size(dtype), 0, ccl_buffer_type::INDIRECT); - - using community_t = typename device_community_container::element_type; - community_t community = device_community_impl.get_topology(ring_index); - - coll_param_gpu params(ccl_coll_reduce_scatter, dtype, reduction); - - return do_collective_op( - communication_device, - ctx, - community, - thread_id, - this->get_native_context(), - send_entry_buffer, - recv_entry_buffer, - recv_count, - params, - stream); -} - -/* sparse_allreduce */ -ccl::event thread_device_group_ring_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -COMM_INTERFACE_COLL_INSTANTIATION(thread_device_group_ring_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(thread_device_group_ring_communicator); -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.hpp b/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.hpp deleted file mode 100644 index 04b93440a..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "common/comm/l0/communicator/typed_base_communicator.hpp" -#include "common/comm/usm_visitor/usm_visitors.hpp" - -namespace native { -struct thread_group_context; -} - -class thread_device_group_ring_communicator - : public typed_base_communicator { -public: - using base_t = typed_base_communicator; - - using communication_devices_t = native::device_variant_t, - native::ccl_numa_proxy*/>; - - thread_device_group_ring_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& ctx, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr); - - void visit(ccl::gpu_comm_attr& comm_attr) override; - - ccl::event barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - communication_devices_t& get_communication_device() { - return communication_device; - } - -private: - std::shared_ptr ctx; - communication_devices_t communication_device; -}; diff --git a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator_impl.hpp b/src/common/comm/l0/communicator/thread_group/thread_ring_communicator_impl.hpp deleted file mode 100644 index c2b646d4f..000000000 --- a/src/common/comm/l0/communicator/thread_group/thread_ring_communicator_impl.hpp +++ /dev/null @@ -1,362 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "common/comm/l0/communicator/thread_group/thread_ring_communicator.hpp" -#include "common/comm/l0/communicator/typed_base_communicator_impl.hpp" - -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "common/comm/l0/device_community.hpp" -#include "common/comm/l0/context/thread_group_ctx.hpp" -// TODO: try to move to cpp file as we now only reference l0_entries from there -#include "common/comm/l0/scheduler/thread_group_scheduler.hpp" -#include "common/event/impls/gpu_event.hpp" -#include "common/comm/l0/communicator/thread_group/thread_communicator_utils.hpp" - -/* allgatherv */ -template -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::allgatherv_impl( - const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - return allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -/* allreduce */ -template -ccl::event thread_device_group_ring_communicator::allreduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event thread_device_group_ring_communicator::allreduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoall */ -template -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_ring_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} -template -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::alltoallv_impl( - const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - return alltoallv_impl(static_cast(send_buf), - send_counts, - static_cast(recv_buf), - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -/* bcast */ -template -ccl::event thread_device_group_ring_communicator::broadcast_impl( - buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - return broadcast_impl(static_cast(buf), - count, - ccl::native_type_info::dtype, - root, - stream, - attr, - deps); -} - -template -ccl::event thread_device_group_ring_communicator::broadcast_impl( - buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce */ -template -ccl::event thread_device_group_ring_communicator::reduce_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - return reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - root, - stream, - attr, - deps); -} - -template -ccl::event thread_device_group_ring_communicator::reduce_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* reduce_scatter */ -template -ccl::event thread_device_group_ring_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - return reduce_scatter_impl(static_cast(send_buf), - static_cast(recv_buf), - recv_count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event thread_device_group_ring_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* sparse_allreduce */ -template -ccl::event thread_device_group_ring_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event thread_device_group_ring_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} diff --git a/src/common/comm/l0/communicator/typed_base_communicator.hpp b/src/common/comm/l0/communicator/typed_base_communicator.hpp index 6249e4a42..4bed58c10 100644 --- a/src/common/comm/l0/communicator/typed_base_communicator.hpp +++ b/src/common/comm/l0/communicator/typed_base_communicator.hpp @@ -79,7 +79,7 @@ class typed_base_communicator : public base_communicator { COMM_INTERFACE_COLL_METHODS(DEFINITION); #ifdef CCL_ENABLE_SYCL SYCL_COMM_INTERFACE_COLL_METHODS(DEFINITION); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL // Device community interface /* template diff --git a/src/common/comm/l0/communicator/typed_base_communicator_impl.hpp b/src/common/comm/l0/communicator/typed_base_communicator_impl.hpp index 26e97ecd4..1f9d9e0fc 100644 --- a/src/common/comm/l0/communicator/typed_base_communicator_impl.hpp +++ b/src/common/comm/l0/communicator/typed_base_communicator_impl.hpp @@ -41,18 +41,18 @@ typed_base_communicator::typed_base_communicator( process_idx /*, comm_attr*/, attr) { try { - LOG_INFO("sheduled for create, device id: ", - device.get_id(), - ", thread_id: ", - thread_idx, - ", process id:", - process_idx); + LOG_DEBUG("sheduled for create, device id: ", + device.get_id(), + ", thread_id: ", + thread_idx, + ", process id:", + process_idx); } catch (...) { - LOG_INFO("sheduled for create single device communicator , thread_id: ", - thread_idx, - ", process id:", - process_idx); + LOG_DEBUG("sheduled for create single device communicator , thread_id: ", + thread_idx, + ", process id:", + process_idx); } } @@ -83,18 +83,18 @@ void typed_base_communicator::initialize_comm_addr( comm_rank = registered_addr.comm_rank; comm_size = registered_addr.comm_size; - LOG_INFO("Communicator finalized. Rank (", - comm_rank, - "/", - comm_size, - ") on {dev: ", - device_id, - ", thr: ", - thread_id, - ", proc: ", - process_id, - "} on device:\n", - p.to_string()); + LOG_DEBUG("Communicator finalized. Rank (", + comm_rank, + "/", + comm_size, + ") on {dev: ", + device_id, + ", thr: ", + thread_id, + ", proc: ", + process_id, + "} on device:\n", + p.to_string()); } template diff --git a/src/common/comm/l0/context/process_group_ctx.cpp b/src/common/comm/l0/context/process_group_ctx.cpp index 7c7d00c5d..e30ac8448 100644 --- a/src/common/comm/l0/context/process_group_ctx.cpp +++ b/src/common/comm/l0/context/process_group_ctx.cpp @@ -101,11 +101,11 @@ bool process_group_context::sync_barrier(const ccl::device_indices_type& thread_ const ccl::process_device_indices_type& thread_indices = thread_group_ctx->get_thread_group_device_indices(); - LOG_INFO("Process (", - process_idx, - "/", - process_count, - ") reached process group communicator barrier"); + LOG_DEBUG("Process (", + process_idx, + "/", + process_count, + ") reached process group communicator barrier"); ccl::device_indices_type process_aggregated_device_indices = std::accumulate(thread_indices.begin(), @@ -120,22 +120,23 @@ bool process_group_context::sync_barrier(const ccl::device_indices_type& thread_ //iterate over allied processes(on the same host) //find possible IPC device with P2P capability - LOG_INFO("Process (", process_idx, "/", process_count, ") starts hardware topologies creation"); + LOG_DEBUG( + "Process (", process_idx, "/", process_count, ") starts hardware topologies creation"); /* TODO -S- enable it later cluster_group_device_creator ally_process_topology( process_idx, process_count, *this, *gpu_device_storage); */ { - LOG_INFO("TODO - Limitation on node processes considered!!!\n" - "process_idx: ", - process_idx, - ", process_count: ", - process_count, - ", cluster_device_rank_offset: ", - cluster_device_rank_offset, - ", cluster_device_size: ", - cluster_device_size); + LOG_DEBUG("TODO - Limitation on node processes considered!!!\n" + "process_idx: ", + process_idx, + ", process_count: ", + process_count, + ", cluster_device_rank_offset: ", + cluster_device_rank_offset, + ", cluster_device_size: ", + cluster_device_size); //TODO -S- Temporary solution for IPC topology allied_process_group_ring_topology ally_process_topology(process_idx, process_count, @@ -164,32 +165,32 @@ bool process_group_context::sync_barrier(const ccl::device_indices_type& thread_ { //TODO Create A2A topology - LOG_INFO("Process Context Topologies A2A TODO"); + LOG_DEBUG("Process Context Topologies A2A TODO"); } // create scheduler - LOG_INFO("Create scheduler"); + LOG_DEBUG("Create scheduler"); scheduler_impl.reset(new allied_process_group_scheduler( process_count, comm_addr.thread_count, ccl_communicator, *gpu_device_storage)); // initialize observer contexts - LOG_INFO("Sync communicator barrier"); + LOG_DEBUG("Sync communicator barrier"); ccl_communicator->barrier({}, ccl::default_barrier_attr); - LOG_INFO("initialize IPC context"); + LOG_DEBUG("initialize IPC context"); get_ipc_ctx().initialize_ctx(ccl_communicator); - LOG_INFO("initialize SCALE-OUT context"); + LOG_DEBUG("initialize SCALE-OUT context"); get_scaleout_ctx().initialize_ctx(ccl_communicator); // dump topology std::stringstream out; dump_process_topologies(out); - LOG_INFO("Thread (MASTER): ", - comm_addr.thread_idx, - " finalized process topology creation:\n", - out.str()); + LOG_DEBUG("Thread (MASTER): ", + comm_addr.thread_idx, + " finalized process topology creation:\n", + out.str()); return true; } @@ -218,7 +219,7 @@ std::shared_ptr process_group_context::get_communicator( bool process_group_context::build_cluster_affinity_table( const ccl::device_indices_type& process_aggregated_device_indices) { - LOG_INFO("Node: ", my_host_name, " start build affinity table for process idx: ", process_idx); + LOG_DEBUG("Node: ", my_host_name, " start build affinity table for process idx: ", process_idx); //create cluster mask affinity //1) request hostname & device indices count @@ -289,12 +290,12 @@ bool process_group_context::build_cluster_affinity_table( cluster_device_size = std::accumulate( my_rank_mask_size_it, receive_process_indices_sizes.end(), cluster_device_rank_offset); } - LOG_INFO("Process idx: ", - ccl_communicator->rank(), - ", device rank offset: ", - cluster_device_rank_offset, - ", total device count: ", - cluster_device_size); + LOG_DEBUG("Process idx: ", + ccl_communicator->rank(), + ", device rank offset: ", + cluster_device_rank_offset, + ", total device count: ", + cluster_device_size); //TODO -S- temporary END //Serialize own devices path data @@ -354,8 +355,9 @@ bool process_group_context::build_cluster_affinity_table( } catch (std::exception& ex) { LOG_ERROR("Cannot submit requests: ", ex.what()); - LOG_INFO("Memory required for hostnames size: ", total_hostname_size, " bytes"); - LOG_INFO("Memory required for device indices size: ", total_device_indices_count, " count"); + LOG_DEBUG("Memory required for hostnames size: ", total_hostname_size, " bytes"); + LOG_DEBUG( + "Memory required for device indices size: ", total_device_indices_count, " count"); abort(); } @@ -421,7 +423,7 @@ bool process_group_context::build_cluster_affinity_table( { std::stringstream ss; process_group_context::dump_cluster_affinity_indices(cluster_gpu_indices, ss); - LOG_INFO("Cluster device affinity indices table: ", ss.str()); + LOG_DEBUG("Cluster device affinity indices table: ", ss.str()); } return true; @@ -671,10 +673,10 @@ void process_group_context::collect_cluster_colored_plain_graphs( detail::global_sorted_colored_plain_graphs& out_global_graphs) { using namespace detail::serialize; - LOG_INFO("Collect cluster colored plain graphs, process initiator: ", - process_idx, - ", graphs count: ", - detail::to_string(send_graph)); + LOG_DEBUG("Collect cluster colored plain graphs, process initiator: ", + process_idx, + ", graphs count: ", + detail::to_string(send_graph)); // serialize current process graph list into bytes device_path_serializable::raw_data_t my_serialized_graph = diff --git a/src/common/comm/l0/context/scale/base/base_session.cpp b/src/common/comm/l0/context/scale/base/base_session.cpp index 7e26961bc..f19ba40ca 100644 --- a/src/common/comm/l0/context/scale/base/base_session.cpp +++ b/src/common/comm/l0/context/scale/base/base_session.cpp @@ -63,7 +63,7 @@ void context_descr::init(size_t staged_buffer_elem_count, ze_device_mem_alloc_desc_t mem_descr = ccl_device::get_default_mem_alloc_desc(); // create total aggregated memory in device context - mem_descr.flags = ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED; + mem_descr.flags = 0; dev_mem_consumer = device.template alloc_memory_ptr( (staged_buffer_elem_count * observer_domain_count) * ccl::get_datatype_size(kernel_params.get_datatype()), diff --git a/src/common/comm/l0/context/scale/ipc/ipc_ctx_impl.hpp b/src/common/comm/l0/context/scale/ipc/ipc_ctx_impl.hpp index 35fdafe5d..434ad7024 100644 --- a/src/common/comm/l0/context/scale/ipc/ipc_ctx_impl.hpp +++ b/src/common/comm/l0/context/scale/ipc/ipc_ctx_impl.hpp @@ -43,18 +43,18 @@ void ipc_ctx::initialize_ctx( (void)communicator; //send_stop(); stop.store(false); - LOG_INFO("IPC context Initialized for mpi rank: (", - std::to_string(communicator->rank()), - "/", - std::to_string(communicator->size()), - ")"); + LOG_DEBUG("IPC context Initialized for mpi rank: (", + std::to_string(communicator->rank()), + "/", + std::to_string(communicator->size()), + ")"); } template template void ipc_ctx::register_observer_impl(size_t rank_addr, observer_t* observer_ptr) { - LOG_INFO( + LOG_DEBUG( "device rank addr: ", std::to_string(rank_addr), ", device: ", observer_ptr->to_string()); observer::container_t>& container = scaling_ctx_base_t::template get_types_container, class_id>( @@ -110,10 +110,10 @@ template template void ipc_ctx::register_observer_impl(size_t rank_addr, ccl_ipc_gpu_comm* observer_ptr) { - LOG_INFO("DST device rank addr: ", - std::to_string(rank_addr), - ", DST device: ", - observer_ptr->to_string()); + LOG_DEBUG("DST device rank addr: ", + std::to_string(rank_addr), + ", DST device: ", + observer_ptr->to_string()); observer::container_t& container = scaling_ctx_base_t::template get_types_container(observables); auto cont_it = container.find(observer_ptr); @@ -151,7 +151,7 @@ void ipc_ctx::register_observer_impl(size_t rank_addr, //Start IPC server for each DST device for listening incoming conections from SRC devices std::string addr = create_ipc_addr_for_rank(rank_addr); - LOG_INFO("Start IPC listener for device:\n", observer_ptr->to_string(), "\nAddr: ", addr); + LOG_DEBUG("Start IPC listener for device:\n", observer_ptr->to_string(), "\nAddr: ", addr); try { observer_ptr->start(addr); auto it = listener_thread_map.find(observer_ptr); @@ -163,7 +163,7 @@ void ipc_ctx::register_observer_impl(size_t rank_addr, listener_thread_map.emplace( observer_ptr, new std::thread(&ipc_ctx::listener, this, observer_ptr)); - LOG_INFO("Listener thread started on addr: ", addr); + LOG_DEBUG("Listener thread started on addr: ", addr); } catch (const std::exception& ex) { LOG_ERROR("Cannot start IPC listener on: ", @@ -206,7 +206,7 @@ void ipc_ctx::send_stop() { template void ipc_ctx::listener(ccl_ipc_gpu_comm* listener_device) { - LOG_INFO("Start IPC context listener worker, Listener device: ", listener_device->to_string()); + LOG_DEBUG("Start IPC context listener worker, Listener device: ", listener_device->to_string()); // TODO ring only, peer-to-peer case: one SRC connects to one DST std::unique_ptr incoming_connection; @@ -214,24 +214,24 @@ void ipc_ctx::listener(ccl_ipc_gpu_comm* listener_device) { try { auto incoming = listener_device->process_connection(); if (incoming) { - LOG_INFO("Got connection on device: ", listener_device->to_string()); + LOG_DEBUG("Got connection on device: ", listener_device->to_string()); incoming_connection = std::move(incoming); } } catch (const std::exception& ex) { - LOG_INFO("Stop requested at serving connection stage"); + LOG_DEBUG("Stop requested at serving connection stage"); return; } if (stop.load()) { - LOG_INFO("Stop requested at serving connection stage"); + LOG_DEBUG("Stop requested at serving connection stage"); return; } } // processing incoming data from connected clients - LOG_INFO("Start IPC context processing worker, Listener device: ", - listener_device->to_string()); + LOG_DEBUG("Start IPC context processing worker, Listener device: ", + listener_device->to_string()); while (!stop.load()) { //TODO choose std::list decltype(processing_queue) sessions_to_execute; diff --git a/src/common/comm/l0/context/scale/numa/numa_ctx_impl.hpp b/src/common/comm/l0/context/scale/numa/numa_ctx_impl.hpp index 2e3d3021e..9c4deb598 100644 --- a/src/common/comm/l0/context/scale/numa/numa_ctx_impl.hpp +++ b/src/common/comm/l0/context/scale/numa/numa_ctx_impl.hpp @@ -27,7 +27,7 @@ template template void numa_ctx::register_observer_impl(size_t rank_addr, observer_t* observer_ptr) { - LOG_INFO( + LOG_DEBUG( "device rank addr: ", std::to_string(rank_addr), ", device: ", observer_ptr->to_string()); observer::container_t>& container = scaling_ctx_base_t::template get_types_container, class_id>( diff --git a/src/common/comm/l0/context/scale/scale_out/scale_out_ctx_impl.hpp b/src/common/comm/l0/context/scale/scale_out/scale_out_ctx_impl.hpp index 39aed531f..6cbcbb5ab 100644 --- a/src/common/comm/l0/context/scale/scale_out/scale_out_ctx_impl.hpp +++ b/src/common/comm/l0/context/scale/scale_out/scale_out_ctx_impl.hpp @@ -28,11 +28,11 @@ void scale_out_ctx::initialize_ctx( std::shared_ptr communicator) { process_communicator = communicator; - LOG_INFO("SCALE-OUT context Initialized for mpi rank: (", - std::to_string(communicator->rank()), - "/", - std::to_string(communicator->size()), - ")"); + LOG_DEBUG("SCALE-OUT context Initialized for mpi rank: (", + std::to_string(communicator->rank()), + "/", + std::to_string(communicator->size()), + ")"); } // observer_ptr interface implementations @@ -40,10 +40,10 @@ template template void scale_out_ctx::register_observer_impl(size_t rank_addr, observer_t* observer_ptr) { - LOG_INFO("scaleout device rank addr: ", - std::to_string(rank_addr), - ", device: ", - observer_ptr->to_string()); + LOG_DEBUG("scaleout device rank addr: ", + std::to_string(rank_addr), + ", device: ", + observer_ptr->to_string()); observer::container_t>& container = scaling_ctx_base_t::template get_types_container, class_id>( observables); diff --git a/src/common/comm/l0/context/thread_group_ctx.cpp b/src/common/comm/l0/context/thread_group_ctx.cpp index a5e7aa16a..24b42ee5a 100644 --- a/src/common/comm/l0/context/thread_group_ctx.cpp +++ b/src/common/comm/l0/context/thread_group_ctx.cpp @@ -64,7 +64,7 @@ bool thread_group_context::sync_barrier(const ccl::device_indices_type& device_i } //Current thread finalize communicator creation - LOG_INFO("Thread ", comm_addr.to_string(), " starts hardware topologies creation"); + LOG_DEBUG("Thread ", comm_addr.to_string(), " starts hardware topologies creation"); { std::stringstream ss; thread_group_ring_topology top(*this, devices); @@ -78,16 +78,16 @@ bool thread_group_context::sync_barrier(const ccl::device_indices_type& device_i { //TODO Create A2A topology - LOG_INFO("Thread Context Topologies A2A TODO"); + LOG_DEBUG("Thread Context Topologies A2A TODO"); } { std::stringstream out; dump_thread_topologies(out); - LOG_INFO("Thread (MASTER): ", - comm_addr.to_string(), - " finalized thread topology creation\n", - out.str()); + LOG_DEBUG("Thread (MASTER): ", + comm_addr.to_string(), + " finalized thread topology creation\n", + out.str()); } // create scheduler in final step scheduler_impl.reset(new thread_group_scheduler(comm_addr.thread_count)); diff --git a/src/common/comm/l0/device_types.hpp b/src/common/comm/l0/device_types.hpp index 31c0e7aab..6da85d8b4 100644 --- a/src/common/comm/l0/device_types.hpp +++ b/src/common/comm/l0/device_types.hpp @@ -65,7 +65,9 @@ enum class gpu_types : size_t { MAX_TYPE }; -using gpu_type_names = ::utils::enum_to_str(gpu_types::MAX_TYPE)>; +using gpu_type_names = + ::utils::enum_to_str::type>( + gpu_types::MAX_TYPE)>; inline std::string to_string(gpu_types type) { return gpu_type_names({ "REAL_GPU", "VIRTUAL_GPU", diff --git a/src/common/comm/l0/devices/ccl_gpu_base_comm.cpp b/src/common/comm/l0/devices/ccl_gpu_base_comm.cpp index f8a7e99c4..4d1d70a00 100644 --- a/src/common/comm/l0/devices/ccl_gpu_base_comm.cpp +++ b/src/common/comm/l0/devices/ccl_gpu_base_comm.cpp @@ -43,7 +43,8 @@ void cmd_list_proxy_base::close_and_execute(std::shared_ptr ctx, } auto& cmd_queue = device.get_cmd_queue(ccl_device::get_default_queue_desc(), ctx); - LOG_INFO("Execute list:", cmd_list.get(), ", queue: ", cmd_queue.get(), ", go to submit entry"); + LOG_DEBUG( + "Execute list:", cmd_list.get(), ", queue: ", cmd_queue.get(), ", go to submit entry"); res = zeCommandQueueExecuteCommandLists(cmd_queue.get(), 1, get_ptr(), fence); if (res != ZE_RESULT_SUCCESS) { throw ccl::exception(std::string("cannot execute command list, error: ") + diff --git a/src/common/comm/l0/devices/ccl_gpu_comm.cpp b/src/common/comm/l0/devices/ccl_gpu_comm.cpp index 4440efd90..9b8e94c46 100644 --- a/src/common/comm/l0/devices/ccl_gpu_comm.cpp +++ b/src/common/comm/l0/devices/ccl_gpu_comm.cpp @@ -18,7 +18,7 @@ #include #include "common/comm/l0/devices/ccl_gpu_comm.hpp" #include "sched/sched.hpp" -#include "sched/entry/l0/l0_entry.hpp" +// #include "sched/entry/l0/l0_entry.hpp" #include "common/comm/l0/modules/specific_modules_source_data.hpp" namespace native { diff --git a/src/common/comm/l0/devices/ccl_ipc_gpu_comm.cpp b/src/common/comm/l0/devices/ccl_ipc_gpu_comm.cpp index aee9c0f2f..716cd4b1d 100644 --- a/src/common/comm/l0/devices/ccl_ipc_gpu_comm.cpp +++ b/src/common/comm/l0/devices/ccl_ipc_gpu_comm.cpp @@ -18,7 +18,7 @@ #include #include "common/comm/l0/devices/ccl_ipc_gpu_comm.hpp" #include "sched/sched.hpp" -#include "sched/entry/l0/l0_entry.hpp" +// #include "sched/entry/l0/l0_entry.hpp" #include "common/comm/l0/modules/specific_modules_source_data.hpp" namespace native { @@ -68,11 +68,11 @@ ccl_ipc_gpu_comm::ccl_ipc_gpu_comm(ccl_device& assigned_device, } } - LOG_INFO("Created ", name_impl(), ", addr: ", reinterpret_cast(this)); + LOG_DEBUG("Created ", name_impl(), ", addr: ", reinterpret_cast(this)); } ccl_ipc_gpu_comm::~ccl_ipc_gpu_comm() { - LOG_INFO("Destroyed ", name_impl(), ", addr: ", reinterpret_cast(this)); + LOG_DEBUG("Destroyed ", name_impl(), ", addr: ", reinterpret_cast(this)); } ccl_ipc_gpu_comm::supported_modules& ccl_ipc_gpu_comm::get_registered_modules() { diff --git a/src/common/comm/l0/devices/communication_structs/connection.cpp b/src/common/comm/l0/devices/communication_structs/connection.cpp index 31ea06460..277bbecd9 100644 --- a/src/common/comm/l0/devices/communication_structs/connection.cpp +++ b/src/common/comm/l0/devices/communication_structs/connection.cpp @@ -82,7 +82,7 @@ ssize_t connection::send_msg_with_pid_data(const std::vector& data, } // fill regular data - struct msghdr msg = { 0 }; + struct msghdr msg = {}; struct iovec io = { .iov_base = const_cast(static_cast(data.data())), .iov_len = data.size() * sizeof(uint8_t) }; @@ -141,7 +141,7 @@ ssize_t connection::recv_msg_with_pid_data(std::vector& out_data_resize msg_buffer.iov_len = out_data_resized.size(); // prepare control data - struct msghdr msg_header = { 0 }; + struct msghdr msg_header = {}; msg_header.msg_iov = &msg_buffer; msg_header.msg_iovlen = 1; msg_header.msg_controllen = CMSG_SPACE(sizeof(fd_t) * out_pids_resized.size()); //sizeof(u.buf); diff --git a/src/common/comm/l0/devices/communication_structs/ipc_connection.cpp b/src/common/comm/l0/devices/communication_structs/ipc_connection.cpp index a48ba6424..d9dbb09a1 100644 --- a/src/common/comm/l0/devices/communication_structs/ipc_connection.cpp +++ b/src/common/comm/l0/devices/communication_structs/ipc_connection.cpp @@ -250,7 +250,8 @@ std::vector ipc_tx_connection::send_ipc_memory_ext( for (const auto& ipc_handle : handles) { serialize_offset += ipc_handle.serialize(out_raw_data, serialize_offset); pids_offset_bytes.push_back(serialize_offset - - sizeof(native::ccl_device::device_ipc_memory_handle::handle_t)); + sizeof(native::ccl_device::device_ipc_memory_handle::handle_t) - + sizeof(size_t)); LOG_DEBUG("Serialized bytes: ", serialize_offset, diff --git a/src/common/comm/l0/devices/communication_structs/ipc_server.cpp b/src/common/comm/l0/devices/communication_structs/ipc_server.cpp index 284ebb35e..4259b8b99 100644 --- a/src/common/comm/l0/devices/communication_structs/ipc_server.cpp +++ b/src/common/comm/l0/devices/communication_structs/ipc_server.cpp @@ -98,7 +98,7 @@ void ipc_server::start(const std::string& path, int expected_backlog_size) { bool ipc_server::stop() { bool ret = false; if (is_ready()) { - LOG_INFO("Gracefully stop listener: ", listen_fd); + LOG_DEBUG("Gracefully stop listener: ", listen_fd); shutdown(listen_fd, SHUT_RDWR); close(listen_fd); listen_fd = -1; diff --git a/src/common/comm/l0/gpu_comm_attr.cpp b/src/common/comm/l0/gpu_comm_attr.cpp index 1a24e56a2..7b348668d 100644 --- a/src/common/comm/l0/gpu_comm_attr.cpp +++ b/src/common/comm/l0/gpu_comm_attr.cpp @@ -73,13 +73,13 @@ bool gpu_comm_attr::sync_register_communicator(std::shared_ptr>; diff --git a/src/common/comm/l0/gpu_comm_utils.hpp b/src/common/comm/l0/gpu_comm_utils.hpp index f0cc36a6a..62177e9ca 100644 --- a/src/common/comm/l0/gpu_comm_utils.hpp +++ b/src/common/comm/l0/gpu_comm_utils.hpp @@ -84,8 +84,7 @@ struct module_loader { module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC; module_description.pNext = nullptr; module_description.format = ZE_MODULE_FORMAT_IL_SPIRV; - module_description.inputSize = - static_cast(module_data.size()); //Ask L0: why not size_t? + module_description.inputSize = module_data.size(); module_description.pInputModule = module_data.data(); module_description.pBuildFlags = nullptr; module_description.pConstants = nullptr; diff --git a/src/common/comm/l0/modules/kernel_functions.hpp b/src/common/comm/l0/modules/kernel_functions.hpp index e3765ce91..8e9c8d446 100644 --- a/src/common/comm/l0/modules/kernel_functions.hpp +++ b/src/common/comm/l0/modules/kernel_functions.hpp @@ -154,7 +154,9 @@ struct execution_kernel : public kernel_data_storage(new_val); @@ -182,7 +184,11 @@ struct execution_kernel : public kernel_data_storage(new_val); diff --git a/src/common/comm/l0/modules/kernel_utils.hpp b/src/common/comm/l0/modules/kernel_utils.hpp index fc4b82804..6acef3623 100644 --- a/src/common/comm/l0/modules/kernel_utils.hpp +++ b/src/common/comm/l0/modules/kernel_utils.hpp @@ -22,7 +22,8 @@ namespace native { namespace detail { +std::string to_string(ccl::reduction red); std::string get_kernel_name(const std::string& kernel_name, const coll_param_gpu& params); -} +} // namespace detail } // namespace native diff --git a/src/common/comm/l0/modules/ring/allreduce_export_functions.hpp b/src/common/comm/l0/modules/ring/allreduce_export_functions.hpp index 06152f64b..5bfcf24b9 100644 --- a/src/common/comm/l0/modules/ring/allreduce_export_functions.hpp +++ b/src/common/comm/l0/modules/ring/allreduce_export_functions.hpp @@ -31,33 +31,18 @@ using send_buf_size_arg = arg; using send_buf_size_arg_type = typename send_buf_size_arg::arg_type; template -using send_buf_arg = arg; +using send_buf_arg = thread_exchangable_arg; template -using recv_buf_arg = arg; +using recv_buf_arg = thread_exchangable_arg; template -using tmp_recv_buf_arg = external_arg; +using right_send_buf_arg = + thread_exchangable_arg; -using income_data_flag_arg = external_arg; -using income_data_flag_arg_type = typename income_data_flag_arg::arg_type; - -using ready_to_recv_flag_arg = external_arg; -using ready_to_recv_flag_arg_type = typename ready_to_recv_flag_arg::arg_type; - -using local_barrier_flag_arg = arg; -using local_barrier_flag_arg_type = typename local_barrier_flag_arg::arg_type; - -// right template -using right_tmp_recv_buf_arg = - thread_exchangable_arg; - -using right_income_data_flag_arg = - thread_exchangable_arg; - -using right_ready_to_recv_flag_arg = - thread_exchangable_arg; +using right_recv_buf_arg = + thread_exchangable_arg; // IMPORTANT: the number and types of arguments must be the same in all classes, // excluding arguments specific for numa/scaleout etc. @@ -65,13 +50,8 @@ struct main_kernel : public execution_kernel, recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg> { + right_send_buf_arg, + right_recv_buf_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -85,36 +65,18 @@ struct main_kernel : public execution_kernel, recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg>; + right_send_buf_arg, + right_recv_buf_arg>; using base::base; }; -struct numa_kernel : public execution_kernel< - numa_kernel, - send_buf_size_arg, - send_buf_arg, - recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - - // numa-specific args - permanent_arg, - permanent_arg, - permanent_arg, - permanent_arg, - permanent_arg> { +struct numa_kernel : public execution_kernel, + recv_buf_arg, + right_send_buf_arg, + right_recv_buf_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -124,54 +86,18 @@ struct numa_kernel : public execution_kernel< using common_entry_buf_size_arg = send_buf_size_arg; using common_entry_buf_arg = send_buf_arg; - // event data - using event_prod_chunk_mem_arg = permanent_arg; - using event_prod_chunk_mem_arg_type = typename event_prod_chunk_mem_arg::arg_type; - - using event_prod_bytes_arg = permanent_arg; - using event_prod_bytes_arg_type = typename event_prod_bytes_arg::arg_type; - - using event_consumed_bytes_offset_arg = - permanent_arg; - using event_consumed_bytes_offset_arg_type = typename event_consumed_bytes_offset_arg::arg_type; - - using event_consumed_chunk_mem_arg = - permanent_arg; - using event_consumed_chunk_mem_arg_type = typename event_consumed_chunk_mem_arg::arg_type; - - using event_consumed_bytes_arg = - permanent_arg; - using event_consumed_bytes_arg_type = typename event_consumed_bytes_arg::arg_type; - using base = execution_kernel, recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - event_prod_chunk_mem_arg, - event_prod_bytes_arg, - event_consumed_bytes_offset_arg, - event_consumed_chunk_mem_arg, - event_consumed_bytes_arg>; + right_send_buf_arg, + right_recv_buf_arg>; template void bind_data(const ctx_params_t& out_ctx_params) { - this->template set_arg( - static_cast(out_ctx_params.host_mem_producer->get())); - this->template set_arg( - out_ctx_params.host_mem_producer_counter->get()); - this->template set_arg( - out_ctx_params.producer_aggregated_memory_offset->get()); - this->template set_arg( - static_cast(out_ctx_params.dev_mem_consumer->get())); - this->template set_arg( - out_ctx_params.dev_mem_consumer_counter->get()); + // TODO not implemented + (void)out_ctx_params; + throw ccl::exception(std::string(__FUNCTION__) + " - not implemented for that kernel type"); } using base::base; @@ -179,15 +105,10 @@ struct numa_kernel : public execution_kernel< struct ipc_kernel : public base_ipc_kernel, - stub_arg, + send_buf_arg, + recv_buf_arg, stub_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg> { + stub_arg> { using processing_type = void; using common_entry_buf_size_arg = send_buf_size_arg; @@ -199,54 +120,31 @@ struct ipc_kernel : public base_ipc_kernel, - stub_arg, + send_buf_arg, + recv_buf_arg, stub_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg>; + stub_arg>; template void bind_data(const ipc_handles_t& ipc_handles) { - auto tmp_recv_buf = reinterpret_cast::arg_type>( + auto send_buf = reinterpret_cast::arg_type>( ipc_handles.at(0).get().pointer); - this->template set_arg>(tmp_recv_buf); + this->template set_arg>(send_buf); - auto income_data_flag = - reinterpret_cast(ipc_handles.at(1).get().pointer); - this->template set_arg(income_data_flag); - - auto ready_to_recv_flag = - reinterpret_cast(ipc_handles.at(2).get().pointer); - this->template set_arg(ready_to_recv_flag); + auto recv_buf = reinterpret_cast::arg_type>( + ipc_handles.at(1).get().pointer); + this->template set_arg>(recv_buf); } using base::base; }; -struct scale_out_cpu_gw_kernel - : public execution_kernel< - scale_out_cpu_gw_kernel, - send_buf_size_arg, - send_buf_arg, - recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - - // scaleout-specific args - permanent_arg, - permanent_arg, - permanent_arg, - permanent_arg, - permanent_arg> { +struct scale_out_cpu_gw_kernel : public execution_kernel, + recv_buf_arg, + right_send_buf_arg, + right_recv_buf_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -256,54 +154,18 @@ struct scale_out_cpu_gw_kernel using common_entry_buf_size_arg = send_buf_size_arg; using common_entry_buf_arg = send_buf_arg; - // event data - using event_prod_chunk_mem_arg = permanent_arg; - using event_prod_chunk_mem_arg_type = typename event_prod_chunk_mem_arg::arg_type; - - using event_prod_bytes_arg = permanent_arg; - using event_prod_bytes_arg_type = typename event_prod_bytes_arg::arg_type; - - using event_consumed_bytes_offset_arg = - permanent_arg; - using event_consumed_bytes_offset_arg_type = typename event_consumed_bytes_offset_arg::arg_type; - - using event_consumed_chunk_mem_arg = - permanent_arg; - using event_consumed_chunk_mem_arg_type = typename event_consumed_chunk_mem_arg::arg_type; - - using event_consumed_bytes_arg = - permanent_arg; - using event_consumed_bytes_arg_type = typename event_consumed_bytes_arg::arg_type; - using base = execution_kernel, recv_buf_arg, - tmp_recv_buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_tmp_recv_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - event_prod_chunk_mem_arg, - event_prod_bytes_arg, - event_consumed_bytes_offset_arg, - event_consumed_chunk_mem_arg, - event_consumed_bytes_arg>; + right_send_buf_arg, + right_recv_buf_arg>; template void bind_data(const ctx_params_t& out_ctx_params) { - this->template set_arg( - static_cast(out_ctx_params.host_mem_producer->get())); - this->template set_arg( - out_ctx_params.host_mem_producer_counter->get()); - this->template set_arg( - out_ctx_params.producer_aggregated_memory_offset->get()); - this->template set_arg( - static_cast(out_ctx_params.dev_mem_consumer->get())); - this->template set_arg( - out_ctx_params.dev_mem_consumer_counter->get()); + // TODO not implemented + (void)out_ctx_params; + throw ccl::exception(std::string(__FUNCTION__) + " - not implemented for that kernel type"); } using base::base; diff --git a/src/common/comm/l0/modules/ring/bcast_export_functions.hpp b/src/common/comm/l0/modules/ring/bcast_export_functions.hpp index 1d9a610ac..67bd72301 100644 --- a/src/common/comm/l0/modules/ring/bcast_export_functions.hpp +++ b/src/common/comm/l0/modules/ring/bcast_export_functions.hpp @@ -26,85 +26,37 @@ namespace bcast { * Common args for all kernel types */ -using buf_size_arg = arg; -using buf_size_arg_type = typename buf_size_arg::arg_type; - template -using buf_arg = thread_exchangable_arg; - -using income_data_flag_arg = external_arg; -using income_data_flag_arg_type = typename income_data_flag_arg::arg_type; - -using ready_to_recv_flag_arg = external_arg; -using ready_to_recv_flag_arg_type = typename ready_to_recv_flag_arg::arg_type; - -using local_barrier_flag_arg = arg; -using local_barrier_flag_arg_type = typename local_barrier_flag_arg::arg_type; +using buf_arg = thread_exchangable_arg; template -using right_buf_arg = thread_exchangable_arg; +using right_buf_arg = thread_exchangable_arg; -using right_income_data_flag_arg = - thread_exchangable_arg; -using right_income_data_flag_arg_type = typename right_income_data_flag_arg::arg_type; - -using right_ready_to_recv_flag_arg = - thread_exchangable_arg; -using right_ready_to_recv_flag_arg_type = typename right_ready_to_recv_flag_arg::arg_type; - -using root_arg = arg; +using root_arg = arg; using root_arg_type = typename root_arg::arg_type; // IMPORTANT: the number and types of arguments must be the same in all classes, // excluding arguments specific for numa/scaleout etc. -struct main_kernel : public execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - root_arg> { +struct main_kernel + : public execution_kernel, right_buf_arg, root_arg> { using processing_type = void; static constexpr const char* specific_name() { return "bcast_execution"; } - using common_entry_buf_size_arg = buf_size_arg; using common_entry_buf_arg = buf_arg; using base = execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, root_arg>; using base::base; }; struct numa_kernel - : public execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - root_arg, - - // numa-specific args - permanent_arg, - permanent_arg> { + : public execution_kernel, right_buf_arg, root_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -113,25 +65,10 @@ struct numa_kernel using common_entry_buf_arg = buf_arg; - // event data - using event_prod_chunk_mem_arg = permanent_arg; - using event_prod_chunk_mem_arg_type = typename event_prod_chunk_mem_arg::arg_type; - - using event_prod_bytes_arg = permanent_arg; - using event_prod_bytes_arg_type = typename event_prod_bytes_arg::arg_type; - using base = execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - root_arg, - event_prod_chunk_mem_arg, - event_prod_bytes_arg>; + root_arg>; template void bind_data(const ctx_params_t& out_ctx_params) { @@ -144,15 +81,9 @@ struct numa_kernel }; struct ipc_kernel : public base_ipc_kernel, buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg> { + stub_arg, + stub_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -162,49 +93,24 @@ struct ipc_kernel : public base_ipc_kernel; using base = base_ipc_kernel, buf_arg, - income_data_flag_arg, - ready_to_recv_flag_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg, - stub_arg>; + stub_arg, + stub_arg>; template void bind_data(const ipc_handles_t& ipc_handles) { auto recv_buf = reinterpret_cast::arg_type>( ipc_handles.at(0).get().pointer); this->template set_arg>(recv_buf); - - auto income_data_flag = - reinterpret_cast(ipc_handles.at(1).get().pointer); - this->template set_arg(income_data_flag); - - auto ready_to_recv_flag = - reinterpret_cast(ipc_handles.at(2).get().pointer); - this->template set_arg(ready_to_recv_flag); } using base::base; }; -struct scale_out_cpu_gw_kernel - : public execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, - right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - root_arg, - - // scaleout-specific args - permanent_arg, - permanent_arg> { +struct scale_out_cpu_gw_kernel : public execution_kernel, + right_buf_arg, + root_arg> { using processing_type = void; static constexpr const char* specific_name() { @@ -213,26 +119,10 @@ struct scale_out_cpu_gw_kernel using common_entry_buf_arg = buf_arg; - // event data - using event_prod_chunk_mem_arg = - permanent_arg; - using event_prod_chunk_mem_arg_type = typename event_prod_chunk_mem_arg::arg_type; - - using event_prod_bytes_arg = permanent_arg; - using event_prod_bytes_arg_type = typename event_prod_bytes_arg::arg_type; - using base = execution_kernel, - income_data_flag_arg, - ready_to_recv_flag_arg, - local_barrier_flag_arg, right_buf_arg, - right_income_data_flag_arg, - right_ready_to_recv_flag_arg, - root_arg, - event_prod_chunk_mem_arg, - event_prod_bytes_arg>; + root_arg>; template void bind_data(const ctx_params_t& out_ctx_params) { diff --git a/src/common/comm/l0/scheduler/allied_process_group_scheduler.hpp b/src/common/comm/l0/scheduler/allied_process_group_scheduler.hpp index 2a2b119c2..f16c9cb8b 100644 --- a/src/common/comm/l0/scheduler/allied_process_group_scheduler.hpp +++ b/src/common/comm/l0/scheduler/allied_process_group_scheduler.hpp @@ -17,8 +17,8 @@ #pragma once #include "common/utils/spinlock.hpp" #include "sched/gpu_concurrent_sched.hpp" -#include "sched/entry/l0/l0_allreduce_typed_entry.hpp" -#include "sched/entry/l0/l0_allgather_handles_entry.hpp" +// #include "sched/entry/l0/l0_allreduce_typed_entry.hpp" +// #include "sched/entry/l0/l0_allgather_handles_entry.hpp" #include "sched/entry/factory/entry_factory.hpp" #include "common/comm/l0/device_community.hpp" #include "common/comm/l0/scheduler/thread_group_scheduler.hpp" @@ -42,8 +42,9 @@ struct allied_process_group_scheduler : public thread_group_scheduler { std::shared_ptr communicator, device_storage& node_devices) : base(threads_count), - ccl_communicator(communicator), - node_total_devices(node_devices) {} + ccl_communicator(communicator) /*, + node_total_devices(node_devices)*/ + {} template (current_thread_schedule.get(), device, device_topology.get_device_storage(), @@ -194,11 +197,12 @@ struct allied_process_group_scheduler : public thread_group_scheduler { auto req = submit_entry( process_id, thread_id, device_topology, device, ctx, std::forward(args)...); return req; + */ } private: std::shared_ptr ccl_communicator; - device_storage& node_total_devices; + // device_storage& node_total_devices; }; } // namespace native diff --git a/src/common/comm/l0/scheduler/device_group_scheduler.hpp b/src/common/comm/l0/scheduler/device_group_scheduler.hpp index c63852a15..3a46fd21a 100644 --- a/src/common/comm/l0/scheduler/device_group_scheduler.hpp +++ b/src/common/comm/l0/scheduler/device_group_scheduler.hpp @@ -15,7 +15,7 @@ */ #pragma once #include "sched/gpu_sched.hpp" -#include "sched/entry/l0/l0_allreduce_typed_entry.hpp" +// #include "sched/entry/l0/l0_allreduce_typed_entry.hpp" //#include "sched/entry/l0/l0_allgather_handles_entry.hpp" #include "sched/entry/factory/entry_factory.hpp" diff --git a/src/common/comm/l0/scheduler/thread_group_scheduler.hpp b/src/common/comm/l0/scheduler/thread_group_scheduler.hpp index 86ae17a57..37c8b3870 100644 --- a/src/common/comm/l0/scheduler/thread_group_scheduler.hpp +++ b/src/common/comm/l0/scheduler/thread_group_scheduler.hpp @@ -16,13 +16,13 @@ #pragma once #include "common/utils/spinlock.hpp" #include "sched/gpu_concurrent_sched.hpp" -#include "sched/entry/l0/l0_allreduce_typed_entry.hpp" -#include "sched/entry/l0/l0_allgatherv_typed_entry.hpp" -#include "sched/entry/l0/l0_alltoallv_typed_entry.hpp" -#include "sched/entry/l0/l0_bcast_typed_entry.hpp" -#include "sched/entry/l0/l0_reduce_typed_entry.hpp" -#include "sched/entry/l0/l0_reduce_scatter_typed_entry.hpp" -#include "sched/entry/l0/l0_allgatherv_typed_entry.hpp" +// #include "sched/entry/l0/l0_allreduce_typed_entry.hpp" +// #include "sched/entry/l0/l0_allgatherv_typed_entry.hpp" +// #include "sched/entry/l0/l0_alltoallv_typed_entry.hpp" +// #include "sched/entry/l0/l0_bcast_typed_entry.hpp" +// #include "sched/entry/l0/l0_reduce_typed_entry.hpp" +// #include "sched/entry/l0/l0_reduce_scatter_typed_entry.hpp" +// #include "sched/entry/l0/l0_allgatherv_typed_entry.hpp" //#include "sched/entry/l0/l0_allgather_handles_entry.hpp" #include "sched/entry/factory/entry_factory.hpp" #include "common/comm/l0/device_community.hpp" diff --git a/src/common/comm/l0/topology/ring/process_group_ring_creator.hpp b/src/common/comm/l0/topology/ring/process_group_ring_creator.hpp index 8b2a9f92d..2c77c54e0 100644 --- a/src/common/comm/l0/topology/ring/process_group_ring_creator.hpp +++ b/src/common/comm/l0/topology/ring/process_group_ring_creator.hpp @@ -49,6 +49,7 @@ class allied_process_group_ring_topology { size_t cluster_rank_offset, size_t cluster_size, const ccl::context_comm_addr& comm_addr = {}); + virtual ~allied_process_group_ring_topology() = default; static std::pair calculate_rank_offset_with_size( size_t process_id, const std::string& host_id, diff --git a/src/common/comm/l0/topology/topology_serializer.cpp b/src/common/comm/l0/topology/topology_serializer.cpp index 309e37979..dc588075a 100644 --- a/src/common/comm/l0/topology/topology_serializer.cpp +++ b/src/common/comm/l0/topology/topology_serializer.cpp @@ -274,7 +274,6 @@ std::map device_path_deserializer::deserialize_generic_indices_map_im size_t stride) { std::map global; size_t global_size = 0; - size_t deserialized_bytes_count = 0; // preconditions if (data.size() < sizeof(global_size)) { @@ -286,7 +285,6 @@ std::map device_path_deserializer::deserialize_generic_indices_map_im auto data_it = data.begin(); std::advance(data_it, sizeof(global_size)); - deserialized_bytes_count += sizeof(global_size); size_t deserialized_processes_count = 0; for (; data_it != data.end();) { @@ -303,7 +301,6 @@ std::map device_path_deserializer::deserialize_generic_indices_map_im } memcpy(&process_id, &(*data_it), expected_count); std::advance(data_it, expected_count); - deserialized_bytes_count += expected_count; //get graph_data for process size_t process_deserialized_count = 0; @@ -311,8 +308,6 @@ std::map device_path_deserializer::deserialize_generic_indices_map_im typename T::value_type>( raw_data_t(data_it, data.end()), process_deserialized_count, 0, stride); std::advance(data_it, process_deserialized_count); - deserialized_bytes_count += process_deserialized_count; - if (!global.emplace(process_id, std::move(process_list)).second) { throw std::runtime_error( std::string(__FUNCTION__) + diff --git a/src/common/comm/single_device_communicator/single_device_base.hpp b/src/common/comm/single_device_communicator/single_device_base.hpp deleted file mode 100644 index 8062ca178..000000000 --- a/src/common/comm/single_device_communicator/single_device_base.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "common/comm/l0/communicator/base_communicator.hpp" - -template -class typed_single_device_base_communicator : public base_communicator { -public: - using base_t = base_communicator; - using impl_t = comm_impl; - using self_t = typed_single_device_base_communicator; - using traits = communicator_traits; - - // Topologies - static constexpr ccl::group_split_type topology_type() { - return ccl::group_split_type::single; - } - - static constexpr ccl::device_topology_type topology_class() { - return ccl::device_topology_type::undetermined; - } - - // traits - bool is_host() const noexcept override { - return traits::is_host(); - } - - bool is_cpu() const noexcept override { - return traits::is_cpu(); - } - - bool is_gpu() const noexcept override { - return traits::is_gpu(); - } - - bool is_accelerator() const noexcept override { - return traits::is_accelerator(); - } - - typed_single_device_base_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& context, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr); - - ccl::group_split_type get_topology_type() const override; - ccl::device_topology_type get_topology_class() const override; - - bool is_ready() const override; - - COMM_INTERFACE_COLL_METHODS(DEFINITION); -#ifdef CCL_ENABLE_SYCL - SYCL_COMM_INTERFACE_COLL_METHODS(DEFINITION); -#endif /* CCL_ENABLE_SYCL */ - - // troubleshooting - std::string to_string() const; - - impl_t* get_impl() { - return static_cast(this); - } -}; diff --git a/src/common/comm/single_device_communicator/single_device_base_impl.hpp b/src/common/comm/single_device_communicator/single_device_base_impl.hpp deleted file mode 100644 index 5f752b3a4..000000000 --- a/src/common/comm/single_device_communicator/single_device_base_impl.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "oneapi/ccl/types.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "common/comm/single_device_communicator/single_device_base.hpp" - -#define TEMPLATE_DECL_ARG class comm_impl, class communicator_traits -#define TEMPLATE_DEF_ARG comm_impl, communicator_traits - -template -typed_single_device_base_communicator::typed_single_device_base_communicator( - ccl::unified_device_type&& owned_device, - ccl::unified_context_type&& context, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_communicator(std::move(owned_device), - std::move(context), - thread_idx, - process_idx /*, comm_attr*/, - attr) { - try { - LOG_INFO("sheduled for create, device id: ", - device.get_id(), - ", thread_id: ", - thread_idx, - ", process id:", - process_idx); - } - catch (...) { - LOG_INFO("sheduled for create single device communicator , thread_id: ", - thread_idx, - ", process id:", - process_idx); - } -} - -template -bool typed_single_device_base_communicator::is_ready() const { - return true; -} - -template -ccl::group_split_type typed_single_device_base_communicator::get_topology_type() - const { - return self_t::topology_type(); -} - -template -ccl::device_topology_type -typed_single_device_base_communicator::get_topology_class() const { - return self_t::topology_class(); -} - -template -std::string typed_single_device_base_communicator::to_string() const { - return std::string("single device communicator, rank (") + std::to_string(rank()) + "/" + - std::to_string(size()); -} - -#undef TEMPLATE_DECL_ARG -#undef TEMPLATE_DEF_ARG diff --git a/src/common/comm/single_device_communicator/single_device_communicator.cpp b/src/common/comm/single_device_communicator/single_device_communicator.cpp deleted file mode 100644 index a24f4a803..000000000 --- a/src/common/comm/single_device_communicator/single_device_communicator.cpp +++ /dev/null @@ -1,308 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -#include "common/comm/single_device_communicator/single_device_communicator_impl.hpp" -#ifdef MULTI_GPU_SUPPORT -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/context/thread_group_ctx.hpp" -#include "common/comm/l0/context/process_group_ctx.hpp" -#endif //MULTI_GPU_SUPPORT -using namespace ccl; - -single_device_communicator::single_device_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& context, - size_t thread_idx, - size_t process_idx, - const ccl::comm_split_attr& attr) - : base_t(std::move(device), - std::move(context), - thread_idx, - process_idx /*, comm_attr*/, - attr) {} - -single_device_communicator::~single_device_communicator() {} - -std::shared_ptr single_device_communicator::split( - const ccl::comm_split_attr& attr) { - // TODO - throw ccl::exception(std::string(__FUNCTION__) + " - 'is not implemented"); - return {}; -} - -void single_device_communicator::set_ccl_comm(std::shared_ptr impl) { - CCL_ASSERT(!comm_impl, "comm_impl must be nullptr before first udage"); - comm_impl = impl; - - comm_rank = comm_impl->rank(); - comm_size = comm_impl->size(); -} - -//TODO use visit() to set `context` -void single_device_communicator::set_context( - const ccl::unified_context_type::ccl_native_t& in_context) { - context = in_context; -} -void single_device_communicator::set_context(const ccl::context& in_context) { - context = in_context.get_native(); -} - -#ifdef MULTI_GPU_SUPPORT -void single_device_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - auto process_ctx = comm_attr.get_process_context(); - auto thread_ctx = process_ctx->get_thread_context(process_id); - auto device_ctx = thread_ctx->get_device_group_ctx(thread_id); - - //get rank & size - - /* this->initialize_comm_addr(get_device_path(), - ctx->get_group_topology()); -*/ - this->set_comm_group_id(comm_attr.get_unique_id()); -} -#endif -ccl::event single_device_communicator::barrier(const ccl::stream::impl_value_t& op_stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - // TODO what exactly we need to do with 'attr' here? - - ccl_barrier_impl(comm_impl.get(), op_stream.get(), deps); - - // TODO what exactly we need to return here? ccl_barrier_impl() is void func - return std::unique_ptr(new ccl::host_event_impl(nullptr)); -} - -/* allgatherv */ -ccl::event single_device_communicator::allgatherv_base_impl( - const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl_coll_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr( - new ccl::host_event_impl(ccl_allgatherv_impl(send_buf, - send_count, - recv_buf, - recv_counts.data(), - dtype, - attr, - comm_impl.get(), - stream.get(), - deps)))); -} - -ccl::event single_device_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_coll_attr internal_attr(attr); - return allgatherv_base_impl( - send_buf, send_count, recv_buf, recv_counts, dtype, stream, internal_attr, deps); -} - -ccl::event single_device_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_coll_attr internal_attr(attr); - internal_attr.vector_buf = 1; - return allgatherv_base_impl(send_buf, - send_count, - (void*)(recv_bufs.data()), - recv_counts, - dtype, - stream, - internal_attr, - deps); -} - -/* allreduce */ -ccl::event single_device_communicator::allreduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr(new ccl::host_event_impl(ccl_allreduce_impl( - send_buf, recv_buf, count, dtype, reduction, attr, comm_impl.get(), stream.get(), deps)))); -} - -/* alltoall */ -ccl::event single_device_communicator::alltoall_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr(new ccl::host_event_impl(ccl_alltoall_impl( - send_buf, recv_buf, count, dtype, attr, comm_impl.get(), stream.get(), deps)))); -} - -ccl::event single_device_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event single_device_communicator::alltoallv_impl(const void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr( - new ccl::host_event_impl(ccl_alltoallv_impl(send_buf, - send_counts.data(), - recv_buf, - recv_counts.data(), - dtype, - attr, - comm_impl.get(), - stream.get(), - deps)))); -} -ccl::event single_device_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event single_device_communicator::broadcast_impl(void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr(new ccl::host_event_impl( - ccl_broadcast_impl(buf, count, dtype, root, attr, comm_impl.get(), stream.get(), deps)))); -} - -/* reduce */ -ccl::event single_device_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - return ccl::event( - std::unique_ptr(new ccl::host_event_impl(ccl_reduce_impl(send_buf, - recv_buf, - count, - dtype, - reduction, - root, - attr, - comm_impl.get(), - stream.get(), - deps)))); -} - -/* reduce_scatter */ -ccl::event single_device_communicator::reduce_scatter_impl( - const void* send_buf, - void* recv_buf, - size_t recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr( - new ccl::host_event_impl(ccl_reduce_scatter_impl(send_buf, - recv_buf, - recv_count, - dtype, - reduction, - attr, - comm_impl.get(), - stream.get(), - deps)))); -} - -/* sparse_allreduce */ -ccl::event single_device_communicator::sparse_allreduce_impl( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - return ccl::event(std::unique_ptr( - new ccl::host_event_impl(ccl_sparse_allreduce_impl(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - index_dtype, - value_dtype, - reduction, - attr, - comm_impl.get(), - stream.get(), - deps)))); -} - -COMM_INTERFACE_COLL_INSTANTIATION(single_device_communicator); -#ifdef CCL_ENABLE_SYCL -SYCL_COMM_INTERFACE_COLL_INSTANTIATION(single_device_communicator); -#endif /* CCL_ENABLE_SYCL */ - -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) diff --git a/src/common/comm/single_device_communicator/single_device_communicator.hpp b/src/common/comm/single_device_communicator/single_device_communicator.hpp deleted file mode 100644 index 535ad692e..000000000 --- a/src/common/comm/single_device_communicator/single_device_communicator.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - -#include "common/comm/single_device_communicator/single_device_base.hpp" -#include "common/comm/comm.hpp" -#include "common/comm/usm_visitor/usm_visitors.hpp" - -namespace native { -struct device_group_context; -} - -class single_device_communicator - : public typed_single_device_base_communicator { -public: - using base_t = typed_single_device_base_communicator; - - single_device_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& context, - size_t thread_idx, - size_t proces_idx, - const ccl::comm_split_attr& attr); - - ~single_device_communicator(); - - std::shared_ptr split(const ccl::comm_split_attr& attr) override; - -#ifdef MULTI_GPU_SUPPORT - void visit(ccl::gpu_comm_attr& comm_attr) override; -#endif - ccl::event barrier(const ccl::stream::impl_value_t& op_stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) override; - - COMM_IMPL_DECLARATION - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION - COMM_IMPL_SPARSE_CLASS_DECLARATION - - void set_ccl_comm(std::shared_ptr imp); - - //TODO use visit() to set `context` - void set_context(const ccl::unified_context_type::ccl_native_t& context); - void set_context(const ccl::context& context); - -private: - std::shared_ptr comm_impl; -}; - -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) diff --git a/src/common/comm/single_device_communicator/single_device_communicator_impl.hpp b/src/common/comm/single_device_communicator/single_device_communicator_impl.hpp deleted file mode 100644 index a57ca84c9..000000000 --- a/src/common/comm/single_device_communicator/single_device_communicator_impl.hpp +++ /dev/null @@ -1,460 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "common/comm/single_device_communicator/single_device_communicator.hpp" -#include "common/comm/single_device_communicator/single_device_base_impl.hpp" - -#include "oneapi/ccl/native_device_api/interop_utils.hpp" -#include "common/request/request.hpp" -#include "common/event/impls/host_event.hpp" -#include "common/event/impls/scoped_event.hpp" - -#include "coll/coll.hpp" -#include "coll/coll_common_attributes.hpp" - -/* allgatherv */ - -template -ccl::event single_device_communicator::allgatherv_base_impl( - const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl_coll_attr& attr, - const ccl::vector_class& deps) { - return allgatherv_base_impl(send_buf, - send_count, - recv_buf, - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_coll_attr internal_attr(attr); - return allgatherv_base_impl( - send_buf, send_count, recv_buf, recv_counts, stream, internal_attr, deps); -} - -template -ccl::event single_device_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_coll_attr internal_attr(attr); - internal_attr.vector_buf = 1; - return allgatherv_base_impl(send_buf, - send_count, - (buffer_type*)(recv_bufs.data()), - recv_counts, - stream, - internal_attr, - deps); -} - -template -ccl::event single_device_communicator::allgatherv_impl(const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(&send_buf), - send_count, - reinterpret_cast(&recv_buf), - recv_counts.data(), - ccl::native_type_info::dtype, - attr, - comm_impl.get(), - stream.get(), - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} -template -ccl::event single_device_communicator::allgatherv_impl( - const buffer_type& send_buf, - size_t send_count, - ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* allreduce */ -template -ccl::event single_device_communicator::allreduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - return allreduce_impl(send_buf, - recv_buf, - count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::allreduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_allreduce_impl(reinterpret_cast(&send_buf), - reinterpret_cast(&recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - attr, - comm_impl.get(), - stream.get(), - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* alltoall */ -template -ccl::event single_device_communicator::alltoall_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - return alltoall_impl( - send_buf, recv_buf, count, ccl::native_type_info::dtype, stream, attr, deps); -} - -template -ccl::event single_device_communicator::alltoall_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event single_device_communicator::alltoall_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_alltoall_impl(reinterpret_cast(&send_buf), - reinterpret_cast(&recv_buf), - count, - ccl::native_type_info::dtype, - attr, - comm_impl.get(), - stream.get(), - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -template -ccl::event single_device_communicator::alltoall_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class>& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -template -ccl::event single_device_communicator::alltoallv_impl(const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - return alltoallv_impl(send_buf, - send_counts, - recv_buf, - recv_counts, - ccl::native_type_info::dtype, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::alltoallv_impl( - const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -template -ccl::event single_device_communicator::alltoallv_impl(const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_alltoallv_impl(reinterpret_cast(&send_buf), - send_counts.data(), - reinterpret_cast(&recv_buf), - recv_counts.data(), - ccl::native_type_info::dtype, - attr, - comm_impl.get(), - stream.get(), - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -template -ccl::event single_device_communicator::alltoallv_impl( - const ccl::vector_class>& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class>& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -template -ccl::event single_device_communicator::broadcast_impl(buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - return broadcast_impl( - buf, count, ccl::native_type_info::dtype, root, stream, attr, deps); -} - -template -ccl::event single_device_communicator::broadcast_impl(buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_broadcast_impl(reinterpret_cast(&buf), - count, - ccl::native_type_info::dtype, - root, - attr, - comm_impl.get(), - stream.get(), - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* reduce */ -template -ccl::event single_device_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - return reduce_impl(send_buf, - recv_buf, - count, - ccl::native_type_info::dtype, - reduction, - root, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - const ccl_stream* stream_ptr = stream.get(); - - ccl_request* req = ccl_reduce_impl(reinterpret_cast(&send_buf), - reinterpret_cast(&recv_buf), - count, - ccl::native_type_info::dtype, - reduction, - root, - attr, - comm_impl.get(), - stream_ptr, - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* reduce_scatter */ -template -ccl::event single_device_communicator::reduce_scatter_impl( - const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - return reduce_scatter_impl(send_buf, - recv_buf, - recv_count, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::reduce_scatter_impl( - const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - const ccl_stream* stream_ptr = stream.get(); - ccl_request* req = ccl_reduce_scatter_impl(reinterpret_cast(&send_buf), - reinterpret_cast(&recv_buf), - recv_count, - ccl::native_type_info::dtype, - reduction, - attr, - comm_impl.get(), - stream_ptr, - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* sparse_allreduce */ -template -ccl::event single_device_communicator::sparse_allreduce_impl( - const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - return sparse_allreduce_impl(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - ccl::native_type_info::dtype, - ccl::native_type_info::dtype, - reduction, - stream, - attr, - deps); -} - -template -ccl::event single_device_communicator::sparse_allreduce_impl( - const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - const ccl_stream* stream_ptr = stream.get(); - - ccl_request* req = - ccl_sparse_allreduce_impl(reinterpret_cast(&send_ind_buf), - send_ind_count, - reinterpret_cast(&send_val_buf), - send_val_count, - reinterpret_cast(&recv_ind_buf), - recv_ind_count, - reinterpret_cast(&recv_val_buf), - recv_val_count, - ccl::native_type_info::dtype, - ccl::native_type_info::dtype, - reduction, - attr, - comm_impl.get(), - stream_ptr, - deps, - true); - return std::unique_ptr(new ccl::host_event_impl(req)); -} diff --git a/src/common/comm/usm_visitor/allgather_usm_visitor.hpp b/src/common/comm/usm_visitor/allgather_usm_visitor.hpp deleted file mode 100644 index a638dda18..000000000 --- a/src/common/comm/usm_visitor/allgather_usm_visitor.hpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct allgather_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - size_t send_count, - void* recv_buf, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = int16_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template allgatherv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - "-no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/allreduce_usm_visitor.hpp b/src/common/comm/usm_visitor/allreduce_usm_visitor.hpp deleted file mode 100644 index 523abb41b..000000000 --- a/src/common/comm/usm_visitor/allreduce_usm_visitor.hpp +++ /dev/null @@ -1,54 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct allreduce_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - void* recv_buf, - size_t count, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - req = get_self()->template allreduce_impl((const uint8_t*)(const void*)send_buf, - (uint8_t*)(void*)recv_buf, - count, - std::forward(args)...); - - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/alltoall_usm_visitor.hpp b/src/common/comm/usm_visitor/alltoall_usm_visitor.hpp deleted file mode 100644 index 498291c71..000000000 --- a/src/common/comm/usm_visitor/alltoall_usm_visitor.hpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct alltoall_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - void* recv_buf, - size_t count, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = uint16_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template alltoall_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/alltoallv_usm_visitor.hpp b/src/common/comm/usm_visitor/alltoallv_usm_visitor.hpp deleted file mode 100644 index 4a4535678..000000000 --- a/src/common/comm/usm_visitor/alltoallv_usm_visitor.hpp +++ /dev/null @@ -1,184 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct alltoallv_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - const ccl::vector_class& send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = int16_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template alltoallv_impl(static_cast(send_buf), - send_count, - static_cast(recv_buf), - recv_counts, - std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/broadcast_usm_visitor.hpp b/src/common/comm/usm_visitor/broadcast_usm_visitor.hpp deleted file mode 100644 index e8bfba0f0..000000000 --- a/src/common/comm/usm_visitor/broadcast_usm_visitor.hpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct broadcast_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, ccl::datatype dtype, void* buf, size_t count, Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = int16_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template broadcast_impl( - static_cast(buf), count, std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/reduce_scatter_usm_visitor.hpp b/src/common/comm/usm_visitor/reduce_scatter_usm_visitor.hpp deleted file mode 100644 index c01f8c48a..000000000 --- a/src/common/comm/usm_visitor/reduce_scatter_usm_visitor.hpp +++ /dev/null @@ -1,184 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct reduce_scatter_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast( - const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - void* recv_buf, - size_t count, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = int16_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template reduce_scatter_impl( - static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/reduce_usm_visitor.hpp b/src/common/comm/usm_visitor/reduce_usm_visitor.hpp deleted file mode 100644 index 0eb226bf8..000000000 --- a/src/common/comm/usm_visitor/reduce_usm_visitor.hpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct reduce_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast(const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype dtype, - const void* send_buf, - void* recv_buf, - size_t count, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype)); - - switch (dtype) { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - using type = uint8_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int16: { - using type = int16_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint16: { - using type = uint16_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - using type = uint32_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - using type = ccl::float16; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - using type = ccl::bfloat16; - req = get_self()->template reduce_impl(static_cast(send_buf), - static_cast(recv_buf), - count, - std::forward(args)...); - processed = true; - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(dtype), - " , handle: ", - utils::enum_to_underlying(dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/sparse_allreduce_usm_visitor.hpp b/src/common/comm/usm_visitor/sparse_allreduce_usm_visitor.hpp deleted file mode 100644 index b95705651..000000000 --- a/src/common/comm/usm_visitor/sparse_allreduce_usm_visitor.hpp +++ /dev/null @@ -1,199 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/type_traits.hpp" - -template -struct sparse_allreduce_usm_visitor { - using self_t = communicator_impl; - - self_t* get_self() { - return static_cast(this); - } - - const self_t* get_self() const { - return static_cast( - const_cast(this)->get_self()); - } - - template - bool visit(ccl::event& req, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - Args&&... args) { - bool processed = false; - LOG_TRACE("comm: ", - /*get_self()->to_string(),*/ - " - starting to find visitor for datatype: ", - ccl::to_string(value_dtype), - " , handle: ", - utils::enum_to_underlying(value_dtype)); - - CCL_THROW("unexpected path"); - - switch (value_dtype) //TODO -S- value only - { - case ccl::datatype::int8: { - using type = int8_t; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint8: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + " - USM convertation of value_dtype: " + - ccl::to_string(value_dtype) + " is not supported for such configuration"); - break; - } - case ccl::datatype::int16: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + " - USM convertation of value_dtype: " + - ccl::to_string(value_dtype) + " is not supported for such configuration"); - break; - } - case ccl::datatype::uint16: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + " - USM convertation of value_dtype: " + - ccl::to_string(value_dtype) + " is not supported for such configuration"); - break; - } - case ccl::datatype::int32: { - using type = int32_t; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint32: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + " - USM convertation of value_dtype: " + - ccl::to_string(value_dtype) + " is not supported for such configuration"); - break; - } - case ccl::datatype::int64: { - using type = int64_t; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::uint64: { - using type = uint64_t; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float16: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + " - USM convertation of value_dtype: " + - ccl::to_string(value_dtype) + " is not supported for such configuration"); - break; - } - case ccl::datatype::float32: { - using type = float; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::float64: { - using type = double; - req = get_self()->template sparse_allreduce_impl( - static_cast(send_ind_buf), - send_ind_count, - static_cast(send_val_buf), - send_val_count, - static_cast(recv_ind_buf), - recv_ind_count, - static_cast(recv_val_buf), - recv_val_count, - std::forward(args)...); - processed = true; - break; - } - case ccl::datatype::bfloat16: { - throw ccl::exception( - std::string(__PRETTY_FUNCTION__) + - " - USM convertationf bloat16 is not supported for such configuration"); - break; - } - default: { - CCL_THROW("unknown datatype ", dtype); - LOG_DEBUG("comm: ", - /*get_self()->to_string(),*/ - " - no found visitor for datatype: ", - ccl::to_string(value_dtype), - " , handle: ", - utils::enum_to_underlying(value_dtype), - ", use RAW types"); - break; - } - } - return processed; - } -}; diff --git a/src/common/comm/usm_visitor/usm_visitors.hpp b/src/common/comm/usm_visitor/usm_visitors.hpp deleted file mode 100644 index 4a557efb1..000000000 --- a/src/common/comm/usm_visitor/usm_visitors.hpp +++ /dev/null @@ -1,25 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "common/comm/usm_visitor/allgather_usm_visitor.hpp" -#include "common/comm/usm_visitor/allreduce_usm_visitor.hpp" -#include "common/comm/usm_visitor/alltoall_usm_visitor.hpp" -#include "common/comm/usm_visitor/alltoallv_usm_visitor.hpp" -#include "common/comm/usm_visitor/broadcast_usm_visitor.hpp" -#include "common/comm/usm_visitor/reduce_usm_visitor.hpp" -#include "common/comm/usm_visitor/reduce_scatter_usm_visitor.hpp" -//#include "common/comm/usm_visitor/sparse_allreduce_usm_visitor.hpp" diff --git a/src/common/datatype/datatype.hpp b/src/common/datatype/datatype.hpp index 555fedd6e..974cef398 100644 --- a/src/common/datatype/datatype.hpp +++ b/src/common/datatype/datatype.hpp @@ -29,14 +29,12 @@ class ccl_datatype { public: - ccl_datatype(ccl::datatype idx, size_t size); ccl_datatype() = default; - ~ccl_datatype() = default; - ccl_datatype& operator=(const ccl_datatype& other) = default; - + ccl_datatype(ccl::datatype idx, size_t size); ccl_datatype(const ccl_datatype& other) = default; + ccl_datatype& operator=(const ccl_datatype& other) = default; - ccl::datatype idx() const { + ccl::datatype idx() const noexcept { return m_idx; } @@ -46,8 +44,8 @@ class ccl_datatype { } private: - ccl::datatype m_idx; - size_t m_size; + ccl::datatype m_idx = ccl::datatype::int8; + size_t m_size = sizeof(int8_t); }; /* frequently used in multiple places */ diff --git a/src/common/env/env.cpp b/src/common/env/env.cpp index e9ffa8f88..97bd1cae0 100644 --- a/src/common/env/env.cpp +++ b/src/common/env/env.cpp @@ -37,8 +37,17 @@ std::map env_data::priority_mode_names = { }; std::map env_data::atl_transport_names = { - std::make_pair(ccl_atl_ofi, "ofi"), + std::make_pair(ccl_atl_ofi, "ofi") +#ifdef CCL_ENABLE_MPI + , std::make_pair(ccl_atl_mpi, "mpi") +#endif // CCL_ENABLE_MPI +}; + +std::map env_data::atl_send_proxy_names = { + std::make_pair(ccl_atl_send_proxy_none, "none"), + std::make_pair(ccl_atl_send_proxy_regular, "regular"), + std::make_pair(ccl_atl_send_proxy_usm, "usm") }; std::map env_data::staging_buffer_names = { @@ -52,22 +61,35 @@ std::map env_data::mnic_type_names = { std::make_pair(ATL_MNIC_GLOBAL, "global") }; +std::map env_data::ze_copy_engine_names = { + std::make_pair(ccl_ze_copy_engine_none, "none"), + std::make_pair(ccl_ze_copy_engine_main, "main"), + std::make_pair(ccl_ze_copy_engine_link, "link") +}; + env_data::env_data() : was_printed(false), log_level(ccl_log_level::warn), + queue_dump(0), sched_dump(0), + sched_profile(0), fw_type(ccl_framework_none), worker_count(1), worker_offload(1), worker_wait(1), - +#ifdef CCL_ENABLE_MPI atl_transport(ccl_atl_mpi), +#else // CCL_ENABLE_MPI + atl_transport(ccl_atl_ofi), +#endif // CCL_ENABLE_MPI enable_shm(0), enable_rma(0), - enable_device_buf(0), + enable_hmem(0), + atl_send_proxy(ccl_atl_send_proxy_none), + enable_atl_cache(1), enable_sync_coll(0), enable_extra_ep(0), @@ -88,9 +110,15 @@ env_data::env_data() max_short_size(0), bcast_part_count(CCL_ENV_SIZET_NOT_SPECIFIED), cache_key_type(ccl_cache_key_match_id), +#ifdef CCL_ENABLE_SYCL + enable_cache_flush(1), +#else // CCL_ENABLE_SYCL enable_cache_flush(0), +#endif // CCL_ENABLE_SYCL + enable_buffer_cache(1), enable_strict_order(0), staging_buffer(ccl_staging_usm), + enable_op_sync(0), chunk_count(1), min_chunk_size(65536), @@ -105,18 +133,29 @@ env_data::env_data() alltoall_scatter_max_ops(CCL_ENV_SIZET_NOT_SPECIFIED), alltoall_scatter_plain(0), - enable_comm_kernels(0), - comm_kernels_path(), - comm_kernels_debug(0), - gpu_thread_count(CCL_ENV_SIZET_NOT_SPECIFIED), + kernel_path(), + kernel_debug(0), + enable_kernel_cache(1), + kernel_group_size(CCL_ENV_SIZET_NOT_SPECIFIED), + kernel_group_count(CCL_ENV_SIZET_NOT_SPECIFIED), + enable_kernel_sync(1), + kernel_1s_lead(0), + enable_kernel_1s_copy_ops(0), + enable_kernel_1s_ipc_wa(0), + enable_kernel_output_event(0), + ze_serialize_mode(0), + ze_copy_engine(ccl_ze_copy_engine_none), bf16_impl_type(ccl_bf16_no_compiler_support), - fp16_impl_type(ccl_fp16_no_compiler_support) {} + fp16_impl_type(ccl_fp16_no_compiler_support) { +} void env_data::parse() { env_2_enum(CCL_LOG_LEVEL, ccl_logger::level_names, log_level); ccl_logger::set_log_level(log_level); + env_2_type(CCL_QUEUE_DUMP, queue_dump); env_2_type(CCL_SCHED_DUMP, sched_dump); + env_2_type(CCL_SCHED_PROFILE, sched_profile); if (fw_type == ccl_framework_none) { /* try to automatically detect framework */ @@ -149,11 +188,17 @@ void env_data::parse() { env_2_atl_transport(); env_2_type(CCL_ATL_SHM, enable_shm); env_2_type(CCL_ATL_RMA, enable_rma); - env_2_type(CCL_ATL_DEVICE_BUF, enable_device_buf); + env_2_type(CCL_ATL_HMEM, enable_hmem); + if (atl_transport == ccl_atl_mpi && enable_hmem) { + worker_count = 1; + } + env_2_enum(CCL_ATL_SEND_PROXY, atl_send_proxy_names, atl_send_proxy); + env_2_type(CCL_ATL_CACHE, enable_atl_cache); env_2_type(CCL_ATL_SYNC_COLL, enable_sync_coll); env_2_type(CCL_ATL_EXTRA_EP, enable_extra_ep); env_2_enum(CCL_MNIC, mnic_type_names, mnic_type); + env_2_type(CCL_MNIC_NAME, mnic_name_raw); env_2_type(CCL_MNIC_COUNT, mnic_count); if (mnic_count == CCL_ENV_SIZET_NOT_SPECIFIED) { mnic_count = worker_count; @@ -204,8 +249,14 @@ void env_data::parse() { env_2_type(CCL_BCAST_PART_COUNT, (size_t&)bcast_part_count); env_2_enum(CCL_CACHE_KEY, ccl_sched_key::key_type_names, cache_key_type); env_2_type(CCL_CACHE_FLUSH, enable_cache_flush); + env_2_type(CCL_BUFFER_CACHE, enable_buffer_cache); env_2_type(CCL_STRICT_ORDER, enable_strict_order); + if (enable_unordered_coll && enable_strict_order) { + LOG_INFO("unordered collectives are requested, disable strict order"); + enable_strict_order = 0; + } env_2_enum(CCL_STAGING_BUFFER, staging_buffer_names, staging_buffer); + env_2_type(CCL_OP_SYNC, enable_op_sync); env_2_type(CCL_CHUNK_COUNT, chunk_count); CCL_THROW_IF_NOT(chunk_count >= 1, "incorrect ", CCL_CHUNK_COUNT, " ", chunk_count); @@ -231,17 +282,24 @@ void env_data::parse() { env_2_type(CCL_ALLTOALL_SCATTER_MAX_OPS, (size_t&)alltoall_scatter_max_ops); env_2_type(CCL_ALLTOALL_SCATTER_PLAIN, alltoall_scatter_plain); - env_2_type(CCL_COMM_KERNELS, enable_comm_kernels); - if (enable_comm_kernels) { - env_2_type(CCL_COMM_KERNELS_PATH, comm_kernels_path); - if (comm_kernels_path.empty()) { - std::string ccl_root = getenv("CCL_ROOT"); - CCL_THROW_IF_NOT(!ccl_root.empty(), "incorrect comm kernels path, CCL_ROOT not found!"); - comm_kernels_path = ccl_root + "/lib/kernels/"; - } - env_2_type(CCL_COMM_KERNELS_DEBUG, comm_kernels_debug); + env_2_type(CCL_KERNEL_PATH, kernel_path); + if (kernel_path.empty()) { + std::string ccl_root = getenv("CCL_ROOT"); + CCL_THROW_IF_NOT(!ccl_root.empty(), "incorrect comm kernels path, CCL_ROOT not found!"); + kernel_path = ccl_root + "/lib/kernels/"; } - env_2_type(CCL_GPU_THREAD_COUNT, gpu_thread_count); + + env_2_type(CCL_KERNEL_DEBUG, kernel_debug); + env_2_type(CCL_KERNEL_CACHE, enable_kernel_cache); + env_2_type(CCL_KERNEL_GROUP_SIZE, kernel_group_size); + env_2_type(CCL_KERNEL_GROUP_COUNT, kernel_group_count); + env_2_type(CCL_KERNEL_SYNC, enable_kernel_sync); + env_2_type(CCL_KERNEL_1S_LEAD, kernel_1s_lead); + env_2_type(CCL_KERNEL_1S_USE_COPY_OPS, enable_kernel_1s_copy_ops); + env_2_type(CCL_KERNEL_1S_IPC_WA, enable_kernel_1s_ipc_wa); + env_2_type(CCL_KERNEL_OUTPUT_EVENT, enable_kernel_output_event); + env_2_type(CCL_ZE_SERIALIZE, ze_serialize_mode); + env_2_enum(CCL_ZE_COPY_ENGINE, ze_copy_engine_names, ze_copy_engine); auto bf16_impl_types = ccl_bf16_get_impl_types(); ccl_bf16_impl_type bf16_env_impl_type; @@ -276,39 +334,42 @@ void env_data::print(int rank) { else was_printed = true; + auto& global_data = ccl::global_data::get(); + if (rank == 0) { auto version = utils::get_library_version(); LOG_INFO("library version: ", version.full); LOG_INFO("specification version: ", ONECCL_SPEC_VERSION); #ifdef CCL_ENABLE_SYCL LOG_INFO("compute backend: ", version.cl_backend_name); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #ifdef ENABLE_DEBUG const char* build_mode = "debug"; -#else /* ENABLE_DEBUG */ +#else // ENABLE_DEBUG const char* build_mode = "release"; -#endif /* ENABLE_DEBUG */ +#endif // ENABLE_DEBUG LOG_INFO("build mode: ", build_mode); LOG_INFO("C compiler: ", CCL_C_COMPILER); LOG_INFO("C++ compiler: ", CCL_CXX_COMPILER); + LOG_INFO(global_data.hwloc_wrapper->to_string()); } - auto& global_data = ccl::global_data::get(); auto local_proc_idx = global_data.executor->get_local_proc_idx(); auto local_proc_count = global_data.executor->get_local_proc_count(); - if (rank < (int)local_proc_count) { + if (rank < local_proc_count) { for (size_t w_idx = 0; w_idx < worker_count; w_idx++) { - LOG_INFO(CCL_WORKER_AFFINITY, - ": local process [", + LOG_INFO("local process [", local_proc_idx, ":", local_proc_count, "]: worker: ", w_idx, - ", core: ", - worker_affinity[local_proc_idx * worker_count + w_idx]); + ", cpu: ", + worker_affinity[local_proc_idx * worker_count + w_idx], + ", numa: ", + worker_mem_affinity[local_proc_idx * worker_count + w_idx]); } } @@ -320,18 +381,24 @@ void env_data::print(int rank) { LOG_INFO(CCL_WORKER_WAIT, ": ", worker_wait); LOG_INFO(CCL_LOG_LEVEL, ": ", str_by_enum(ccl_logger::level_names, log_level)); + LOG_INFO(CCL_QUEUE_DUMP, ": ", queue_dump); LOG_INFO(CCL_SCHED_DUMP, ": ", sched_dump); + LOG_INFO(CCL_SCHED_PROFILE, ": ", sched_profile); LOG_INFO(CCL_FRAMEWORK, ": ", str_by_enum(ccl_framework_type_names, fw_type)); LOG_INFO(CCL_ATL_TRANSPORT, ": ", str_by_enum(atl_transport_names, atl_transport)); LOG_INFO(CCL_ATL_SHM, ": ", enable_shm); LOG_INFO(CCL_ATL_RMA, ": ", enable_rma); - LOG_INFO(CCL_ATL_DEVICE_BUF, ": ", enable_device_buf); + LOG_INFO(CCL_ATL_HMEM, ": ", enable_hmem); + LOG_INFO(CCL_ATL_SEND_PROXY, ": ", str_by_enum(atl_send_proxy_names, atl_send_proxy)); + LOG_INFO(CCL_ATL_CACHE, ": ", enable_atl_cache); LOG_DEBUG(CCL_ATL_SYNC_COLL, ": ", enable_sync_coll); LOG_DEBUG(CCL_ATL_EXTRA_EP, ": ", enable_extra_ep); LOG_INFO(CCL_MNIC, ": ", str_by_enum(mnic_type_names, mnic_type)); + LOG_INFO( + CCL_MNIC_NAME, ": ", (mnic_name_raw.length()) ? mnic_name_raw : CCL_ENV_STR_NOT_SPECIFIED); LOG_INFO(CCL_MNIC_COUNT, ": ", mnic_count); LOG_INFO(CCL_ALLGATHERV, @@ -379,8 +446,10 @@ void env_data::print(int rank) { : CCL_ENV_STR_NOT_SPECIFIED); LOG_INFO(CCL_CACHE_KEY, ": ", str_by_enum(ccl_sched_key::key_type_names, cache_key_type)); LOG_INFO(CCL_CACHE_FLUSH, ": ", enable_cache_flush); + LOG_INFO(CCL_BUFFER_CACHE, ": ", enable_buffer_cache); LOG_INFO(CCL_STRICT_ORDER, ": ", enable_strict_order); LOG_INFO(CCL_STAGING_BUFFER, ": ", str_by_enum(staging_buffer_names, staging_buffer)); + LOG_INFO(CCL_OP_SYNC, ": ", enable_op_sync); LOG_INFO(CCL_CHUNK_COUNT, ": ", chunk_count); LOG_INFO(CCL_MIN_CHUNK_SIZE, ": ", min_chunk_size); @@ -404,16 +473,27 @@ void env_data::print(int rank) { LOG_INFO(CCL_ALLTOALL_SCATTER_PLAIN, ": ", alltoall_scatter_plain); #ifdef CCL_ENABLE_SYCL - LOG_INFO(CCL_COMM_KERNELS, ": ", enable_comm_kernels); - LOG_INFO(CCL_COMM_KERNELS_PATH, + LOG_INFO( + CCL_KERNEL_PATH, ": ", (!kernel_path.empty()) ? kernel_path : CCL_ENV_STR_NOT_SPECIFIED); + LOG_INFO(CCL_KERNEL_DEBUG, ": ", kernel_debug); + LOG_INFO(CCL_KERNEL_CACHE, ": ", enable_kernel_cache); + LOG_INFO(CCL_KERNEL_GROUP_SIZE, ": ", - (!comm_kernels_path.empty()) ? comm_kernels_path : CCL_ENV_STR_NOT_SPECIFIED); - LOG_INFO(CCL_COMM_KERNELS_DEBUG, ": ", comm_kernels_debug); - LOG_INFO(CCL_GPU_THREAD_COUNT, + (kernel_group_size != CCL_ENV_SIZET_NOT_SPECIFIED) ? std::to_string(kernel_group_size) + : CCL_ENV_STR_NOT_SPECIFIED); + LOG_INFO(CCL_KERNEL_GROUP_COUNT, ": ", - (gpu_thread_count != CCL_ENV_SIZET_NOT_SPECIFIED) ? std::to_string(gpu_thread_count) - : CCL_ENV_STR_NOT_SPECIFIED); -#endif /* CCL_ENABLE_SYCL */ + (kernel_group_count != CCL_ENV_SIZET_NOT_SPECIFIED) + ? std::to_string(kernel_group_count) + : CCL_ENV_STR_NOT_SPECIFIED); + LOG_INFO(CCL_KERNEL_SYNC, ": ", enable_kernel_sync); + LOG_INFO(CCL_KERNEL_1S_LEAD, ": ", kernel_1s_lead); + LOG_INFO(CCL_KERNEL_1S_USE_COPY_OPS, ": ", enable_kernel_1s_copy_ops); + LOG_INFO(CCL_KERNEL_1S_IPC_WA, ": ", enable_kernel_1s_ipc_wa); + LOG_INFO(CCL_KERNEL_OUTPUT_EVENT, ": ", enable_kernel_output_event); + LOG_INFO(CCL_ZE_SERIALIZE, ": ", ze_serialize_mode); + LOG_INFO(CCL_ZE_COPY_ENGINE, ": ", str_by_enum(ze_copy_engine_names, ze_copy_engine)); +#endif // CCL_ENABLE_SYCL LOG_INFO(CCL_BF16, ": ", str_by_enum(bf16_impl_names, bf16_impl_type)); LOG_INFO(CCL_FP16, ": ", str_by_enum(fp16_impl_names, fp16_impl_type)); @@ -442,7 +522,7 @@ void env_data::set_internal_env() { } } -int env_data::env_2_worker_affinity_auto(size_t local_proc_idx, size_t workers_per_process) { +int env_data::env_2_worker_affinity_auto(int local_proc_idx, size_t workers_per_process) { char* available_cores = std::getenv(I_MPI_AVAILABLE_CORES_ENV); CCL_THROW_IF_NOT(available_cores && strlen(available_cores) != 0, "auto pinning requires ", @@ -497,46 +577,104 @@ int env_data::env_2_worker_affinity_auto(size_t local_proc_idx, size_t workers_p return 1; } -int env_data::parse_core_id(const std::string& core_id_str, size_t& result) { +int env_data::parse_number(const std::string& number_str, size_t& result) { char* end_ptr; - const char* core_id_str_ptr = core_id_str.c_str(); + const char* number_str_ptr = number_str.c_str(); errno = 0; - auto core_id = std::strtol(core_id_str_ptr, &end_ptr, 10); + auto core_id = std::strtol(number_str_ptr, &end_ptr, 10); if ((errno == ERANGE && (core_id == LONG_MAX || core_id == LONG_MIN)) || (errno != 0 && core_id == 0)) { - LOG_ERROR("core id value is invalid in string: ", core_id_str); + LOG_ERROR("core id value is invalid in string: ", number_str); return 0; } - if (end_ptr == core_id_str_ptr) { - LOG_ERROR("no digits were found in string: ", core_id_str); + if (end_ptr == number_str_ptr) { + LOG_ERROR("no digits were found in string: ", number_str); return 0; } if (core_id < 0) { - LOG_ERROR( - "core id cannot be less than zero but got ", core_id, " in string: ", core_id_str); + LOG_ERROR("core id cannot be less than zero but got ", core_id, " in string: ", number_str); return 0; } result = core_id; return 1; } -int env_data::env_2_worker_affinity(size_t local_proc_idx, size_t local_proc_count) { +int env_data::parse_affinity(const std::string& input, + std::vector& output, + size_t expected_output_size) { + size_t idx; + char* range_str; + + /* create copy of input string because it will be modified in strsep */ + std::string input_copy(input.c_str()); + char* input_str = (char*)input_copy.c_str(); + + output.clear(); + + while (input_str) { + range_str = strsep(&input_str, ","); + if (!range_str) { + break; + } + + auto range = tokenize>(std::string(range_str), '-'); + + if ((range.size() != 2) && (range.size() != 1)) { + LOG_ERROR( + "unexpected format in input: ", + input, + ", specify range values using - or single value using "); + return 0; + } + + if (range.size() == 1) { + /* to unify logic below */ + range.push_back(*range.begin()); + } + + CCL_ASSERT(range.size() == 2, "unexpected number of values in range"); + + size_t first_value, last_value; + if (!parse_number(range[0], first_value) || !parse_number(range[1], last_value)) { + return 0; + } + + if (first_value > last_value) { + LOG_ERROR("unexpected first and last values in range: ", + range_str, + ", first value should be less or equal to last value"); + return 0; + } + + for (idx = first_value; idx <= last_value; idx++) { + output.push_back(idx); + } + } + + if (output.size() < expected_output_size) { + LOG_ERROR("unexpected number of values in input: ", + input, + ", expected at least ", + expected_output_size, + " values"); + return 0; + } + + return 1; +} + +int env_data::env_2_worker_affinity(int local_proc_idx, int local_proc_count) { CCL_THROW_IF_NOT(local_proc_count > 0); size_t idx; - std::unique_ptr affinity_copy; - char* affinity_to_parse = getenv(CCL_WORKER_AFFINITY); - char* core_range_str; - char* tmp; + char* env_to_parse = getenv(CCL_WORKER_AFFINITY); size_t system_core_count; - size_t affinity_size = local_proc_count * worker_count; - if (!affinity_to_parse || (strlen(affinity_to_parse) == 0) || - (strcmp(affinity_to_parse, "auto") == 0)) { - worker_affinity.assign(affinity_size, 0); + if (!env_to_parse || (strlen(env_to_parse) == 0) || (strcmp(env_to_parse, "auto") == 0)) { + worker_affinity.assign(affinity_size, CCL_UNDEFINED_CPU_ID); if (std::getenv(I_MPI_AVAILABLE_CORES_ENV)) { /* generate auto affinity based on IMPI process pinning */ return env_2_worker_affinity_auto(local_proc_idx, worker_count); @@ -556,63 +694,37 @@ int env_data::env_2_worker_affinity(size_t local_proc_idx, size_t local_proc_cou } } - /* create copy of original buffer because it will be modified in strsep */ - size_t affinity_len = strlen(affinity_to_parse); - affinity_copy = - std::unique_ptr(static_cast(CCL_CALLOC(affinity_len + 1, "affinity_copy"))); - CCL_MEMCPY(affinity_copy.get(), affinity_to_parse, affinity_len); - tmp = affinity_copy.get(); - - while (tmp) { - core_range_str = strsep(&tmp, ","); - if (!core_range_str) { - break; - } - - auto core_range = tokenize>(std::string(core_range_str), '-'); + CCL_THROW_IF_NOT(parse_affinity(env_to_parse, worker_affinity, affinity_size), + "failed to parse worker affinity"); - if ((core_range.size() != 2) && (core_range.size() != 1)) { - LOG_ERROR( - "unexpected format in affinity: ", - affinity_to_parse, - ", specify core range using - or single core using "); - return 0; - } - - if (core_range.size() == 1) { - /* to unify logic below */ - core_range.push_back(*core_range.begin()); - } - - CCL_ASSERT(core_range.size() == 2, "unexpected number of cores in range"); - - size_t first_core, last_core; - if (!parse_core_id(core_range[0], first_core) || !parse_core_id(core_range[1], last_core)) { - return 0; - } + return 1; +} - if (first_core > last_core) { - LOG_ERROR("unexpected first and last cores in range: ", - core_range_str, - ", first core should be less or equal to last core"); - return 0; - } +int env_data::env_2_worker_mem_affinity() { + CCL_THROW_IF_NOT(worker_affinity.size() > 0); - for (idx = first_core; idx <= last_core; idx++) { - worker_affinity.push_back(idx); + size_t idx; + char* env_to_parse = getenv(CCL_WORKER_MEM_AFFINITY); + size_t affinity_size = worker_affinity.size(); + + if (!env_to_parse || (strlen(env_to_parse) == 0) || (strcmp(env_to_parse, "auto") == 0)) { + worker_mem_affinity.assign(affinity_size, CCL_UNDEFINED_NUMA_NODE); + /* generate list of default numa nodes, local wrt worker cores */ + for (idx = 0; idx < affinity_size; idx++) { + worker_mem_affinity[idx] = + ccl::global_data::get().hwloc_wrapper->get_numa_node_by_cpu(worker_affinity[idx]); } + return 1; } - if (worker_affinity.size() < affinity_size) { - LOG_ERROR("unexpected number of cores in affinity: ", - affinity_to_parse, - ", specify 1 core per 1 worker thread"); - return 0; - } + CCL_THROW_IF_NOT(parse_affinity(env_to_parse, worker_mem_affinity, affinity_size), + "failed to parse worker memory affinity"); + return 1; } void env_data::env_2_atl_transport() { +#ifdef CCL_ENABLE_MPI if (!getenv(CCL_ATL_TRANSPORT) && !with_mpirun()) { LOG_WARN("did not find MPI-launcher specific variables, switch to ATL/OFI, " "to force enable ATL/MPI set CCL_ATL_TRANSPORT=mpi"); @@ -620,6 +732,7 @@ void env_data::env_2_atl_transport() { atl_transport = ccl_atl_ofi; } else +#endif // CCL_ENABLE_MPI env_2_enum(CCL_ATL_TRANSPORT, atl_transport_names, atl_transport); } @@ -630,4 +743,4 @@ bool env_data::with_mpirun() { : false; } -} /* namespace ccl */ +} // namespace ccl diff --git a/src/common/env/env.hpp b/src/common/env/env.hpp index 00e708de2..f69ed88ee 100644 --- a/src/common/env/env.hpp +++ b/src/common/env/env.hpp @@ -35,7 +35,9 @@ constexpr const char* CCL_ENV_STR_NOT_SPECIFIED = ""; constexpr const ssize_t CCL_ENV_SIZET_NOT_SPECIFIED = -1; constexpr const char* CCL_LOG_LEVEL = "CCL_LOG_LEVEL"; +constexpr const char* CCL_QUEUE_DUMP = "CCL_QUEUE_DUMP"; constexpr const char* CCL_SCHED_DUMP = "CCL_SCHED_DUMP"; +constexpr const char* CCL_SCHED_PROFILE = "CCL_SCHED_PROFILE"; constexpr const char* CCL_FRAMEWORK = "CCL_FRAMEWORK"; @@ -43,6 +45,7 @@ constexpr const char* CCL_WORKER_COUNT = "CCL_WORKER_COUNT"; constexpr const char* CCL_WORKER_OFFLOAD = "CCL_WORKER_OFFLOAD"; constexpr const char* CCL_WORKER_WAIT = "CCL_WORKER_WAIT"; constexpr const char* CCL_WORKER_AFFINITY = "CCL_WORKER_AFFINITY"; +constexpr const char* CCL_WORKER_MEM_AFFINITY = "CCL_WORKER_MEM_AFFINITY"; constexpr const char* I_MPI_AVAILABLE_CORES_ENV = "I_MPI_PIN_INFO"; constexpr const char* I_MPI_AVAILABLE_CORES_DELIMS = ",x"; @@ -50,11 +53,14 @@ constexpr const char* I_MPI_AVAILABLE_CORES_DELIMS = ",x"; constexpr const char* CCL_ATL_TRANSPORT = "CCL_ATL_TRANSPORT"; constexpr const char* CCL_ATL_SHM = "CCL_ATL_SHM"; constexpr const char* CCL_ATL_RMA = "CCL_ATL_RMA"; -constexpr const char* CCL_ATL_DEVICE_BUF = "CCL_ATL_DEVICE_BUF"; +constexpr const char* CCL_ATL_HMEM = "CCL_ATL_HMEM"; +constexpr const char* CCL_ATL_SEND_PROXY = "CCL_ATL_SEND_PROXY"; constexpr const char* CCL_ATL_SYNC_COLL = "CCL_ATL_SYNC_COLL"; constexpr const char* CCL_ATL_EXTRA_EP = "CCL_ATL_EXTRA_EP"; +constexpr const char* CCL_ATL_CACHE = "CCL_ATL_CACHE"; constexpr const char* CCL_MNIC = "CCL_MNIC"; +constexpr const char* CCL_MNIC_NAME = "CCL_MNIC_NAME"; constexpr const char* CCL_MNIC_COUNT = "CCL_MNIC_COUNT"; constexpr const char* CCL_ALLGATHERV = "CCL_ALLGATHERV"; @@ -81,8 +87,10 @@ constexpr const char* CCL_MAX_SHORT_SIZE = "CCL_MAX_SHORT_SIZE"; constexpr const char* CCL_BCAST_PART_COUNT = "CCL_BCAST_PART_COUNT"; constexpr const char* CCL_CACHE_KEY = "CCL_CACHE_KEY"; constexpr const char* CCL_CACHE_FLUSH = "CCL_CACHE_FLUSH"; +constexpr const char* CCL_BUFFER_CACHE = "CCL_BUFFER_CACHE"; constexpr const char* CCL_STRICT_ORDER = "CCL_STRICT_ORDER"; constexpr const char* CCL_STAGING_BUFFER = "CCL_STAGING_BUFFER"; +constexpr const char* CCL_OP_SYNC = "CCL_OP_SYNC"; constexpr const char* CCL_CHUNK_COUNT = "CCL_CHUNK_COUNT"; constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE"; @@ -98,9 +106,18 @@ constexpr const char* CCL_ALLTOALL_SCATTER_MAX_OPS = "CCL_ALLTOALL_SCATTER_MAX_O constexpr const char* CCL_ALLTOALL_SCATTER_PLAIN = "CCL_ALLTOALL_SCATTER_PLAIN"; constexpr const char* CCL_COMM_KERNELS = "CCL_COMM_KERNELS"; -constexpr const char* CCL_COMM_KERNELS_PATH = "CCL_COMM_KERNELS_PATH"; -constexpr const char* CCL_COMM_KERNELS_DEBUG = "CCL_COMM_KERNELS_DEBUG"; -constexpr const char* CCL_GPU_THREAD_COUNT = "CCL_GPU_THREAD_COUNT"; +constexpr const char* CCL_KERNEL_PATH = "CCL_KERNEL_PATH"; +constexpr const char* CCL_KERNEL_DEBUG = "CCL_KERNEL_DEBUG"; +constexpr const char* CCL_KERNEL_CACHE = "CCL_KERNEL_CACHE"; +constexpr const char* CCL_KERNEL_GROUP_SIZE = "CCL_KERNEL_GROUP_SIZE"; +constexpr const char* CCL_KERNEL_GROUP_COUNT = "CCL_KERNEL_GROUP_COUNT"; +constexpr const char* CCL_KERNEL_SYNC = "CCL_KERNEL_SYNC"; +constexpr const char* CCL_KERNEL_1S_LEAD = "CCL_KERNEL_1S_LEAD"; +constexpr const char* CCL_KERNEL_1S_USE_COPY_OPS = "CCL_KERNEL_1S_USE_COPY_OPS"; +constexpr const char* CCL_KERNEL_1S_IPC_WA = "CCL_KERNEL_1S_IPC_WA"; +constexpr const char* CCL_KERNEL_OUTPUT_EVENT = "CCL_KERNEL_OUTPUT_EVENT"; +constexpr const char* CCL_ZE_SERIALIZE = "CCL_ZE_SERIALIZE"; +constexpr const char* CCL_ZE_COPY_ENGINE = "CCL_ZE_COPY_ENGINE"; constexpr const char* CCL_BF16 = "CCL_BF16"; constexpr const char* CCL_FP16 = "CCL_FP16"; @@ -109,8 +126,20 @@ enum ccl_priority_mode { ccl_priority_none, ccl_priority_direct, ccl_priority_li enum ccl_atl_transport { ccl_atl_ofi, ccl_atl_mpi }; +enum ccl_atl_send_proxy { + ccl_atl_send_proxy_none, + ccl_atl_send_proxy_regular, + ccl_atl_send_proxy_usm +}; + enum ccl_staging_buffer { ccl_staging_regular, ccl_staging_usm }; +enum ccl_ze_copy_engine_mode { + ccl_ze_copy_engine_none, + ccl_ze_copy_engine_main, + ccl_ze_copy_engine_link +}; + namespace ccl { class env_data { @@ -132,23 +161,29 @@ class env_data { ccl_spinlock print_guard{}; ccl_log_level log_level; + int queue_dump; int sched_dump; + int sched_profile; ccl_framework_type fw_type; size_t worker_count; int worker_offload; int worker_wait; - std::vector worker_affinity; + std::vector worker_affinity; + std::vector worker_mem_affinity; ccl_atl_transport atl_transport; int enable_shm; int enable_rma; - int enable_device_buf; + int enable_hmem; + ccl_atl_send_proxy atl_send_proxy; + int enable_atl_cache; int enable_sync_coll; int enable_extra_ep; atl_mnic_t mnic_type; + std::string mnic_name_raw; ssize_t mnic_count; /* @@ -180,8 +215,10 @@ class env_data { ssize_t bcast_part_count; ccl_cache_key_type cache_key_type; int enable_cache_flush; + int enable_buffer_cache; int enable_strict_order; ccl_staging_buffer staging_buffer; + int enable_op_sync; size_t chunk_count; size_t min_chunk_size; @@ -196,10 +233,18 @@ class env_data { ssize_t alltoall_scatter_max_ops; int alltoall_scatter_plain; - int enable_comm_kernels; - std::string comm_kernels_path; - int comm_kernels_debug; - ssize_t gpu_thread_count; + std::string kernel_path; + int kernel_debug; + int enable_kernel_cache; + ssize_t kernel_group_size; + ssize_t kernel_group_count; + int enable_kernel_sync; + int kernel_1s_lead; + int enable_kernel_1s_copy_ops; + int enable_kernel_1s_ipc_wa; + int enable_kernel_output_event; + int ze_serialize_mode; + ccl_ze_copy_engine_mode ze_copy_engine; ccl_bf16_impl_type bf16_impl_type; ccl_fp16_impl_type fp16_impl_type; @@ -273,15 +318,22 @@ class env_data { static std::map priority_mode_names; static std::map atl_transport_names; + static std::map atl_send_proxy_names; static std::map staging_buffer_names; + static std::map ze_copy_engine_names; static std::map mnic_type_names; - int env_2_worker_affinity(size_t local_proc_idx, size_t local_proc_count); + int env_2_worker_affinity(int local_proc_idx, int local_proc_count); + int env_2_worker_mem_affinity(); void env_2_atl_transport(); private: - int env_2_worker_affinity_auto(size_t local_proc_idx, size_t workers_per_process); - int parse_core_id(const std::string& core_id_str, size_t& result); + int env_2_worker_affinity_auto(int local_proc_idx, size_t workers_per_process); + + int parse_affinity(const std::string& input, + std::vector& output, + size_t expected_output_size); + int parse_number(const std::string& number_str, size_t& result); }; -} /* namespace ccl */ +} // namespace ccl diff --git a/src/common/event/impls/gpu_event.cpp b/src/common/event/impls/gpu_event.cpp deleted file mode 100644 index 64d6f03b6..000000000 --- a/src/common/event/impls/gpu_event.cpp +++ /dev/null @@ -1,116 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "common/request/request.hpp" -#include "common/event/impls/gpu_event.hpp" -#include "sched/gpu_sched.hpp" - -namespace ccl { -gpu_event_impl::gpu_event_impl(std::unique_ptr&& sched) - : gpu_sched(std::move(sched)) { - if (!gpu_sched) { - completed = true; - } -} - -gpu_event_impl::~gpu_event_impl() { - if (!completed) { - LOG_ERROR("not completed gpu event is destroyed"); - } -} - -void gpu_event_impl::wait() { - if (!completed && gpu_sched) { - do { - gpu_sched->do_progress(); - completed = gpu_sched->wait(0); - } while (!completed); - } -} - -bool gpu_event_impl::test() { - if (!completed && gpu_sched) { - completed = gpu_sched->wait(0); - gpu_sched->do_progress(); - } - return completed; -} - -bool gpu_event_impl::cancel() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -event::native_t& gpu_event_impl::get_native() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -gpu_shared_event_impl::gpu_shared_event_impl(std::shared_ptr&& sched) - : gpu_sched(std::move(sched)) { - if (!gpu_sched) { - completed = true; - } -} - -gpu_shared_event_impl::~gpu_shared_event_impl() { - if (!completed) { - LOG_ERROR("not completed shared gpu event is destroyed"); - } -} - -void gpu_shared_event_impl::wait() { - if (!completed && gpu_sched) { - do { - gpu_sched->do_progress(); - completed = gpu_sched->wait(0); - } while (!completed); - } -} - -bool gpu_shared_event_impl::test() { - if (!completed && gpu_sched) { - completed = gpu_sched->wait(0); - gpu_sched->do_progress(); - } - return completed; -} - -bool gpu_shared_event_impl::cancel() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -event::native_t& gpu_shared_event_impl::get_native() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -gpu_shared_process_event_impl::gpu_shared_process_event_impl( - std::shared_ptr&& sched) {} - -gpu_shared_process_event_impl::~gpu_shared_process_event_impl() {} - -void gpu_shared_process_event_impl::wait() {} - -bool gpu_shared_process_event_impl::test() { - return false; -} - -bool gpu_shared_process_event_impl::cancel() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -event::native_t& gpu_shared_process_event_impl::get_native() { - throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -} - -} // namespace ccl diff --git a/src/common/event/impls/gpu_event.hpp b/src/common/event/impls/gpu_event.hpp deleted file mode 100644 index f1113e437..000000000 --- a/src/common/event/impls/gpu_event.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include - -#include "oneapi/ccl.hpp" -#include "common/event/impls/event_impl.hpp" - -class ccl_gpu_sched; - -namespace ccl { - -class gpu_event_impl final : public event_impl { -public: - explicit gpu_event_impl(std::unique_ptr&& sched); - ~gpu_event_impl(); - - void wait() override; - bool test() override; - bool cancel() override; - event::native_t& get_native() override; - -private: - std::unique_ptr gpu_sched; - bool completed = false; -}; - -class gpu_shared_event_impl final : public event_impl { -public: - explicit gpu_shared_event_impl(std::shared_ptr&& sched); - ~gpu_shared_event_impl(); - - void wait() override; - bool test() override; - bool cancel() override; - event::native_t& get_native() override; - -private: - std::shared_ptr gpu_sched; - bool completed = false; -}; - -class gpu_shared_process_event_impl final : public event_impl { -public: - explicit gpu_shared_process_event_impl(std::shared_ptr&& sched); - ~gpu_shared_process_event_impl(); - - void wait() override; - bool test() override; - bool cancel() override; - event::native_t& get_native() override; - -private: - std::shared_ptr gpu_sched; -}; - -} // namespace ccl diff --git a/src/common/event/impls/host_event.cpp b/src/common/event/impls/host_event.cpp index 35f8baf3b..80aff3603 100644 --- a/src/common/event/impls/host_event.cpp +++ b/src/common/event/impls/host_event.cpp @@ -15,13 +15,14 @@ */ #include "common/request/request.hpp" #include "common/event/impls/host_event.hpp" +#include "common/utils/sycl_utils.hpp" #include "exec/exec.hpp" namespace ccl { host_event_impl::host_event_impl(ccl_request* r) : req(r) { if (!req) { - // If the user calls collective with coll_attr->synchronous=1 then it will be progressed + // if the user calls collective with coll_attr->synchronous=1 then it will be progressed // in place and API will return null event. In this case mark cpp wrapper as completed, // all calls to wait() or test() will do nothing completed = true; @@ -29,7 +30,14 @@ host_event_impl::host_event_impl(ccl_request* r) : req(r) { } host_event_impl::~host_event_impl() { - if (!completed) { + // TODO: need to find a way to syncronize these 2 statuses, right now there are + // some issues, e.g. in case of pure host event get_native() is an empty sycl + // event which always complete, this way LOG_ERROR is never called + if (!completed +#ifdef CCL_ENABLE_SYCL + && (!utils::is_sycl_event_completed(get_native())) +#endif + ) { LOG_ERROR("not completed event is destroyed"); } } @@ -53,7 +61,16 @@ bool host_event_impl::cancel() { } event::native_t& host_event_impl::get_native() { +#ifdef CCL_ENABLE_SYCL + if (ccl::global_data::env().enable_kernel_output_event) { + return req->get_native_event(); + } + else { + CCL_THROW("get_native() is not available without CCL_KERNEL_OUTPUT_EVENT=1 env variable"); + } +#else throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); +#endif } } // namespace ccl diff --git a/src/common/event/impls/native_event.cpp b/src/common/event/impls/native_event.cpp index f5effe2e2..3cab35de2 100644 --- a/src/common/event/impls/native_event.cpp +++ b/src/common/event/impls/native_event.cpp @@ -23,20 +23,24 @@ native_event_impl::native_event_impl(std::unique_ptr ev) : ev(std::mo void native_event_impl::wait() { if (!completed) { #ifdef CCL_ENABLE_SYCL - auto native_event = ev->get_attribute_value( - detail::ccl_api_type_attr_traits{}); + auto native_event = get_native(); native_event.wait(); -#else +#else // CCL_ENABLE_SYCL throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); -#endif +#endif // CCL_ENABLE_SYCL completed = true; } } bool native_event_impl::test() { if (!completed) { +#ifdef CCL_ENABLE_SYCL + auto native_event = get_native(); + completed = native_event.get_info() == + sycl::info::event_command_status::complete; +#else // CCL_ENABLE_SYCL throw ccl::exception(std::string(__FUNCTION__) + " - is not implemented"); +#endif // CCL_ENABLE_SYCL } return completed; } diff --git a/src/common/global/global.cpp b/src/common/global/global.cpp index f1f04ef66..516aac8ee 100644 --- a/src/common/global/global.cpp +++ b/src/common/global/global.cpp @@ -23,8 +23,14 @@ #include "exec/exec.hpp" #include "fusion/fusion.hpp" #include "parallelizer/parallelizer.hpp" +#include "sched/buffer_cache.hpp" #include "sched/cache/cache.hpp" +#ifdef MULTI_GPU_SUPPORT +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#endif // MULTI_GPU_SUPPORT + namespace ccl { thread_local bool global_data::is_worker_thread = false; @@ -34,8 +40,7 @@ global_data::global_data() { to ensure static objects construction/destruction rule */ LOG_INFO("create global_data object"); - //TODO new_api configure thread wait timeout - thread_barrier_wait_timeout_sec = 5; + kernel_counter = 0; } global_data::~global_data() { @@ -60,6 +65,10 @@ ccl::status global_data::reset() { reset_resize_dependent_objects(); reset_resize_independent_objects(); +#ifdef MULTI_GPU_SUPPORT + finalize_gpu(); +#endif // MULTI_GPU_SUPPORT + return ccl::status::success; } @@ -67,6 +76,10 @@ ccl::status global_data::init() { env_object.parse(); env_object.set_internal_env(); +#ifdef MULTI_GPU_SUPPORT + init_gpu(); +#endif // MULTI_GPU_SUPPORT + init_resize_dependent_objects(); init_resize_independent_objects(); @@ -77,6 +90,8 @@ void global_data::init_resize_dependent_objects() { dtypes = std::unique_ptr(new ccl_datatype_storage()); sched_cache = std::unique_ptr(new ccl_sched_cache()); + buffer_cache = + std::unique_ptr(new ccl::buffer_cache(env_object.worker_count)); if (env_object.enable_fusion) { /* create fusion_manager before executor because service_worker uses fusion_manager */ @@ -94,20 +109,41 @@ void global_data::init_resize_independent_objects() { algorithm_selector = std::unique_ptr>( new ccl_algorithm_selector_wrapper()); - algorithm_selector->init(); + + hwloc_wrapper = std::unique_ptr(new ccl_hwloc_wrapper()); } void global_data::reset_resize_dependent_objects() { comm_ids.reset(); fusion_manager.reset(); sched_cache.reset(); + buffer_cache.reset(); dtypes.reset(); } void global_data::reset_resize_independent_objects() { parallelizer.reset(); algorithm_selector.reset(); + hwloc_wrapper.reset(); +} + +#ifdef MULTI_GPU_SUPPORT +void global_data::init_gpu() { + LOG_INFO("initializing level-zero"); + ze_result_t res = zeInit(ZE_INIT_FLAG_GPU_ONLY); + if (res != ZE_RESULT_SUCCESS) { + CCL_THROW("error at zeInit, code: ", ccl::ze::to_string(res)); + } + ze_cache = std::unique_ptr(new ccl::ze::cache(env_object.worker_count)); + LOG_INFO("initialized level-zero"); +} + +void global_data::finalize_gpu() { + LOG_INFO("finalizing level-zero"); + ze_cache.reset(); + LOG_INFO("finalized level-zero"); } +#endif // MULTI_GPU_SUPPORT -} /* namespace ccl */ +} // namespace ccl diff --git a/src/common/global/global.hpp b/src/common/global/global.hpp index 3d4d813f9..a2e29b2e3 100644 --- a/src/common/global/global.hpp +++ b/src/common/global/global.hpp @@ -19,6 +19,7 @@ #include "common/env/env.hpp" #include "common/utils/utils.hpp" #include "common/comm/l0/comm_context_storage.hpp" +#include "hwloc/hwloc_wrapper.hpp" #include "internal_types.hpp" #include @@ -46,30 +47,17 @@ class ccl_executor; class ccl_sched_cache; class ccl_parallelizer; class ccl_fusion_manager; -struct ccl_group_context; template class ccl_algorithm_selector_wrapper; namespace ccl { -// class comm_group; -// using comm_group_t = std::shared_ptr; +class buffer_cache; -// struct ccl_group_context { -// TODO -// * In multithreading scenario we use different comm_group_t objects in different threads. -// * But we need to match different groups created for the same world in different threads -// * The assumption is done: if different groups created from the same communicator color, than they -// * should be interpreted as the same groups in the same world. -// * -// * -// * In the final solution the 'group_unique_key' should be equal to unique KVS idenditifier - -// using group_unique_key = typename ccl::ccl_host_attributes_traits::type; -// std::map communicator_group_map; -// ccl_spinlock mutex; -// }; +namespace ze { +class cache; +} // namespace ze class global_data { public: @@ -96,23 +84,31 @@ class global_data { std::unique_ptr dtypes; std::unique_ptr executor; std::unique_ptr sched_cache; + std::unique_ptr buffer_cache; std::unique_ptr parallelizer; std::unique_ptr fusion_manager; std::unique_ptr> algorithm_selector; - std::unique_ptr global_ctx; + std::unique_ptr hwloc_wrapper; + std::atomic kernel_counter; + +#ifdef MULTI_GPU_SUPPORT + std::unique_ptr ze_cache; +#endif // MULTI_GPU_SUPPORT static thread_local bool is_worker_thread; bool is_ft_enabled; - //TODO new_api configure thread wait timeout - size_t thread_barrier_wait_timeout_sec = 5; - private: global_data(); void init_resize_independent_objects(); void reset_resize_independent_objects(); +#ifdef MULTI_GPU_SUPPORT + void init_gpu(); + void finalize_gpu(); +#endif // MULTI_GPU_SUPPORT + env_data env_object; }; @@ -125,4 +121,4 @@ class global_data { } while (0); \ } -} /* namespace ccl */ +} // namespace ccl diff --git a/src/common/log/log.cpp b/src/common/log/log.cpp index 48119e30d..a92adb861 100644 --- a/src/common/log/log.cpp +++ b/src/common/log/log.cpp @@ -39,6 +39,7 @@ std::ostream& operator<<(std::ostream& os, ccl_streambuf& buf) { void ccl_logger::write_prefix(std::ostream& str) { constexpr size_t time_buf_size = 20; + constexpr size_t tid_width = 5; time_t timer; char time_buf[time_buf_size]{}; struct tm time_info {}; @@ -47,7 +48,7 @@ void ccl_logger::write_prefix(std::ostream& str) { strftime(time_buf, time_buf_size, "%Y:%m:%d-%H:%M:%S", &time_info); str << time_buf; } - str << ":(" << gettid() << ") "; + str << ":(" << std::setw(tid_width) << gettid() << ") "; } void ccl_logger::write_backtrace(std::ostream& str) { diff --git a/src/common/log/log.hpp b/src/common/log/log.hpp index 4ffd009a3..c5d7877bf 100644 --- a/src/common/log/log.hpp +++ b/src/common/log/log.hpp @@ -23,6 +23,7 @@ #include #include +#include "oneapi/ccl/exception.hpp" #include "oneapi/ccl/types.hpp" #include "common/utils/spinlock.hpp" #include "common/utils/utils.hpp" @@ -247,9 +248,9 @@ extern ccl_logger logger; basedir_static(__FILE__), \ ":", \ __LINE__, \ - " ", \ - __FUNCTION__, \ " ", \ + __FUNCTION__, \ + ": ", \ ##__VA_ARGS__); \ } \ } @@ -275,9 +276,9 @@ extern ccl_logger logger; basedir_static(__FILE__), \ ":", \ __LINE__, \ - " ", \ - __FUNCTION__, \ " ", \ + __FUNCTION__, \ + ": ", \ ##__VA_ARGS__); \ } \ } @@ -289,9 +290,9 @@ extern ccl_logger logger; basedir_static(__FILE__), \ ":", \ __LINE__, \ - " ", \ - __FUNCTION__, \ " ", \ + __FUNCTION__, \ + ": ", \ ##__VA_ARGS__); \ } \ } @@ -312,11 +313,11 @@ extern ccl_logger logger; do { \ std::stringstream throw_msg_ss; \ ccl_logger::format(throw_msg_ss, \ - __FILENAME__, \ - ":", \ - __FUNCTION__, \ + basedir_static(__FILE__), \ ":", \ __LINE__, \ + " ", \ + __FUNCTION__, \ ": EXCEPTION: ", \ ##__VA_ARGS__); \ throw ccl::exception(throw_msg_ss.str()); \ @@ -329,11 +330,11 @@ extern ccl_logger logger; do { \ std::stringstream throw_msg_ss; \ ccl_logger::format(throw_msg_ss, \ - __FILENAME__, \ - ":", \ - __FUNCTION__, \ + basedir_static(__FILE__), \ ":", \ __LINE__, \ + " ", \ + __FUNCTION__, \ ": EXCEPTION: ", \ ##__VA_ARGS__); \ LOG_ERROR("Error - ", ##__VA_ARGS__); \ diff --git a/src/common/request/request.hpp b/src/common/request/request.hpp index 4df670f1b..62881c34b 100644 --- a/src/common/request/request.hpp +++ b/src/common/request/request.hpp @@ -20,11 +20,15 @@ #include "common/utils/utils.hpp" +#ifdef CCL_ENABLE_SYCL +#include +#endif + class alignas(CACHELINE_SIZE) ccl_request { public: - using dump_func = std::function; + using dump_func = std::function; #ifdef ENABLE_DEBUG - void set_dump_callback(dump_func &&callback); + void set_dump_callback(dump_func&& callback); #endif virtual ~ccl_request(); @@ -39,9 +43,36 @@ class alignas(CACHELINE_SIZE) ccl_request { mutable bool urgent = false; +#ifdef CCL_ENABLE_SYCL + void set_native_event(sycl::event new_event) { + native_event = new_event; + } + + sycl::event& get_native_event() { + return native_event; + } + + void set_sync_event(sycl::event new_event) { + sync_event = new_event; + } + + sycl::event& get_sync_event() { + return sync_event; + } + +#endif private: std::atomic_int completion_counter{ 0 }; +#ifdef CCL_ENABLE_SYCL + // The actual event from submit_barrier. It's returned to the user via ccl::event.get_native() + sycl::event native_event; + // This is basically a wrapped l0 event from sched_base, we need to keep as sycl object because its destructor + // implies wait on the event, but in our case it's not yet completed(right after we created it from l0 event). + // So just keep it here until we signal the corresponding l0 event. + sycl::event sync_event; +#endif + #ifdef ENABLE_DEBUG dump_func dump_callback; mutable size_t complete_checks_count = 0; diff --git a/src/common/stream/stream.cpp b/src/common/stream/stream.cpp index 8aae8b8b2..2776d9aa9 100644 --- a/src/common/stream/stream.cpp +++ b/src/common/stream/stream.cpp @@ -35,138 +35,17 @@ ccl_stream::ccl_stream(stream_type type, for (size_t idx = 0; idx < native_streams.size(); idx++) { native_streams[idx] = stream_native_t(stream.get_context(), stream.get_device()); } -#endif /* CCL_ENABLE_SYCL */ -} - -ccl_stream::ccl_stream(stream_type type, - stream_native_handle_t handle, - const ccl::library_version& version) - : type(type), - version(version) { - creation_is_postponed = true; - (void)handle; - throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + " - unsupported "); -} -ccl_stream::ccl_stream(stream_type type, const ccl::library_version& version) - : type(type), - version(version) { - creation_is_postponed = true; - LOG_DEBUG("Scheduled postponed stream creation"); + backend = stream.get_device().get_backend(); +#endif // CCL_ENABLE_SYCL } -void ccl_stream::build_from_params() { - if (!creation_is_postponed) { - throw ccl::exception(std::string(__FUNCTION__) + - " - incorrect usage, stream is not sheduled for postponed creation"); - } - - type = stream_type::host; - try { -#ifdef CCL_ENABLE_SYCL - if (native_context.first) { - if (!native_device.first) { - throw ccl::exception( - std::string(__FUNCTION__) + - " - incorrect usage, not enough parameters for stream creation: " - " context is available, but device is not. Both required"); - } - - LOG_DEBUG("create stream from device & context"); - stream_native_t stream_candidate{ native_context.second, native_device.second }; - std::swap(stream_candidate, - native_stream); //TODO USE attributes from sycl queue construction - } - else if (native_device.first) { - LOG_DEBUG("create stream from device only"); - stream_native_t stream_candidate{ native_device.second }; - std::swap(stream_candidate, - native_stream); //TODO USE attributes from sycl queue construction - - native_context.second = native_stream.get_context(); - native_context.first = true; - } - else { - throw ccl::exception(std::string(__FUNCTION__) + - " - incorrect usage, not enough parameters for stream creation: " - " context is empty and device is empty too."); - } - - //override type - if (native_stream.get_device().is_host()) { - type = stream_type::host; - } - else if (native_stream.get_device().is_cpu()) { - type = stream_type::cpu; - } - else if (native_stream.get_device().is_gpu()) { - type = stream_type::gpu; - } - else { - throw ccl::invalid_argument( - "CORE", - "create_stream", - std::string("Unsupported SYCL queue's device type for postponed creation:\n") + - native_stream.get_device().template get_info() + - std::string("Supported types: host, cpu, gpu")); - } - LOG_INFO("SYCL queue type from postponed creation: ", - ::to_string(type), - ", device: ", - native_stream.get_device().template get_info()); -#else -#ifdef MULTI_GPU_SUPPORT - ze_command_queue_desc_t descr = - stream_native_device_t::element_type::get_default_queue_desc(); - - //TODO use attributes.... - //Create from device & context - if (native_context.first) { - if (!native_device.first) { - throw ccl::exception( - std::string(__FUNCTION__) + - " - incorrect usage, not enough parameters for stream creation: " - " context is available, but device is not. Both required"); - } - - LOG_DEBUG("create stream from device & context"); - auto stream_candidate = - native_device.second->create_cmd_queue(native_context.second, descr); - native_stream = std::make_shared( - std::move(stream_candidate)); - } - else if (native_device.first) { - LOG_DEBUG("create stream from device only"); - - auto stream_candidate = native_device.second->create_cmd_queue({}, descr); - native_stream = std::make_shared( - std::move(stream_candidate)); - - native_context.second = native_stream->get_ctx().lock(); - native_context.first = true; - } - else { - throw ccl::exception(std::string(__FUNCTION__) + - " - incorrect usage, not enough parameters for stream creation: " - " context is empty and device is empty too."); - } - - type = stream_type::gpu; -#endif -#endif - } - catch (const std::exception& ex) { - throw ccl::exception(std::string("Cannot build ccl_stream from params: ") + ex.what()); - } - creation_is_postponed = false; -} - -//Export Attributes +// export attributes typename ccl_stream::version_traits_t::type ccl_stream::set_attribute_value( typename version_traits_t::type val, const version_traits_t& t) { (void)t; - throw ccl::exception("Set value for 'ccl::stream_attr_id::library_version' is not allowed"); + throw ccl::exception("set value for 'ccl::stream_attr_id::library_version' is not allowed"); return version; } @@ -177,127 +56,18 @@ const typename ccl_stream::version_traits_t::return_type& ccl_stream::get_attrib typename ccl_stream::native_handle_traits_t::return_type& ccl_stream::get_attribute_value( const native_handle_traits_t& id) { - if (creation_is_postponed) { - throw ccl::exception(std::string(__FUNCTION__) + " - stream is not properly created yet"); - } - return native_stream; } -typename ccl_stream::device_traits_t::return_type& ccl_stream::get_attribute_value( - const device_traits_t& id) { - if (!native_device.first) { - throw ccl::exception(std::string(__FUNCTION__) + " - stream has no native device"); - } - return native_device.second; -} - -typename ccl_stream::context_traits_t::return_type& ccl_stream::get_attribute_value( - const context_traits_t& id) { - if (!native_context.first) { - throw ccl::exception(std::string(__FUNCTION__) + " - stream has no native context"); - } - return native_context.second; -} - -typename ccl_stream::context_traits_t::return_type& ccl_stream::set_attribute_value( - typename context_traits_t::type val, - const context_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::context'`for constructed stream"); - } - std::swap(native_context.second, val); - native_context.first = true; - return native_context.second; -} -/* -typename ccl_stream::context_traits_t::return_type& ccl_stream::set_attribute_value( - typename context_traits_t::handle_t val, - const context_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::context'`for constructed stream"); - } - native_context.second = ccl::unified_context_type{ val }.get(); //context_traits_t::type - native_context.first = true; - return native_context.second; -}*/ - -typename ccl_stream::ordinal_traits_t::type ccl_stream::set_attribute_value( - typename ordinal_traits_t::type val, - const ordinal_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::ordinal'`for constructed stream"); - } - auto old = ordinal_val; - std::swap(ordinal_val, val); - return old; -} - -const typename ccl_stream::ordinal_traits_t::return_type& ccl_stream::get_attribute_value( - const ordinal_traits_t& id) const { - return ordinal_val; -} - -typename ccl_stream::index_traits_t::type ccl_stream::set_attribute_value( - typename index_traits_t::type val, - const index_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::index'`for constructed stream"); - } - auto old = index_val; - std::swap(index_val, val); - return old; -} - -const typename ccl_stream::index_traits_t::return_type& ccl_stream::get_attribute_value( - const index_traits_t& id) const { - return index_val; -} - -typename ccl_stream::flags_traits_t::type ccl_stream::set_attribute_value( - typename flags_traits_t::type val, - const flags_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::flags'`for constructed stream"); - } - auto old = flags_val; - std::swap(flags_val, val); - return old; -} - -const typename ccl_stream::flags_traits_t::return_type& ccl_stream::get_attribute_value( - const flags_traits_t& id) const { - return flags_val; -} - -typename ccl_stream::mode_traits_t::type ccl_stream::set_attribute_value( - typename mode_traits_t::type val, - const mode_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::mode'`for constructed stream"); - } - auto old = mode_val; - std::swap(mode_val, val); - return old; -} - -const typename ccl_stream::mode_traits_t::return_type& ccl_stream::get_attribute_value( - const mode_traits_t& id) const { - return mode_val; -} - -typename ccl_stream::priority_traits_t::type ccl_stream::set_attribute_value( - typename priority_traits_t::type val, - const priority_traits_t& t) { - if (!creation_is_postponed) { - throw ccl::exception("Cannot set 'ccl::stream_attr_id::priority'`for constructed stream"); - } - auto old = priority_val; - std::swap(priority_val, val); - return old; -} - -const typename ccl_stream::priority_traits_t::return_type& ccl_stream::get_attribute_value( - const priority_traits_t& id) const { - return priority_val; +std::string ccl_stream::to_string() const { + std::stringstream ss; +#ifdef CCL_ENABLE_SYCL + ss << "{ " + << "type: " << ::to_string(type) << ", in_order: " << native_stream.is_in_order() + << ", device: " << native_stream.get_device().get_info() + << " }"; +#else // CCL_ENABLE_SYCL + ss << reinterpret_cast(native_stream.get()); +#endif // CCL_ENABLE_SYCL + return ss.str(); } diff --git a/src/common/stream/stream.hpp b/src/common/stream/stream.hpp index 3e7210d58..7f192715b 100644 --- a/src/common/stream/stream.hpp +++ b/src/common/stream/stream.hpp @@ -15,17 +15,16 @@ */ #pragma once -#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/types.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/stream_attr_ids.hpp" -#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "coll/coll_common_attributes.hpp" +#include "common/stream/stream_provider_dispatcher.hpp" #include "common/utils/enums.hpp" #include "common/utils/utils.hpp" -#include "common/stream/stream_provider_dispatcher.hpp" - -#include "coll/coll_common_attributes.hpp" #include "internal_types.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/type_traits.hpp" namespace ccl { namespace detail { @@ -36,21 +35,12 @@ class environment; using stream_str_enum = utils::enum_to_str; std::string to_string(const stream_type& type); -/* -ccl::status CCL_API ccl_stream_create(stream_type type, - void* native_stream, - ccl_stream_t* stream); -*/ class alignas(CACHELINE_SIZE) ccl_stream : public stream_provider_dispatcher { public: friend class stream_provider_dispatcher; friend class ccl::detail::environment; - /* - friend ccl::status CCL_API ccl_stream_create(stream_type type, - void* native_stream, - ccl_stream_t* stream);*/ + using stream_native_t = stream_provider_dispatcher::stream_native_t; - using stream_native_handle_t = stream_provider_dispatcher::stream_native_handle_t; ccl_stream() = delete; ccl_stream(const ccl_stream& other) = delete; @@ -60,6 +50,8 @@ class alignas(CACHELINE_SIZE) ccl_stream : public stream_provider_dispatcher { using stream_provider_dispatcher::get_native_stream; + std::string to_string() const; + stream_type get_type() const { return type; } @@ -68,10 +60,20 @@ class alignas(CACHELINE_SIZE) ccl_stream : public stream_provider_dispatcher { return (type == stream_type::cpu || type == stream_type::gpu); } + bool is_gpu() const { + return type == stream_type::gpu; + } + +#ifdef CCL_ENABLE_SYCL + cl::sycl::backend get_backend() const noexcept { + return backend; + } +#endif // CCL_ENBALE_SYCL + static std::unique_ptr create(stream_native_t& native_stream, const ccl::library_version& version); - //Export Attributes + // export attributes using version_traits_t = ccl::detail::ccl_api_type_attr_traits; typename version_traits_t::return_type set_attribute_value(typename version_traits_t::type val, @@ -86,79 +88,14 @@ class alignas(CACHELINE_SIZE) ccl_stream : public stream_provider_dispatcher { typename native_handle_traits_t::return_type& get_attribute_value( const native_handle_traits_t& id); - using device_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename device_traits_t::return_type& get_attribute_value(const device_traits_t& id); - - using context_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename context_traits_t::return_type& get_attribute_value(const context_traits_t& id); - - typename context_traits_t::return_type& set_attribute_value(typename context_traits_t::type val, - const context_traits_t& t); - /* - typename context_traits_t::return_type& set_attribute_value( - typename context_traits_t::handle_t val, - const context_traits_t& t); -*/ - using ordinal_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename ordinal_traits_t::return_type set_attribute_value(typename ordinal_traits_t::type val, - const ordinal_traits_t& t); - - const typename ordinal_traits_t::return_type& get_attribute_value( - const ordinal_traits_t& id) const; - - using index_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename index_traits_t::return_type set_attribute_value(typename index_traits_t::type val, - const index_traits_t& t); - - const typename index_traits_t::return_type& get_attribute_value(const index_traits_t& id) const; - - using flags_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename flags_traits_t::return_type set_attribute_value(typename flags_traits_t::type val, - const flags_traits_t& t); - - const typename flags_traits_t::return_type& get_attribute_value(const flags_traits_t& id) const; - - using mode_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename mode_traits_t::return_type set_attribute_value(typename mode_traits_t::type val, - const mode_traits_t& t); - - const typename mode_traits_t::return_type& get_attribute_value(const mode_traits_t& id) const; - - using priority_traits_t = - ccl::detail::ccl_api_type_attr_traits; - typename priority_traits_t::return_type set_attribute_value( - typename priority_traits_t::type val, - const priority_traits_t& t); - - const typename priority_traits_t::return_type& get_attribute_value( - const priority_traits_t& id) const; - - void build_from_params(); - private: ccl_stream(stream_type type, stream_native_t& native_stream, const ccl::library_version& version); - ccl_stream(stream_type type, - stream_native_handle_t native_stream, - const ccl::library_version& version); - - ccl_stream(stream_type type, const ccl::library_version& version); - stream_type type; +#ifdef CCL_ENABLE_SYCL + cl::sycl::backend backend; +#endif // CCL_ENBALE_SYCL const ccl::library_version version; - typename ordinal_traits_t::return_type ordinal_val; - typename index_traits_t::return_type index_val; - typename flags_traits_t::return_type flags_val; - typename mode_traits_t::return_type mode_val; - typename priority_traits_t::return_type priority_val; - - bool is_context_enabled{ false }; }; diff --git a/src/common/stream/stream_provider_dispatcher.hpp b/src/common/stream/stream_provider_dispatcher.hpp index bdbea6b9b..cd5b7a028 100644 --- a/src/common/stream/stream_provider_dispatcher.hpp +++ b/src/common/stream/stream_provider_dispatcher.hpp @@ -38,33 +38,22 @@ enum class stream_type : int { class ccl_stream; class stream_provider_dispatcher { public: - using stream_native_handle_t = typename ccl::unified_stream_type::handle_t; using stream_native_t = typename ccl::unified_stream_type::ccl_native_t; + using stream_native_device_t = typename ccl::unified_device_type::ccl_native_t; - ; using stream_native_context_t = typename ccl::unified_context_type::ccl_native_t; stream_native_t get_native_stream() const; #ifdef CCL_ENABLE_SYCL stream_native_t* get_native_stream(size_t idx); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL const stream_native_device_t& get_native_device() const; stream_native_device_t& get_native_device(); - std::string to_string() const; - - // available admissions to create stream static std::unique_ptr create(stream_native_t& native_stream, const ccl::library_version& version); - static std::unique_ptr create(stream_native_handle_t native_handle, - const ccl::library_version& version); - static std::unique_ptr create(stream_native_device_t device, - const ccl::library_version& version); - static std::unique_ptr create(stream_native_device_t device, - stream_native_context_t context, - const ccl::library_version& version); template using optional = std::pair; @@ -72,12 +61,10 @@ class stream_provider_dispatcher { optional native_device; optional native_context; - bool creation_is_postponed{ false }; - stream_native_t native_stream; #ifdef CCL_ENABLE_SYCL /* FIXME: tmp w/a for MT support in queue */ std::vector native_streams; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL }; diff --git a/src/common/stream/stream_provider_dispatcher_impl.hpp b/src/common/stream/stream_provider_dispatcher_impl.hpp index 1f44700bf..bff3c91eb 100644 --- a/src/common/stream/stream_provider_dispatcher_impl.hpp +++ b/src/common/stream/stream_provider_dispatcher_impl.hpp @@ -19,13 +19,14 @@ #ifdef CCL_ENABLE_SYCL #include -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL -// Creation from class-type: cl::sycl::queue or native::ccl_device::devie_queue +// creation from sycl::queue std::unique_ptr stream_provider_dispatcher::create( stream_native_t& native_stream, const ccl::library_version& version) { stream_type type = stream_type::host; + #ifdef CCL_ENABLE_SYCL if (native_stream.get_device().is_host()) { type = stream_type::host; @@ -38,11 +39,11 @@ std::unique_ptr stream_provider_dispatcher::create( } else { throw ccl::invalid_argument( - "CORE", + "core", "create_stream", - std::string("Unsupported SYCL queue's device type:\n") + + std::string("unsupported SYCL queue's device type:\n") + native_stream.get_device().template get_info() + - std::string("Supported types: host, cpu, gpu")); + std::string("supported types: host, cpu, gpu")); } std::unique_ptr ret(new ccl_stream(type, native_stream, version)); @@ -50,123 +51,31 @@ std::unique_ptr stream_provider_dispatcher::create( ret->native_device.first = true; ret->native_context.second = native_stream.get_context(); ret->native_context.first = true; + LOG_INFO("SYCL queue type: ", ::to_string(type), + ", in_order: ", + native_stream.is_in_order(), ", device: ", native_stream.get_device().template get_info()); -#else -#ifdef MULTI_GPU_SUPPORT - LOG_INFO("L0 queue type: gpu - supported only"); - type = stream_type::gpu; - std::unique_ptr ret(new ccl_stream(type, native_stream, version)); - ret->native_device.second = native_stream->get_owner().lock(); - ret->native_device.first = true; - ret->native_context.second = native_stream->get_ctx().lock(); - ret->native_context.first = true; -#else +#else // CCL_ENABLE_SYCL std::unique_ptr ret(new ccl_stream(type, native_stream, version)); -#endif -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL return ret; } -// Creation from handles: cl_queue or ze_device_queue_handle_t -std::unique_ptr stream_provider_dispatcher::create( - stream_native_handle_t native_stream, - const ccl::library_version& version) { - return std::unique_ptr(new ccl_stream(stream_type::gpu, native_stream, version)); -} - -// Postponed creation from device -std::unique_ptr stream_provider_dispatcher::create( - stream_native_device_t device, - const ccl::library_version& version) { - auto ret = std::unique_ptr(new ccl_stream(stream_type::gpu, version)); - ret->native_device.second = device; - ret->native_device.first = true; - return ret; -} - -// Postponed creation from device & context -std::unique_ptr stream_provider_dispatcher::create( - stream_native_device_t device, - stream_native_context_t context, - const ccl::library_version& version) { - auto ret = stream_provider_dispatcher::create(device, version); - ret->native_context.second = context; - ret->native_context.first = true; - return ret; -} - stream_provider_dispatcher::stream_native_t stream_provider_dispatcher::get_native_stream() const { - if (creation_is_postponed) { - throw ccl::exception("native stream is not set"); - } - return native_stream; } #ifdef CCL_ENABLE_SYCL stream_provider_dispatcher::stream_native_t* stream_provider_dispatcher::get_native_stream( size_t idx) { - if (creation_is_postponed) { - throw ccl::exception("native stream is not set"); - } - if (idx >= native_streams.size()) { throw ccl::exception("unexpected stream idx"); } - return &(native_streams[idx]); } -#endif /* CCL_ENABLE_SYCL */ - -const stream_provider_dispatcher::stream_native_device_t& -stream_provider_dispatcher::get_native_device() const { - if (!native_device.first) { - throw ccl::exception(std::string(__FUNCTION__) + " - stream has no native device"); - } - return native_device.second; -} - -stream_provider_dispatcher::stream_native_device_t& -stream_provider_dispatcher::get_native_device() { - return const_cast( - static_cast(this)->get_native_device()); -} - -std::string stream_provider_dispatcher::to_string() const { - if (creation_is_postponed) { - throw ccl::exception("stream is not properly created yet"); - } - std::stringstream ss; -#ifdef CCL_ENABLE_SYCL - ss << "sycl: " - << native_stream.get_info() - .get_info(); -#else - ss << reinterpret_cast(native_stream.get()); //TODO -#endif - return ss.str(); -} - -/* -stream_provider_dispatcher::stream_native_handle_t -stream_provider_dispatcher::get_native_stream_handle_impl(stream_native_t &handle) -{ -#ifdef CCL_ENABLE_SYCL - if (!handle.get_device().is_host()) - { - return *reinterpret_cast(handle.get()); - } - else - { - return *reinterpret_cast(&handle); - } -#else - return handle; -#endif -} -*/ +#endif // CCL_ENABLE_SYCL diff --git a/src/common/utils/buffer.hpp b/src/common/utils/buffer.hpp index ef2dd3531..d35e9aada 100644 --- a/src/common/utils/buffer.hpp +++ b/src/common/utils/buffer.hpp @@ -23,7 +23,7 @@ enum class ccl_buffer_type { DIRECT, INDIRECT }; inline std::ostream& operator<<(std::ostream& os, const ccl_buffer_type& type) { - os << static_cast::type>(type); + os << static_cast::type>(type); return os; } @@ -68,7 +68,16 @@ class ccl_buffer { size(size), offset(offset), type(type) { - LOG_DEBUG("create: src ", src, ", size ", size, ", offset ", offset, ", type ", type); + LOG_DEBUG("create: src ", + src, + ", size ", + size, + ", offset ", + offset, + ", type ", + type, + ", ptr ", + get_ptr()); CCL_ASSERT(check_offset()); } @@ -212,7 +221,8 @@ class ccl_buffer { friend std::ostream& operator<<(std::ostream& out, const ccl_buffer& buf) { out << "(src: " << buf.get_src() << ", size " << buf.get_size() << ", off " - << buf.get_offset() << ", type: " << buf.get_type() << ")"; + << buf.get_offset() << ", type: " << buf.get_type() << ", ptr: " << buf.get_ptr() + << ")"; return out; } }; diff --git a/src/common/utils/hash.hpp b/src/common/utils/hash.hpp new file mode 100644 index 000000000..9e98b4f60 --- /dev/null +++ b/src/common/utils/hash.hpp @@ -0,0 +1,56 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include + +namespace ccl { +namespace utils { + +inline size_t calculate_hash(size_t left, size_t right) { + left = ((left >> 2) + right + (left << 6) + 0x9e3779b9) ^ left; + return left; +} + +template +struct hash_impl { + size_t operator()(size_t left, const std::tuple& tuple) const { + using next_t = typename std::tuple_element>::type; + hash_impl next; + size_t right = std::hash()(std::get(tuple)); + return next(calculate_hash(left, right), tuple); + } +}; + +template +struct hash_impl<0, types...> { + size_t operator()(size_t left, const std::tuple& tuple) const { + using next_t = typename std::tuple_element<0, std::tuple>::type; + size_t right = std::hash()(std::get<0>(tuple)); + return calculate_hash(left, right); + } +}; + +struct tuple_hash { + template + size_t operator()(const std::tuple& tuple) const { + const size_t start = std::tuple_size>::value - 1; + return hash_impl()(0, tuple); + } +}; + +} // namespace utils +} // namespace ccl diff --git a/src/common/utils/sycl_utils.hpp b/src/common/utils/sycl_utils.hpp new file mode 100644 index 000000000..fba47ea42 --- /dev/null +++ b/src/common/utils/sycl_utils.hpp @@ -0,0 +1,69 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifdef CCL_ENABLE_SYCL + +#include + +#include "common/stream/stream.hpp" +#include "common/global/global.hpp" + +namespace ccl { +namespace utils { + +static inline bool is_sycl_event_completed(sycl::event e) { + return e.template get_info() == + sycl::info::event_command_status::complete; +} + +static inline bool should_use_sycl_output_event(ccl_stream* stream) { + return (stream && stream->is_sycl_device_stream() && stream->is_gpu() && + ccl::global_data::env().enable_kernel_output_event); +} + +static inline std::string usm_type_to_str(sycl::usm::alloc type) { + switch (type) { + case sycl::usm::alloc::host: return "host"; + case sycl::usm::alloc::device: return "device"; + case sycl::usm::alloc::shared: return "shared"; + case sycl::usm::alloc::unknown: return "unknown"; + default: CCL_THROW("unexpected USM type: ", static_cast(type)); + } +} + +static inline std::string sycl_device_to_str(const sycl::device& dev) { + if (dev.is_host()) { + return "host"; + } + else if (dev.is_cpu()) { + return "cpu"; + } + else if (dev.is_gpu()) { + return "gpu"; + } + else if (dev.is_accelerator()) { + return "accel"; + } + else { + CCL_THROW("unexpected device type"); + } +} + +} // namespace utils +} // namespace ccl + +#endif // CCL_ENABLE_SYCL diff --git a/src/common/utils/sync_object.hpp b/src/common/utils/sync_object.hpp index 51b252cd8..f910f0138 100644 --- a/src/common/utils/sync_object.hpp +++ b/src/common/utils/sync_object.hpp @@ -27,7 +27,7 @@ class sync_object { void visit() { auto cnt = sync.fetch_sub(1, std::memory_order_release); - CCL_ASSERT(cnt >= 0 && cnt <= initial_cnt, "invalid count ", cnt); + CCL_ASSERT(cnt <= initial_cnt, "invalid count ", cnt); } void reset() { diff --git a/src/common/utils/utils.hpp b/src/common/utils/utils.hpp index ab3412570..2d37ce136 100644 --- a/src/common/utils/utils.hpp +++ b/src/common/utils/utils.hpp @@ -56,6 +56,9 @@ #define container_of(ptr, type, field) ((type*)((char*)ptr - offsetof(type, field))) #endif +#define CCL_UNDEFINED_CPU_ID (-1) +#define CCL_UNDEFINED_NUMA_NODE (-1) + #define CACHELINE_SIZE 64 #define CCL_REG_MSG_ALIGNMENT (4096) @@ -143,6 +146,10 @@ /* other */ +static inline size_t ccl_get_ptr_diff(const void* ptr1, const void* ptr2) { + return static_cast(ptr2) - static_cast(ptr1); +} + static inline size_t ccl_pof2(size_t number) { size_t last_bit_mask = ((size_t)1 << (8 * sizeof(size_t) - 1)); if (number & last_bit_mask) { @@ -161,6 +168,7 @@ static inline size_t ccl_aligned_sz(size_t size, size_t alignment) { return ((size % alignment) == 0) ? size : ((size / alignment) + 1) * alignment; } +#if 0 static inline timespec ccl_from_time_point( const std::chrono::time_point point) { auto sec = std::chrono::time_point_cast(point); @@ -169,6 +177,7 @@ static inline timespec ccl_from_time_point( return timespec{ .tv_sec = sec.time_since_epoch().count(), .tv_nsec = ns.count() }; } +#endif template container tokenize(const std::string& input, char delimeter) { @@ -202,6 +211,7 @@ void ccl_str_to_array(const char* input, std::set delims, std::vector& } while (can_parse); } +#if 0 //TODO naite implementation, use TBB template 1 && !ccl::global_data::env().enable_comm_kernels) { - throw ccl::unimplemented("API", "create_communicators", "for multiple devices"); - } + // if (table_size > 1 && !ccl::global_data::env().enable_comm_kernels) { + // throw ccl::unimplemented("API", "create_communicators", "for multiple devices"); + // } } template @@ -300,9 +299,6 @@ struct comm_impl_dispatch_selector atl, ccl::group_split_type::single); - //TODO use gpu_comm_attr to automatically visit() - //auto single_dev_comm = std::dynamic_pointer_cast(impl); - //single_dev_comm->set_context(context); ccl::vector_class ret; ret.push_back(ccl::communicator(std::move(impl))); return ret; diff --git a/src/comp/bf16/bf16.cpp b/src/comp/bf16/bf16.cpp index 116e8ec14..b79915e63 100644 --- a/src/comp/bf16/bf16.cpp +++ b/src/comp/bf16/bf16.cpp @@ -113,7 +113,7 @@ void ccl_convert_bf16_to_fp32_arrays(void* bf16_buf, float* fp32_buf, size_t cou } } -#else /* CCL_BF16_COMPILER */ +#else // CCL_BF16_COMPILER void ccl_bf16_reduce(const void* in_buf, size_t in_cnt, @@ -131,4 +131,4 @@ void ccl_convert_bf16_to_fp32_arrays(void* bf16_buf, float* fp32_buf, size_t cou CCL_FATAL("BF16->FP32 conversion was requested but CCL was compiled w/o BF16 support"); } -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER diff --git a/src/comp/bf16/bf16.hpp b/src/comp/bf16/bf16.hpp index 18d6c797b..2c6e73d3b 100644 --- a/src/comp/bf16/bf16.hpp +++ b/src/comp/bf16/bf16.hpp @@ -26,13 +26,13 @@ __attribute__((target("avx512bw,avx512vl"))) void ccl_bf16_reduce(const void* in_buf, size_t in_cnt, void* inout_buf, size_t* out_cnt, ccl::reduction reduction_op); -#else /* CCL_BF16_TARGET_ATTRIBUTES */ +#else // CCL_BF16_TARGET_ATTRIBUTES void ccl_bf16_reduce(const void* in_buf, size_t in_cnt, void* inout_buf, size_t* out_cnt, ccl::reduction reduction_op); -#endif /* CCL_BF16_TARGET_ATTRIBUTES */ +#endif // CCL_BF16_TARGET_ATTRIBUTES void ccl_convert_fp32_to_bf16_arrays(void*, void*, size_t); void ccl_convert_bf16_to_fp32_arrays(void*, float*, size_t); @@ -46,7 +46,7 @@ void ccl_convert_fp32_to_bf16(const void* src, void* dst) #else void ccl_convert_fp32_to_bf16(const void* src, void* dst) __attribute__((target("avx512bw"))); #endif -#endif /* CCL_BF16_TARGET_ATTRIBUTES */ +#endif // CCL_BF16_TARGET_ATTRIBUTES #ifdef CCL_BF16_TARGET_ATTRIBUTES #ifdef CCL_BF16_AVX512BF_COMPILER @@ -55,6 +55,6 @@ void ccl_convert_bf16_to_fp32(const void* src, void* dst) #else void ccl_convert_bf16_to_fp32(const void* src, void* dst) __attribute__((target("avx512bw"))); #endif -#endif /* CCL_BF16_TARGET_ATTRIBUTES */ +#endif // CCL_BF16_TARGET_ATTRIBUTES -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER diff --git a/src/comp/bf16/bf16_intrisics.cpp b/src/comp/bf16/bf16_intrisics.cpp index 1bef254db..3f38ce927 100644 --- a/src/comp/bf16/bf16_intrisics.cpp +++ b/src/comp/bf16/bf16_intrisics.cpp @@ -37,4 +37,4 @@ BF16_TARGET_ATTRIBUTE_BWF __m512 bf16_reduce(__m512 a, __m512 b, ccl_bf16_reduct return (*op)(a, b); } -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER diff --git a/src/comp/bf16/bf16_intrisics.hpp b/src/comp/bf16/bf16_intrisics.hpp index e452aab9c..de8670895 100644 --- a/src/comp/bf16/bf16_intrisics.hpp +++ b/src/comp/bf16/bf16_intrisics.hpp @@ -42,7 +42,7 @@ #define BF16_INLINE_TARGET_ATTRIBUTE_ALL \ __attribute__((__always_inline__, target(BF16_ALL_ATTRS))) inline -#else /* CCL_BF16_TARGET_ATTRIBUTES */ +#else // CCL_BF16_TARGET_ATTRIBUTES #define BF16_TARGET_ATTRIBUTE_BWF #define BF16_TARGET_ATTRIBUTE_ALL @@ -50,7 +50,7 @@ #define BF16_INLINE_TARGET_ATTRIBUTE __attribute__((__always_inline__)) inline #define BF16_INLINE_TARGET_ATTRIBUTE_ALL __attribute__((__always_inline__)) inline -#endif /* CCL_BF16_TARGET_ATTRIBUTES */ +#endif // CCL_BF16_TARGET_ATTRIBUTES typedef __m512 (*ccl_bf16_reduction_func_ptr)(__m512 a, __m512 b); BF16_TARGET_ATTRIBUTE_BWF __m512 bf16_sum_wrap(__m512 a, __m512 b); @@ -141,4 +141,4 @@ BF16_INLINE_TARGET_ATTRIBUTE_ALL void ccl_bf16_reduce_impl(const void* in_buf, #endif } -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER diff --git a/src/comp/bf16/bf16_utils.hpp b/src/comp/bf16/bf16_utils.hpp index 0e0e36db7..212b74491 100644 --- a/src/comp/bf16/bf16_utils.hpp +++ b/src/comp/bf16/bf16_utils.hpp @@ -59,7 +59,7 @@ __attribute__((__always_inline__)) inline std::set ccl_bf16_ : "=a"(reg[0]), "=b"(reg[1]), "=c"(reg[2]), "=d"(reg[3]) : "a"(7), "c"(1)); is_avx512bf_enabled = (reg[0] & (1 << 5)) >> 5; -#endif /* CCL_BF16_AVX512BF_COMPILER */ +#endif // CCL_BF16_AVX512BF_COMPILER if (is_avx512f_enabled) result.insert(ccl_bf16_avx512f); diff --git a/src/comp/comp.cpp b/src/comp/comp.cpp index 8bee9f353..1d9f8645e 100644 --- a/src/comp/comp.cpp +++ b/src/comp/comp.cpp @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/types.hpp" #include "comp/bf16/bf16.hpp" #include "comp/comp.hpp" #include "comp/fp16/fp16.hpp" #include "common/log/log.hpp" #include "common/global/global.hpp" #include "common/utils/enums.hpp" +#include "common/utils/sycl_utils.hpp" +#include "oneapi/ccl/types.hpp" #include "sched/queue/queue.hpp" #ifdef CCL_ENABLE_SYCL #include -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #define CCL_REDUCE(type) \ do { \ @@ -111,6 +112,10 @@ ccl::status ccl_comp_reduce(ccl_sched* sched, ccl::reduction reduction, ccl::reduction_fn reduction_fn, const ccl::fn_context* context) { + if (!in_count) { + return ccl::status::success; + } + #ifdef CCL_ENABLE_SYCL ccl_stream* stream = (ccl_stream*)sched->coll_param.stream; @@ -125,9 +130,9 @@ ccl::status ccl_comp_reduce(ccl_sched* sched, auto inout_ptr_type = sycl::get_pointer_type(inout_buf, q->get_context()); LOG_DEBUG("in_ptr_type: ", - native::detail::usm_to_string(in_ptr_type), + ccl::utils::usm_type_to_str(in_ptr_type), ", inout_ptr_type: ", - native::detail::usm_to_string(inout_ptr_type), + ccl::utils::usm_type_to_str(inout_ptr_type), ", native_stream: ", stream->to_string(), ", in_count: ", @@ -143,12 +148,12 @@ ccl::status ccl_comp_reduce(ccl_sched* sched, size_t bytes = in_count * dtype.size(); if (in_ptr_type == sycl::usm::alloc::device) { - host_in_buf = CCL_MALLOC(bytes, "host_in_buf"); + host_in_buf = sched->alloc_buffer_unmanaged(bytes, ccl_sched_buf_runtime); q->memcpy(host_in_buf, in_buf, bytes).wait(); } if (inout_ptr_type == sycl::usm::alloc::device) { - host_inout_buf = CCL_MALLOC(bytes, "host_inout_buf"); + host_inout_buf = sched->alloc_buffer_unmanaged(bytes, ccl_sched_buf_runtime); q->memcpy(host_inout_buf, inout_buf, bytes).wait(); } @@ -156,20 +161,20 @@ ccl::status ccl_comp_reduce(ccl_sched* sched, host_in_buf, in_count, host_inout_buf, out_count, dtype, reduction, reduction_fn, context); if (host_in_buf != in_buf) { - CCL_FREE(host_in_buf); + sched->free_buffer_unmanaged(host_in_buf, bytes, ccl_sched_buf_runtime); } if (host_inout_buf != inout_buf) { q->memcpy(inout_buf, host_inout_buf, bytes).wait(); - CCL_FREE(host_inout_buf); + sched->free_buffer_unmanaged(host_inout_buf, bytes, ccl_sched_buf_runtime); } return ccl::status::success; -#else /* CCL_ENABLE_SYCL */ +#else // CCL_ENABLE_SYCL return ccl_comp_reduce_regular( in_buf, in_count, inout_buf, out_count, dtype, reduction, reduction_fn, context); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } ccl::status ccl_comp_batch_reduce(const void* in_buf, @@ -223,11 +228,11 @@ ccl::status ccl_comp_batch_reduce(const void* in_buf, const char* ccl_reduction_to_str(ccl::reduction type) { switch (type) { - case ccl::reduction::sum: return "SUM"; - case ccl::reduction::prod: return "PROD"; - case ccl::reduction::min: return "MIN"; - case ccl::reduction::max: return "MAX"; - case ccl::reduction::custom: return "CUSTOM"; - default: return "UNKNOWN"; + case ccl::reduction::sum: return "sum"; + case ccl::reduction::prod: return "prod"; + case ccl::reduction::min: return "min"; + case ccl::reduction::max: return "max"; + case ccl::reduction::custom: return "custom"; + default: return "unknown"; } } diff --git a/src/comp/fp16/fp16.cpp b/src/comp/fp16/fp16.cpp index f9ac7641f..99e33f325 100644 --- a/src/comp/fp16/fp16.cpp +++ b/src/comp/fp16/fp16.cpp @@ -58,7 +58,7 @@ void ccl_convert_fp16_to_fp32(const void* src, void* dst) { _mm256_storeu_si256((__m256i*)dst, (__m256i)_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src))); } -#else /* CCL_FP16_COMPILER */ +#else // CCL_FP16_COMPILER void ccl_fp16_reduce(const void* in_buf, size_t in_cnt, @@ -76,4 +76,4 @@ void ccl_convert_fp16_to_fp32(const void* src, void* dst) { CCL_FATAL("FP16->FP32 conversion was requested but CCL was compiled w/o FP16 support"); } -#endif /* CCL_FP16_COMPILER */ +#endif // CCL_FP16_COMPILER diff --git a/src/comp/fp16/fp16.hpp b/src/comp/fp16/fp16.hpp index 62e66ac3d..7f239b35e 100644 --- a/src/comp/fp16/fp16.hpp +++ b/src/comp/fp16/fp16.hpp @@ -25,7 +25,7 @@ __attribute__((target("avx512bw,avx512vl,f16c"))) void ccl_fp16_reduce(const voi ccl::reduction reduction_op); __attribute__((target("f16c"))) void ccl_convert_fp32_to_fp16(const void* src, void* dst); __attribute__((target("f16c"))) void ccl_convert_fp16_to_fp32(const void* src, void* dst); -#else /* CCL_FP16_TARGET_ATTRIBUTES */ +#else // CCL_FP16_TARGET_ATTRIBUTES void ccl_fp16_reduce(const void* in_buf, size_t in_cnt, void* inout_buf, @@ -33,4 +33,4 @@ void ccl_fp16_reduce(const void* in_buf, ccl::reduction reduction_op); void ccl_convert_fp32_to_fp16(const void* src, void* dst); void ccl_convert_fp16_to_fp32(const void* src, void* dst); -#endif /* CCL_FP16_TARGET_ATTRIBUTES */ +#endif // CCL_FP16_TARGET_ATTRIBUTES diff --git a/src/comp/fp16/fp16_intrisics.cpp b/src/comp/fp16/fp16_intrisics.cpp index 203f2e7f9..6913c1990 100644 --- a/src/comp/fp16/fp16_intrisics.cpp +++ b/src/comp/fp16/fp16_intrisics.cpp @@ -20,4 +20,4 @@ CCL_FP16_DEFINE_ELEM_FUNCS(256); CCL_FP16_DEFINE_ELEM_FUNCS(512); -#endif /* CCL_FP16_COMPILER */ +#endif // CCL_FP16_COMPILER diff --git a/src/comp/fp16/fp16_intrisics.hpp b/src/comp/fp16/fp16_intrisics.hpp index 4b2f4b28e..ca88b0dba 100644 --- a/src/comp/fp16/fp16_intrisics.hpp +++ b/src/comp/fp16/fp16_intrisics.hpp @@ -38,14 +38,14 @@ __attribute__((__always_inline__, target("avx512f"))) inline #define FP16_INLINE_TARGET_ATTRIBUTE_ALL \ __attribute__((__always_inline__, target(FP16_ALL_ATTRS))) inline -#else /* CCL_FP16_TARGET_ATTRIBUTES */ +#else // CCL_FP16_TARGET_ATTRIBUTES #define FP16_TARGET_ATTRIBUTE_F16C #define FP16_TARGET_ATTRIBUTE_AVX512 #define FP16_TARGET_ATTRIBUTE_ALL #define FP16_INLINE_TARGET_ATTRIBUTE_F16C __attribute__((__always_inline__)) inline #define FP16_INLINE_TARGET_ATTRIBUTE_AVX512F __attribute__((__always_inline__)) inline #define FP16_INLINE_TARGET_ATTRIBUTE_ALL __attribute__((__always_inline__)) inline -#endif /* CCL_FP16_TARGET_ATTRIBUTES */ +#endif // CCL_FP16_TARGET_ATTRIBUTES #define FP16_TARGET_ATTRIBUTE_256 FP16_TARGET_ATTRIBUTE_F16C #define FP16_TARGET_ATTRIBUTE_512 FP16_TARGET_ATTRIBUTE_AVX512 @@ -187,4 +187,4 @@ FP16_INLINE_TARGET_ATTRIBUTE_ALL void ccl_fp16_reduce_impl(const void* in_buf, } } -#endif /* CCL_FP16_COMPILER */ +#endif // CCL_FP16_COMPILER diff --git a/src/environment_impl.hpp b/src/environment_impl.hpp index 000178da2..2d117c714 100644 --- a/src/environment_impl.hpp +++ b/src/environment_impl.hpp @@ -76,12 +76,6 @@ stream CCL_API environment::create_stream(native_stream_type& native_stream) { return stream::create_stream(native_stream); } -template -stream CCL_API environment::create_stream(native_stream_type& native_stream, - native_context_type& native_ctx) { - return stream::create_stream(native_stream, native_ctx); -} - /******************** COMMUNICATOR ********************/ template @@ -102,16 +96,6 @@ vector_class CCL_API environment::create_communicators( shared_ptr_class kvs, const comm_attr& attr) const { return communicator::create_communicators(comm_size, local_rank_device_map, context, kvs); - /* - (void)context; - vector_class ret; - ret.push_back(create_single_device_communicator(comm_size, - local_rank_device_map.begin()->first, - local_rank_device_map.begin()->second, - context, - kvs)); - return ret; -*/ } template @@ -122,16 +106,6 @@ environment::create_communicators(const int comm_size, shared_ptr_class kvs, const comm_attr& attr) const { return communicator::create_communicators(comm_size, local_rank_device_map, context, kvs); - /* - (void)context; - vector_class ret; - ret.push_back(create_single_device_communicator(comm_size, - local_rank_device_map.begin()->first, - local_rank_device_map.begin()->second, - context, - kvs)); - return ret; -*/ } } // namespace detail @@ -140,7 +114,7 @@ environment::create_communicators(const int comm_size, /******************** TypeGenerations ********************/ -#define CREATE_DEV_COMM_INSTANTIATION(DeviceType, ContextType) \ +#define CREATE_COMM_INSTANTIATION(DeviceType, ContextType) \ template ccl::vector_class CCL_API \ ccl::detail::environment::create_communicators( \ const int comm_size, \ diff --git a/src/exec/exec.cpp b/src/exec/exec.cpp index 2d86bfe6e..2dc64ed0c 100644 --- a/src/exec/exec.cpp +++ b/src/exec/exec.cpp @@ -50,11 +50,12 @@ atl_attr_t ccl_executor::generate_atl_attr(const ccl::env_data& env) { don't use ring_rma till that */ attr.in.enable_rma = 0; // env.enable_rma; - attr.in.enable_device_buf = env.enable_device_buf; + attr.in.enable_hmem = env.enable_hmem; attr.in.enable_sync_coll = env.enable_sync_coll; attr.in.enable_extra_ep = env.enable_extra_ep; attr.in.ep_count = calculate_atl_ep_count(env.worker_count); attr.in.mnic_type = env.mnic_type; + attr.in.mnic_name = env.mnic_name_raw; attr.in.mnic_count = env.mnic_count; memset(&attr.out, 0, sizeof(attr.out)); @@ -82,10 +83,11 @@ ccl_executor::ccl_executor(const char* main_addr) { atl_wrapper::set_exec(this); } -void ccl_executor::start_workers(size_t proc_idx, size_t proc_count) { +void ccl_executor::start_workers(int proc_idx, int proc_count) { set_local_coord(proc_idx, proc_count); auto& env = ccl::global_data::env(); CCL_THROW_IF_NOT(env.env_2_worker_affinity(get_local_proc_idx(), get_local_proc_count())); + CCL_THROW_IF_NOT(env.env_2_worker_mem_affinity()); start_workers(); } @@ -116,18 +118,23 @@ void ccl_executor::start_workers() { } if (env.worker_offload) { - size_t affinity = env.worker_affinity[get_local_proc_idx() * worker_count + idx]; + size_t cpu_affinity = env.worker_affinity[get_local_proc_idx() * worker_count + idx]; + size_t mem_affinity = + env.worker_mem_affinity[get_local_proc_idx() * worker_count + idx]; - CCL_THROW_IF_NOT(workers.back()->start(affinity) == ccl::status::success, - "failed to start worker # ", - idx); + CCL_THROW_IF_NOT( + workers.back()->start(cpu_affinity, mem_affinity) == ccl::status::success, + "failed to start worker # ", + idx); LOG_DEBUG("started worker: local_proc_idx ", get_local_proc_idx(), ", worker_idx ", idx, - ", affinity ", - affinity); + ", cpu: ", + cpu_affinity, + ", numa: ", + mem_affinity); } } workers_started = true; @@ -238,6 +245,8 @@ void ccl_executor::start(ccl_master_sched* sched) { size_t worker_idx; for (size_t idx = 0; idx < sched->partial_scheds.size(); idx++) { worker_idx = (this->*get_worker_idx_fn)(sched->partial_scheds[idx].get()); + LOG_DEBUG( + "worker idx: ", worker_idx, ", coll: ", ccl_coll_type_to_str(sched->coll_param.ctype)); workers[worker_idx]->add(sched->partial_scheds[idx].get()); } } @@ -277,7 +286,7 @@ void ccl_executor::do_work() { } } -void ccl_executor::set_local_coord(size_t proc_idx, size_t proc_count) { +void ccl_executor::set_local_coord(int proc_idx, int proc_count) { local_proc_idx = proc_idx; local_proc_count = proc_count; auto& env = ccl::global_data::env(); @@ -288,8 +297,8 @@ void ccl_executor::set_local_coord(size_t proc_idx, size_t proc_count) { char* local_idx_env = getenv(local_idx_env_name); char* local_count_env = getenv(local_count_env_name); if (local_idx_env && local_count_env) { - size_t local_idx = std::atoi(local_idx_env); - size_t local_count = std::atoi(local_count_env); + int local_idx = std::atoi(local_idx_env); + int local_count = std::atoi(local_count_env); CCL_THROW_IF_NOT(local_idx == local_proc_idx, "unexpected local_proc_idx ", local_proc_idx, diff --git a/src/exec/exec.hpp b/src/exec/exec.hpp index fb25f11a5..43dbe9c61 100644 --- a/src/exec/exec.hpp +++ b/src/exec/exec.hpp @@ -70,7 +70,7 @@ class alignas(CACHELINE_SIZE) ccl_executor { bool test(const ccl_request* req); void start_workers(); - void start_workers(size_t local_proc_idx, size_t local_proc_count); + void start_workers(int local_proc_idx, int local_proc_count); bool are_workers_started() { return workers_started; }; @@ -86,10 +86,10 @@ class alignas(CACHELINE_SIZE) ccl_executor { void unlock_workers(); bool is_locked = false; - size_t get_local_proc_idx() const { + int get_local_proc_idx() const { return local_proc_idx; } - size_t get_local_proc_count() const { + int get_local_proc_count() const { return local_proc_count; } @@ -102,7 +102,7 @@ class alignas(CACHELINE_SIZE) ccl_executor { std::unique_ptr create_sched_queue(size_t idx, size_t ep_per_worker); void do_work(); - void set_local_coord(size_t proc_idx, size_t proc_count); + void set_local_coord(int proc_idx, int proc_count); std::vector> workers; // TODO: Rework to support listener @@ -111,8 +111,8 @@ class alignas(CACHELINE_SIZE) ccl_executor { typedef size_t (ccl_executor::*get_worker_idx_fn_t)(ccl_sched* sched); get_worker_idx_fn_t get_worker_idx_fn; size_t rr_worker_idx = 0; /* to distribute work in round-robin */ - size_t local_proc_idx; - size_t local_proc_count; + int local_proc_idx; + int local_proc_count; bool workers_started = false; }; diff --git a/src/exec/thread/base_thread.cpp b/src/exec/thread/base_thread.cpp index d3fc960fe..247c1789a 100644 --- a/src/exec/thread/base_thread.cpp +++ b/src/exec/thread/base_thread.cpp @@ -17,17 +17,18 @@ #include "common/utils/yield.hpp" #include "exec/thread/base_thread.hpp" -ccl::status ccl_base_thread::start(int affinity) { +ccl::status ccl_base_thread::start(int cpu_affinity, int mem_affinity) { LOG_DEBUG(name(), " ", idx); - start_affinity = affinity; + start_cpu_affinity = cpu_affinity; + start_mem_affinity = mem_affinity; - /* start thread with initial affinity */ + /* start thread with initial CPU affinity */ pthread_attr_t attr; pthread_attr_init(&attr); cpu_set_t cpuset; __CPU_ZERO_S(sizeof(cpu_set_t), &cpuset); - __CPU_SET_S(affinity, sizeof(cpu_set_t), &cpuset); + __CPU_SET_S(cpu_affinity, sizeof(cpu_set_t), &cpuset); pthread_attr_setaffinity_np(&attr, sizeof(cpu_set_t), &cpuset); int err = pthread_create(&thread, &attr, thread_function, get_this()); @@ -77,29 +78,31 @@ ccl::status ccl_base_thread::stop() { return ccl::status::success; } -ccl::status ccl_base_thread::set_affinity(int affinity) { - LOG_DEBUG(name(), " # ", idx, ", affinity ", affinity); +ccl::status ccl_base_thread::set_cpu_affinity(int cpu_affinity) { + /* unused, cpu affinity is set on thread start */ + + LOG_DEBUG(name(), " # ", idx, ", CPU affinity ", cpu_affinity); int pthread_err; cpu_set_t cpuset; __CPU_ZERO_S(sizeof(cpu_set_t), &cpuset); - __CPU_SET_S(affinity, sizeof(cpu_set_t), &cpuset); + __CPU_SET_S(cpu_affinity, sizeof(cpu_set_t), &cpuset); if ((pthread_err = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset)) != 0) { LOG_ERROR("pthread_setaffinity_np failed, err ", pthread_err); return ccl::status::runtime_error; } - if (get_affinity() != affinity) { - LOG_ERROR(name(), " ", idx, " is not pinned ", affinity); + if (get_real_cpu_affinity() != cpu_affinity) { + LOG_ERROR(name(), " ", idx, " is not pinned to CPU ", cpu_affinity); return ccl::status::runtime_error; } return ccl::status::success; } -int ccl_base_thread::get_affinity() { +int ccl_base_thread::get_real_cpu_affinity() { int pthread_err; int result = CCL_UNDEFINED_CPU_ID; cpu_set_t cpuset; @@ -121,7 +124,7 @@ int ccl_base_thread::get_affinity() { } } - CCL_THROW_IF_NOT(result != CCL_UNDEFINED_CPU_ID, "can't retrieve affinity"); + CCL_THROW_IF_NOT(result != CCL_UNDEFINED_CPU_ID, "can't retrieve CPU affinity"); return result; } diff --git a/src/exec/thread/base_thread.hpp b/src/exec/thread/base_thread.hpp index 85e4c7f65..4088921c5 100644 --- a/src/exec/thread/base_thread.hpp +++ b/src/exec/thread/base_thread.hpp @@ -23,8 +23,6 @@ #include "common/log/log.hpp" #include "internal_types.hpp" -#define CCL_UNDEFINED_CPU_ID (-1) - class ccl_base_thread { public: ccl_base_thread(size_t idx, void* (*thread_function)(void*)) @@ -32,11 +30,12 @@ class ccl_base_thread { started(false), wait(0), idx(idx), - start_affinity(CCL_UNDEFINED_CPU_ID), + start_cpu_affinity(CCL_UNDEFINED_CPU_ID), + start_mem_affinity(CCL_UNDEFINED_NUMA_NODE), thread_function(thread_function) {} ccl_base_thread() = delete; - ~ccl_base_thread() = default; + virtual ~ccl_base_thread() = default; ccl_base_thread(const ccl_base_thread&) = delete; ccl_base_thread(ccl_base_thread&&) = delete; @@ -44,7 +43,7 @@ class ccl_base_thread { ccl_base_thread& operator=(const ccl_base_thread&) = delete; ccl_base_thread& operator=(ccl_base_thread&&) = delete; - ccl::status start(int affinity); + ccl::status start(int cpu_affinity, int mem_affinity); ccl::status stop(); virtual bool can_reset() { @@ -58,10 +57,15 @@ class ccl_base_thread { return static_cast(this); }; - int get_start_affinity() { - return start_affinity; + int get_start_cpu_affinity() { + return start_cpu_affinity; + } + + int get_real_cpu_affinity(); + + int get_start_mem_affinity() { + return start_mem_affinity; } - int get_affinity(); virtual const std::string& name() const { static const std::string name("base_thread"); @@ -90,11 +94,13 @@ class ccl_base_thread { wait_data wait; private: - ccl::status set_affinity(int affinity); + ccl::status set_cpu_affinity(int cpu_affinity); const size_t idx; - int start_affinity; + int start_cpu_affinity; + int start_mem_affinity; + void* (*thread_function)(void*); pthread_t thread{}; }; diff --git a/src/exec/thread/worker.cpp b/src/exec/thread/worker.cpp index 25eea9ea1..e39bfcee4 100644 --- a/src/exec/thread/worker.cpp +++ b/src/exec/thread/worker.cpp @@ -63,11 +63,10 @@ ccl::status ccl_worker::do_work(size_t& processed_count) { if (ret != ccl::status::success) return ret; -#ifdef ENABLE_DEBUG - if (processed_count == 0 && (do_work_counter % CCL_WORKER_PROCESS_ALL_ITERS * 1024) == 0) { - //sched_queue->dump(std::cout); + if ((do_work_counter % (4 * CCL_WORKER_PROCESS_ALL_ITERS) == 0) && + ccl::global_data::env().queue_dump) { + sched_queue->dump(std::cout); } -#endif return ccl::status::success; } @@ -84,11 +83,12 @@ ccl::status ccl_worker::process_strict_sched_queue() { ccl_sched* sched = *sched_it; if (sched->get_in_bin_status() == ccl_sched_in_bin_erased) { - CCL_ASSERT(!sched->bin); + CCL_THROW_IF_NOT(!sched->bin, "erased sched should be without bin"); erased_scheds++; - /* only single sched in active strict queue can be erased since previous call */ - CCL_ASSERT(erased_scheds == 1); + CCL_THROW_IF_NOT( + erased_scheds == 1, + "only single sched in active strict queue can be erased since previous call"); /* now it is safe to release this sched */ sched->req->complete(); @@ -96,17 +96,17 @@ ccl::status ccl_worker::process_strict_sched_queue() { } if (sched->get_in_bin_status() == ccl_sched_in_bin_none) { - CCL_ASSERT(!sched->bin, "unexpected bin ", sched->bin); + CCL_THROW_IF_NOT(!sched->bin, "unexpected bin ", sched->bin); /* here we add sched from strict_queue to regular queue for real execution */ LOG_DEBUG("add sched ", sched, " from strict_queue to exec_queue, req ", sched->req); sched_queue->add(sched); } - CCL_ASSERT(sched->get_in_bin_status() == ccl_sched_in_bin_added, - "sched ", - sched, - " unexpected in_bin_status ", - sched->get_in_bin_status()); + CCL_THROW_IF_NOT(sched->get_in_bin_status() == ccl_sched_in_bin_added, + "sched ", + sched, + " unexpected in_bin_status ", + sched->get_in_bin_status()); sched->do_progress(); @@ -162,7 +162,9 @@ ccl::status ccl_worker::process_sched_bin(ccl_sched_bin* bin, size_t& completed_ completed_sched_count = 0; size_t bin_size = bin->size(); - CCL_ASSERT(bin_size > 0); + + if (bin_size == 0) + return ccl::status::success; LOG_TRACE("bin ", bin, ", sched_count ", bin_size); @@ -279,15 +281,15 @@ bool ccl_worker::check_stop_condition(size_t iter) { bool ccl_worker::check_affinity_condition(size_t iter) { if ((iter % CCL_WORKER_CHECK_AFFINITY_ITERS) == 0) { - int start_affinity = get_start_affinity(); - int affinity = get_affinity(); - if (start_affinity != affinity) { + int start_cpu_affinity = get_start_cpu_affinity(); + int real_cpu_affinity = get_real_cpu_affinity(); + if (start_cpu_affinity != real_cpu_affinity) { LOG_ERROR("worker ", get_idx(), - " unexpectedly changed affinity from ", - start_affinity, + " unexpectedly changed CPU affinity from ", + start_cpu_affinity, " to ", - affinity); + real_cpu_affinity); } } @@ -299,7 +301,18 @@ static void* ccl_worker_func(void* args) { auto worker_idx = worker->get_idx(); - LOG_DEBUG("worker_idx ", worker_idx); + int cpu_core = worker->get_start_cpu_affinity(); + int numa_node = worker->get_start_mem_affinity(); + + LOG_DEBUG("worker: ", + "idx: ", + worker_idx, + ", cpu: ", + cpu_core, + ", numa: ", + ccl::global_data::get().hwloc_wrapper->get_numa_node(numa_node).to_string()); + + ccl::global_data::get().hwloc_wrapper->membind_thread(numa_node); size_t iter = 0; size_t processed_count = 0; diff --git a/src/exec/thread/worker.hpp b/src/exec/thread/worker.hpp index 168f691ae..59749027f 100644 --- a/src/exec/thread/worker.hpp +++ b/src/exec/thread/worker.hpp @@ -32,7 +32,12 @@ class ccl_worker : public ccl_base_thread { ccl_worker(const ccl_worker& other) = delete; ccl_worker& operator=(const ccl_worker& other) = delete; ccl_worker(size_t idx, std::unique_ptr queue); - virtual ~ccl_worker() = default; + + virtual ~ccl_worker() { + strict_sched_queue.reset(); + sched_queue.reset(); + } + virtual void* get_this() override { return static_cast(this); }; diff --git a/src/fusion/fusion.cpp b/src/fusion/fusion.cpp index 7385f9982..d555b90bf 100644 --- a/src/fusion/fusion.cpp +++ b/src/fusion/fusion.cpp @@ -15,6 +15,7 @@ */ #include "exec/exec.hpp" #include "fusion/fusion.hpp" +#include "sched/buffer_cache.hpp" #include "sched/cache/cache.hpp" #include "sched/entry/factory/entry_factory.hpp" @@ -40,66 +41,13 @@ ccl::status release_fusion_buf_for_cached_sched(ccl_sched* sched, const void* ct return release_fusion_buf(ctx); } -ccl_fusion_buffer_cache::ccl_fusion_buffer_cache(size_t buf_size) : buf_size(buf_size) { - void* buf; - for (size_t idx = 0; idx < CCL_FUSION_BUFFER_CACHE_PREALLOC; idx++) { - buf = CCL_MALLOC(buf_size, "buffer"); - free_buffers.push_back(buf); - all_buffers.push_back(buf); - } - LOG_INFO("created buffer_cache: buf_size ", buf_size); -} - -ccl_fusion_buffer_cache::~ccl_fusion_buffer_cache() { - std::lock_guard lock{ guard }; - - if (all_buffers.size() != free_buffers.size()) { - LOG_INFO("fusion buffers may be still in use" - ", free_buffers: ", - free_buffers.size(), - ", all_buffers: ", - all_buffers.size()); - } - - for (size_t idx = 0; idx < all_buffers.size(); idx++) { - CCL_FREE(all_buffers[idx]); - } - - all_buffers.clear(); - free_buffers.clear(); -} - -void* ccl_fusion_buffer_cache::get() { - std::lock_guard lock{ guard }; - - void* buf; - if (!free_buffers.empty()) { - buf = free_buffers.front(); - free_buffers.pop_front(); - } - else { - buf = CCL_MALLOC(buf_size, "buffer"); - LOG_DEBUG("get buf from extra allocation ", buf); - all_buffers.push_back(buf); - } - CCL_THROW_IF_NOT(buf, "empty buf"); - - return buf; -} - -void ccl_fusion_buffer_cache::release(void* buf) { - std::lock_guard lock{ guard }; - CCL_THROW_IF_NOT(buf, "empty buf"); - free_buffers.push_back(buf); -} - ccl_fusion_manager::ccl_fusion_manager() : bytes_threshold(ccl::global_data::env().fusion_bytes_threshold), count_threshold(ccl::global_data::env().fusion_count_threshold), - buf_cache(ccl::global_data::env().fusion_bytes_threshold * - ccl::global_data::env().fusion_count_threshold) { - CCL_ASSERT(bytes_threshold >= 1, "unexpected fusion_bytes_threshold ", bytes_threshold); - CCL_ASSERT(count_threshold >= 1, "unexpected fusion_count_threshold ", count_threshold); + buffer_size(bytes_threshold * count_threshold) { + CCL_THROW_IF_NOT(bytes_threshold >= 1, "unexpected fusion_bytes_threshold ", bytes_threshold); + CCL_THROW_IF_NOT(count_threshold >= 1, "unexpected fusion_count_threshold ", count_threshold); + CCL_THROW_IF_NOT(buffer_size >= 1, "unexpected fusion_buffer_size ", buffer_size); long cycle_usec = long(ccl::global_data::env().fusion_cycle_ms * 1000.0); cycle = std::chrono::microseconds(cycle_usec); @@ -110,7 +58,9 @@ ccl_fusion_manager::ccl_fusion_manager() ", bytes_threshold ", bytes_threshold, ", count_threshold ", - count_threshold); + count_threshold, + ", buffer_size ", + buffer_size); } ccl_fusion_manager::~ccl_fusion_manager() { @@ -145,20 +95,20 @@ void ccl_fusion_manager::reset() { } bool ccl_fusion_manager::can_fuse(ccl_master_sched* sched) { - if (atl_wrapper::attr.out.enable_device_buf) { + if (atl_wrapper::attr.out.enable_hmem) { /* TODO: implement fusion with D2D copies */ return false; } - size_t bytes = sched->coll_param.count * sched->coll_param.dtype.size(); - - if (bytes >= bytes_threshold) { - LOG_DEBUG("can't fuse due to size ", bytes, ", max ", bytes_threshold); + if (sched->coll_param.ctype != ccl_coll_allreduce) { + LOG_DEBUG("can't fuse due to coll_type ", ccl_coll_type_to_str(sched->coll_param.ctype)); return false; } - if (sched->coll_param.ctype != ccl_coll_allreduce) { - LOG_DEBUG("can't fuse due to coll_type ", ccl_coll_type_to_str(sched->coll_param.ctype)); + size_t bytes = sched->coll_param.get_send_count() * sched->coll_param.dtype.size(); + + if (bytes >= bytes_threshold) { + LOG_DEBUG("can't fuse due to size ", bytes, ", max ", bytes_threshold); return false; } @@ -186,7 +136,7 @@ bool ccl_fusion_manager::add(ccl_master_sched* sched) { sched->set_counter(1); { - std::lock_guard lock{ guard }; + std::lock_guard lock{ guard }; postponed_queue.push_back(sched); } @@ -217,7 +167,7 @@ ccl_master_sched* ccl_fusion_manager::build_sched() { max_priority = first_sched->coll_attr.priority; for (const auto& s : exec_queue) { - sum_count += s->coll_param.count; + sum_count += s->coll_param.get_send_count(); if (!s->coll_attr.to_cache) { use_cache = false; } @@ -233,20 +183,20 @@ ccl_master_sched* ccl_fusion_manager::build_sched() { exec_queue.size()); ccl_master_sched* sched = nullptr; - auto create_fn = [this, ctype, &fusion_buf, sum_count, dtype, reduction, comm]() { + auto create_fn = [this, ctype, &fusion_buf, sum_count, dtype, reduction, comm, stream]() { ccl_master_sched* sched = nullptr; switch (ctype) { case ccl_coll_allreduce: { - ccl_coll_param coll_param{}; - fusion_buf = this->buf_cache.get(); - coll_param.ctype = ctype; - coll_param.send_buf = fusion_buf; - coll_param.recv_buf = fusion_buf; - coll_param.count = sum_count; - coll_param.dtype = dtype; - coll_param.reduction = reduction; - coll_param.comm = comm; - coll_param.stream = nullptr; + ccl::global_data::get().buffer_cache->get(0, buffer_size, &fusion_buf); + ccl_coll_attr coll_attr; + ccl_coll_param coll_param = ccl_coll_param::create_allreduce_param(fusion_buf, + fusion_buf, + sum_count, + dtype.idx(), + reduction, + coll_attr, + comm, + stream); sched = new ccl_master_sched(coll_param); sched->internal_type = ccl_sched_internal_fusion; } break; @@ -294,7 +244,7 @@ ccl_master_sched* ccl_fusion_manager::build_sched() { CCL_THROW_IF_NOT(sched); { - std::lock_guard lock{ guard }; + std::lock_guard lock{ guard }; tracked_scheds.push_back(sched); } @@ -340,28 +290,30 @@ ccl_master_sched* ccl_fusion_manager::build_sched() { size_t global_copy_idx = idx * copies_per_part + copy_idx; #ifdef CCL_ENABLE_SYCL if (stream && stream->is_sycl_device_stream()) - entry_factory::make_entry( + entry_factory::make_entry( part_scheds[idx].get(), - copy_direction::d2h, - ccl_buffer(&(exec_queue[global_copy_idx]->coll_param.device_send_buf), - exec_queue[global_copy_idx]->coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - ccl_buffer(fusion_buf, buf_cache.get_buf_size(), offset), - exec_queue[global_copy_idx]->coll_param.count, + ccl_buffer( + exec_queue[global_copy_idx]->coll_param.get_send_buf_ptr( + 0, ccl_coll_param::buf_type::device), + exec_queue[global_copy_idx]->coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT), + ccl_buffer(fusion_buf, buffer_size, offset), + exec_queue[global_copy_idx]->coll_param.get_send_count(), dtype, - stream); + copy_attr(copy_direction::d2h)); else -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL entry_factory::make_entry( part_scheds[idx].get(), - ccl_buffer(&(exec_queue[global_copy_idx]->coll_param.send_buf), - exec_queue[global_copy_idx]->coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - ccl_buffer(fusion_buf, buf_cache.get_buf_size(), offset), - exec_queue[global_copy_idx]->coll_param.count, + ccl_buffer( + exec_queue[global_copy_idx]->coll_param.get_send_buf_ptr(), + exec_queue[global_copy_idx]->coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT), + ccl_buffer(fusion_buf, buffer_size, offset), + exec_queue[global_copy_idx]->coll_param.get_send_count(), dtype); - offset += exec_queue[global_copy_idx]->coll_param.count * dtype_size; + offset += exec_queue[global_copy_idx]->coll_param.get_send_count() * dtype_size; } } @@ -378,30 +330,32 @@ ccl_master_sched* ccl_fusion_manager::build_sched() { size_t global_copy_idx = idx * copies_per_part + copy_idx; #ifdef CCL_ENABLE_SYCL if (stream && stream->is_sycl_device_stream()) - entry_factory::make_entry( + entry_factory::make_entry( part_scheds[idx].get(), - copy_direction::h2d, - ccl_buffer(fusion_buf, buf_cache.get_buf_size(), offset), - ccl_buffer(&(exec_queue[global_copy_idx]->coll_param.device_recv_buf), - exec_queue[global_copy_idx]->coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - exec_queue[global_copy_idx]->coll_param.count, + ccl_buffer(fusion_buf, buffer_size, offset), + ccl_buffer( + exec_queue[global_copy_idx]->coll_param.get_recv_buf_ptr( + 0, ccl_coll_param::buf_type::device), + exec_queue[global_copy_idx]->coll_param.get_recv_count() * dtype_size, + ccl_buffer_type::INDIRECT), + exec_queue[global_copy_idx]->coll_param.get_recv_count(), dtype, - stream); + copy_attr(copy_direction::h2d)); else -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL entry_factory::make_entry( part_scheds[idx].get(), - ccl_buffer(fusion_buf, buf_cache.get_buf_size(), offset), - ccl_buffer(&(exec_queue[global_copy_idx]->coll_param.recv_buf), - exec_queue[global_copy_idx]->coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - exec_queue[global_copy_idx]->coll_param.count, + ccl_buffer(fusion_buf, buffer_size, offset), + ccl_buffer( + exec_queue[global_copy_idx]->coll_param.get_recv_buf_ptr(), + exec_queue[global_copy_idx]->coll_param.get_recv_count() * dtype_size, + ccl_buffer_type::INDIRECT), + exec_queue[global_copy_idx]->coll_param.get_recv_count(), dtype); part_scheds[idx]->add_barrier(); - offset += exec_queue[global_copy_idx]->coll_param.count * dtype_size; + offset += exec_queue[global_copy_idx]->coll_param.get_recv_count() * dtype_size; entry_factory::make_entry( part_scheds[idx].get(), complete_user_request, exec_queue[global_copy_idx]); CCL_THROW_IF_NOT(!exec_queue[global_copy_idx]->is_completed(), @@ -448,7 +402,7 @@ void ccl_fusion_manager::execute() { /* separate block to reduce lock scope */ { - std::lock_guard lock{ guard }; + std::lock_guard lock{ guard }; if (!postponed_queue.empty()) { LOG_DEBUG("postponed_queue size ", postponed_queue.size()); @@ -461,7 +415,7 @@ void ccl_fusion_manager::execute() { exec_queue.push_back(first_sched); postponed_queue.pop_front(); exec_queue_sum_bytes = - first_sched->coll_param.count * first_sched->coll_param.dtype.size(); + first_sched->coll_param.get_send_count() * first_sched->coll_param.dtype.size(); } for (auto it = postponed_queue.begin(); it != postponed_queue.end();) { @@ -471,8 +425,8 @@ void ccl_fusion_manager::execute() { s->coll_param.ctype == first_sched->coll_param.ctype && s->coll_param.reduction == first_sched->coll_param.reduction && s->coll_param.stream == first_sched->coll_param.stream) { - size_t size = s->coll_param.count * s->coll_param.dtype.size(); - if (exec_queue_sum_bytes + size > CCL_FUSION_BUFFER_SIZE) { + size_t size = s->coll_param.get_send_count() * s->coll_param.dtype.size(); + if (exec_queue_sum_bytes + size > buffer_size) { LOG_DEBUG("too much bytes in buffer, flush exec_queue"); flush_exec_queue = true; break; @@ -515,7 +469,7 @@ void ccl_fusion_manager::execute() { } void ccl_fusion_manager::release_buffer(void* buf) { - buf_cache.release(buf); + ccl::global_data::get().buffer_cache->push(0, buffer_size, buf); } void ccl_fusion_manager::clear_exec_queue() { @@ -524,7 +478,7 @@ void ccl_fusion_manager::clear_exec_queue() { } void ccl_fusion_manager::check_tracked_scheds(bool force_release) { - std::lock_guard lock{ guard }; + std::lock_guard lock{ guard }; for (auto it = tracked_scheds.begin(); it != tracked_scheds.end();) { ccl_master_sched* sched = *it; if (sched->is_completed() && (!sched->coll_attr.to_cache || force_release)) { diff --git a/src/fusion/fusion.hpp b/src/fusion/fusion.hpp index ff102ee5b..c9e7a5ab9 100644 --- a/src/fusion/fusion.hpp +++ b/src/fusion/fusion.hpp @@ -22,36 +22,6 @@ #include #include -#define CCL_FUSION_BYTES_THRESHOLD (8 * 8192) -#define CCL_FUSION_COUNT_THRESHOLD (256) -#define CCL_FUSION_BUFFER_SIZE (CCL_FUSION_BYTES_THRESHOLD * CCL_FUSION_COUNT_THRESHOLD) -#define CCL_FUSION_BUFFER_CACHE_PREALLOC (4) - -using ccl_fusion_lock_t = ccl_spinlock; - -class ccl_fusion_buffer_cache { -public: - ccl_fusion_buffer_cache(size_t buf_size); - ~ccl_fusion_buffer_cache(); - - ccl_fusion_buffer_cache(const ccl_fusion_buffer_cache& other) = delete; - ccl_fusion_buffer_cache& operator=(const ccl_fusion_buffer_cache& other) = delete; - - void* get(); - void release(void* buf); - void clear(); - - size_t get_buf_size() { - return buf_size; - } - -private: - size_t buf_size; - ccl_fusion_lock_t guard{}; - std::deque free_buffers; - std::deque all_buffers; -}; - class ccl_fusion_manager { public: ccl_fusion_manager(); @@ -74,13 +44,17 @@ class ccl_fusion_manager { const size_t bytes_threshold; const size_t count_threshold; + const size_t buffer_size; + + using lock_t = ccl_spinlock; + lock_t guard{}; - ccl_fusion_lock_t guard{}; using sched_queue_t = std::deque; sched_queue_t postponed_queue{}; sched_queue_t exec_queue{}; + size_t exec_queue_sum_bytes = 0; - ccl_fusion_buffer_cache buf_cache; + std::list tracked_scheds{}; std::chrono::steady_clock::duration cycle; diff --git a/src/hwloc/hwloc_wrapper.c b/src/hwloc/hwloc_wrapper.c deleted file mode 100644 index 5b0601743..000000000 --- a/src/hwloc/hwloc_wrapper.c +++ /dev/null @@ -1,93 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "hwloc_wrapper.h" - -static hwloc_info_t hwloc_info = { .initialized = 0 }; - -hwloc_status_t hwloc_init() { - hwloc_status_t ret = HWLOC_SUCCESS; - - hwloc_info.initialized = 0; - hwloc_info.bindset = hwloc_bitmap_alloc(); - - if (hwloc_topology_init(&hwloc_info.topology) < 0) { - printf("hwloc_topology_init failed (%s)\n", strerror(errno)); - goto err; - } - - hwloc_topology_set_io_types_filter(hwloc_info.topology, HWLOC_TYPE_FILTER_KEEP_ALL); - - if (hwloc_topology_load(hwloc_info.topology) < 0) { - printf("hwloc_topology_load failed (%s)\n", strerror(errno)); - goto err; - } - - if (hwloc_get_proc_cpubind( - hwloc_info.topology, getpid(), hwloc_info.bindset, HWLOC_CPUBIND_PROCESS) < 0) { - printf("hwloc_get_proc_cpubind failed (%s)\n", strerror(errno)); - goto err; - } - - hwloc_info.initialized = 1; - - return ret; - -err: - return HWLOC_FAILURE; -} - -hwloc_status_t hwloc_finalize() { - hwloc_status_t ret = HWLOC_SUCCESS; - - hwloc_topology_destroy(hwloc_info.topology); - hwloc_bitmap_free(hwloc_info.bindset); - hwloc_info.initialized = 0; - - return ret; -} - -int hwloc_is_initialized() { - return hwloc_info.initialized; -} - -static hwloc_obj_t hwloc_get_first_non_io_obj_by_pci(int domain, int bus, int dev, int func) { - hwloc_obj_t io_device = hwloc_get_pcidev_by_busid(hwloc_info.topology, domain, bus, dev, func); - HWLOC_ASSERT(io_device, - "failed to get PCI device with domain %d, bus %d, dev %d, func %d", - domain, - bus, - dev, - func); - hwloc_obj_t first_non_io = hwloc_get_non_io_ancestor_obj(hwloc_info.topology, io_device); - HWLOC_ASSERT(first_non_io, "failed to get ancestor of PCI device"); - return first_non_io; -} - -int hwloc_is_dev_close_by_pci(int domain, int bus, int dev, int func) { - int is_close = 0; - - if (!hwloc_is_initialized()) - return is_close; - - hwloc_obj_t first_non_io = hwloc_get_first_non_io_obj_by_pci(domain, bus, dev, func); - - /* determine if PCI device is "close" to process by checking if process's affinity is included - * in PCI device's affinity or if PCI device's affinity is included in process's affinity */ - is_close = (hwloc_bitmap_isincluded(hwloc_info.bindset, first_non_io->cpuset) || - hwloc_bitmap_isincluded(first_non_io->cpuset, hwloc_info.bindset)); - - return is_close; -} diff --git a/src/hwloc/hwloc_wrapper.cpp b/src/hwloc/hwloc_wrapper.cpp new file mode 100644 index 000000000..514d5bc94 --- /dev/null +++ b/src/hwloc/hwloc_wrapper.cpp @@ -0,0 +1,360 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/global/global.hpp" +#include "hwloc/hwloc_wrapper.hpp" + +ccl_numa_node::ccl_numa_node() + : idx(CCL_UNDEFINED_NUMA_NODE), + os_idx(CCL_UNDEFINED_NUMA_NODE), + mem_in_mb(0), + core_count(0), + membind_support(0) {} + +ccl_numa_node::ccl_numa_node(int idx, + int os_idx, + size_t mem_in_mb, + int core_count, + const std::vector& cpus, + int membind_support) + : idx(idx), + os_idx(os_idx), + mem_in_mb(mem_in_mb), + core_count(core_count), + cpus(cpus), + membind_support(membind_support) {} + +std::string ccl_numa_node::to_string() { + std::stringstream ss; + + ss << "{" + << "idx: " << idx << ", memory: " << mem_in_mb << " MB" + << ", cores: " << core_count << ", cpus: " << cpus.size() << ", membind: " << membind_support + << "}"; + + return ss.str(); +} + +ccl_hwloc_wrapper::ccl_hwloc_wrapper() + : membind_thread_supported(false), + bindset(nullptr), + topology(nullptr) { + /* mandatory checks */ + + if (hwloc_topology_init(&topology) < 0) { + LOG_WARN("hwloc_topology_init failed (", strerror(errno), ")"); + return; + } + + hwloc_topology_set_io_types_filter(topology, HWLOC_TYPE_FILTER_KEEP_ALL); + + if (hwloc_topology_load(topology) < 0) { + LOG_WARN("hwloc_topology_load failed (", strerror(errno), ")"); + return; + } + + hwloc_obj_t root_obj = hwloc_get_root_obj(topology); + LOG_DEBUG("hwloc root object: ", obj_to_string(root_obj)); + + bindset = hwloc_bitmap_alloc(); + if (hwloc_get_proc_cpubind(topology, getpid(), bindset, HWLOC_CPUBIND_PROCESS) < 0) { + LOG_WARN("hwloc_get_proc_cpubind failed (", strerror(errno), ")"); + return; + } + + CCL_THROW_IF_NOT(topology && bindset); + + /* optional checks */ + + const struct hwloc_topology_support* topo_support = hwloc_topology_get_support(topology); + membind_thread_supported = topo_support->membind->set_thisthread_membind; + if (!membind_thread_supported) { + LOG_WARN("no support for memory binding of current thread"); + } + + hwloc_const_bitmap_t nodeset = hwloc_topology_get_topology_nodeset(topology); + int numa_node_count = hwloc_bitmap_weight(nodeset); + + for (int idx = 0; idx < numa_node_count; idx++) { + hwloc_obj_t numa_node = hwloc_get_numanode_obj_by_os_index(topology, idx); + int os_idx = numa_node->logical_index; + int mem_in_mb = + (numa_node->attr) ? numa_node->attr->numanode.local_memory / (1024 * 1024) : 0; + int core_count = + hwloc_get_nbobjs_inside_cpuset_by_type(topology, numa_node->cpuset, HWLOC_OBJ_CORE); + std::vector cpus; + for (int core_idx = 0; core_idx < core_count; core_idx++) { + hwloc_obj_t core_obj = hwloc_get_obj_inside_cpuset_by_type( + topology, numa_node->cpuset, HWLOC_OBJ_CORE, core_idx); + int cpus_per_core = + hwloc_get_nbobjs_inside_cpuset_by_type(topology, core_obj->cpuset, HWLOC_OBJ_PU); + for (int cpu_idx = 0; cpu_idx < cpus_per_core; cpu_idx++) { + hwloc_obj_t cpu_obj = hwloc_get_obj_inside_cpuset_by_type( + topology, core_obj->cpuset, HWLOC_OBJ_PU, cpu_idx); + cpus.push_back(cpu_obj->os_index); + } + } + numa_nodes.push_back( + ccl_numa_node(idx, os_idx, mem_in_mb, core_count, cpus, check_membind(idx))); + } +} + +ccl_hwloc_wrapper::~ccl_hwloc_wrapper() { + hwloc_bitmap_free(bindset); + hwloc_topology_destroy(topology); +} + +bool ccl_hwloc_wrapper::is_initialized() { + return (topology && bindset) ? true : false; +} + +std::string ccl_hwloc_wrapper::to_string() { + std::stringstream ss; + bool initialized = is_initialized(); + ss << "hwloc initialized: " << initialized << "\n"; + if (initialized) { + ss << "{\n"; + ss << " membind_thread_supported: " << membind_thread_supported << "\n"; + for (auto& node : numa_nodes) { + ss << " numa: {" + << "idx: " << node.idx << ", os idx: " << node.os_idx + << ", memory: " << node.mem_in_mb << " MB" + << ", cores: " << node.core_count << ", cpus: " << node.cpus.size() + << ", membind: " << node.membind_support << "}\n"; + } + ss << "}"; + } + return ss.str(); +} + +bool ccl_hwloc_wrapper::is_dev_close_by_pci(int domain, int bus, int dev, int func) { + bool is_close = false; + + if (!is_initialized()) { + LOG_WARN("hwloc is not initialized, skip checking of locality for device: [", + domain, + ":", + bus, + ":", + dev, + ":", + func, + "]"); + return is_close; + } + + hwloc_obj_t first_non_io = get_first_non_io_obj_by_pci(domain, bus, dev, func); + CCL_THROW_IF_NOT(first_non_io); + + LOG_DEBUG("first_non_io object: ", obj_to_string(first_non_io)); + + /* determine if PCI device is "close" to process by checking if process's affinity is included + * in PCI device's affinity or if PCI device's affinity is included in process's affinity */ + is_close = (hwloc_bitmap_isincluded(bindset, first_non_io->cpuset) || + hwloc_bitmap_isincluded(first_non_io->cpuset, bindset)); + + return is_close; +} + +void ccl_hwloc_wrapper::membind_thread(int numa_node) { + if (!is_initialized()) { + LOG_WARN("hwloc is not initialized, skip thread membind for NUMA node ", numa_node); + return; + } + + if (!membind_thread_supported) { + LOG_WARN( + "no support for memory binding of current thread, skip thread membind for NUMA node ", + numa_node); + return; + } + + if (!is_valid_numa_node(numa_node)) { + LOG_WARN("invalid NUMA node ", + numa_node, + ", NUMA node count ", + get_numa_node_count(), + ", skip thread membind"); + return; + } + + if (!get_numa_node(numa_node).membind_support) { + LOG_WARN("no membind support for NUMA node ", numa_node, ", skip thread membind"); + return; + } + + hwloc_nodeset_t nodeset = hwloc_bitmap_alloc(); + hwloc_bitmap_only(nodeset, unsigned(numa_node)); + CCL_THROW_IF_NOT(hwloc_bitmap_isset(nodeset, numa_node) == 1, "hwloc_bitmap_isset failed"); + + if (hwloc_set_membind(topology, + nodeset, + HWLOC_MEMBIND_BIND, + HWLOC_MEMBIND_THREAD | HWLOC_MEMBIND_STRICT | HWLOC_MEMBIND_BYNODESET) < + 0) { + LOG_WARN("failed to bind thread to NUMA node ", numa_node, " (", strerror(errno), ")"); + } + else { + LOG_DEBUG("bound thread to NUMA node ", numa_node); + } + + hwloc_bitmap_free(nodeset); +} + +int ccl_hwloc_wrapper::get_numa_node_by_cpu(int cpu) { + if (!is_initialized()) { + LOG_WARN("hwloc is not initialized, can't get numa NUMA for CPU ", cpu); + return CCL_UNDEFINED_NUMA_NODE; + } + + if (cpu == CCL_UNDEFINED_CPU_ID) { + return CCL_UNDEFINED_NUMA_NODE; + } + + for (auto& node : numa_nodes) { + for (auto cpu_idx : node.cpus) { + if (cpu_idx == cpu) { + return node.idx; + } + } + } + + return CCL_UNDEFINED_NUMA_NODE; +} + +ccl_numa_node ccl_hwloc_wrapper::get_numa_node(int numa_node) { + if (!is_initialized()) { + LOG_WARN("hwloc is not initialized, can't get info for NUMA node ", numa_node); + return {}; + } + + if (!is_valid_numa_node(numa_node)) { + LOG_WARN("invalid NUMA node ", numa_node, ", NUMA node count ", get_numa_node_count()); + return {}; + } + + return numa_nodes[numa_node]; +} + +bool ccl_hwloc_wrapper::is_valid_numa_node(int numa_node) { + if ((numa_node == CCL_UNDEFINED_NUMA_NODE) || (numa_node < 0) || + (numa_node >= static_cast(get_numa_node_count()))) { + return false; + } + return true; +} + +bool ccl_hwloc_wrapper::check_membind(int numa_node) { + hwloc_obj_t numa_node_obj = hwloc_get_numanode_obj_by_os_index(topology, numa_node); + size_t check_buf_len = 8192; + void* buffer = hwloc_alloc_membind(topology, + check_buf_len, + numa_node_obj->nodeset, + HWLOC_MEMBIND_BIND, + HWLOC_MEMBIND_STRICT | HWLOC_MEMBIND_BYNODESET); + + if (!buffer) { + return false; + } + + bool membind_ok = true; + + hwloc_bitmap_t nodeset = hwloc_bitmap_alloc(); + hwloc_bitmap_zero(nodeset); + hwloc_membind_policy_t policy = HWLOC_MEMBIND_DEFAULT; + + if (hwloc_get_area_membind( + topology, buffer, check_buf_len, nodeset, &policy, HWLOC_MEMBIND_BYNODESET) < 0) { + LOG_WARN("NUMA node ", numa_node, ", failed to get nodeset and policy for buffer ", buffer); + membind_ok = false; + } + + if (policy != HWLOC_MEMBIND_BIND) { + LOG_WARN("NUMA node ", + numa_node, + ", unxpected membind policy ", + policy, + ", expected ", + HWLOC_MEMBIND_BIND); + membind_ok = false; + } + + int i = 0, bind_count = 0; + hwloc_bitmap_foreach_begin(i, nodeset) { + hwloc_obj_t obj = hwloc_get_numanode_obj_by_os_index(topology, i); + if (obj) { + bind_count++; + } + } + hwloc_bitmap_foreach_end(); + + if (bind_count != 1) { + LOG_WARN("buffer should be bound to single NUMA node but actual bind_count", bind_count); + membind_ok = false; + } + + if (!hwloc_bitmap_isset(nodeset, numa_node)) { + LOG_WARN("nodeset doesn't have expected index ", numa_node); + membind_ok = false; + } + + if (hwloc_bitmap_first(nodeset) != numa_node) { + LOG_WARN("nodeset has unexpected first index ", + hwloc_bitmap_first(nodeset), + ", expected ", + numa_node); + membind_ok = false; + } + + hwloc_bitmap_free(nodeset); + hwloc_free(topology, buffer, check_buf_len); + + return membind_ok; +} + +size_t ccl_hwloc_wrapper::get_numa_node_count() { + return numa_nodes.size(); +} + +hwloc_obj_t ccl_hwloc_wrapper::get_first_non_io_obj_by_pci(int domain, int bus, int dev, int func) { + hwloc_obj_t io_device = hwloc_get_pcidev_by_busid(topology, domain, bus, dev, func); + CCL_THROW_IF_NOT(io_device, + "failed to get PCI device with domain %d, bus %d, dev %d, func %d", + domain, + bus, + dev, + func); + + hwloc_obj_t first_non_io = hwloc_get_non_io_ancestor_obj(topology, io_device); + CCL_THROW_IF_NOT(first_non_io, "failed to get ancestor of PCI device"); + return first_non_io; +} + +std::string ccl_hwloc_wrapper::obj_to_string(hwloc_obj_t obj) { + std::stringstream ss; + const size_t obj_str_len = 4096; + char str[obj_str_len]; + + hwloc_obj_type_snprintf(str, obj_str_len, obj, 1); + ss << "type: " << str << "\n"; + hwloc_obj_attr_snprintf(str, obj_str_len, obj, " :: ", 1); + ss << "attr: " << str << "\n"; + hwloc_bitmap_taskset_snprintf(str, obj_str_len, obj->cpuset); + ss << "cpuset: " << str << "\n"; + hwloc_bitmap_taskset_snprintf(str, obj_str_len, obj->nodeset); + ss << "nodeset: " << str << "\n"; + + return ss.str(); +} diff --git a/src/hwloc/hwloc_wrapper.h b/src/hwloc/hwloc_wrapper.h deleted file mode 100644 index 7b7ff7b9d..000000000 --- a/src/hwloc/hwloc_wrapper.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#ifndef HWLOC_WRAPPER_H -#define HWLOC_WRAPPER_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include "hwloc.h" -#include - -#define GETTID() syscall(SYS_gettid) - -#define HWLOC_ASSERT(cond, fmt, ...) \ - do { \ - if (!(cond)) { \ - fprintf(stderr, \ - "(%ld): %s:%s:%d: ASSERT '%s' FAILED: " fmt "\n", \ - GETTID(), \ - __FILE__, \ - __FUNCTION__, \ - __LINE__, \ - #cond, \ - ##__VA_ARGS__); \ - fflush(stderr); \ - } \ - } while (0) - -typedef enum { HWLOC_SUCCESS, HWLOC_FAILURE, HWLOC_UNSUPPORTED } hwloc_status_t; - -inline const char* hwloc_status_to_str(hwloc_status_t status) { - switch (status) { - case HWLOC_SUCCESS: return "SUCCESS"; - case HWLOC_FAILURE: return "FAILURE"; - case HWLOC_UNSUPPORTED: return "UNSUPPORTED"; - default: return "UNKNOWN"; - } -} - -typedef struct { - hwloc_topology_t topology; - hwloc_cpuset_t bindset; - int initialized; -} hwloc_info_t; - -hwloc_status_t hwloc_init(); -hwloc_status_t hwloc_finalize(); -int hwloc_is_initialized(); - -/* - * return true if pci device is close to this process - */ -int hwloc_is_dev_close_by_pci(int domain, int bus, int dev, int func); - -#ifdef __cplusplus -} -#endif - -#endif /* HWLOC_WRAPPER_H */ diff --git a/src/hwloc/hwloc_wrapper.hpp b/src/hwloc/hwloc_wrapper.hpp new file mode 100644 index 000000000..2919c2857 --- /dev/null +++ b/src/hwloc/hwloc_wrapper.hpp @@ -0,0 +1,68 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "hwloc.h" + +#define CCL_HWLOC_INVALID_NUMA_NODE (-1) + +struct ccl_numa_node { + int idx; + int os_idx; + size_t mem_in_mb; + int core_count; + std::vector cpus; + int membind_support; + + ccl_numa_node(); + ccl_numa_node(int idx, + int os_idx, + size_t mem_in_mb, + int core_count, + const std::vector& cpus, + int membind_support); + + std::string to_string(); +}; + +class ccl_hwloc_wrapper { +public: + ccl_hwloc_wrapper(); + ~ccl_hwloc_wrapper(); + + bool is_initialized(); + + std::string to_string(); + + bool is_dev_close_by_pci(int domain, int bus, int dev, int func); + + void membind_thread(int numa_node); + int get_numa_node_by_cpu(int cpu); + ccl_numa_node get_numa_node(int numa_node); + +private: + bool is_valid_numa_node(int numa_node); + bool check_membind(int numa_node); + size_t get_numa_node_count(); + hwloc_obj_t get_first_non_io_obj_by_pci(int domain, int bus, int dev, int func); + std::string obj_to_string(hwloc_obj_t obj); + + std::vector numa_nodes; + + bool membind_thread_supported; + hwloc_cpuset_t bindset; + hwloc_topology_t topology; +}; diff --git a/src/kernels/a2a_helpers.h b/src/kernels/a2a_helpers.h deleted file mode 100644 index 10c44e398..000000000 --- a/src/kernels/a2a_helpers.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "common.h" - -#define DEFINE_A2A_COMM_DATA(NAME, T) \ - typedef struct __attribute__((packed)) a2a_gpu_comm_data_##NAME { \ - __global T* recv_buf; \ - __global sync_flag_type* ready_to_receive_flag; \ - __global sync_flag_type* data_sent_flag; \ - } a2a_gpu_comm_data_##NAME; - -DEFINE_A2A_COMM_DATA(int8, int8_t) -DEFINE_A2A_COMM_DATA(uint8, uint8_t) -DEFINE_A2A_COMM_DATA(int16, int16_t) -DEFINE_A2A_COMM_DATA(uint16, uint16_t) -DEFINE_A2A_COMM_DATA(int32, int32_t) -DEFINE_A2A_COMM_DATA(uint32, uint32_t) -DEFINE_A2A_COMM_DATA(int64, int64_t) -DEFINE_A2A_COMM_DATA(uint64, uint64_t) -//DEFINE_A2A_COMM_DATA(float16, half) -DEFINE_A2A_COMM_DATA(float32, float) -DEFINE_A2A_COMM_DATA(float64, double) -DEFINE_A2A_COMM_DATA(bfloat16, uint16_t) diff --git a/src/kernels/common.h b/src/kernels/common.h index 493e2c5c1..132c6f6a3 100644 --- a/src/kernels/common.h +++ b/src/kernels/common.h @@ -1,12 +1,12 @@ /* Copyright 2016-2020 Intel Corporation - + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,9 +27,9 @@ typedef atomic_int sync_flag_type; #else // default type for sync flags typedef volatile int sync_flag_type; -#endif /* ENABLE_KERNEL_ATOMICS */ +#endif // ENABLE_KERNEL_ATOMICS -#else /* HOST_CTX */ +#else // HOST_CTX #pragma OPENCL EXTENSION cl_intel_subgroups : enable #pragma OPENCL EXTENSION cl_khr_subgroups : enable @@ -118,14 +118,14 @@ typedef ushort bfloat16; printf("kernel %d.%d barrier passed\n", rank, thread_id); #define LOG_IN_BARRIER(rank, thread_id, flag, desired) \ printf("kernel %d.%d barrier %d/%d\n", rank, thread_id, flag, desired); -#else /* ENABLE_KERNEL_DEBUG */ +#else // ENABLE_KERNEL_DEBUG #define LOG_INPUT_DATA_START(rank) #define LOG_INPUT_DATA_END(rank) #define LOG_OUTGOING_DATA_START(rank) #define LOG_OUTGOING_DATA_END(rank) #define LOG_BARRIER_PASSED(rank, thread_id) #define LOG_IN_BARRIER(rank, thread_id, flag, desired) -#endif /* ENABLE_KERNEL_DEBUG */ +#endif // ENABLE_KERNEL_DEBUG #define SWAP_VARIABLES(var1, var2, type) \ do { \ @@ -193,7 +193,7 @@ typedef atomic_int sync_flag_type; #define GET_PROXY_SIZE(_sync_flag, size) \ size = atomic_load_explicit(_sync_flag, memory_order_seq_cst, memory_scope_all_svm_devices); -#else /* ENABLE_KERNEL_ATOMICS */ +#else // ENABLE_KERNEL_ATOMICS // default type for sync flags typedef volatile int sync_flag_type; @@ -236,7 +236,7 @@ typedef volatile int sync_flag_type; #define GET_PROXY_SIZE(_sync_flag, size) size = *_sync_flag; -#endif /* ENABLE_KERNEL_ATOMICS */ +#endif // ENABLE_KERNEL_ATOMICS /* #define KERNEL_BARRIER(_barrier_flag, _desired, _increment) \ @@ -284,4 +284,4 @@ typedef volatile int sync_flag_type; _desired += comm_size; \ }*/ -#endif /* HOST_CTX */ +#endif // HOST_CTX diff --git a/src/kernels/kernels.cl b/src/kernels/kernels.cl new file mode 100644 index 000000000..ad96eba79 --- /dev/null +++ b/src/kernels/kernels.cl @@ -0,0 +1,157 @@ +#include "common.h" +#include "shared.h" + +__kernel void empty_kernel(int my_rank, + int comm_size, + ulong count, + const __global void* input_buffer, + __global void* output_buffer, + const __global void* right_input_buffer, + __global void* right_output_buffer) { + return; +} + +// Name - unique name suffix for the kernel +// T - type parameter(e.g. float, int4, etc) +// VecSize - vector size of the type. E.g. if float4 is used, VecSize is 4. Note: if just float is used, +// the value must be one as it's used for division inside the kernel. +// Op - A operation parameter(e.g. add(x, y)) +// OpName - Operator name which goes to the kernel name, e.g. OpName = add, Op = __add_int(actual function) +#define DEFINE_ALLREDUCE_KERNEL(Name, T, VecSize, Op, OpName) \ + __kernel void allreduce_kernel_##Name##_##OpName(int my_rank, \ + int comm_size, \ + ulong count, \ + const __global T* input_buffer, \ + __global T* output_buffer, \ + const __global T* right_input_buffer, \ + __global T* right_output_buffer) { \ + DEBUG_BLOCK(printf("rank: %d, comm size: %d, count: %zu\n", my_rank, comm_size, count)); \ + size_t work_group_size = get_global_size(0); \ + size_t thread_id = get_global_id(0); \ +\ + for (size_t i = 0; thread_id + i < count; i += work_group_size) { \ + const size_t idx = thread_id + i; \ + output_buffer[idx] = Op(input_buffer[idx], right_input_buffer[idx]); \ + right_output_buffer[idx] = output_buffer[idx]; \ + } \ + } + +#define DEFINE_REDUCE_LOCAL_OUTOFPLACE_KERNEL(Name, T, VecSize, Op, OpName) \ + __kernel void reduce_local_outofplace_kernel_##Name##_##OpName( \ + int my_rank, \ + int comm_size, \ + ulong count, \ + const __global T* input_buffer_1, \ + const __global T* input_buffer_2, \ + __global T* output_buffer) { \ + DEBUG_BLOCK(printf("rank: %d, comm size: %d, count: %zu\n", my_rank, comm_size, count)); \ + size_t work_group_size = get_global_size(0); \ + size_t thread_id = get_global_id(0); \ +\ + for (size_t i = 0; thread_id + i < count; i += work_group_size) { \ + const size_t idx = thread_id + i; \ + output_buffer[idx] = Op(input_buffer_1[idx], input_buffer_2[idx]); \ + } \ + } + +#define DEFINE_REDUCE_LOCAL_INPLACE_KERNEL(Name, T, VecSize, Op, OpName) \ + __kernel void reduce_local_inplace_kernel_##Name##_##OpName( \ + ulong count, const __global T* input_buffer, __global T* inoutput_buffer) { \ + DEBUG_BLOCK(/* int sg_id = get_sub_group_id(); */ \ + printf("in reduce_local_inplace_kernel_\n")); \ + size_t work_group_size = get_global_size(0); \ + size_t thread_id = get_global_id(0); \ +\ + for (size_t i = 0; thread_id + i < count; i += work_group_size) { \ + const size_t idx = thread_id + i; \ + inoutput_buffer[idx] = Op(input_buffer[idx], inoutput_buffer[idx]); \ + } \ + } + +// Define kernels for a specific operation for all the supported types. +// Note: for op function we use convention ___, where type is the actual type(e.g. int4, float) +// FIXME: Temporary use scalar types instead of vector ones. This is a workaround for issues in case when +// elems_count % VecSize != 0. Need to find a proper fix with a good performance. +#define VEC_SIZE RING_ALLREDUCE_VEC_SIZE + +#define DEFINE_KERNELS_WITH_OP(KernelName, OpName) \ + DEFINE_##KernelName##_KERNEL(int8, char, VEC_SIZE, __##OpName##_##char, OpName) \ + DEFINE_##KernelName##_KERNEL(uint8, uchar, VEC_SIZE, __##OpName##_##uchar, OpName) \ +\ + DEFINE_##KernelName##_KERNEL(int16, short, VEC_SIZE, __##OpName##_##short, OpName) \ + DEFINE_##KernelName##_KERNEL( \ + uint16, ushort, VEC_SIZE, __##OpName##_##ushort, OpName) \ +\ + DEFINE_##KernelName##_KERNEL(int32, int, VEC_SIZE, __##OpName##_##int, OpName) \ + DEFINE_##KernelName##_KERNEL( \ + uint32, uint, VEC_SIZE, __##OpName##_##uint, OpName) \ +\ + DEFINE_##KernelName##_KERNEL( \ + int64, long, VEC_SIZE, __##OpName##_##long, OpName) \ + DEFINE_##KernelName##_KERNEL( \ + uint64, ulong, VEC_SIZE, __##OpName##_##ulong, OpName) \ +\ + DEFINE_##KernelName##_KERNEL( \ + float32, float, VEC_SIZE, __##OpName##_##float, OpName) \ + DEFINE_##KernelName##_KERNEL(float64, \ + double, \ + VEC_SIZE, \ + __##OpName##_##double, \ + OpName) + +#define DEFINE_KERNELS_WITH_LP_OP(KernelName, OpName) \ + DEFINE_##KernelName##_KERNEL(bfloat16, ushort, VEC_SIZE, __bf16_##OpName##_##ushort, OpName) \ + DEFINE_##KernelName##_KERNEL(float16, half, VEC_SIZE, __##OpName##_##half, OpName) + +#define DEFINE_OPS(T) \ + DEFINE_SUM_OP(T) \ + DEFINE_PROD_OP(T) \ + DEFINE_MIN_OP(T) \ + DEFINE_MAX_OP(T) + +#define DEFINE_BF16OPS(T) \ + DEFINE_BF16SUM_OP(T) \ + DEFINE_BF16PROD_OP(T) \ + DEFINE_BF16MIN_OP(T) \ + DEFINE_BF16MAX_OP(T) + +#define DEFINE_FP16OPS(T) \ + DEFINE_FP16SUM_OP(T) \ + DEFINE_FP16PROD_OP(T) \ + DEFINE_FP16MIN_OP(T) \ + DEFINE_FP16MAX_OP(T) + +// Define Op function for each supported type(use vector types for some of them as required by the kernel) +DEFINE_OPS(char) +DEFINE_OPS(uchar) + +DEFINE_OPS(short) +DEFINE_OPS(ushort) + +DEFINE_OPS(int) +DEFINE_OPS(uint) + +DEFINE_OPS(long) +DEFINE_OPS(ulong) + +DEFINE_OPS(float) +DEFINE_OPS(double) + +DEFINE_BF16OPS(ushort) +DEFINE_FP16OPS(half) + +// Define the actual kernels +#define DEFINE_ALL_KERNELS(KernelName) \ + DEFINE_KERNELS_WITH_OP(KernelName, sum) \ + DEFINE_KERNELS_WITH_OP(KernelName, prod) \ + DEFINE_KERNELS_WITH_OP(KernelName, min) \ + DEFINE_KERNELS_WITH_OP(KernelName, max) \ +\ + DEFINE_KERNELS_WITH_LP_OP(KernelName, sum) \ + DEFINE_KERNELS_WITH_LP_OP(KernelName, prod) \ + DEFINE_KERNELS_WITH_LP_OP(KernelName, min) \ + DEFINE_KERNELS_WITH_LP_OP(KernelName, max) + +DEFINE_ALL_KERNELS(ALLREDUCE) +DEFINE_ALL_KERNELS(REDUCE_LOCAL_OUTOFPLACE) +DEFINE_ALL_KERNELS(REDUCE_LOCAL_INPLACE) diff --git a/src/kernels/kernels.spv b/src/kernels/kernels.spv new file mode 100644 index 000000000..d386d0d80 Binary files /dev/null and b/src/kernels/kernels.spv differ diff --git a/src/kernels/lp.h b/src/kernels/lp.h index e28ea1d13..7e159e061 100644 --- a/src/kernels/lp.h +++ b/src/kernels/lp.h @@ -1,12 +1,12 @@ /* Copyright 2016-2020 Intel Corporation - + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,9 +27,9 @@ ushort __fp32_to_bf16(float V) { ushort2 temp = as_ushort2(V); return temp.s1; } -#else /* CCL_BF16_GPU_TRUNCATE */ -#include "rne.h" -#endif /* CCL_BF16_GPU_TRUNCATE */ +#else // CCL_BF16_GPU_TRUNCATE +#include "legacy/rne.h" +#endif // CCL_BF16_GPU_TRUNCATE #define DEFINE_BF16SUM_OP(T) \ T __bf16_sum_##T(T lhs, T rhs) { \ @@ -139,7 +139,8 @@ half __fp32_to_fp16(float V) { T __max_##T(T lhs, T rhs) { \ return __fp32_to_fp16(max(__fp16_to_fp32(lhs), __fp16_to_fp32(rhs))); \ } -#else /* CCL_FP16_GPU_TRUNCATE */ + +#else // CCL_FP16_GPU_TRUNCATE #define DEFINE_FP16SUM_OP(T) \ T __sum_##T(T lhs, T rhs) { \ return lhs + rhs; \ @@ -159,4 +160,5 @@ half __fp32_to_fp16(float V) { T __max_##T(T lhs, T rhs) { \ return max(lhs, rhs); \ } -#endif /* CCL_FP16_GPU_TRUNCATE */ + +#endif // CCL_FP16_GPU_TRUNCATE diff --git a/src/kernels/rne.h b/src/kernels/rne.h deleted file mode 100644 index 47ca9bf78..000000000 --- a/src/kernels/rne.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#ifndef RNE_H -#define RNE_H - -// bf <--> float conversion -// bf : no igc type for bf yet. Use short as *opaque* type for it. -// -// float -> bf conversion builtins (rte rounding mode) -short __builtin_IB_ftobf_1(float a) __attribute__((const)); -short2 __builtin_IB_ftobf_2(float2 a) __attribute__((const)); -short4 __builtin_IB_ftobf_4(float4 a) __attribute__((const)); -short8 __builtin_IB_ftobf_8(float8 a) __attribute__((const)); -short16 __builtin_IB_ftobf_16(float16 a) __attribute__((const)); - -// bf -> float conversion builtins (precise conversion) -float __builtin_IB_bftof_1(short a) __attribute__((const)); -float2 __builtin_IB_bftof_2(short2 a) __attribute__((const)); -float4 __builtin_IB_bftof_4(short4 a) __attribute__((const)); -float8 __builtin_IB_bftof_8(short8 a) __attribute__((const)); -float16 __builtin_IB_bftof_16(short16 a) __attribute__((const)); - -// 2 floats --> packed 2 bf (rte rounding mode) -int __builtin_IB_2fto2bf_1(float a, float b) __attribute__((const)); -int2 __builtin_IB_2fto2bf_2(float2 a, float2 b) __attribute__((const)); -int4 __builtin_IB_2fto2bf_4(float4 a, float4 b) __attribute__((const)); -int8 __builtin_IB_2fto2bf_8(float8 a, float8 b) __attribute__((const)); -int16 __builtin_IB_2fto2bf_16(float16 a, float16 b) __attribute__((const)); - -float __bf16_to_fp32(ushort V) { - return __builtin_IB_bftof_1(as_short(V)); -} - -ushort __fp32_to_bf16(float V) { - return as_ushort(__builtin_IB_ftobf_1(V)); -} - -#endif /* RNE_H */ diff --git a/src/kernels/shared.h b/src/kernels/shared.h index 3dce51e12..3b4f6c17e 100644 --- a/src/kernels/shared.h +++ b/src/kernels/shared.h @@ -1,12 +1,12 @@ /* Copyright 2016-2020 Intel Corporation - + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -68,4 +68,4 @@ static inline size_t ring_reduce_scatter_tmp_buffer_size(size_t elems_count, siz return 2 * ring_reduce_scatter_get_segment_size(elems_count, comm_size); } -#endif /* SHARED_H */ +#endif // SHARED_H diff --git a/src/native_device_api/interop_utils.cpp b/src/native_device_api/interop_utils.cpp index f76230dee..cba21295b 100644 --- a/src/native_device_api/interop_utils.cpp +++ b/src/native_device_api/interop_utils.cpp @@ -83,145 +83,5 @@ size_t get_sycl_subdevice_id(const cl::sycl::device& device) { } #endif -#ifdef CCL_ENABLE_SYCL -using usm_type_str = utils::enum_to_str; -std::string usm_to_string(cl::sycl::usm::alloc val) { - return usm_type_str({ - "HOST", - "DEVICE", - "SHARED", - }) - .choose(val, "UNKNOWN"); -} -#endif - -using usm_move_str = utils::enum_to_str; -std::string to_string(usm_support_mode val) { - return usm_move_str({ - "prohibited", - "direct", - "shared", - "convertation", - }) - .choose(val, "UNKNOWN"); -} - -size_t get_platform_type_index(const ccl::unified_device_type::ccl_native_t& device) { - size_t index = 2; //`gpu` for default L0 backend - -#ifdef CCL_ENABLE_SYCL - if (device.is_host()) { - index = 0; - } - else if (device.is_cpu()) { - index = 1; - } - else if (device.is_gpu()) { - index = 2; - } - else if (device.is_accelerator()) { - index = 3; - } - else { - CCL_THROW("invalid device type"); - } -#endif - - return index; -} - -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -assoc_result check_assoc_device_memory(const void* mem, - const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_context_type::ccl_native_t& ctx) { - assoc_result ret{ usm_support_mode::direct, mem, "" }; - -#ifdef CCL_ENABLE_SYCL - - sycl::usm::alloc pointer_type = sycl::get_pointer_type(mem, ctx); - - using usm_truth_table = - std::array; - - constexpr int platform_config_count = 4; /*host, cpu, gpu, accel*/ - constexpr std::array usm_target_table{ { - { { usm_support_mode::direct, - usm_support_mode::prohibited, - usm_support_mode::shared, - usm_support_mode::direct } }, //host conf: host, device, shared, unknown - { { usm_support_mode::direct, - usm_support_mode::prohibited, - usm_support_mode::shared, - usm_support_mode::direct } }, //cpu conf: host, device, shared, unknown - { { usm_support_mode::prohibited, - usm_support_mode::need_conversion, - usm_support_mode::shared, - usm_support_mode::prohibited } }, //gpu conf: host, device, shared, unknown - { { usm_support_mode::prohibited, - usm_support_mode::prohibited, - usm_support_mode::shared, - usm_support_mode::prohibited } } //accel conf: host, device, shared, unknown - } }; - - auto platform_type_index = get_platform_type_index(device); - - auto pointer_type_idx = utils::enum_to_underlying(pointer_type); - CCL_THROW_IF_NOT(pointer_type_idx < usm_target_table[platform_type_index].size(), - "usm_type index ", - pointer_type_idx, - " is larger that array size ", - usm_target_table[platform_type_index].size()); - - std::get(ret) = - usm_target_table[platform_type_index][pointer_type_idx]; - - if (std::get(ret) == usm_support_mode::prohibited) { - std::stringstream ss; - ss << "incompatible usm type requested: " << usm_to_string(pointer_type) - << " for device: " << std::to_string(platform_type_index); - std::get(ret) = ss.str(); - } -#else - //TODO calls method `assoc` for ccl_device -#endif - return ret; -} - -usm_support_mode check_assoc_device_memory(const std::vector& mems, - const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_context_type::ccl_native_t& ctx) { - usm_support_mode ret = usm_support_mode::direct; - std::string err_msg; - - for (size_t idx = 0; idx < mems.size(); idx++) { - usm_support_mode mode; - std::tie(mode, std::ignore, err_msg) = check_assoc_device_memory(mems[idx], device, ctx); - - if (idx > 0) - CCL_THROW_IF_NOT(mode == ret, "different USM modes between buffers: ", err_msg); - - ret = mode; - - CCL_THROW_IF_NOT((mode == usm_support_mode::direct) || (mode == usm_support_mode::shared) || - (mode == usm_support_mode::need_conversion), - "unsupported USM configuration: ", - err_msg); - } - - return ret; -} - -#endif //defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - -std::string to_string(const assoc_result& res) { - std::stringstream ss; - ss << "Mem: " << std::get(res) - << ", is: " << to_string(std::get(res)); - const std::string& err_cause = std::get(res); - if (!err_cause.empty()) { - ss << ", error cause: " << err_cause; - } - return ss.str(); -} } // namespace detail } // namespace native diff --git a/src/native_device_api/l0/base.cpp b/src/native_device_api/l0/base.cpp index a91a5bc1b..792aa0bee 100644 --- a/src/native_device_api/l0/base.cpp +++ b/src/native_device_api/l0/base.cpp @@ -176,15 +176,12 @@ std::string CCL_BE_API to_string(const ze_memory_allocation_properties_t& prop) std::string CCL_BE_API to_string(const ze_device_mem_alloc_desc_t& mem_descr) { std::stringstream ss; - std::string flag; + std::string flag = "0"; if (mem_descr.flags & ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_CACHED) { flag = "ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_CACHED"; } - if (mem_descr.flags & ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED) { - flag += flag.empty() ? "" : "|"; - flag = flag + "ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED"; - } + if (flag.empty()) { CCL_THROW("unknown ze_device_mem_alloc_flags_t flag: " + std::to_string(static_cast(mem_descr.flags))); diff --git a/src/native_device_api/l0/device.cpp b/src/native_device_api/l0/device.cpp index 201ce63e9..bf8a9de2e 100644 --- a/src/native_device_api/l0/device.cpp +++ b/src/native_device_api/l0/device.cpp @@ -471,7 +471,24 @@ CCL_BE_API ccl_device::device_ipc_memory_handle ccl_device::create_ipc_memory_ha if (ret != ZE_RESULT_SUCCESS) { CCL_THROW("cannot get ipc mem handle, error: " + native::to_string(ret)); } - return device_ipc_memory_handle(ipc_handle, get_ptr(), ctx); + + void* base_ptr = nullptr; + size_t alloc_size = 0; + ret = zeMemGetAddressRange(ctx->get(), device_mem_ptr, &base_ptr, &alloc_size); + if (ret != ZE_RESULT_SUCCESS) { + CCL_THROW("zeMemGetAddressRange failed, error: " + native::to_string(ret)); + } + + LOG_DEBUG("Retrieved memory info for ", + device_mem_ptr, + ", base ptr: ", + base_ptr, + ", size ", + alloc_size, + ", offset ", + ((char*)device_mem_ptr) - ((char*)base_ptr)); + return device_ipc_memory_handle( + ipc_handle, get_ptr(), ctx, ((char*)device_mem_ptr) - ((char*)base_ptr)); } CCL_BE_API std::shared_ptr @@ -496,7 +513,7 @@ ccl_device::create_shared_ipc_memory_handle(void* device_mem_ptr, return ipc_storage .insert({ device_mem_ptr, std::shared_ptr( - new device_ipc_memory_handle(ipc_handle, get_ptr(), ctx)) }) + new device_ipc_memory_handle(ipc_handle, get_ptr(), ctx, 0)) }) .first->second; } @@ -528,10 +545,8 @@ CCL_BE_API ccl_device::device_ipc_memory ccl_device::get_ipc_memory( std::shared_ptr&& ipc_handle, std::shared_ptr ctx) { assert(ipc_handle->get_owner().lock().get() == this && "IPC handle doesn't belong to device: "); - //, this, - // ", expected device: ", ipc_handle.get_owner()); - ze_ipc_memory_flags_t flag = ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED; + ze_ipc_memory_flags_t flag = 0; ip_memory_elem_t ipc_memory{}; if (!ctx) { @@ -552,6 +567,13 @@ CCL_BE_API ccl_device::device_ipc_memory ccl_device::get_ipc_memory( //ipc_handle.handle = nullptr; ipc_handle->owner.reset(); + LOG_DEBUG("Open ipc handle, got ptr(w/o offset): ", + ipc_memory.pointer, + " and offset: ", + ipc_handle->get_offset()); + + ipc_memory.pointer = (void*)((char*)ipc_memory.pointer + ipc_handle->get_offset()); + ipc_memory.offset = ipc_handle->get_offset(); return device_ipc_memory(ipc_memory, get_ptr(), ctx); } @@ -559,7 +581,7 @@ CCL_BE_API std::shared_ptr ccl_device::restore_sh std::shared_ptr&& ipc_handle, std::shared_ptr ctx) { assert(ipc_handle->get_owner().lock().get() == this && "IPC handle doesn't belong to device: "); - ze_ipc_memory_flags_t flag = ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED; + ze_ipc_memory_flags_t flag = 0; ip_memory_elem_t ipc_memory{}; if (!ctx) { @@ -583,9 +605,14 @@ CCL_BE_API std::shared_ptr ccl_device::restore_sh } void CCL_BE_API ccl_device::on_delete(ip_memory_elem_t& ipc_mem, ze_context_handle_t& ctx) { - ze_result_t ret = zeMemCloseIpcHandle(ctx, ipc_mem.pointer); + // There are cases when we call this function on the same pointers(e.g. there are 2 ipc handles for + // the same L0 allocation, so subtracting their offsets results in the same ptr), and the function + // can return an error on the second run. Technically, this function is like free() meaning we + // can skip the error without throwing an exception without any affect on the execution. + // And just in case report the error in the debug log + ze_result_t ret = zeMemCloseIpcHandle(ctx, (void*)((char*)ipc_mem.pointer - ipc_mem.offset)); if (ret != ZE_RESULT_SUCCESS) { - CCL_THROW("cannot close ipc mem handle, error: " + native::to_string(ret)); + LOG_DEBUG("Cannot close ipc mem handle, ignoring error: " + native::to_string(ret)); } } diff --git a/src/native_device_api/sycl/export.cpp b/src/native_device_api/sycl/export.cpp index da40e3595..7c32c47e7 100644 --- a/src/native_device_api/sycl/export.cpp +++ b/src/native_device_api/sycl/export.cpp @@ -49,6 +49,12 @@ CCL_BE_API generic_device_type::generic_device_type device_index_type id, cl::sycl::info::device_type type) : device() { + if ((std::get<0>(id) == ccl::unused_index_value) && + (std::get<1>(id) == ccl::unused_index_value) && + (std::get<2>(id) == ccl::unused_index_value)) { + return; + } + LOG_DEBUG("Try to find SYCL device by index: ", id, ", type: ", diff --git a/src/native_device_api/sycl_l0/export.cpp b/src/native_device_api/sycl_l0/export.cpp index 7a49797ed..f27a87dbc 100644 --- a/src/native_device_api/sycl_l0/export.cpp +++ b/src/native_device_api/sycl_l0/export.cpp @@ -49,6 +49,12 @@ generic_device_type::generic_device_type( device_index_type id, cl::sycl::info::device_type type /* = info::device_type::gpu*/) : device() { + if ((std::get<0>(id) == ccl::unused_index_value) && + (std::get<1>(id) == ccl::unused_index_value) && + (std::get<2>(id) == ccl::unused_index_value)) { + return; + } + LOG_DEBUG("Try to find SYCL device by index: ", id, ", type: ", diff --git a/src/parallelizer/parallelizer.cpp b/src/parallelizer/parallelizer.cpp index afb0b5bd9..6c1483a2b 100644 --- a/src/parallelizer/parallelizer.cpp +++ b/src/parallelizer/parallelizer.cpp @@ -18,6 +18,7 @@ #include "coll/selection/selection.hpp" #include "common/global/global.hpp" +#include "common/utils/sycl_utils.hpp" #include "parallelizer/parallelizer.hpp" #include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/entry_factory.hpp" @@ -86,12 +87,13 @@ ccl::status ccl_parallelizer::process(ccl_master_sched* sched) { process_base(sched); #ifdef CCL_ENABLE_SYCL - ccl_coll_param& coll_param = sched->coll_param; - if (coll_param.stream && coll_param.stream->is_sycl_device_stream() && - (coll_param.device_send_buf || coll_param.device_recv_buf)) { + ccl_coll_param& param = sched->coll_param; + if (param.stream && param.stream->is_sycl_device_stream() && + (!param.device_send_bufs.empty() || !param.device_recv_bufs.empty())) { process_pre_post_copies(sched); } -#endif /* CCL_ENABLE_SYCL */ + process_output_event(sched); +#endif // CCL_ENABLE_SYCL /* should be the last call in the sequence of process_* calls because it sets dependencies for all partial schedules @@ -104,9 +106,9 @@ ccl::status ccl_parallelizer::process(ccl_master_sched* sched) { ccl::status ccl_parallelizer::process_deps(ccl_master_sched* sched) { auto& part_scheds = sched->partial_scheds; ccl_sched* deps_sched = part_scheds[0].get(); - size_t part_count = part_scheds.size(); + size_t sched_count = part_scheds.size(); - for (size_t idx = 0; idx < part_count; idx++) { + for (size_t idx = 0; idx < sched_count; idx++) { part_scheds[idx]->set_add_mode(ccl_sched_add_front); } sched->sync_partial_scheds(); @@ -120,120 +122,102 @@ ccl::status ccl_parallelizer::process_deps(ccl_master_sched* sched) { #ifdef CCL_ENABLE_SYCL ccl::status ccl_parallelizer::process_pre_post_copies(ccl_master_sched* sched) { auto& part_scheds = sched->partial_scheds; - ccl_sched* copy_sched = part_scheds[0].get(); - size_t part_count = part_scheds.size(); - + size_t sched_count = part_scheds.size(); ccl_coll_param& coll_param = sched->coll_param; ccl_comm* comm = coll_param.comm; - int comm_size = comm->size(); int my_rank = comm->rank(); - const ccl_datatype& dtype = coll_param.dtype; size_t dtype_size = dtype.size(); - ccl_coll_type coll_type = coll_param.ctype; - - size_t d2h_bytes = 0, h2d_bytes = 0; - size_t d2h_count = 0, h2d_count = 0; - - void* device_in_buf = nullptr; - void* device_out_buf = nullptr; - void* host_in_buf = nullptr; - void* host_out_buf = nullptr; - size_t device_in_buf_offset = 0; - switch (coll_type) { - case ccl_coll_bcast: - if (comm->rank() == coll_param.root) - d2h_count = coll_param.count; - else - d2h_count = 0; - h2d_count = coll_param.count; - break; - - case ccl_coll_reduce: - d2h_count = coll_param.count; - if (my_rank == coll_param.root) - h2d_count = coll_param.count; - else - h2d_count = 0; - break; - - case ccl_coll_reduce_scatter: - d2h_count = coll_param.count * comm_size; - h2d_count = coll_param.count; - break; - - case ccl_coll_allreduce: d2h_count = h2d_count = coll_param.count; break; - - case ccl_coll_allgatherv: - if (coll_param.device_send_buf == coll_param.device_recv_buf) { - device_in_buf_offset = - std::accumulate(coll_param.recv_counts, coll_param.recv_counts + my_rank, 0); - LOG_TRACE("device_in_buf_offset = ", device_in_buf_offset); - } - d2h_count = coll_param.send_count; - h2d_count = - std::accumulate(coll_param.recv_counts, coll_param.recv_counts + comm_size, 0); - break; - - case ccl_coll_alltoall: d2h_count = h2d_count = coll_param.count * comm_size; break; - case ccl_coll_alltoallv: - d2h_count = - std::accumulate(coll_param.send_counts, coll_param.send_counts + comm_size, 0); - h2d_count = - std::accumulate(coll_param.recv_counts, coll_param.recv_counts + comm_size, 0); - break; - - default: CCL_FATAL("unexpected coll_type ", coll_type); break; + std::vector d2h_counts; + std::vector h2d_counts; + bool reuse_buffers; + sched->get_pre_post_copy_counts(d2h_counts, h2d_counts, reuse_buffers); + + if ((coll_type == ccl_coll_allgatherv) && + coll_param.is_inplace(ccl_coll_param::buf_type::device)) { + CCL_THROW_IF_NOT(coll_param.device_recv_bufs.size() == 1, + "unexpected device_recv_bufs.size ", + coll_param.device_recv_bufs.size()); + device_in_buf_offset = std::accumulate( + coll_param.recv_counts.begin(), coll_param.recv_counts.begin() + my_rank, 0); + LOG_TRACE("device_in_buf_offset = ", device_in_buf_offset); } - device_in_buf = &(coll_param.device_send_buf); - host_in_buf = (void*)coll_param.send_buf; - d2h_bytes = d2h_count * dtype_size; - - host_out_buf = coll_param.recv_buf; - device_out_buf = &(coll_param.device_recv_buf); - h2d_bytes = h2d_count * dtype_size; + size_t total_d2h_count = std::accumulate(d2h_counts.begin(), d2h_counts.end(), 0); + size_t total_h2d_count = std::accumulate(h2d_counts.begin(), h2d_counts.end(), 0); - if (d2h_bytes) { - for (size_t idx = 0; idx < part_count; idx++) { + if (total_d2h_count) { + for (size_t idx = 0; idx < sched_count; idx++) { part_scheds[idx]->set_add_mode(ccl_sched_add_front); } sched->sync_partial_scheds(); - entry_factory::make_entry( - copy_sched, - copy_direction::d2h, - ccl_buffer(device_in_buf, d2h_bytes, ccl_buffer_type::INDIRECT), - ccl_buffer(host_in_buf, d2h_bytes), - d2h_count, - dtype, - coll_param.stream, - device_in_buf_offset); + for (size_t idx = 0; idx < d2h_counts.size(); idx++) { + size_t sched_idx = idx % sched_count; + size_t count = d2h_counts[idx]; + size_t bytes = count * dtype_size; + + entry_factory::make_entry( + part_scheds[sched_idx].get(), + ccl_buffer(coll_param.get_send_buf_ptr(idx, ccl_coll_param::buf_type::device), + bytes, + ccl_buffer_type::INDIRECT), + ccl_buffer(coll_param.get_send_buf(idx), bytes), + count, + dtype, + copy_attr(copy_direction::d2h, device_in_buf_offset)); + } } - if (h2d_bytes) { - for (size_t idx = 0; idx < part_count; idx++) { + if (total_h2d_count) { + for (size_t idx = 0; idx < sched_count; idx++) { part_scheds[idx]->set_add_mode(ccl_sched_add_back); } sched->sync_partial_scheds(); - entry_factory::make_entry( - copy_sched, - copy_direction::h2d, - ccl_buffer(host_out_buf, h2d_bytes), - ccl_buffer(device_out_buf, h2d_bytes, ccl_buffer_type::INDIRECT), - h2d_count, - dtype, - coll_param.stream); - part_scheds[0]->add_barrier(); + for (size_t idx = 0; idx < h2d_counts.size(); idx++) { + size_t sched_idx = idx % sched_count; + size_t count = h2d_counts[idx]; + size_t bytes = count * dtype_size; + + entry_factory::make_entry( + part_scheds[sched_idx].get(), + ccl_buffer(coll_param.get_recv_buf(idx), bytes), + ccl_buffer(coll_param.get_recv_buf_ptr(idx, ccl_coll_param::buf_type::device), + bytes, + ccl_buffer_type::INDIRECT), + count, + dtype, + copy_attr(copy_direction::h2d, 0)); + } + + sched->sync_partial_scheds(); } return ccl::status::success; } -#endif /* CCL_ENABLE_SYCL */ + +ccl::status ccl_parallelizer::process_output_event(ccl_master_sched* sched) { + if (!ccl::utils::should_use_sycl_output_event(sched->coll_param.stream)) { + return ccl::status::success; + } + + auto& part_scheds = sched->partial_scheds; + size_t sched_count = part_scheds.size(); + + for (size_t idx = 0; idx < sched_count; idx++) { + part_scheds[idx]->set_add_mode(ccl_sched_add_back); + } + sched->sync_partial_scheds(); + + entry_factory::make_entry(part_scheds[0].get(), sched); + + return ccl::status::success; +} +#endif // CCL_ENABLE_SYCL ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { /* TODO: split on per-collective classes */ @@ -243,17 +227,16 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { ccl::global_data& data = ccl::global_data::get(); ccl::status status = ccl::status::success; - size_t part_count = 1, idx, base_count, dtype_size, comm_size, my_rank; + size_t part_count = 1, idx, base_count, dtype_size, comm_size; ccl_coll_param& coll_param = sched->coll_param; - ccl_coll_param_copy* coll_param_copy = &(sched->coll_param_copy); - ccl_coll_attr* coll_attr = &(sched->coll_attr); + ccl_coll_attr& coll_attr = sched->coll_attr; + ccl_comm* comm = coll_param.comm; const ccl_datatype& dtype = coll_param.dtype; dtype_size = dtype.size(); comm_size = comm->size(); - my_rank = comm->rank(); ccl_coll_type coll_type = coll_param.ctype; @@ -262,26 +245,26 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { auto& part_scheds = sched->partial_scheds; std::vector part_scheds_vector; - const size_t* recv_counts = nullptr; - std::vector ag_recv_bufs; size_t ag_recv_bytes = 0, ag_recv_count = 0; size_t a2av_send_bytes = 0, a2av_recv_bytes = 0; size_t a2av_send_count = 0, a2av_recv_count = 0; - ccl_coll_allgatherv_algo ag_algo = ccl_coll_allgatherv_naive; - ccl_coll_alltoall_algo a2a_algo = ccl_coll_alltoall_direct; - ccl_coll_alltoallv_algo a2av_algo = ccl_coll_alltoallv_direct; - ccl_coll_bcast_algo ag_mbcast_algo = ccl_coll_bcast_naive; + ccl_coll_algo algo; + ccl_coll_algo internal_algo; std::vector part_ctxs; ccl_selector_param selector_param; selector_param.ctype = coll_type; - selector_param.count = coll_param.count; - selector_param.recv_counts = coll_param.recv_counts; + selector_param.count = coll_param.get_send_count(); selector_param.dtype = dtype; selector_param.comm = comm; + selector_param.stream = coll_param.stream; + selector_param.is_vector_buf = coll_attr.is_vector_buf; +#ifdef CCL_ENABLE_SYCL + selector_param.is_sycl_buf = coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL switch (coll_type) { case ccl_coll_barrier: part_count = max_data_partition_count; break; @@ -292,20 +275,24 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } case ccl_coll_reduce: case ccl_coll_allreduce: - if ((coll_param.count * dtype_size <= ccl::global_data::env().max_short_size) || - (coll_param.count < max_data_partition_count)) { + if ((coll_param.get_send_count() * dtype_size <= + ccl::global_data::env().max_short_size) || + (coll_param.get_send_count() < max_data_partition_count)) { part_count = 1; } else { /* to workaround lack of large msg protocol on ATL level */ - part_count = (coll_param.count * dtype_size) / CCL_ATL_LARGE_MSG_SIZE; + part_count = (coll_param.get_send_count() * dtype_size) / CCL_ATL_LARGE_MSG_SIZE; if (part_count < max_data_partition_count) part_count = max_data_partition_count; } + if (ccl_is_topo_ring_algo(selector_param)) { + part_count = 1; + } break; case ccl_coll_alltoall: - a2a_algo = data.algorithm_selector->get(selector_param); - if (a2a_algo == ccl_coll_alltoall_direct) { + algo.alltoall = data.algorithm_selector->get(selector_param); + if (algo.alltoall == ccl_coll_alltoall_direct) { part_count = 1; } else { @@ -313,13 +300,8 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } break; case ccl_coll_alltoallv: - a2av_algo = data.algorithm_selector->get(selector_param); - coll_param_copy->a2av_send_counts.assign((size_t*)coll_param.send_counts, - (size_t*)coll_param.send_counts + comm_size); - coll_param_copy->a2av_recv_counts.assign((size_t*)coll_param.recv_counts, - (size_t*)coll_param.recv_counts + comm_size); - - if (a2av_algo == ccl_coll_alltoallv_direct) { + algo.alltoallv = data.algorithm_selector->get(selector_param); + if (algo.alltoallv == ccl_coll_alltoallv_direct) { part_count = 1; } else { @@ -327,31 +309,24 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } break; case ccl_coll_allgatherv: - selector_param.vector_buf = coll_attr->vector_buf; - ag_algo = data.algorithm_selector->get(selector_param); - coll_param_copy->ag_recv_counts.assign((size_t*)coll_param.recv_counts, - (size_t*)coll_param.recv_counts + comm_size); - - if (ag_algo == ccl_coll_allgatherv_direct || ag_algo == ccl_coll_allgatherv_naive || - ag_algo == ccl_coll_allgatherv_ring) { + selector_param.recv_counts = coll_param.recv_counts.data(); + algo.allgatherv = data.algorithm_selector->get(selector_param); + if (algo.allgatherv == ccl_coll_allgatherv_direct || + algo.allgatherv == ccl_coll_allgatherv_naive || + algo.allgatherv == ccl_coll_allgatherv_ring) { part_count = 1; } - else if (ag_algo == ccl_coll_allgatherv_multi_bcast || - ag_algo == ccl_coll_allgatherv_flat) { + else if (algo.allgatherv == ccl_coll_allgatherv_multi_bcast || + algo.allgatherv == ccl_coll_allgatherv_flat) { part_count = comm_size; ag_recv_bufs.resize(comm_size); - - if (coll_attr->vector_buf) { - coll_param_copy->ag_recv_bufs.assign((void**)coll_param.recv_buf, - (void**)coll_param.recv_buf + comm_size); - } - - if (ag_algo == ccl_coll_allgatherv_multi_bcast) { + if (algo.allgatherv == ccl_coll_allgatherv_multi_bcast) { selector_param.ctype = ccl_coll_bcast; - selector_param.count = sched->coll_param.send_count; + selector_param.count = sched->coll_param.get_send_count(); selector_param.dtype = dtype; - ag_mbcast_algo = data.algorithm_selector->get(selector_param); - if (ag_mbcast_algo == ccl_coll_bcast_direct) { + internal_algo.bcast = + data.algorithm_selector->get(selector_param); + if (internal_algo.bcast == ccl_coll_bcast_direct) { /* group all direct bcasts for specific worker together into single schedule w/o any barriers to get all required MPI level tags by single do_progress call @@ -363,7 +338,7 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } } else { - CCL_FATAL("unexpected allgatherv_algo ", ag_algo); + CCL_FATAL("unexpected allgatherv_algo ", algo.allgatherv); } break; case ccl_coll_reduce_scatter: part_count = 1; break; @@ -378,8 +353,8 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { ", part_count ", part_count); - if (coll_type == ccl_coll_allgatherv && ag_algo == ccl_coll_allgatherv_multi_bcast && - ag_mbcast_algo == ccl_coll_bcast_direct) { + if (coll_type == ccl_coll_allgatherv && algo.allgatherv == ccl_coll_allgatherv_multi_bcast && + internal_algo.bcast == ccl_coll_bcast_direct) { counts.resize(comm_size, 0); offsets.resize(comm_size, 0); } @@ -390,8 +365,7 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { for (idx = 0; idx < part_count; idx++) { ccl_coll_param part_coll_param{}; - part_coll_param.ctype = sched->coll_param.ctype; - part_coll_param.dtype = sched->coll_param.dtype; + part_coll_param.ctype = ccl_coll_partial; part_coll_param.stream = sched->coll_param.stream; part_coll_param.comm = comm; sched->add_partial_sched(part_coll_param); @@ -402,7 +376,7 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { for (idx = 0; idx < part_scheds.size(); idx++) { /* in this place all coll attributes for partial schedules * are taken from master schedule, including priority */ - part_scheds[idx]->coll_attr = *coll_attr; + part_scheds[idx]->coll_attr = coll_attr; part_scheds_vector[idx] = part_scheds[idx].get(); } @@ -413,49 +387,46 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { case ccl_coll_allreduce: case ccl_coll_reduce_scatter: case ccl_coll_sparse_allreduce: - base_count = coll_param.count / part_count; + base_count = coll_param.get_recv_count() / part_count; for (idx = 0; idx < counts.size(); idx++) { counts[idx] = base_count; offsets[idx] = idx * counts[idx] * dtype_size; } - counts[counts.size() - 1] += coll_param.count % counts.size(); - for (idx = 0; idx < part_scheds.size(); idx++) { - part_scheds[idx]->coll_param.count = counts[idx]; - } + counts[counts.size() - 1] += coll_param.get_recv_count() % counts.size(); + // for (idx = 0; idx < part_scheds.size(); idx++) { + // part_scheds[idx]->coll_param.recv_counts.push_back(counts[idx]); + // } break; case ccl_coll_alltoall: case ccl_coll_alltoallv: if (coll_type == ccl_coll_alltoallv) { - CCL_ASSERT(coll_param.send_counts); - CCL_ASSERT(coll_param.recv_counts); - a2av_send_count = - std::accumulate(coll_param.send_counts, coll_param.send_counts + comm_size, 0); - a2av_recv_count = - std::accumulate(coll_param.recv_counts, coll_param.recv_counts + comm_size, 0); + a2av_send_count = std::accumulate( + coll_param.send_counts.begin(), coll_param.send_counts.end(), 0); + a2av_recv_count = std::accumulate( + coll_param.recv_counts.begin(), coll_param.recv_counts.end(), 0); } else { - a2av_send_count = coll_param.count * comm_size; - a2av_recv_count = coll_param.count * comm_size; + a2av_send_count = coll_param.get_send_count() * comm_size; + a2av_recv_count = coll_param.get_recv_count() * comm_size; } a2av_send_bytes = a2av_send_count * dtype_size; a2av_recv_bytes = a2av_recv_count * dtype_size; break; case ccl_coll_allgatherv: - recv_counts = coll_param.recv_counts; - CCL_ASSERT(recv_counts); - counts[0] = recv_counts[0]; + counts[0] = coll_param.get_recv_count(0); offsets[0] = 0; - if (ag_algo == ccl_coll_allgatherv_direct || ag_algo == ccl_coll_allgatherv_naive || - ag_algo == ccl_coll_allgatherv_ring) { + if (algo.allgatherv == ccl_coll_allgatherv_direct || + algo.allgatherv == ccl_coll_allgatherv_naive || + algo.allgatherv == ccl_coll_allgatherv_ring) { } else { for (idx = 1; idx < comm_size; idx++) { - counts[idx] = recv_counts[idx]; + counts[idx] = coll_param.get_recv_count(idx); offsets[idx] = offsets[idx - 1] + counts[idx - 1] * dtype_size; } } ag_recv_count = - std::accumulate(coll_param.recv_counts, coll_param.recv_counts + comm_size, 0); + std::accumulate(coll_param.recv_counts.begin(), coll_param.recv_counts.end(), 0); ag_recv_bytes = ag_recv_count * dtype_size; break; default: CCL_FATAL("unexpected coll_type ", coll_type); break; @@ -477,8 +448,8 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { for (idx = 0; idx < part_count; idx++) { ccl_coll_entry_param param{}; param.ctype = ccl_coll_bcast; - param.recv_buf = ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, + param.recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, offsets[idx], ccl_buffer_type::INDIRECT); param.count = counts[idx]; @@ -493,12 +464,12 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { for (idx = 0; idx < part_count; idx++) { ccl_coll_entry_param param{}; param.ctype = ccl_coll_reduce; - param.send_buf = ccl_buffer(&(coll_param.send_buf), - coll_param.count * dtype_size, + param.send_buf = ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, offsets[idx], ccl_buffer_type::INDIRECT); - param.recv_buf = ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, + param.recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, offsets[idx], ccl_buffer_type::INDIRECT); param.count = counts[idx]; @@ -506,6 +477,7 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { param.reduction = coll_param.reduction; param.root = coll_param.root; param.comm = comm; + param.stream = coll_param.stream; coll_entry_helper::add_coll_entry(part_scheds[idx].get(), param); } break; @@ -515,17 +487,19 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { ccl_coll_entry_param param{}; param.ctype = ccl_coll_reduce_scatter; - bool inplace = (coll_param.send_buf == coll_param.recv_buf) ? true : false; - size_t recv_buf_size = coll_param.count * dtype_size; + bool inplace = coll_param.is_inplace(); + size_t recv_buf_size = coll_param.get_recv_count() * dtype_size; if (inplace) recv_buf_size *= comm_size; - param.send_buf = ccl_buffer(&(coll_param.send_buf), - coll_param.count * comm_size * dtype_size, + param.send_buf = ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + offsets[idx], + ccl_buffer_type::INDIRECT); + param.recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), + recv_buf_size, offsets[idx], ccl_buffer_type::INDIRECT); - param.recv_buf = ccl_buffer( - &(coll_param.recv_buf), recv_buf_size, offsets[idx], ccl_buffer_type::INDIRECT); param.count = counts[idx]; param.dtype = dtype; param.reduction = coll_param.reduction; @@ -537,7 +511,7 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { case ccl_coll_allreduce: { ccl_parallelizer_prologue_ctx* main_ctx = nullptr; - if (coll_attr->prologue_fn) { + if (coll_attr.prologue_fn) { part_ctxs.reserve(part_count); main_ctx = (ccl_parallelizer_prologue_ctx*)part_scheds[0] @@ -545,16 +519,17 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { .get_ptr(); main_ctx->part_idx = 0; main_ctx->part_count = 1; - entry_factory::make_entry(part_scheds[0].get(), - coll_attr->prologue_fn, - ccl_buffer(&(coll_param.send_buf), - coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - coll_param.count, - dtype, - &(main_ctx->buf), - &(main_ctx->count), - &(main_ctx->dt_idx)); + entry_factory::make_entry( + part_scheds[0].get(), + coll_attr.prologue_fn, + ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT), + coll_param.get_send_count(), + dtype, + &(main_ctx->buf), + &(main_ctx->count), + &(main_ctx->dt_idx)); sched->sync_partial_scheds(); @@ -580,13 +555,13 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { for (idx = 0; idx < part_count; idx++) { ccl_coll_entry_param param{}; param.ctype = ccl_coll_allreduce; - if (!coll_attr->prologue_fn) { - param.send_buf = ccl_buffer(&(coll_param.send_buf), - coll_param.count * dtype_size, + if (!coll_attr.prologue_fn) { + param.send_buf = ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, offsets[idx], ccl_buffer_type::INDIRECT); - param.recv_buf = ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, + param.recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, offsets[idx], ccl_buffer_type::INDIRECT); param.count = counts[idx]; @@ -600,11 +575,12 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } param.reduction = coll_param.reduction; param.comm = comm; + param.stream = coll_param.stream; auto entry = coll_entry_helper::add_coll_entry( part_scheds[idx].get(), param); - if (coll_attr->prologue_fn) { + if (coll_attr.prologue_fn) { auto part_ctx = part_ctxs[idx]; entry->set_field_fn( ccl_parallelizer_prologue_get_buf, part_ctx, false); @@ -617,17 +593,17 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } } - if (coll_attr->prologue_fn && !coll_attr->epilogue_fn) { + if (coll_attr.prologue_fn && !coll_attr.epilogue_fn) { sched->sync_partial_scheds(); - auto entry = - entry_factory::make_entry(part_scheds[0].get(), - ccl_buffer(), /* in_buf */ - ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, - ccl_buffer_type::INDIRECT), - 0, /* count */ - ccl_datatype_int8); + auto entry = entry_factory::make_entry( + part_scheds[0].get(), + ccl_buffer(), /* in_buf */ + ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, + ccl_buffer_type::INDIRECT), + 0, /* count */ + ccl_datatype_int8); entry->set_field_fn( ccl_parallelizer_prologue_get_buf, main_ctx, false); entry->set_field_fn( @@ -636,22 +612,22 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { ccl_parallelizer_prologue_get_dtype, main_ctx, false); } - if (coll_attr->epilogue_fn) { + if (coll_attr.epilogue_fn) { sched->sync_partial_scheds(); auto entry = entry_factory::make_entry( part_scheds[0].get(), - coll_attr->epilogue_fn, - ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, + coll_attr.epilogue_fn, + ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, ccl_buffer_type::INDIRECT), - coll_param.count, + coll_param.get_recv_count(), dtype, - ccl_buffer(&(coll_param.recv_buf), - coll_param.count * dtype_size, + ccl_buffer(coll_param.get_recv_buf_ptr(), + coll_param.get_recv_count() * dtype_size, ccl_buffer_type::INDIRECT), - coll_param.count, + coll_param.get_recv_count(), dtype); - if (coll_attr->prologue_fn) { + if (coll_attr.prologue_fn) { entry->set_field_fn( ccl_parallelizer_prologue_get_buf, main_ctx, false); entry->set_field_fn( @@ -664,125 +640,33 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { } case ccl_coll_allgatherv: { - if (ag_algo == ccl_coll_allgatherv_direct || ag_algo == ccl_coll_allgatherv_naive || - ag_algo == ccl_coll_allgatherv_ring) { + if (algo.allgatherv == ccl_coll_allgatherv_direct || + algo.allgatherv == ccl_coll_allgatherv_naive || + algo.allgatherv == ccl_coll_allgatherv_ring) { ccl_coll_entry_param param{}; param.ctype = ccl_coll_allgatherv; - param.send_buf = ccl_buffer(&(coll_param.send_buf), - coll_param.send_count * dtype_size, + param.send_buf = ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, ccl_buffer_type::INDIRECT); - param.recv_buf = - ccl_buffer(&(coll_param.recv_buf), ag_recv_bytes, ccl_buffer_type::INDIRECT); - param.send_count = coll_param.send_count; - param.recv_counts = coll_param_copy->ag_recv_counts.data(); + param.recv_buf = ccl_buffer( + coll_param.get_recv_buf_ptr(), ag_recv_bytes, ccl_buffer_type::INDIRECT); + param.send_count = coll_param.get_send_count(); + param.recv_counts = coll_param.recv_counts.data(); param.dtype = dtype; param.comm = comm; coll_entry_helper::add_coll_entry(part_scheds[0].get(), param); } else { - CCL_ASSERT(ag_algo == ccl_coll_allgatherv_flat || - ag_algo == ccl_coll_allgatherv_multi_bcast, + CCL_ASSERT(algo.allgatherv == ccl_coll_allgatherv_flat || + algo.allgatherv == ccl_coll_allgatherv_multi_bcast, "unexpected allgatherv algorithm"); - for (idx = 0; idx < comm_size; idx++) { - if (coll_attr->vector_buf) { - ag_recv_bufs[idx].set(&(coll_param_copy->ag_recv_bufs[idx]), - counts[idx] * dtype_size, - ccl_buffer_type::INDIRECT); - } - else { - ag_recv_bufs[idx].set(&(coll_param.recv_buf), - ag_recv_bytes, - offsets[idx], - ccl_buffer_type::INDIRECT); - } - } - - if (ag_algo == ccl_coll_allgatherv_flat) { - auto send_seg = ccl_buffer(&(coll_param.send_buf), - coll_param.send_count * dtype_size, - ccl_buffer_type::INDIRECT); - - if (coll_param.send_buf != coll_param.recv_buf) { - entry_factory::make_entry( - part_scheds[2 * my_rank % part_count].get(), - ccl_buffer(&(coll_param.send_buf), - coll_param.send_count * dtype_size, - ccl_buffer_type::INDIRECT), - ag_recv_bufs[my_rank], - counts[my_rank], - dtype); - } - else { - send_seg = ccl_buffer(&(coll_param.send_buf), - ag_recv_bytes, - offsets[my_rank], - ccl_buffer_type::INDIRECT); - } - - CCL_ASSERT(part_count == comm_size); - - for (idx = 0; idx < part_count; idx++) { - if (idx == my_rank) - continue; - - entry_factory::make_entry( - part_scheds[(my_rank + idx) % part_count].get(), - ag_recv_bufs[idx], - counts[idx], - dtype, - idx, - comm); - entry_factory::make_entry( - part_scheds[(my_rank + idx) % part_count].get(), - send_seg, - counts[my_rank], - dtype, - idx, - comm); - } - sched->sync_partial_scheds(); + if (algo.allgatherv == ccl_coll_allgatherv_flat) { + ccl_coll_build_flat_allgatherv(sched, part_scheds_vector, coll_param); } else { - CCL_ASSERT(ag_algo == ccl_coll_allgatherv_multi_bcast); - - if (coll_param.send_buf != coll_param.recv_buf) { - std::vector copy_counts(max_data_partition_count); - std::vector copy_offsets(max_data_partition_count); - for (idx = 0; idx < max_data_partition_count; idx++) { - copy_counts[idx] = counts[comm->rank()] / max_data_partition_count; - copy_offsets[idx] = idx * copy_counts[idx] * dtype_size; - } - copy_counts[max_data_partition_count - 1] += - counts[comm->rank()] % max_data_partition_count; - - CCL_ASSERT(part_scheds.size() >= max_data_partition_count); - - for (idx = 0; idx < max_data_partition_count; idx++) { - entry_factory::make_entry( - part_scheds[idx].get(), - ccl_buffer(&(coll_param.send_buf), - coll_param.send_count * dtype_size, - copy_offsets[idx], - ccl_buffer_type::INDIRECT), - ag_recv_bufs[comm->rank()] + copy_offsets[idx], - copy_counts[idx], - dtype); - } - sched->sync_partial_scheds(); - } - - for (idx = 0; idx < comm_size; idx++) { - ccl_coll_entry_param param{}; - param.ctype = ccl_coll_bcast; - param.recv_buf = ag_recv_bufs[idx]; - param.count = counts[idx]; - param.dtype = dtype; - param.root = idx; - param.comm = comm; - coll_entry_helper::add_coll_entry( - part_scheds[idx % part_count].get(), param); - } + ccl_coll_build_multi_bcast_allgatherv( + sched, part_scheds_vector, coll_param, max_data_partition_count); } } break; @@ -790,35 +674,36 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { case ccl_coll_alltoall: case ccl_coll_alltoallv: { - if (a2a_algo == ccl_coll_alltoall_naive || a2av_algo == ccl_coll_alltoallv_naive) { + if (algo.alltoall == ccl_coll_alltoall_naive || + algo.alltoallv == ccl_coll_alltoallv_naive) { ccl_coll_build_naive_alltoallv(sched, part_scheds_vector, coll_param); } - else if (a2a_algo == ccl_coll_alltoall_scatter || - a2av_algo == ccl_coll_alltoallv_scatter) { + else if (algo.alltoall == ccl_coll_alltoall_scatter || + algo.alltoallv == ccl_coll_alltoallv_scatter) { ccl_coll_build_scatter_alltoallv(sched, part_scheds_vector, coll_param); } - else if (a2a_algo == ccl_coll_alltoall_scatter_barrier || - a2av_algo == ccl_coll_alltoallv_scatter_barrier) { + else if (algo.alltoall == ccl_coll_alltoall_scatter_barrier || + algo.alltoallv == ccl_coll_alltoallv_scatter_barrier) { ccl_coll_build_scatter_barrier_alltoallv(sched, part_scheds_vector, coll_param); } else { ccl_coll_entry_param param{}; param.ctype = coll_type; - param.send_buf = - ccl_buffer(&(coll_param.send_buf), a2av_send_bytes, ccl_buffer_type::INDIRECT); - param.recv_buf = - ccl_buffer(&(coll_param.recv_buf), a2av_recv_bytes, ccl_buffer_type::INDIRECT); + param.send_buf = ccl_buffer( + coll_param.get_send_buf_ptr(), a2av_send_bytes, ccl_buffer_type::INDIRECT); + param.recv_buf = ccl_buffer( + coll_param.get_recv_buf_ptr(), a2av_recv_bytes, ccl_buffer_type::INDIRECT); param.dtype = dtype; param.comm = comm; if (coll_type == ccl_coll_alltoall) { - param.count = coll_param.count; + param.count = coll_param.get_send_count(); coll_entry_helper::add_coll_entry(part_scheds[0].get(), param); } else { - param.send_counts = coll_param_copy->a2av_send_counts.data(); - param.recv_counts = coll_param_copy->a2av_recv_counts.data(); + param.send_counts = coll_param.send_counts.data(); + param.recv_counts = coll_param.recv_counts.data(); coll_entry_helper::add_coll_entry(part_scheds[0].get(), param); } @@ -867,15 +752,15 @@ ccl::status ccl_parallelizer::process_base(ccl_master_sched* sched) { comm)); } - if (coll_attr->sparse_allreduce_completion_fn) { - CCL_THROW_IF_NOT(!coll_attr->sparse_allreduce_alloc_fn); + if (coll_attr.sparse_allreduce_completion_fn) { + CCL_THROW_IF_NOT(!coll_attr.sparse_allreduce_alloc_fn); sched->sync_partial_scheds(); auto entry = entry_factory::make_entry( part_scheds[0].get(), - coll_attr->sparse_allreduce_completion_fn, - coll_attr->sparse_allreduce_fn_ctx, + coll_attr.sparse_allreduce_completion_fn, + coll_attr.sparse_allreduce_fn_ctx, ccl_buffer(), 0, coll_param.sparse_param.itype, diff --git a/src/parallelizer/parallelizer.hpp b/src/parallelizer/parallelizer.hpp index 6c3d30e44..5b7a3500a 100644 --- a/src/parallelizer/parallelizer.hpp +++ b/src/parallelizer/parallelizer.hpp @@ -38,7 +38,8 @@ class ccl_parallelizer { #ifdef CCL_ENABLE_SYCL ccl::status process_pre_post_copies(ccl_master_sched* sched); -#endif /* CCL_ENABLE_SYCL */ + ccl::status process_output_event(ccl_master_sched* sched); +#endif // CCL_ENABLE_SYCL ccl::status process_base(ccl_master_sched* sched); diff --git a/src/sched/buffer_cache.cpp b/src/sched/buffer_cache.cpp new file mode 100644 index 000000000..f53a77c57 --- /dev/null +++ b/src/sched/buffer_cache.cpp @@ -0,0 +1,151 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/global/global.hpp" +#include "sched/buffer_cache.hpp" + +namespace ccl { + +buffer_cache::~buffer_cache() { + for (auto& instance : reg_buffers) { + instance.clear(); + } + +#ifdef CCL_ENABLE_SYCL + for (auto& instance : sycl_buffers) { + instance.clear(); + } +#endif // CCL_ENABLE_SYCL +} + +void buffer_cache::get(size_t idx, size_t bytes, void** pptr) { + reg_buffers.at(idx % reg_buffers.size()).get(bytes, pptr); +} + +void buffer_cache::push(size_t idx, size_t bytes, void* ptr) { + reg_buffers.at(idx % reg_buffers.size()).push(bytes, ptr); +} + +#ifdef CCL_ENABLE_SYCL +void buffer_cache::get(size_t idx, size_t bytes, const sycl::context& ctx, void** pptr) { + sycl_buffers.at(idx % sycl_buffers.size()).get(bytes, ctx, pptr); +} + +void buffer_cache::push(size_t idx, size_t bytes, const sycl::context& ctx, void* ptr) { + sycl_buffers.at(idx % sycl_buffers.size()).push(bytes, ctx, ptr); +} +#endif // CCL_ENABLE_SYCL + +regular_buffer_cache::~regular_buffer_cache() { + if (!cache.empty()) { + LOG_WARN("buffer cache is not empty, size: ", cache.size()); + clear(); + } +} + +void regular_buffer_cache::clear() { + std::lock_guard lock{ guard }; + LOG_DEBUG("clear buffer cache: size: ", cache.size()); + for (auto& key_value : cache) { + CCL_FREE(key_value.second); + } + cache.clear(); +} + +void regular_buffer_cache::get(size_t bytes, void** pptr) { + if (global_data::env().enable_buffer_cache) { + std::lock_guard lock{ guard }; + key_t key(bytes); + auto key_value = cache.find(key); + if (key_value != cache.end()) { + *pptr = key_value->second; + cache.erase(key_value); + LOG_DEBUG("loaded from buffer cache: bytes: ", bytes, ", ptr: ", *pptr); + return; + } + } + *pptr = CCL_MALLOC(bytes, "buffer"); +} + +void regular_buffer_cache::push(size_t bytes, void* ptr) { + if (global_data::env().enable_buffer_cache) { + std::lock_guard lock{ guard }; + key_t key(bytes); + cache.insert({ std::move(key), ptr }); + LOG_DEBUG("inserted to buffer cache: bytes: ", bytes, ", ptr: ", ptr); + return; + } + CCL_FREE(ptr); +} + +#ifdef CCL_ENABLE_SYCL +sycl_buffer_cache::~sycl_buffer_cache() { + if (!cache.empty()) { + LOG_WARN("sycl buffer cache is not empty, size: ", cache.size()); + clear(); + } +} + +void sycl_buffer_cache::clear() { + std::lock_guard lock{ guard }; + LOG_DEBUG("clear sycl buffer cache: size: ", cache.size()); + for (auto& key_value : cache) { + const sycl::context& ctx = std::get<1>(key_value.first); + if (ctx.get_backend() == sycl::backend::opencl) { + continue; + } + try { + sycl::free(key_value.second, ctx); + } + catch (sycl::exception& e) { + LOG_INFO("clear: got exception during sycl::free, ptr: ", key_value.second); + } + } + cache.clear(); +} + +void sycl_buffer_cache::get(size_t bytes, const sycl::context& ctx, void** pptr) { + if (global_data::env().enable_buffer_cache) { + std::lock_guard lock{ guard }; + key_t key(bytes, ctx); + auto key_value = cache.find(key); + if (key_value != cache.end()) { + *pptr = key_value->second; + cache.erase(key_value); + LOG_DEBUG("loaded from sycl buffer cache: bytes: ", bytes, ", ptr: ", *pptr); + return; + } + } + *pptr = sycl::aligned_alloc_host(64, bytes, ctx); +} + +void sycl_buffer_cache::push(size_t bytes, const sycl::context& ctx, void* ptr) { + if (global_data::env().enable_buffer_cache) { + std::lock_guard lock{ guard }; + key_t key(bytes, ctx); + cache.insert({ std::move(key), ptr }); + LOG_DEBUG("inserted to sycl buffer cache: bytes: ", bytes, ", ptr: ", ptr); + return; + } + try { + sycl::free(ptr, ctx); + } + catch (sycl::exception& e) { + LOG_INFO("push: got exception during sycl::free, ptr: ", ptr); + } +} +#endif // CCL_ENABLE_SYCL + +} // namespace ccl diff --git a/src/sched/buffer_cache.hpp b/src/sched/buffer_cache.hpp new file mode 100644 index 000000000..50b28e034 --- /dev/null +++ b/src/sched/buffer_cache.hpp @@ -0,0 +1,104 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include +#include + +#ifdef CCL_ENABLE_SYCL +#include +#include "common/utils/hash.hpp" +#endif // CCL_ENABLE_SYCL + +#include "common/utils/spinlock.hpp" + +namespace ccl { + +class regular_buffer_cache; +#ifdef CCL_ENABLE_SYCL +class sycl_buffer_cache; +#endif // CCL_ENABLE_SYCL + +class buffer_cache { +public: + buffer_cache(size_t instance_count) + : reg_buffers(instance_count) +#ifdef CCL_ENABLE_SYCL + , + sycl_buffers(instance_count) +#endif // CCL_ENABLE_SYCL + { + } + buffer_cache(const buffer_cache&) = delete; + buffer_cache& operator=(const buffer_cache&) = delete; + ~buffer_cache(); + + void get(size_t idx, size_t bytes, void** pptr); + + void push(size_t idx, size_t bytes, void* ptr); + +#ifdef CCL_ENABLE_SYCL + void get(size_t idx, size_t bytes, const sycl::context& ctx, void** pptr); + + void push(size_t idx, size_t bytes, const sycl::context& ctx, void* ptr); +#endif // CCL_ENABLE_SYCL + + using lock_t = ccl_spinlock; + +private: + std::vector reg_buffers; +#ifdef CCL_ENABLE_SYCL + std::vector sycl_buffers; +#endif // CCL_ENABLE_SYCL +}; + +class regular_buffer_cache { +public: + regular_buffer_cache() = default; + ~regular_buffer_cache(); + + void clear(); + void get(size_t bytes, void** pptr); + void push(size_t bytes, void* ptr); + +private: + buffer_cache::lock_t guard{}; + + using key_t = size_t; + using value_t = void*; + std::unordered_multimap cache; +}; + +#ifdef CCL_ENABLE_SYCL +class sycl_buffer_cache { +public: + sycl_buffer_cache() = default; + ~sycl_buffer_cache(); + + void clear(); + void get(size_t bytes, const sycl::context& ctx, void** pptr); + void push(size_t bytes, const sycl::context& ctx, void* ptr); + +private: + buffer_cache::lock_t guard{}; + + using key_t = typename std::tuple; + using value_t = void*; + std::unordered_multimap cache; +}; +#endif // CCL_ENABLE_SYCL + +} // namespace ccl diff --git a/src/sched/cache/key.cpp b/src/sched/cache/key.cpp index 7196e78ac..6f232d42f 100644 --- a/src/sched/cache/key.cpp +++ b/src/sched/cache/key.cpp @@ -19,6 +19,7 @@ #include "common/utils/enums.hpp" #include +#include std::map ccl_sched_key::key_type_names = { std::make_pair(ccl_cache_key_full, "full"), @@ -45,28 +46,31 @@ void ccl_sched_key::set(const ccl_coll_param& param, const ccl_coll_attr& attr) f.comm = param.comm; switch (f.ctype) { - case ccl_coll_allgatherv: f.count1 = param.send_count; break; + case ccl_coll_allgatherv: + f.count1 = param.get_send_count(); + vec1 = param.recv_counts; + break; case ccl_coll_allreduce: - f.count1 = param.count; + f.count1 = param.get_send_count(); f.reduction = param.reduction; break; - case ccl_coll_alltoall: f.count1 = param.count; break; + case ccl_coll_alltoall: f.count1 = param.get_send_count(); break; case ccl_coll_alltoallv: - f.buf1 = (void*)param.send_counts; - f.buf2 = (void*)param.recv_counts; + vec1 = param.send_counts; + vec2 = param.recv_counts; break; case ccl_coll_barrier: break; case ccl_coll_bcast: - f.count1 = param.count; + f.count1 = param.get_send_count(); f.root = param.root; break; case ccl_coll_reduce: - f.count1 = param.count; + f.count1 = param.get_send_count(); f.reduction = param.reduction; f.root = param.root; break; case ccl_coll_reduce_scatter: - f.count1 = param.count; + f.count1 = param.get_send_count(); f.reduction = param.reduction; break; case ccl_coll_sparse_allreduce: @@ -89,22 +93,26 @@ bool ccl_sched_key::check(const ccl_coll_param& param, const ccl_coll_attr& attr param.dtype.idx() == f.dtype || param.comm == f.comm); switch (f.ctype) { - case ccl_coll_allgatherv: result &= (param.send_count == f.count1); break; + case ccl_coll_allgatherv: + result &= (param.get_send_count() == f.count1 && param.recv_counts == vec1); + break; case ccl_coll_allreduce: - result &= (param.count == f.count1 && param.reduction == f.reduction); + result &= (param.get_send_count() == f.count1 && param.reduction == f.reduction); break; - case ccl_coll_alltoall: result &= (param.count == f.count1); break; + case ccl_coll_alltoall: result &= (param.get_send_count() == f.count1); break; case ccl_coll_alltoallv: - result &= (param.send_counts == f.buf1 && param.recv_counts == f.buf2); + result &= (param.send_counts == vec1 && param.recv_counts == vec2); break; case ccl_coll_barrier: break; - case ccl_coll_bcast: result &= (param.count == f.count1 && param.root == f.root); break; + case ccl_coll_bcast: + result &= (param.get_send_count() == f.count1 && param.root == f.root); + break; case ccl_coll_reduce: - result &= - (param.count == f.count1 && param.reduction == f.reduction && param.root == f.root); + result &= (param.get_send_count() == f.count1 && param.reduction == f.reduction && + param.root == f.root); break; case ccl_coll_reduce_scatter: - result &= (param.count == f.count1 && param.reduction == f.reduction); + result &= (param.get_send_count() == f.count1 && param.reduction == f.reduction); break; case ccl_coll_sparse_allreduce: result &= (param.sparse_param.send_ind_count == f.count1 && @@ -120,9 +128,13 @@ bool ccl_sched_key::check(const ccl_coll_param& param, const ccl_coll_attr& attr } bool ccl_sched_key::operator==(const ccl_sched_key& k) const { - bool are_fields_equal = (ccl::global_data::env().cache_key_type == ccl_cache_key_full) - ? !memcmp(&f, &(k.f), sizeof(ccl_sched_key_inner_fields)) - : 1; + bool are_fields_equal = 1; + if (ccl::global_data::env().cache_key_type == ccl_cache_key_full) { + are_fields_equal = !memcmp(&f, &(k.f), sizeof(ccl_sched_key_inner_fields)); + are_fields_equal &= (vec1 == k.vec1) ? 1 : 0; + are_fields_equal &= (vec2 == k.vec2) ? 1 : 0; + } + bool are_keys_equal = are_fields_equal && !match_id.compare(k.match_id); LOG_DEBUG("are_keys_equal ", are_keys_equal); @@ -133,7 +145,7 @@ bool ccl_sched_key::operator==(const ccl_sched_key& k) const { } void ccl_sched_key::print() const { - LOG_DEBUG("ctype ", + LOG_DEBUG("coll ", ccl_coll_type_to_str(f.ctype), ", dtype ", ccl::global_data::get().dtypes->name(f.dtype), @@ -163,6 +175,10 @@ void ccl_sched_key::print() const { (void*)f.epilogue_fn, ", reduction_fn ", (void*)f.reduction_fn, + ", vec1.size ", + vec1.size(), + ", vec2.size ", + vec2.size(), ", match_id ", match_id); } @@ -173,12 +189,16 @@ size_t ccl_sched_key_hasher::operator()(const ccl_sched_key& k) const { size_t hash_value = string_hasher(k.match_id); if (ccl::global_data::env().cache_key_type == ccl_cache_key_full) { + /* TODO: improve hashing for vec fields to reduce probability of collisions + e.g. sum(a[idx]*(idx+1)) */ + size_t vec1_sum = std::accumulate(k.vec1.begin(), k.vec1.end(), 0); + size_t vec2_sum = std::accumulate(k.vec2.begin(), k.vec2.end(), 0); hash_value += k.f.ctype + utils::enum_to_underlying(k.f.dtype) + utils::enum_to_underlying(k.f.itype) + utils::enum_to_underlying(k.f.reduction) + k.f.count1 + k.f.count2 + k.f.root + (size_t)k.f.buf1 + (size_t)k.f.buf2 + (size_t)k.f.count3 + (size_t)k.f.count4 + (size_t)k.f.comm + (size_t)k.f.prologue_fn + - (size_t)k.f.epilogue_fn + (size_t)k.f.reduction_fn; + (size_t)k.f.epilogue_fn + (size_t)k.f.reduction_fn + vec1_sum + vec2_sum; } const_cast(k).set_hasher_result(hash_value); diff --git a/src/sched/cache/key.hpp b/src/sched/cache/key.hpp index 71776d4c2..e8dae9a40 100644 --- a/src/sched/cache/key.hpp +++ b/src/sched/cache/key.hpp @@ -20,6 +20,7 @@ #include #include +#include enum ccl_cache_key_type { ccl_cache_key_full, @@ -77,6 +78,9 @@ class ccl_sched_key { /* inner structure for bit comparison */ ccl_sched_key_inner_fields f; + std::vector vec1; + std::vector vec2; + std::string match_id{}; bool operator==(const ccl_sched_key& k) const; diff --git a/src/sched/entry/coll/coll_entry.cpp b/src/sched/entry/coll/coll_entry.cpp index 7a97b75c0..14467a06c 100644 --- a/src/sched/entry/coll/coll_entry.cpp +++ b/src/sched/entry/coll/coll_entry.cpp @@ -24,7 +24,7 @@ void coll_entry::start() { if (!coll_sched) { ccl_coll_param coll_param{}; - coll_param.ctype = param.ctype; + coll_param.ctype = sched->coll_param.ctype; coll_param.comm = sched->coll_param.comm; coll_param.stream = sched->coll_param.stream; coll_sched.reset(new ccl_extra_sched(coll_param, sched->sched_id)); @@ -36,7 +36,7 @@ void coll_entry::start() { LOG_DEBUG("starting COLL entry: ", this, ", subsched: ", coll_sched.get()); auto req = sched->start_subsched(coll_sched.get()); - LOG_DEBUG(" started COLL entry: ", this, ", subsched ", coll_sched.get(), ", req ", req); + LOG_DEBUG("started COLL entry: ", this, ", subsched ", coll_sched.get(), ", req ", req); status = ccl_sched_entry_status_started; } diff --git a/src/sched/entry/coll/coll_entry.hpp b/src/sched/entry/coll/coll_entry.hpp index d3fa6bdb5..ef9a5edff 100644 --- a/src/sched/entry/coll/coll_entry.hpp +++ b/src/sched/entry/coll/coll_entry.hpp @@ -47,7 +47,12 @@ class coll_entry : public sched_entry, void update() override; bool is_strict_order_satisfied() override { +#ifdef CCL_ENABLE_SYCL + /* use more strict condition for SYCL build to handle async execution */ + return (coll_sched) ? coll_sched->is_completed() : false; +#else // CCL_ENABLE_SYCL return (coll_sched) ? coll_sched->is_strict_order_satisfied() : false; +#endif // CCL_ENABLE_SYCL } const char* name() const override { diff --git a/src/sched/entry/coll/coll_entry_helper.cpp b/src/sched/entry/coll/coll_entry_helper.cpp index 344f73b92..b878d73f8 100644 --- a/src/sched/entry/coll/coll_entry_helper.cpp +++ b/src/sched/entry/coll/coll_entry_helper.cpp @@ -28,6 +28,13 @@ ccl::status coll_entry_helper::build_schedule(ccl_sched* sched, sched->coll_attr.match_id = parent_sched->coll_attr.match_id; } } + sched->coll_attr.to_cache = parent_sched->coll_attr.to_cache; + +#ifdef CCL_ENABLE_SYCL + sched->coll_attr.is_sycl_buf = parent_sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL + + sched->hint_algo = param.hint_algo; switch (param.ctype) { case ccl_coll_allgatherv: { diff --git a/src/sched/entry/coll/coll_entry_helper.hpp b/src/sched/entry/coll/coll_entry_helper.hpp index ce222d855..f1d8860f4 100644 --- a/src/sched/entry/coll/coll_entry_helper.hpp +++ b/src/sched/entry/coll/coll_entry_helper.hpp @@ -26,36 +26,41 @@ class coll_entry_helper { template static coll_entry* add_coll_entry(ccl_sched* sched, const ccl_coll_entry_param& param) { CCL_THROW_IF_NOT(coll_id == param.ctype); - if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { - ccl_selector_param selector_param; - selector_param.ctype = param.ctype; - selector_param.count = param.count; - selector_param.recv_counts = param.recv_counts; - selector_param.dtype = param.dtype; - selector_param.comm = param.comm; - if (param.ctype == ccl_coll_allgatherv) { - selector_param.count = param.send_count; - selector_param.vector_buf = sched->coll_attr.vector_buf; - } - bool is_direct_algo = - ccl::global_data::get().algorithm_selector->is_direct(selector_param); + ccl_selector_param selector_param; + selector_param.ctype = param.ctype; + selector_param.count = param.count; + if (param.ctype == ccl_coll_allgatherv) { + selector_param.count = param.send_count; + } + selector_param.recv_counts = param.recv_counts; + selector_param.dtype = param.dtype; + selector_param.comm = param.comm; + selector_param.stream = param.stream; + selector_param.is_vector_buf = sched->coll_attr.is_vector_buf; +#ifdef CCL_ENABLE_SYCL + selector_param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL + selector_param.hint_algo = param.hint_algo; - if (is_direct_algo) { - if (sched->coll_attr.prologue_fn) { - /* - for direct MPI algo with prologue will use regular coll_entry - to simplify work with postponed fields - */ - sched->strict_order = true; - } - else { - /* otherwise will place entry directly into schedule due to performance reasons */ - auto res = coll_entry_helper::build_schedule(sched, sched, param); - CCL_ASSERT( - res == ccl::status::success, "error during build_schedule, res ", res); - return nullptr; /* coll_entry ptr is required for prologue case only */ - } + if (ccl_is_topo_ring_algo(selector_param)) { + sched->strict_order = true; + } + + if ((ccl::global_data::env().atl_transport == ccl_atl_mpi) && + ccl_is_direct_algo(selector_param)) { + if (sched->coll_attr.prologue_fn) { + /* + for direct MPI algo with prologue will use regular coll_entry + to simplify work with postponed fields + */ + sched->strict_order = true; + } + else { + /* otherwise will place entry directly into schedule due to performance reasons */ + auto res = coll_entry_helper::build_schedule(sched, sched, param); + CCL_ASSERT(res == ccl::status::success, "error during build_schedule, res ", res); + return nullptr; /* coll_entry ptr is required for prologue case only */ } } diff --git a/src/sched/entry/coll/coll_entry_param.hpp b/src/sched/entry/coll/coll_entry_param.hpp index 431bf0997..88db2aea5 100644 --- a/src/sched/entry/coll/coll_entry_param.hpp +++ b/src/sched/entry/coll/coll_entry_param.hpp @@ -30,21 +30,5 @@ struct ccl_coll_entry_param { int root; ccl_comm* comm; ccl_stream* stream; - - ccl_coll_param to_coll_param() const { - ccl_coll_param param; - param.ctype = ctype; - param.send_buf = send_buf.get_ptr(); - param.recv_buf = recv_buf.get_ptr(); - param.count = count; - param.send_count = send_count; - param.send_counts = send_counts; - param.recv_counts = recv_counts; - param.dtype = dtype; - param.reduction = reduction; - param.root = root; - param.comm = comm; - param.stream = stream; - return param; - } + ccl_coll_algo hint_algo; }; diff --git a/src/sched/entry/copy/copy_entry.cpp b/src/sched/entry/copy/copy_entry.cpp new file mode 100644 index 000000000..6d350ad19 --- /dev/null +++ b/src/sched/entry/copy/copy_entry.cpp @@ -0,0 +1,165 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/copy/copy_entry.hpp" +#include "sched/queue/queue.hpp" + +#ifdef CCL_ENABLE_SYCL +#include +#include +#endif // CCL_ENABLE_SYCL + +copy_entry::copy_entry(ccl_sched* sched, + ccl_buffer in_buf, + ccl_buffer out_buf, + size_t count, + const ccl_datatype& dtype, + copy_attr attr) + : +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + ze_copy_entry(sched, in_buf, out_buf, count, dtype, attr), +#else + sched_entry(sched), +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + sched(sched), + in_buf(in_buf), + out_buf(out_buf), + count(count), + dtype(dtype), + attr(attr) { + CCL_THROW_IF_NOT(sched, "no sched"); +} + +void copy_entry::start() { + //update_fields(); + + LOG_DEBUG(class_name(), ": in_buf ", in_buf, ", out_buf ", out_buf, ", count ", count); + +#ifdef CCL_ENABLE_SYCL + int is_sycl_buf = sched->coll_attr.is_sycl_buf; + sycl::queue* q = nullptr; + + sycl::usm::alloc in_ptr_type = sycl::usm::alloc::unknown; + sycl::usm::alloc out_ptr_type = sycl::usm::alloc::unknown; + + if (sched->coll_param.stream) { + q = sched->coll_param.stream->get_native_stream(sched->queue->get_idx()); + CCL_THROW_IF_NOT(q, "null sycl queue"); + in_ptr_type = sycl::get_pointer_type(in_buf.get_ptr(), q->get_context()); + out_ptr_type = sycl::get_pointer_type(out_buf.get_ptr(), q->get_context()); + + LOG_DEBUG("in_ptr_type: ", + ccl::utils::usm_type_to_str(in_ptr_type), + ", out_ptr_type: ", + ccl::utils::usm_type_to_str(out_ptr_type)); + + if (attr.direction == copy_direction::undefined) { + if (in_ptr_type == sycl::usm::alloc::device && + out_ptr_type == sycl::usm::alloc::device) { + attr.direction = copy_direction::d2d; + } + + if ((in_ptr_type != sycl::usm::alloc::device) && + (out_ptr_type != sycl::usm::alloc::device)) { + attr.direction = copy_direction::h2h; + } + + if ((in_ptr_type == sycl::usm::alloc::device) && + (out_ptr_type != sycl::usm::alloc::device)) { + attr.direction = copy_direction::d2h; + } + + if ((in_ptr_type != sycl::usm::alloc::device) && + (out_ptr_type == sycl::usm::alloc::device)) { + attr.direction = copy_direction::h2d; + } + + CCL_THROW_IF_NOT(attr.direction != copy_direction::undefined); + } + } +#endif // CCL_ENABLE_SYCL + + LOG_DEBUG("count: ", count, ", direction: ", to_string(attr.direction)); + + if (!sched->coll_param.stream || (attr.direction == copy_direction::h2h)) { +#ifdef CCL_ENABLE_SYCL + CCL_THROW_IF_NOT(in_ptr_type != sycl::usm::alloc::device, + "unexpected device usm type for input buffer"); + CCL_THROW_IF_NOT(out_ptr_type != sycl::usm::alloc::device, + "unexpected device usm type for output buffer"); +#endif // CCL_ENABLE_SYCL + do_regular_copy(); + return; + } + +#ifdef CCL_ENABLE_SYCL + if (q->get_backend() != cl::sycl::backend::level_zero || is_sycl_buf) { + ctype = copy_type::sycl; + if (!is_sycl_buf) { + if ((in_ptr_type != sycl::usm::alloc::device) && + (out_ptr_type != sycl::usm::alloc::device)) { + do_regular_copy(); + return; + } + } + + copier = sycl_copier( + attr.direction, in_buf, out_buf, count, dtype, is_sycl_buf, attr.in_buf_offset); + copier.set_queue(q); + ccl_tuple_for_each_indexed(copier); + status = ccl_sched_entry_status_started; + } +#ifdef MULTI_GPU_SUPPORT + else { + ctype = copy_type::ze; + ze_copy_entry::start(); // status + } +#endif // MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL +} + +void copy_entry::update() { +#ifdef CCL_ENABLE_SYCL + if (ctype == copy_type::sycl) { + if (copier.is_completed()) { + status = ccl_sched_entry_status_complete; + } + } +#ifdef MULTI_GPU_SUPPORT + else { + ze_copy_entry::update(); + } +#endif // MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL +} + +void copy_entry::do_regular_copy() { + size_t bytes = dtype.size() * count; + auto comp_status = ccl_comp_copy(in_buf.get_ptr(bytes), out_buf.get_ptr(bytes), count, dtype); + CCL_ASSERT(comp_status == ccl::status::success, "bad status ", comp_status); + status = ccl_sched_entry_status_complete; +} + +const ccl_buffer& copy_entry::get_field_ref(field_id_t id) { + return in_buf; +} + +const size_t& copy_entry::get_field_ref(field_id_t id) { + return count; +} + +const ccl_datatype& copy_entry::get_field_ref(field_id_t id) { + return dtype; +} diff --git a/src/sched/entry/copy/copy_entry.hpp b/src/sched/entry/copy/copy_entry.hpp index d7a1e95e2..4d0ce48bb 100644 --- a/src/sched/entry/copy/copy_entry.hpp +++ b/src/sched/entry/copy/copy_entry.hpp @@ -18,11 +18,17 @@ #include "sched/entry/copy/copy_helper.hpp" #include "sched/entry/entry.hpp" -#ifdef CCL_ENABLE_SYCL -#include -#endif /* CCL_ENABLE_SYCL */ +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#include "sched/entry/gpu/ze_copy_entry.hpp" +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + +enum class copy_type : int { regular, sycl, ze }; +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +class copy_entry : public ze_copy_entry, +#else class copy_entry : public sched_entry, +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT public postponed_fieldscoll_param.stream; + copy_attr attr = {}); - if (!stream) { - do_regular_copy(); - return; - } + void start() override; + void update() override; - sycl::queue* q = stream->get_native_stream(sched->queue->get_idx()); - CCL_THROW_IF_NOT(q, "null sycl queue"); - auto in_ptr_type = sycl::get_pointer_type(in_buf.get_ptr(), q->get_context()); - auto out_ptr_type = sycl::get_pointer_type(out_buf.get_ptr(), q->get_context()); - - LOG_DEBUG("in_ptr_type: ", - native::detail::usm_to_string(in_ptr_type), - ", out_ptr_type: ", - native::detail::usm_to_string(out_ptr_type), - ", native_stream: ", - stream->to_string(), - ", count: ", - count) - - if ((in_ptr_type != sycl::usm::alloc::device) && - (out_ptr_type != sycl::usm::alloc::device)) { - do_regular_copy(); - return; - } - - copy_direction direction; - - if ((in_ptr_type == sycl::usm::alloc::device) && - (out_ptr_type == sycl::usm::alloc::device)) { - direction = copy_direction::d2d; - } - - if ((in_ptr_type == sycl::usm::alloc::host) && (out_ptr_type == sycl::usm::alloc::device)) { - direction = copy_direction::h2d; - } - - if ((in_ptr_type == sycl::usm::alloc::device) && (out_ptr_type == sycl::usm::alloc::host)) { - direction = copy_direction::d2h; - } - - copier = sycl_copier(direction, in_buf, out_buf, count, dtype, 0); - copier.set_queue(q); - ccl_tuple_for_each_indexed(copier); - status = ccl_sched_entry_status_started; -#else /* CCL_ENABLE_SYCL */ - do_regular_copy(); -#endif /* CCL_ENABLE_SYCL */ - } - - void update() override { -#ifdef CCL_ENABLE_SYCL - if (copier.is_completed()) { - status = ccl_sched_entry_status_complete; - } -#endif /* CCL_ENABLE_SYCL */ - } - - void do_regular_copy() { - size_t bytes = count * dtype.size(); - auto comp_status = - ccl_comp_copy(in_buf.get_ptr(bytes), out_buf.get_ptr(bytes), count, dtype); - CCL_ASSERT(comp_status == ccl::status::success, "bad status ", comp_status); - status = ccl_sched_entry_status_complete; - } - - const char* name() const override { - return class_name(); - } - - ccl_buffer& get_field_ref(field_id_t id) { - return in_buf; - } - - size_t& get_field_ref(field_id_t id) { - return count; - } - - ccl_datatype& get_field_ref(field_id_t id) { - return dtype; - } + const ccl_buffer& get_field_ref(field_id_t id); + const size_t& get_field_ref(field_id_t id); + const ccl_datatype& get_field_ref(field_id_t id); protected: void dump_detail(std::stringstream& str) const override { @@ -145,18 +69,22 @@ class copy_entry : public sched_entry, ", out_buf ", out_buf, ", in_buf_offset ", - in_buf_offset, + attr.in_buf_offset, "\n"); } private: - ccl_buffer in_buf; - ccl_buffer out_buf; - size_t count; - ccl_datatype dtype; - size_t in_buf_offset; + ccl_sched* const sched; + ccl_buffer in_buf{}; + ccl_buffer out_buf{}; + const size_t count; + const ccl_datatype dtype; + copy_attr attr; + copy_type ctype{ copy_type::regular }; #ifdef CCL_ENABLE_SYCL - sycl_copier copier; -#endif /* CCL_ENABLE_SYCL */ + sycl_copier copier{}; +#endif // CCL_ENABLE_SYCL + + void do_regular_copy(); }; diff --git a/src/sched/entry/copy/copy_helper.cpp b/src/sched/entry/copy/copy_helper.cpp index 8854d22de..f5d2db4e3 100644 --- a/src/sched/entry/copy/copy_helper.cpp +++ b/src/sched/entry/copy/copy_helper.cpp @@ -15,8 +15,27 @@ */ #include "sched/entry/copy/copy_helper.hpp" +copy_attr::copy_attr(int peer_rank, + size_t peer_buf_idx, + copy_direction direction, + ccl_comm* map_comm, + size_t in_buf_offset) + : peer_rank(peer_rank), + peer_buf_idx(peer_buf_idx), + direction(direction), + map_comm(map_comm), + in_buf_offset(in_buf_offset) {} + +copy_attr::copy_attr(copy_direction direction, size_t in_buf_offset) + : peer_rank(ccl_comm::invalid_rank), + peer_buf_idx(0), + direction(direction), + map_comm(nullptr), + in_buf_offset(in_buf_offset) {} + using copy_direction_str_enum = utils::enum_to_str; std::string to_string(copy_direction val) { - return copy_direction_str_enum({ "D2H", "H2D", "D2D" }).choose(val, "UNKNOWN"); + return copy_direction_str_enum({ "UNDEFINED", "H2H", "D2H", "H2D", "D2D" }) + .choose(val, "UNKNOWN"); } diff --git a/src/sched/entry/copy/copy_helper.hpp b/src/sched/entry/copy/copy_helper.hpp index ec8a44e33..1e3666885 100644 --- a/src/sched/entry/copy/copy_helper.hpp +++ b/src/sched/entry/copy/copy_helper.hpp @@ -20,11 +20,28 @@ #include "common/utils/buffer.hpp" #include "common/utils/enums.hpp" #include "common/utils/tuple.hpp" +#include "common/utils/sycl_utils.hpp" #include "oneapi/ccl/native_device_api/interop_utils.hpp" -enum class copy_direction { d2h, h2d, d2d }; +enum class copy_direction { undefined, h2h, d2h, h2d, d2d }; std::string to_string(copy_direction val); +struct copy_attr { + int peer_rank; + size_t peer_buf_idx; + copy_direction direction; + ccl_comm* map_comm; + size_t in_buf_offset; + + copy_attr(int peer_rank = ccl_comm::invalid_rank, + size_t peer_buf_idx = 0, + copy_direction direction = copy_direction::undefined, + ccl_comm* map_comm = nullptr, + size_t in_buf_offset = 0); + + copy_attr(copy_direction direction, size_t in_buf_offset = 0); +}; + #ifdef CCL_ENABLE_SYCL struct sycl_copier { @@ -34,12 +51,14 @@ struct sycl_copier { ccl_buffer out_buf, size_t count, const ccl_datatype& dtype, - size_t in_buf_offset) + bool is_sycl_buf = false, + size_t in_buf_offset = 0) : direction(direction), in_buf(in_buf), out_buf(out_buf), count(count), dtype(dtype), + is_sycl_buf(is_sycl_buf), in_buf_offset(in_buf_offset) {} bool is_completed() { @@ -69,13 +88,12 @@ struct sycl_copier { void* in_buf_ptr = in_buf.get_ptr(bytes); void* out_buf_ptr = out_buf.get_ptr(bytes); - size_t offset = in_buf_offset; - if (direction == copy_direction::d2d) { + CCL_THROW_IF_NOT(!is_sycl_buf, "D2D + SYCL buffer"); e = q->submit([&](sycl::handler& h) { h.memcpy(out_buf_ptr, static_cast(in_buf_ptr) + - offset, + in_buf_offset, bytes); }); return; @@ -83,33 +101,13 @@ struct sycl_copier { void* void_device_ptr = (direction == copy_direction::h2d) ? out_buf_ptr : in_buf_ptr; - /* - don't print this pointer through CCL logger - as in case of char/int8_t it will be interpreted as string - and logger will try access device memory - use void_device_ptr instead - */ - typename specific_sycl_buffer::value_type* device_ptr = - static_cast(void_device_ptr); - - auto device_ptr_type = sycl::get_pointer_type(device_ptr, q->get_context()); - - CCL_THROW_IF_NOT((device_ptr_type == sycl::usm::alloc::device || - device_ptr_type == sycl::usm::alloc::shared || - device_ptr_type == sycl::usm::alloc::unknown), - "unexpected USM type ", - native::detail::usm_to_string(device_ptr_type), - " for device_ptr ", - device_ptr); - - specific_sycl_buffer* device_buf_ptr = nullptr; - - if (device_ptr_type == sycl::usm::alloc::unknown) { - /* cast pointer into SYCL buffer */ - device_buf_ptr = static_cast(void_device_ptr); - } - else { - /* do nothing, provided USM pointer can be used as is in copy kernel */ + if (!is_sycl_buf) { + auto device_ptr_type = sycl::get_pointer_type(void_device_ptr, q->get_context()); + CCL_THROW_IF_NOT(device_ptr_type == sycl::usm::alloc::device, + "unexpected USM type ", + ccl::utils::usm_type_to_str(device_ptr_type), + " for device_ptr ", + void_device_ptr); } LOG_DEBUG("count: ", @@ -128,12 +126,14 @@ struct sycl_copier { out_buf_ptr, ", device_ptr: ", void_device_ptr, - ", is_device_usm: ", - (device_buf_ptr) ? "no" : "yes", - ", device_ptr usm_type: ", - native::detail::usm_to_string(device_ptr_type)); + ", is_sycl_buf: ", + (is_sycl_buf) ? "yes" : "no"); + + if (is_sycl_buf) { + /* cast device pointer into SYCL buffer */ + specific_sycl_buffer* device_buf_ptr = + static_cast(void_device_ptr); - if (device_buf_ptr) { specific_sycl_buffer host_buf( static_cast( (direction == copy_direction::h2d) ? in_buf_ptr : out_buf_ptr), @@ -143,20 +143,22 @@ struct sycl_copier { e = q->submit([&](sycl::handler& h) { auto& src_buf = (direction == copy_direction::h2d) ? host_buf : *device_buf_ptr; auto& dst_buf = (direction == copy_direction::h2d) ? *device_buf_ptr : host_buf; - auto src_buf_acc = - src_buf.template get_access(h, count, offset); + auto src_buf_acc = src_buf.template get_access( + h, count, in_buf_offset); auto dst_buf_acc = dst_buf.template get_access(h); h.copy(src_buf_acc, dst_buf_acc); }); } else { - e = q->submit([&](sycl::handler& h) { - h.memcpy(out_buf_ptr, - static_cast(in_buf_ptr) + - offset, - bytes); - }); + /* don't do special cast, provided USM pointer can be used as is in copy kernel */ + e = q->memcpy(out_buf_ptr, + static_cast(in_buf_ptr) + + in_buf_offset, + bytes); } + + /* TODO: fix parallel copies */ + e.wait(); } else { LOG_TRACE("visitor skipped index: ", @@ -173,9 +175,10 @@ struct sycl_copier { ccl_buffer out_buf; size_t count; ccl_datatype dtype; + bool is_sycl_buf; sycl::queue* q; size_t in_buf_offset; sycl::event e; }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/src/sched/entry/copy/sycl_copy_entry.hpp b/src/sched/entry/copy/sycl_copy_entry.hpp deleted file mode 100644 index 2ae7a2729..000000000 --- a/src/sched/entry/copy/sycl_copy_entry.hpp +++ /dev/null @@ -1,99 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#ifdef CCL_ENABLE_SYCL - -#include "sched/entry/copy/copy_helper.hpp" -#include "sched/entry/entry.hpp" - -#include - -class sycl_copy_entry : public sched_entry { -public: - static constexpr const char* class_name() noexcept { - return "SYCL_COPY"; - } - - sycl_copy_entry() = delete; - sycl_copy_entry(ccl_sched* sched, - copy_direction direction, - ccl_buffer in_buf, - ccl_buffer out_buf, - size_t count, - const ccl_datatype& dtype, - const ccl_stream* stream, - size_t offset = 0) - : sched_entry(sched), - direction(direction), - in_buf(in_buf), - out_buf(out_buf), - count(count), - dtype(dtype), - stream(stream), - offset(offset), - copier(sycl_copier(direction, in_buf, out_buf, count, dtype, offset)) {} - - void start() override { - LOG_DEBUG(class_name(), ": in_buf ", in_buf, ", out_buf ", out_buf, ", count ", count); - - copier.set_queue(((ccl_stream*)stream)->get_native_stream(sched->queue->get_idx())); - ccl_tuple_for_each_indexed(copier); - status = ccl_sched_entry_status_started; - } - - void update() override { - if (copier.is_completed()) { - status = ccl_sched_entry_status_complete; - } - } - - const char* name() const override { - return class_name(); - } - -protected: - void dump_detail(std::stringstream& str) const override { - ccl_logger::format(str, - "direction ", - to_string(direction), - ", dtype ", - ccl::global_data::get().dtypes->name(dtype), - ", count ", - count, - ", in_buf ", - in_buf, - ", out_buf ", - out_buf, - ", native_stream ", - stream->to_string(), - ", offset ", - offset, - "\n"); - } - -private: - copy_direction direction; - ccl_buffer in_buf; - ccl_buffer out_buf; - size_t count; - ccl_datatype dtype; - const ccl_stream* stream; - size_t offset; - sycl_copier copier; -}; - -#endif /* CCL_ENABLE_SYCL */ diff --git a/src/sched/entry/deps_entry.hpp b/src/sched/entry/deps_entry.hpp index 81464cdac..8e23ebb69 100644 --- a/src/sched/entry/deps_entry.hpp +++ b/src/sched/entry/deps_entry.hpp @@ -12,16 +12,24 @@ class deps_entry : public sched_entry { deps_entry(ccl_sched* sched) : sched_entry(sched) {} void start() override { + status = ccl_sched_entry_status_started; + } + + void update() override { + bool all_completed = true; std::vector& deps = sched->get_deps(); + + // Note: ccl event caches the true result of test() method, so we can just iterate over whole + // array of deps each update() call without any overhead. for (size_t idx = 0; idx < deps.size(); idx++) { -#ifdef CCL_ENABLE_SYCL - /* TODO: detect pure sycl::event and ccl::event for device op */ - deps[idx].get_native().wait(); -#else /* CCL_ENABLE_SYCL */ - deps[idx].wait(); -#endif /* CCL_ENABLE_SYCL */ + bool completed = deps[idx].test(); + + all_completed = all_completed && completed; + } + + if (all_completed) { + status = ccl_sched_entry_status_complete; } - status = ccl_sched_entry_status_complete; } const char* name() const override { diff --git a/src/sched/entry/entry.cpp b/src/sched/entry/entry.cpp index 320838dc4..4cbfa18ee 100644 --- a/src/sched/entry/entry.cpp +++ b/src/sched/entry/entry.cpp @@ -13,53 +13,77 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "common/global/global.hpp" +#include "common/log/log.hpp" #include "sched/entry/entry.hpp" #include "sched/sched.hpp" -#include "common/log/log.hpp" void sched_entry::do_progress() { if (is_completed()) return; - // TODO: fix this tempropary workaround - // For l0 entry take_credit & return_credit isn't needed - // That's why we'd skip it - bool is_l0_entry = false; - const char* name_entry = this->name(); - - // in case if entry is empty name or its length = 1 - if (strlen(name_entry) >= 2) - is_l0_entry = name_entry[0] == 'L' && name_entry[1] == '0'; - if (status < ccl_sched_entry_status_started) { - CCL_ASSERT( + CCL_THROW_IF_NOT( status == ccl_sched_entry_status_not_started || status == ccl_sched_entry_status_again, "bad status ", - status); - - if (is_l0_entry || sched->flow_control.take_credit()) { - start(); - CCL_ASSERT(status >= ccl_sched_entry_status_again, "bad status ", status); + status, + "(", + status_to_str(status), + ")"); + + bool took_credits = false; + if (status == ccl_sched_entry_status_not_started) { + took_credits = sched->flow_control.take_credit(); + if (took_credits && ccl::global_data::env().sched_profile) { + timer.start(); + } } - else { - status = ccl_sched_entry_status_again; + else if (status == ccl_sched_entry_status_again) { + took_credits = true; + } + + if (!took_credits) { + return; } + + start(); + CCL_THROW_IF_NOT(status >= ccl_sched_entry_status_again, + "bad status ", + status, + "(", + status_to_str(status), + ")"); } else if (status == ccl_sched_entry_status_started) { LOG_TRACE("update entry ", name()); update(); - CCL_ASSERT(status >= ccl_sched_entry_status_started, "bad status ", status); + CCL_THROW_IF_NOT(status >= ccl_sched_entry_status_started, + "bad status ", + status, + "(", + status_to_str(status), + ")"); } - if (status == ccl_sched_entry_status_complete && !is_l0_entry) { - sched->flow_control.return_credit(); - } + if (status == ccl_sched_entry_status_complete) { + if (ccl::global_data::env().sched_profile) { + timer.stop(); + } - if (status == ccl_sched_entry_status_complete && exec_mode == ccl_sched_entry_exec_once) { - status = ccl_sched_entry_status_complete_once; + if (exec_mode == ccl_sched_entry_exec_once) { + status = ccl_sched_entry_status_complete_once; + } + + sched->flow_control.return_credit(); } - // TODO: what if status is ccl_sched_entry_status_failed or ccl_sched_entry_status_invalid? + CCL_THROW_IF_NOT( + status != ccl_sched_entry_status_failed && status != ccl_sched_entry_status_invalid, + "bad status ", + status, + "(", + status_to_str(status), + ")"); } bool sched_entry::is_completed() { @@ -69,12 +93,16 @@ bool sched_entry::is_completed() { void sched_entry::update() { /* - update is required for communication/synchronization/wait_value ops + update is required for async ops (atl, ze, sync) for other ops it is empty method */ } void sched_entry::reset(size_t idx) { + if (ccl::global_data::env().sched_profile) { + timer.reset(); + } + if (status == ccl_sched_entry_status_complete_once) { return; } @@ -84,7 +112,7 @@ void sched_entry::reset(size_t idx) { } bool sched_entry::is_strict_order_satisfied() { - return (status > ccl_sched_entry_status_not_started); + return (status >= ccl_sched_entry_status_started); } void sched_entry::dump(std::stringstream& str, size_t idx) const { diff --git a/src/sched/entry/entry.hpp b/src/sched/entry/entry.hpp index f816bd59a..5ca1b0bcb 100644 --- a/src/sched/entry/entry.hpp +++ b/src/sched/entry/entry.hpp @@ -18,10 +18,10 @@ #include "atl/atl.h" #include "common/datatype/datatype.hpp" #include "common/utils/utils.hpp" +#include "sched/sched_timer.hpp" #include "sched/entry/postponed_fields.hpp" #include "internal_types.hpp" -#include #include typedef ccl::status (*ccl_sched_entry_function_t)(const void*); @@ -61,7 +61,7 @@ class alignas(CACHELINE_SIZE) sched_entry { void do_progress(); bool is_completed(); - virtual void reset(size_t start_idx); + virtual void reset(size_t idx); virtual bool is_strict_order_satisfied(); @@ -77,6 +77,8 @@ class alignas(CACHELINE_SIZE) sched_entry { static const char* status_to_str(ccl_sched_entry_status status); + ccl::sched_timer timer; + protected: virtual void start() = 0; virtual void update(); diff --git a/src/sched/entry/factory/entry_factory.hpp b/src/sched/entry/factory/entry_factory.hpp index 83a3ab33e..90f8c51af 100644 --- a/src/sched/entry/factory/entry_factory.hpp +++ b/src/sched/entry/factory/entry_factory.hpp @@ -28,9 +28,6 @@ #include "sched/entry/factory/entry_factory.h" #include "sched/entry/copy/copy_entry.hpp" -#ifdef CCL_ENABLE_SYCL -#include "sched/entry/copy/sycl_copy_entry.hpp" -#endif /* CCL_ENABLE_SYCL */ #include "sched/entry/deps_entry.hpp" #include "sched/entry/deregister_entry.hpp" #include "sched/entry/epilogue_entry.hpp" @@ -48,6 +45,15 @@ #include "sched/entry/wait_value_entry.hpp" #include "sched/entry/write_entry.hpp" +#if defined(MULTI_GPU_SUPPORT) && defined(CCL_ENABLE_SYCL) +#include "sched/entry/gpu/ze_allreduce_entry.hpp" +#include "sched/entry/gpu/ze_copy_entry.hpp" +#include "sched/entry/gpu/ze_handle_exchange_entry.hpp" +#include "sched/entry/gpu/ze_event_signal_entry.hpp" +#include "sched/entry/gpu/ze_event_wait_entry.hpp" +#include "sched/entry/gpu/ze_reduce_entry.hpp" +#endif // MULTI_GPU_SUPPORT && CCL_ENABLE_SYCL + #include "sched/sched.hpp" namespace entry_factory { diff --git a/src/sched/entry/gpu/ze_allreduce_entry.cpp b/src/sched/entry/gpu/ze_allreduce_entry.cpp new file mode 100644 index 000000000..822547c27 --- /dev/null +++ b/src/sched/entry/gpu/ze_allreduce_entry.cpp @@ -0,0 +1,299 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/comm/l0/modules/kernel_utils.hpp" +#include "common/stream/stream.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/entry/gpu/ze_allreduce_entry.hpp" +#include "sched/queue/queue.hpp" + +#include + +using namespace ccl; +using namespace ccl::ze; + +ze_allreduce_entry::ze_allreduce_entry(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t cnt, + const ccl_datatype& dtype, + reduction op, + ccl_comm* comm) + : ze_base_entry(sched, comm, local_events_count /* request additional events */), + send_buf(send_buf), + recv_buf(recv_buf), + cnt(cnt), + dtype(dtype), + op(op), + buf_size_bytes(dtype.size() * cnt) {} + +ze_allreduce_entry::~ze_allreduce_entry() { + finalize(); +} + +void ze_allreduce_entry::init() { + if (ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("initialization"); + + init_mode init_mode_type; + if (global_data::env().enable_kernel_1s_copy_ops) { + init_mode_type = (init_mode::copy | init_mode::compute); + } + else { + init_mode_type = init_mode::compute; + } + + ze_base_entry::init(init_mode_type); + + /* create kernels */ + ccl_buffer right_send_buf; + ccl_buffer right_recv_buf; + int peer_rank = (comm_rank + 1) % comm_size; + + send_buf_ptr = send_buf.get_ptr(); + recv_buf_ptr = recv_buf.get_ptr(); + if (send_buf_ptr == recv_buf_ptr) { + sched->get_memory().handle_manager.get(peer_rank, 1, right_send_buf, comm); + sched->get_memory().handle_manager.get(peer_rank, 1, right_recv_buf, comm); + } + else { + sched->get_memory().handle_manager.get(peer_rank, 0, right_send_buf, comm); + sched->get_memory().handle_manager.get(peer_rank, 1, right_recv_buf, comm); + } + right_send_buf_ptr = right_send_buf.get_ptr(); + right_recv_buf_ptr = right_recv_buf.get_ptr(); + + ze_kernel_args_t allreduce_kernel_args = { { sizeof(comm_rank), &comm_rank }, + { sizeof(comm_size), &comm_size }, + { sizeof(cnt), &cnt }, + { sizeof(send_buf_ptr), &send_buf_ptr }, + { sizeof(recv_buf_ptr), &recv_buf_ptr }, + { sizeof(right_send_buf_ptr), &right_send_buf_ptr }, + { sizeof(right_recv_buf_ptr), + &right_recv_buf_ptr } }; + + ze_kernel_args_t reduce_local_kernel_args = { { sizeof(comm_rank), &comm_rank }, + { sizeof(comm_size), &comm_size }, + { sizeof(cnt), &cnt }, + { sizeof(send_buf_ptr), &send_buf_ptr }, + { sizeof(tmp_buf_ptr), &tmp_buf_ptr }, + { sizeof(recv_buf_ptr), &recv_buf_ptr } }; + + global_data::get().ze_cache->get(context, device, "kernels.spv", &module); + + if (global_data::env().enable_kernel_1s_copy_ops) { + main_kernel_name = "reduce_local_outofplace_kernel_"; + device_mem_alloc_desc = default_device_mem_alloc_desc; + global_data::get().ze_cache->get(worker_idx, + context, + device, + device_mem_alloc_desc, + buf_size_bytes, + 0, /*alignment*/ + &tmp_buf_ptr); + } + else { + main_kernel_name = "allreduce_kernel_"; + } + main_kernel_name += to_string(dtype.idx()) + "_" + ccl_reduction_to_str(op); + LOG_DEBUG("get kernel: name: ", main_kernel_name); + global_data::get().ze_cache->get(worker_idx, module, main_kernel_name, &main_kernel); + + auto& main_kernel_args = (global_data::env().enable_kernel_1s_copy_ops) + ? reduce_local_kernel_args + : allreduce_kernel_args; + LOG_DEBUG("kernel ", main_kernel, " args:\n", to_string(main_kernel_args)); + set_kernel_args(main_kernel, main_kernel_args); + + ze_group_size_t group_size; + get_suggested_group_size(main_kernel, cnt, &group_size); + LOG_DEBUG("suggested group size: ", to_string(group_size)); + + get_suggested_group_count(group_size, cnt, &group_count); + LOG_DEBUG("suggested group count: ", to_string(group_count)); + + ZE_CALL(zeKernelSetGroupSize, + (main_kernel, group_size.groupSizeX, group_size.groupSizeY, group_size.groupSizeZ)); + + if (global_data::env().enable_kernel_1s_ipc_wa) { + LOG_DEBUG("get kernel: name: ", empty_kernel_name); + global_data::get().ze_cache->get(worker_idx, module, empty_kernel_name, &empty_kernel); + CCL_THROW_IF_NOT(empty_kernel, "null empty_kernel"); + /* use allreduce_kernel_args since they have pointers to peer mem */ + set_kernel_args(empty_kernel, allreduce_kernel_args); + } + + ze_event_desc_t event_desc = default_event_desc; + event_desc.signal = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + event_desc.wait = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + + uint32_t last_event_idx = 1; // 0 is used to track entry progress + + if (empty_kernel) { + LOG_DEBUG("create event for empty kernel"); + event_desc.index = last_event_idx++; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, &empty_kernel_event)); + } + + if (global_data::env().enable_kernel_1s_copy_ops) { + event_desc.index = last_event_idx++; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, ©_from_peer_event)); + event_desc.index = last_event_idx++; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, &reduce_local_kernel_event)); + } + + LOG_DEBUG("real event count: ", last_event_idx); + + /* do appends */ + if (empty_kernel) { + LOG_DEBUG("append empty kernel"); + ze_group_count_t empty_group_count = { 1, 1, 1 }; + ZE_CALL(zeCommandListAppendLaunchKernel, + (ze_base_entry::comp_primitives.list, + empty_kernel, + &empty_group_count, + empty_kernel_event, + 0, + nullptr)); + } + + if (global_data::env().enable_kernel_1s_copy_ops) { + LOG_DEBUG("one-sided multi-phase algorithm"); + + ZE_CALL(zeCommandListAppendMemoryCopy, + (ze_base_entry::get_copy_list(), + tmp_buf_ptr, + right_send_buf_ptr, + buf_size_bytes, + copy_from_peer_event, + (empty_kernel_event) ? 1 : 0, + &empty_kernel_event)); + + ZE_CALL(zeCommandListAppendLaunchKernel, + (ze_base_entry::comp_primitives.list, + main_kernel, + &group_count, + reduce_local_kernel_event, + 1, + ©_from_peer_event)); + + ZE_CALL(zeCommandListAppendMemoryCopy, + (ze_base_entry::get_copy_list(), + right_recv_buf_ptr, + recv_buf_ptr, + buf_size_bytes, + ze_base_entry::entry_event, + 1, + &reduce_local_kernel_event)); + } + else { + LOG_DEBUG("one-sided monolithic algorithm"); + ZE_CALL(zeCommandListAppendLaunchKernel, + (ze_base_entry::comp_primitives.list, + main_kernel, + &group_count, + ze_base_entry::entry_event, + (empty_kernel_event) ? 1 : 0, + &empty_kernel_event)); + } + + ZE_CALL(zeCommandListClose, (ze_base_entry::comp_primitives.list)); + if (global_data::env().enable_kernel_1s_copy_ops) { + ZE_CALL(zeCommandListClose, (ze_base_entry::copy_primitives.list)); + } + LOG_DEBUG("initialization complete"); +} + +void ze_allreduce_entry::start() { + init(); + + if (ze_base_entry::is_initialized && status == ccl_sched_entry_status_not_started) { + reset_sync_objects(); + } + + size_t kernel_counter = 0; + if (global_data::env().enable_kernel_sync) { + kernel_counter = global_data::get().kernel_counter++; + } + + if (kernel_counter == 0) { + ze_base_entry::start(); + status = ccl_sched_entry_status_started; + } + else { + global_data::get().kernel_counter--; + status = ccl_sched_entry_status_again; + } +} + +void ze_allreduce_entry::update() { + ze_base_entry::update(); + if (status == ccl_sched_entry_status_complete && !sched->coll_attr.to_cache) { + finalize(); + } + + if (global_data::env().enable_kernel_sync && global_data::get().kernel_counter > 0) { + global_data::get().kernel_counter--; + } +} + +void ze_allreduce_entry::finalize() { + if (!ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("finalization"); + + /* events */ + if (global_data::env().enable_kernel_1s_copy_ops) { + LOG_DEBUG("copy ops finalization"); + ZE_CALL(zeEventDestroy, (copy_from_peer_event)); + ZE_CALL(zeEventDestroy, (reduce_local_kernel_event)); + /* device mem */ + global_data::get().ze_cache->push(worker_idx, + context, + device, + device_mem_alloc_desc, + buf_size_bytes, + 0, /*alignment*/ + tmp_buf_ptr); + } + + /* kernels */ + if (empty_kernel_event) { + ZE_CALL(zeEventDestroy, (empty_kernel_event)); + global_data::get().ze_cache->push(worker_idx, module, empty_kernel_name, empty_kernel); + } + global_data::get().ze_cache->push(worker_idx, module, main_kernel_name, main_kernel); + + ze_base_entry::finalize(); + + LOG_DEBUG("finalization complete"); +} + +void ze_allreduce_entry::reset_sync_objects() { + if (empty_kernel_event) { + ZE_CALL(zeEventHostReset, (empty_kernel_event)); + } + + if (global_data::env().enable_kernel_1s_copy_ops) { + ZE_CALL(zeEventHostReset, (copy_from_peer_event)); + ZE_CALL(zeEventHostReset, (reduce_local_kernel_event)); + } +} diff --git a/src/sched/entry/gpu/ze_allreduce_entry.hpp b/src/sched/entry/gpu/ze_allreduce_entry.hpp new file mode 100644 index 000000000..6231e44c0 --- /dev/null +++ b/src/sched/entry/gpu/ze_allreduce_entry.hpp @@ -0,0 +1,106 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/utils/buffer.hpp" +#include "comp/comp.hpp" +#include "sched/entry/gpu/ze_base_entry.hpp" + +#include +#include + +class ze_allreduce_entry : public ze_base_entry { +public: + static constexpr const char* class_name() noexcept { + return "ZE_ALLREDUCE"; + } + + const char* name() const noexcept override { + return class_name(); + } + + ze_allreduce_entry() = delete; + explicit ze_allreduce_entry(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t cnt, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm); + ~ze_allreduce_entry(); + + void init(); + void start() override; + void update() override; + void finalize(); + + void reset_sync_objects(); + + bool is_strict_order_satisfied() override { + return (status >= ccl_sched_entry_status_complete); + } + +protected: + void dump_detail(std::stringstream& str) const override { + ccl_logger::format(str, + "dt ", + ccl::global_data::get().dtypes->name(dtype), + ", cnt ", + cnt, + ", send_buf ", + send_buf, + ", recv_buf ", + recv_buf, + ", op ", + ccl_reduction_to_str(op), + ", comm_id ", + sched->get_comm_id(), + ", context ", + context, + "\n"); + } + +private: + static constexpr uint32_t local_events_count{ 3 }; + + const ccl_buffer send_buf; + const ccl_buffer recv_buf; + void* send_buf_ptr{}; + void* recv_buf_ptr{}; + void* right_send_buf_ptr{}; + void* right_recv_buf_ptr{}; + void* tmp_buf_ptr{}; + const unsigned long cnt; + const ccl_datatype dtype; + const ccl::reduction op; + const size_t buf_size_bytes; + + ze_event_handle_t empty_kernel_event{}; + ze_event_handle_t copy_from_peer_event{}; + ze_event_handle_t reduce_local_kernel_event{}; + + ze_module_handle_t module{}; + + ze_group_count_t group_count{}; + + ze_kernel_handle_t main_kernel{}; + std::string main_kernel_name{}; + + ze_kernel_handle_t empty_kernel{}; + std::string empty_kernel_name{ "empty_kernel" }; + + ze_device_mem_alloc_desc_t device_mem_alloc_desc; +}; diff --git a/src/sched/entry/gpu/ze_base_entry.cpp b/src/sched/entry/gpu/ze_base_entry.cpp new file mode 100644 index 000000000..a8998634b --- /dev/null +++ b/src/sched/entry/gpu/ze_base_entry.cpp @@ -0,0 +1,245 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/stream/stream.hpp" +#include "sched/queue/queue.hpp" + +#include "sched/entry/gpu/ze_base_entry.hpp" +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/entry/gpu/ze_call.hpp" +#include "ze_primitives.hpp" + +#include + +using namespace ccl; +using namespace ccl::ze; + +ze_base_entry::ze_base_entry(ccl_sched *sched, ccl_comm *comm, uint32_t add_event_count) + : sched_entry(sched), + sched(sched), + comm(comm), + add_event_count(add_event_count) { + CCL_THROW_IF_NOT(sched, "no sched"); + if (!comm) { + comm = sched->coll_param.comm; + } + CCL_THROW_IF_NOT(comm, "no comm"); + comm_rank = comm->rank(); + comm_size = comm->size(); +} + +void ze_base_entry::init(init_mode ze_init_mode) { + if (is_initialized) { + return; + } + worker_idx = sched->queue->get_idx(); + + CCL_THROW_IF_NOT(sched->coll_param.stream, "null stream"); + + LOG_DEBUG("getting a native stream"); + auto native_stream = sched->coll_param.stream->get_native_stream(worker_idx); + if (native_stream->get_backend() != sycl::backend::level_zero) { + CCL_THROW("unsupported sycl backend"); + } + + auto sycl_device = native_stream->get_device(); + device = sycl_device.template get_native(); + + auto sycl_context = native_stream->get_context(); + context = sycl_context.template get_native(); + + /* get queue properties */ + uint32_t num_queue_groups; + get_num_queue_groups(device, &num_queue_groups); + + ze_queue_properties_t queue_props; + get_queues_properties(device, num_queue_groups, &queue_props); + + /* init compute queue, list */ + if (init_mode::compute & ze_init_mode) { + LOG_DEBUG("compute init mode is enabled"); + get_comp_primitives(queue_props, comp_primitives); + init_primitives(comp_primitives); + } + + /* init copy queue, list */ + if (init_mode::copy & ze_init_mode) { + LOG_DEBUG("copy init mode is enabled"); + get_copy_primitives(queue_props, copy_primitives, ze_init_mode); + init_primitives(copy_primitives); + } + + /* create event pool */ + event_pool_desc = default_event_pool_desc; + event_pool_desc.count = 1 + add_event_count; // at least one event to track progress + global_data::get().ze_cache->get(worker_idx, context, event_pool_desc, &event_pool); + LOG_DEBUG("get event pool: { max event count: ", event_pool_desc.count, " }"); + + /* create event */ + ze_event_desc_t event_desc = default_event_desc; + event_desc.signal = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + event_desc.wait = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + event_desc.index = 0; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, &entry_event)); + + is_initialized = true; +} + +void ze_base_entry::finalize() { + if (!is_initialized) { + return; + } + ZE_CALL(zeEventDestroy, (entry_event)); + + /* event pool */ + global_data::get().ze_cache->push(worker_idx, context, event_pool_desc, event_pool); + + if (comp_primitives.list && comp_primitives.queue) { + LOG_DEBUG("push from cache for compute list and queue"); + /* list */ + global_data::get().ze_cache->push( + worker_idx, context, device, comp_primitives.list_desc, comp_primitives.list); + + /* queue */ + global_data::get().ze_cache->push( + worker_idx, context, device, comp_primitives.queue_desc, comp_primitives.queue); + } + + if (copy_primitives.list && copy_primitives.queue) { + LOG_DEBUG("push from cache for copy list and queue"); + /* copy list */ + global_data::get().ze_cache->push( + worker_idx, context, device, copy_primitives.list_desc, copy_primitives.list); + + /* copy queue */ + global_data::get().ze_cache->push( + worker_idx, context, device, copy_primitives.queue_desc, copy_primitives.queue); + } + + is_initialized = false; +} + +void ze_base_entry::start() { + CCL_THROW_IF_NOT(entry_event, "no entry event"); + ZE_CALL(zeEventHostReset, (entry_event)); + + if (comp_primitives.list && comp_primitives.queue) { + LOG_DEBUG("execute compute command list"); + ZE_CALL(zeCommandQueueExecuteCommandLists, + (comp_primitives.queue, 1, &comp_primitives.list, nullptr)); + } + + if (copy_primitives.list && copy_primitives.queue) { + LOG_DEBUG("execute copy command list"); + ZE_CALL(zeCommandQueueExecuteCommandLists, + (copy_primitives.queue, 1, ©_primitives.list, nullptr)); + } + + if (((global_data::env().ze_serialize_mode & ze_call::serialize_mode::block)) != 0) { + LOG_DEBUG("wait until command lists are executed"); + if (copy_primitives.queue) + ZE_CALL(zeHostSynchronize, (copy_primitives.queue)); + if (comp_primitives.queue) + ZE_CALL(zeHostSynchronize, (comp_primitives.queue)); + } +} + +void ze_base_entry::update() { + ze_result_t query_status; + + if (global_data::env().kernel_debug == 0) { + query_status = zeEventQueryStatus(entry_event); + } + else { + if (copy_primitives.queue) + query_status = zeHostSynchronize(copy_primitives.queue); + if (comp_primitives.queue) + query_status = zeHostSynchronize(comp_primitives.queue); + } + + if (query_status == ZE_RESULT_SUCCESS) { + LOG_DEBUG("command list complete"); + status = ccl_sched_entry_status_complete; + } + else if (query_status == ZE_RESULT_NOT_READY) { + // just return in case if the kernel is not ready yet, will check again on the next iteration + return; + } + else { + CCL_THROW("error at zeEventQueryStatus"); + } +} + +ze_command_list_handle_t ze_base_entry::get_copy_list() { + ze_command_list_handle_t list = nullptr; + if (copy_primitives.list) { + list = copy_primitives.list; + LOG_DEBUG("copy list is returned"); + } + else { + list = comp_primitives.list; + LOG_DEBUG("compute list is returned"); + } + CCL_THROW_IF_NOT(list, "command list is invalid"); + return list; +} + +void ze_base_entry::get_comp_primitives(const ze_queue_properties_t &queue_props, + cmd_primitives &comp_primitives) { + uint32_t ordinal, queue_index; + get_comp_queue_ordinal(device, queue_props, &ordinal); + get_queue_index(queue_props, ordinal, comm_rank, &queue_index); + + comp_primitives.queue_desc.ordinal = ordinal; + comp_primitives.queue_desc.index = queue_index; + comp_primitives.list_desc.commandQueueGroupOrdinal = ordinal; +} + +void ze_base_entry::get_copy_primitives(const ze_queue_properties_t &queue_props, + cmd_primitives ©_primitives, + init_mode ze_init_mode) { + uint32_t ordinal, queue_index; + get_copy_queue_ordinal(device, queue_props, &ordinal); + + // TODO: index depends on rank's changing, when > 1 queues are created, + // the index is still the same for different queues, that's the issue. + // WA is adding optional counter, which says the order number of a queue. + // Need to think, how we'd calculate the index for every queue. + // Hang in case of CCL_KERNEL_1S_USE_COPY_OPS=1 CCL_ZE_COPY_ENGINE=none + if (ze_init_mode == (init_mode::copy | init_mode::compute)) { + get_queue_index(queue_props, ordinal, comm_rank + 1, &queue_index); + } + else { + get_queue_index(queue_props, ordinal, comm_rank, &queue_index); + } + + copy_primitives.queue_desc.ordinal = ordinal; + copy_primitives.queue_desc.index = queue_index; + copy_primitives.list_desc.commandQueueGroupOrdinal = ordinal; +} + +void ze_base_entry::init_primitives(cmd_primitives &cmd_primitives) { + global_data::get().ze_cache->get( + worker_idx, context, device, cmd_primitives.queue_desc, &cmd_primitives.queue); + LOG_DEBUG("get queue: { ordinal: ", + cmd_primitives.queue_desc.ordinal, + ", index: ", + cmd_primitives.queue_desc.index, + " }"); + + global_data::get().ze_cache->get( + worker_idx, context, device, cmd_primitives.list_desc, &cmd_primitives.list); + LOG_DEBUG("get list: { ordinal: ", cmd_primitives.list_desc.commandQueueGroupOrdinal, " }"); +} diff --git a/src/sched/entry/gpu/ze_base_entry.hpp b/src/sched/entry/gpu/ze_base_entry.hpp new file mode 100644 index 000000000..b04d3f4df --- /dev/null +++ b/src/sched/entry/gpu/ze_base_entry.hpp @@ -0,0 +1,79 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/comm/comm.hpp" +#include "common/global/global.hpp" +#include "sched/sched.hpp" +#include "sched/entry/entry.hpp" + +#include + +using namespace ccl::ze; + +struct cmd_primitives { + ze_command_queue_handle_t queue{}; + ze_command_queue_desc_t queue_desc{ default_cmd_queue_desc }; + ze_command_list_handle_t list{}; + ze_command_list_desc_t list_desc{ default_cmd_list_desc }; +}; + +class ze_base_entry : public sched_entry { +public: + ze_base_entry() = delete; + ze_base_entry(const ze_base_entry &) = delete; + virtual ~ze_base_entry(){}; + +protected: + explicit ze_base_entry(ccl_sched *sched, + ccl_comm *comm = nullptr, + uint32_t add_event_count = 0); + + void init(init_mode ze_init_mode); + virtual void start() override; + virtual void update() override; + void finalize(); + + ze_command_list_handle_t get_copy_list(); + + void init_primitives(cmd_primitives &cmd_primitives); + void get_copy_primitives(const ze_queue_properties_t &queue_props, + cmd_primitives ©_primitives, + init_mode ze_init_mode); + void get_comp_primitives(const ze_queue_properties_t &queue_props, + cmd_primitives &comp_primitives); + + ccl_sched *const sched; + + ccl_comm *comm{}; + int comm_rank{}; + int comm_size{}; + + size_t worker_idx{}; + + bool is_initialized{}; + + ze_device_handle_t device{}; + ze_context_handle_t context{}; + + cmd_primitives comp_primitives{}; + cmd_primitives copy_primitives{}; + + ze_event_pool_desc_t event_pool_desc{}; + ze_event_pool_handle_t event_pool{}; + ze_event_handle_t entry_event{}; + const uint32_t add_event_count; +}; diff --git a/src/sched/entry/gpu/ze_cache.cpp b/src/sched/entry/gpu/ze_cache.cpp new file mode 100644 index 000000000..ea3b55e87 --- /dev/null +++ b/src/sched/entry/gpu/ze_cache.cpp @@ -0,0 +1,411 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/global/global.hpp" +#include "sched/entry/gpu/ze_cache.hpp" + +#include + +namespace ccl { +namespace ze { + +template +bool get_from_cache(map_t& cache, typename map_t::mapped_type& object, keys_t... keys) { + bool success{}; + + if (!global_data::env().enable_kernel_cache) + return success; + + typename map_t::key_type key(keys...); + auto key_value = cache.find(key); + if (key_value != cache.end()) { + object = key_value->second; + cache.erase(key_value); + LOG_DEBUG("loaded from cache: object: ", object); + success = true; + } + return success; +} + +template +bool push_to_cache(map_t& cache, const typename map_t::mapped_type& object, keys_t... keys) { + bool success{}; + + if (!global_data::env().enable_kernel_cache) + return success; + + typename map_t::key_type key(keys...); + auto range = cache.equal_range(key); + auto range_len = std::distance(range.first, range.second); + if (range_len > 0) { + LOG_DEBUG("cache already contain ", range_len, " objects with the same key"); + for (auto i = range.first; i != range.second; ++i) { + CCL_THROW_IF_NOT(i->second != object, "trying to push object that already exists"); + } + } + cache.insert({ std::move(key), object }); + LOG_DEBUG("inserted to cache: object: ", object); + success = true; + return success; +} + +// fence_cache +fence_cache::~fence_cache() { + if (!cache.empty()) { + LOG_WARN("fence cache is not empty, size: ", cache.size()); + clear(); + } +} + +void fence_cache::clear() { + LOG_DEBUG("clear fence cache: size: ", cache.size()); + for (auto& key_value : cache) { + ZE_CALL(zeFenceDestroy, (key_value.second)); + } + cache.clear(); +} + +void fence_cache::get(ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t* fence) { + CCL_THROW_IF_NOT(queue); + CCL_THROW_IF_NOT(fence); + if (get_from_cache(cache, *fence, queue)) { + ZE_CALL(zeFenceReset, (*fence)); + } + else { + ZE_CALL(zeFenceCreate, (queue, &fence_desc, fence)); + } +} + +void fence_cache::push(ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t fence) { + CCL_THROW_IF_NOT(queue); + CCL_THROW_IF_NOT(fence); + if (!push_to_cache(cache, fence, queue)) { + zeFenceDestroy(fence); + } +} + +// kernel_cache +kernel_cache::~kernel_cache() { + if (!cache.empty()) { + LOG_WARN("kernel cache is not empty, size: ", cache.size()); + clear(); + } +} + +void kernel_cache::clear() { + LOG_DEBUG("clear kernel cache: size: ", cache.size()); + for (auto& key_value : cache) { + ZE_CALL(zeKernelDestroy, (key_value.second)); + } + cache.clear(); +} + +void kernel_cache::get(ze_module_handle_t module, + const std::string& kernel_name, + ze_kernel_handle_t* kernel) { + CCL_THROW_IF_NOT(module); + CCL_THROW_IF_NOT(!kernel_name.empty()); + CCL_THROW_IF_NOT(kernel); + if (!get_from_cache(cache, *kernel, module, kernel_name)) { + create_kernel(module, kernel_name, kernel); + } +} + +void kernel_cache::push(ze_module_handle_t module, + const std::string& kernel_name, + ze_kernel_handle_t kernel) { + CCL_THROW_IF_NOT(module); + CCL_THROW_IF_NOT(!kernel_name.empty()); + CCL_THROW_IF_NOT(kernel); + if (!push_to_cache(cache, kernel, module, kernel_name)) { + ZE_CALL(zeKernelDestroy, (kernel)); + } +} + +// list_cache +list_cache::~list_cache() { + if (!cache.empty()) { + LOG_WARN("list cache is not empty, size: ", cache.size()); + clear(); + } +} + +void list_cache::clear() { + LOG_DEBUG("clear list cache: size: ", cache.size()); + for (auto& key_value : cache) { + ZE_CALL(zeCommandListDestroy, (key_value.second)); + } + cache.clear(); +} + +void list_cache::get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t* list) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(list); + if (get_from_cache( + cache, *list, context, device, list_desc.commandQueueGroupOrdinal, list_desc.flags)) { + ZE_CALL(zeCommandListReset, (*list)); + } + else { + ZE_CALL(zeCommandListCreate, (context, device, &list_desc, list)); + } +} + +void list_cache::push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t list) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(list); + if (!push_to_cache( + cache, list, context, device, list_desc.commandQueueGroupOrdinal, list_desc.flags)) { + ZE_CALL(zeCommandListDestroy, (list)); + } +} + +// queue_cache +queue_cache::~queue_cache() { + if (!cache.empty()) { + LOG_WARN("queue cache is not empty, size: ", cache.size()); + clear(); + } +} + +void queue_cache::clear() { + LOG_DEBUG("clear queue cache: size: ", cache.size()); + for (auto& key_value : cache) { + ZE_CALL(zeCommandQueueDestroy, (key_value.second)); + } + cache.clear(); +} + +void queue_cache::get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t* queue) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(queue); + if (!get_from_cache(cache, + *queue, + context, + device, + queue_desc.index, + queue_desc.ordinal, + queue_desc.flags, + queue_desc.mode, + queue_desc.priority)) { + ZE_CALL(zeCommandQueueCreate, (context, device, &queue_desc, queue)); + } +} + +void queue_cache::push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t queue) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(queue); + if (!push_to_cache(cache, + queue, + context, + device, + queue_desc.index, + queue_desc.ordinal, + queue_desc.flags, + queue_desc.mode, + queue_desc.priority)) { + ZE_CALL(zeCommandQueueDestroy, (queue)); + } +} + +// event_pool_cache +event_pool_cache::~event_pool_cache() { + if (!cache.empty()) { + LOG_WARN("event pool cache is not empty, size: ", cache.size()); + clear(); + } +} + +void event_pool_cache::clear() { + LOG_DEBUG("clear event pool cache: size: ", cache.size()); + for (auto& key_value : cache) { + ZE_CALL(zeEventPoolDestroy, (key_value.second)); + } + cache.clear(); +} + +void event_pool_cache::get(ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t* event_pool) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(event_pool); + // TODO: we can potentially use pool with count >= pool_desc.count + if (!get_from_cache(cache, *event_pool, context, pool_desc.flags, pool_desc.count)) { + ZE_CALL(zeEventPoolCreate, (context, &pool_desc, 0, nullptr, event_pool)); + } +} + +void event_pool_cache::push(ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t event_pool) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(event_pool); + if (!push_to_cache(cache, event_pool, context, pool_desc.flags, pool_desc.count)) { + ZE_CALL(zeEventPoolDestroy, (event_pool)); + } +} + +// device_mem_cache +device_mem_cache::~device_mem_cache() { + if (!cache.empty()) { + LOG_WARN("device memory cache is not empty, size: ", cache.size()); + clear(); + } +} + +void device_mem_cache::clear() { + LOG_DEBUG("clear device memory cache: size: ", cache.size()); + //for (auto& key_value : cache) { + // TODO: there is a segfault on this call, when ~cache is invoked w/ or w/0 api cache. + // But it passes, when CCL_KERNEL_CACHE=0 (calls of zeMemAllocDevice and ZeMemFree happen on every iteration). + // We don't control destroying phase and may be key_value.second (mem_ptr) is already away to free? + // ZE_CALL(zeMemFree, (std::get<0>(key_value.first), key_value.second)) + //} + cache.clear(); +} + +void device_mem_cache::get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void** pptr) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(pptr); + if (!get_from_cache(cache, + *pptr, + context, + device, + bytes, + device_mem_alloc_desc.flags, + device_mem_alloc_desc.ordinal)) { + ZE_CALL(zeMemAllocDevice, + (context, &device_mem_alloc_desc, bytes, alignment, device, pptr)); + } +} + +void device_mem_cache::push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void* ptr) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(ptr); + if (!push_to_cache(cache, + ptr, + context, + device, + bytes, + device_mem_alloc_desc.flags, + device_mem_alloc_desc.ordinal)) { + ZE_CALL(zeMemFree, (context, ptr)); + } +} + +// module_cache +module_cache::~module_cache() { + if (!cache.empty()) { + LOG_WARN("module cache is not empty, size: ", cache.size()); + clear(); + } +} + +void module_cache::clear() { + LOG_DEBUG("clear module cache: size: ", cache.size()); + std::lock_guard lock(mutex); + for (auto& key_value : cache) { + ZE_CALL(zeModuleDestroy, (key_value.second)); + } + cache.clear(); +} + +void module_cache::get(ze_context_handle_t context, + ze_device_handle_t device, + const std::string& spv_name, + ze_module_handle_t* module) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(!spv_name.empty()); + CCL_THROW_IF_NOT(module); + std::lock_guard lock(mutex); + key_t key(device, spv_name); + auto key_value = cache.find(key); + if (key_value != cache.end()) { + *module = key_value->second; + LOG_DEBUG("loaded from cache: module: ", *module); + } + else { + load(context, device, spv_name, module); + cache.insert({ std::move(key), *module }); + LOG_DEBUG("inserted to cache: module: ", *module); + } +} + +void module_cache::load(ze_context_handle_t context, + ze_device_handle_t device, + const std::string& spv_name, + ze_module_handle_t* module) { + CCL_THROW_IF_NOT(context); + CCL_THROW_IF_NOT(device); + CCL_THROW_IF_NOT(!spv_name.empty()); + CCL_THROW_IF_NOT(module); + std::string modules_dir = global_data::env().kernel_path; + // TODO: remove + if (modules_dir.empty()) { + std::string ccl_root = getenv("CCL_ROOT"); + CCL_THROW_IF_NOT(!ccl_root.empty(), "incorrect comm kernels path, CCL_ROOT not found!"); + modules_dir = ccl_root + "/lib/kernels/"; + } + load_module(modules_dir, spv_name, device, context, module); +} + +// cache +cache::~cache() { + for (size_t i = 0; i < instance_count; ++i) { + fences[i].clear(); + kernels[i].clear(); + lists[i].clear(); + queues[i].clear(); + event_pools[i].clear(); + device_mems[i].clear(); + } + + modules.clear(); +} + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_cache.hpp b/src/sched/entry/gpu/ze_cache.hpp new file mode 100644 index 000000000..1ee76ddb9 --- /dev/null +++ b/src/sched/entry/gpu/ze_cache.hpp @@ -0,0 +1,323 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/utils/hash.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" + +#include + +namespace ccl { +namespace ze { + +class fence_cache { +public: + fence_cache() = default; + ~fence_cache(); + + void clear(); + + void get(ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t* fence); + void push(ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t fence); + +private: + using key_t = typename std::tuple; + using value_t = ze_fence_handle_t; + std::unordered_multimap cache; +}; + +class kernel_cache { +public: + kernel_cache() = default; + ~kernel_cache(); + + void clear(); + + void get(ze_module_handle_t module, const std::string& kernel_name, ze_kernel_handle_t* kernel); + void push(ze_module_handle_t module, const std::string& kernel_name, ze_kernel_handle_t kernel); + +private: + using key_t = typename std::tuple; + using value_t = ze_kernel_handle_t; + std::unordered_multimap cache; +}; + +// TODO: need to improve with ability to save list with commands for specific algo +class list_cache { +public: + list_cache() = default; + ~list_cache(); + + void clear(); + + void get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t* list); + void push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t list); + +private: + using key_t = typename std:: + tuple; + using value_t = ze_command_list_handle_t; + std::unordered_multimap cache; +}; + +class queue_cache { +public: + queue_cache() = default; + ~queue_cache(); + + void clear(); + + void get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t* queue); + void push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t queue); + +private: + using key_t = typename std::tuple; + using value_t = ze_command_queue_handle_t; + std::unordered_multimap cache; +}; + +class event_pool_cache { +public: + event_pool_cache() = default; + ~event_pool_cache(); + + void clear(); + + void get(ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t* event_pool); + + void push(ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t event_pool); + +private: + using key_t = typename std::tuple; + using value_t = ze_event_pool_handle_t; + std::unordered_multimap cache; +}; + +class device_mem_cache { +public: + device_mem_cache() = default; + ~device_mem_cache(); + + void clear(); + + void get(ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void** pptr); + + void push(ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void* ptr); + +private: + using key_t = typename std::tuple; + using value_t = void*; + std::unordered_multimap cache; +}; + +class module_cache { +public: + module_cache() = default; + ~module_cache(); + + void clear(); + + void get(ze_context_handle_t context, + ze_device_handle_t device, + const std::string& spv_name, + ze_module_handle_t* module); + +private: + using key_t = typename std::tuple; + using value_t = ze_module_handle_t; + std::unordered_multimap cache; + std::mutex mutex; + + void load(ze_context_handle_t context, + ze_device_handle_t device, + const std::string& spv_name, + ze_module_handle_t* module); +}; + +class cache { +public: + cache(size_t instance_count) + : instance_count(instance_count), + fences(instance_count), + kernels(instance_count), + lists(instance_count), + queues(instance_count), + event_pools(instance_count), + device_mems(instance_count) { + LOG_DEBUG("create cache with ", instance_count, " instances"); + } + cache(const cache&) = delete; + cache& operator=(const cache&) = delete; + ~cache(); + + /* get */ + void get(size_t instance_idx, + ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t* fence) { + fences.at(instance_idx).get(queue, fence_desc, fence); + } + + void get(size_t instance_idx, + ze_module_handle_t module, + const std::string& kernel_name, + ze_kernel_handle_t* kernel) { + kernels.at(instance_idx).get(module, kernel_name, kernel); + } + + void get(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t* list) { + lists.at(instance_idx).get(context, device, list_desc, list); + } + + void get(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t* queue) { + queues.at(instance_idx).get(context, device, queue_desc, queue); + } + + void get(size_t instance_idx, + ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t* event_pool) { + event_pools.at(instance_idx).get(context, pool_desc, event_pool); + } + + void get(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void** pptr) { + device_mems.at(instance_idx) + .get(context, device, device_mem_alloc_desc, bytes, alignment, pptr); + } + + void get(ze_context_handle_t context, + ze_device_handle_t device, + const std::string& spv_name, + ze_module_handle_t* module) { + modules.get(context, device, spv_name, module); + } + + /* push */ + void push(size_t instance_idx, + ze_command_queue_handle_t queue, + const ze_fence_desc_t& fence_desc, + ze_fence_handle_t fence) { + fences.at(instance_idx).push(queue, fence_desc, fence); + } + + void push(size_t instance_idx, + ze_module_handle_t module, + const std::string& kernel_name, + ze_kernel_handle_t kernel) { + kernels.at(instance_idx).push(module, kernel_name, kernel); + } + + void push(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_list_desc_t& list_desc, + ze_command_list_handle_t list) { + lists.at(instance_idx).push(context, device, list_desc, list); + } + + void push(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_command_queue_desc_t& queue_desc, + ze_command_queue_handle_t queue) { + queues.at(instance_idx).push(context, device, queue_desc, queue); + } + + void push(size_t instance_idx, + ze_context_handle_t context, + const ze_event_pool_desc_t& pool_desc, + ze_event_pool_handle_t event_pool) { + event_pools.at(instance_idx).push(context, pool_desc, event_pool); + } + + void push(size_t instance_idx, + ze_context_handle_t context, + ze_device_handle_t device, + const ze_device_mem_alloc_desc_t& device_mem_alloc_desc, + size_t bytes, + size_t alignment, + void* ptr) { + device_mems.at(instance_idx) + .push(context, device, device_mem_alloc_desc, bytes, alignment, ptr); + } + +private: + const size_t instance_count; + std::vector fences; + std::vector kernels; + std::vector lists; + std::vector queues; + std::vector event_pools; + std::vector device_mems; + module_cache modules{}; +}; + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_call.cpp b/src/sched/entry/gpu/ze_call.cpp new file mode 100644 index 000000000..a1876b451 --- /dev/null +++ b/src/sched/entry/gpu/ze_call.cpp @@ -0,0 +1,64 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/global/global.hpp" +#include "common/log/log.hpp" +#include "sched/entry/gpu/ze_call.hpp" + +namespace ccl { +namespace ze { + +std::mutex ze_call::mutex; + +ze_call::ze_call() { + if ((global_data::env().ze_serialize_mode & ze_call::serialize_mode::lock) != 0) { + LOG_DEBUG("ze call is locked"); + mutex.lock(); + } +} + +ze_call::~ze_call() { + if ((global_data::env().ze_serialize_mode & ze_call::serialize_mode::lock) != 0) { + LOG_DEBUG("ze call is unlocked"); + mutex.unlock(); + } +} + +ze_result_t ze_call::do_call(ze_result_t ze_result, const char* ze_name) const { + if (ze_result != ZE_RESULT_SUCCESS) { + CCL_THROW("ze error at ", ze_name, ", code: ", to_string(ze_result)); + } + LOG_DEBUG("call ze function: ", ze_name); + return ze_result; +} + +// provides different level zero synchronize methods +template <> +ze_result_t zeHostSynchronize(ze_event_handle_t handle) { + return zeHostSynchronizeImpl(zeEventHostSynchronize, handle); +} + +template <> +ze_result_t zeHostSynchronize(ze_command_queue_handle_t handle) { + return zeHostSynchronizeImpl(zeCommandQueueSynchronize, handle); +} + +template <> +ze_result_t zeHostSynchronize(ze_fence_handle_t handle) { + return zeHostSynchronizeImpl(zeFenceHostSynchronize, handle); +} + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_call.hpp b/src/sched/entry/gpu/ze_call.hpp new file mode 100644 index 000000000..fdce2af9f --- /dev/null +++ b/src/sched/entry/gpu/ze_call.hpp @@ -0,0 +1,51 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include + +namespace ccl { +namespace ze { + +// class provides the serialization of level zero calls +class ze_call { +public: + // rule level zero calls serialization + enum serialize_mode : int { + none, // no locking or blocking + lock, // locking around each ZE_CALL + block, // blocking ZE calls + }; + + ze_call(); + ~ze_call(); + ze_result_t do_call(ze_result_t ze_result, const char* ze_name) const; + +private: + // mutex that is used for total serialization + static std::mutex mutex; +}; + +//host synchronize primitives +template +ze_result_t zeHostSynchronize(T handle); +template +ze_result_t zeHostSynchronizeImpl(Func sync_func, T handle) { + return sync_func(handle, std::numeric_limits::max()); +} + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_copy_entry.cpp b/src/sched/entry/gpu/ze_copy_entry.cpp new file mode 100644 index 000000000..ce11d0f59 --- /dev/null +++ b/src/sched/entry/gpu/ze_copy_entry.cpp @@ -0,0 +1,99 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/gpu/ze_copy_entry.hpp" + +#include + +using namespace ccl; + +ze_copy_entry::ze_copy_entry(ccl_sched* sched, + ccl_buffer in_buf, + ccl_buffer out_buf, + size_t count, + const ccl_datatype& dtype, + copy_attr attr) + : ze_base_entry(sched), + sched(sched), + in_buf(in_buf), + out_buf(out_buf), + dtype(dtype), + attr(attr), + buf_size_bytes(dtype.size() * count) { + CCL_THROW_IF_NOT(sched, "no sched"); +} + +ze_copy_entry::~ze_copy_entry() { + finalize(); +} + +void ze_copy_entry::init() { + if (ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("initialization"); + + ze_base_entry::init(init_mode::copy); + + if (attr.peer_rank != ccl_comm::invalid_rank) { + if (!out_buf) { + sched->get_memory().handle_manager.get( + attr.peer_rank, attr.peer_buf_idx, out_buf, attr.map_comm); + } + + if (!in_buf) { + sched->get_memory().handle_manager.get( + attr.peer_rank, attr.peer_buf_idx, in_buf, attr.map_comm); + } + } + + void* dst = out_buf.get_ptr(); + void* src = static_cast(in_buf.get_ptr()) + attr.in_buf_offset * dtype.size(); + ze_command_list_handle_t list = ze_base_entry::get_copy_list(); + + ZE_CALL(zeCommandListAppendMemoryCopy, + (list, dst, src, buf_size_bytes, ze_base_entry::entry_event, 0, nullptr)); + ZE_CALL(zeCommandListClose, (list)); + + LOG_DEBUG("initialization complete"); +} + +void ze_copy_entry::start() { + init(); + + ze_base_entry::start(); + + status = ccl_sched_entry_status_started; +} + +void ze_copy_entry::update() { + ze_base_entry::update(); + if (status == ccl_sched_entry_status_complete && !sched->coll_attr.to_cache) { + finalize(); + } +} + +void ze_copy_entry::finalize() { + if (!ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("finalization"); + + ze_base_entry::finalize(); + + LOG_DEBUG("finalization complete"); +} diff --git a/src/sched/entry/gpu/ze_copy_entry.hpp b/src/sched/entry/gpu/ze_copy_entry.hpp new file mode 100644 index 000000000..6052b3875 --- /dev/null +++ b/src/sched/entry/gpu/ze_copy_entry.hpp @@ -0,0 +1,56 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "sched/entry/copy/copy_helper.hpp" +#include "sched/entry/entry.hpp" +#include "sched/sched.hpp" + +#include "sched/entry/gpu/ze_base_entry.hpp" + +struct copy_attr; + +class ze_copy_entry : public ze_base_entry { +public: + static constexpr const char* class_name() noexcept { + return "ZE_COPY"; + } + + const char* name() const override { + return class_name(); + } + + explicit ze_copy_entry(ccl_sched* sched, + ccl_buffer in_buf, + ccl_buffer out_buf, + size_t count, + const ccl_datatype& dtype, + copy_attr attr = {}); + ~ze_copy_entry(); + + void init(); + void start() override; + void update() override; + void finalize(); + +private: + ccl_sched* const sched; + ccl_buffer in_buf{}; + ccl_buffer out_buf{}; + const ccl_datatype& dtype; + const copy_attr attr; + const size_t buf_size_bytes; +}; diff --git a/src/sched/entry/gpu/ze_event_signal_entry.cpp b/src/sched/entry/gpu/ze_event_signal_entry.cpp new file mode 100644 index 000000000..1ff9c3279 --- /dev/null +++ b/src/sched/entry/gpu/ze_event_signal_entry.cpp @@ -0,0 +1,40 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/gpu/ze_event_signal_entry.hpp" +#include "sched/queue/queue.hpp" +#include "common/utils/sycl_utils.hpp" + +ze_event_signal_entry::ze_event_signal_entry(ccl_sched* sched, ccl_master_sched* master_sched) + : sched_entry(sched), + master_sched(master_sched) { + CCL_THROW_IF_NOT(sched, "no sched"); + CCL_THROW_IF_NOT(master_sched, "no master_sched"); +} + +void ze_event_signal_entry::start() { + LOG_DEBUG("signal event: ", master_sched->get_memory().sync_event); + ZE_CALL(zeEventHostSignal, (master_sched->get_memory().sync_event)); + + status = ccl_sched_entry_status_started; +} + +void ze_event_signal_entry::update() { + if (ccl::utils::is_sycl_event_completed(master_sched->get_native_event()) && + ccl::utils::is_sycl_event_completed(master_sched->get_sync_event())) { + LOG_DEBUG("native and sync events are completed"); + status = ccl_sched_entry_status_complete; + } +} diff --git a/src/sched/entry/gpu/ze_event_signal_entry.hpp b/src/sched/entry/gpu/ze_event_signal_entry.hpp new file mode 100644 index 000000000..4b9b2c4f4 --- /dev/null +++ b/src/sched/entry/gpu/ze_event_signal_entry.hpp @@ -0,0 +1,45 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "sched/entry/entry.hpp" +#include "sched/master_sched.hpp" +#include "sched/sched.hpp" + +class ze_event_signal_entry : public sched_entry { +public: + static constexpr const char* class_name() noexcept { + return "ZE_EVENT_SIGNAL"; + } + + const char* name() const override { + return class_name(); + } + + bool is_strict_order_satisfied() override { + return (status >= ccl_sched_entry_status_complete); + } + + ze_event_signal_entry() = delete; + explicit ze_event_signal_entry(ccl_sched* sched, ccl_master_sched* master_sched); + ze_event_signal_entry(const ze_event_signal_entry&) = delete; + + void start() override; + void update() override; + +private: + ccl_master_sched* const master_sched; +}; diff --git a/src/sched/entry/gpu/ze_event_wait_entry.cpp b/src/sched/entry/gpu/ze_event_wait_entry.cpp new file mode 100644 index 000000000..c317ed81d --- /dev/null +++ b/src/sched/entry/gpu/ze_event_wait_entry.cpp @@ -0,0 +1,50 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/gpu/ze_event_wait_entry.hpp" + +#include + +ze_event_wait_entry::ze_event_wait_entry(ccl_sched* sched, ze_event_handle_t event) + : sched_entry(sched), + event(event) { + CCL_THROW_IF_NOT(sched, "no sched"); + CCL_THROW_IF_NOT(event, "no event"); +} + +void ze_event_wait_entry::check_event_status() { + auto query_status = zeEventQueryStatus(event); + if (query_status == ZE_RESULT_SUCCESS) { + LOG_DEBUG("event complete"); + status = ccl_sched_entry_status_complete; + } + else if (query_status == ZE_RESULT_NOT_READY) { + // just return in case if the kernel is not ready yet, will check again on the next iteration + return; + } + else { + CCL_THROW("error at zeEventQueryStatus"); + } +} + +void ze_event_wait_entry::start() { + LOG_DEBUG("start event waiting"); + status = ccl_sched_entry_status_started; + check_event_status(); +} + +void ze_event_wait_entry::update() { + check_event_status(); +} diff --git a/src/sched/entry/gpu/ze_event_wait_entry.hpp b/src/sched/entry/gpu/ze_event_wait_entry.hpp new file mode 100644 index 000000000..23cbb0092 --- /dev/null +++ b/src/sched/entry/gpu/ze_event_wait_entry.hpp @@ -0,0 +1,44 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "sched/entry/entry.hpp" +#include "sched/sched.hpp" + +class ze_event_wait_entry : public sched_entry { +public: + static constexpr const char* class_name() noexcept { + return "ZE_EVENT_WAIT"; + } + + const char* name() const override { + return class_name(); + } + + bool is_strict_order_satisfied() override { + return (status >= ccl_sched_entry_status_complete); + } + + explicit ze_event_wait_entry(ccl_sched* sched, ze_event_handle_t event); + + void start() override; + void update() override; + +private: + const ze_event_handle_t event; + + void check_event_status(); +}; diff --git a/src/sched/entry/gpu/ze_handle_exchange_entry.cpp b/src/sched/entry/gpu/ze_handle_exchange_entry.cpp new file mode 100644 index 000000000..6911108ca --- /dev/null +++ b/src/sched/entry/gpu/ze_handle_exchange_entry.cpp @@ -0,0 +1,486 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/gpu/ze_handle_exchange_entry.hpp" +#include "sched/queue/queue.hpp" +#include "sched/ze_handle_manager.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void cast_pool_to_mem_handle(ze_ipc_mem_handle_t* mem, + const ze_ipc_event_pool_handle_t* pool) { + static_assert(sizeof(ze_ipc_mem_handle_t) == sizeof(ze_ipc_event_pool_handle_t)); + memcpy(mem, pool, sizeof(*pool)); +} + +ze_handle_exchange_entry::ze_handle_exchange_entry(ccl_sched* sched, + ccl_comm* comm, + const std::vector& in_buffers, + int skip_rank) + : sched_entry(sched), + comm(comm), + in_buffers(in_buffers), + rank(comm->rank()), + comm_size(comm->size()), + skip_rank(skip_rank) { + LOG_DEBUG("initialization"); + CCL_THROW_IF_NOT(sched, "no sched"); + CCL_THROW_IF_NOT(!in_buffers.empty(), "in_buffers should be non empty"); + + poll_fds.reserve(max_pfds); + + handles.resize(comm_size); + for (auto& buffers : handles) { + buffers.resize(in_buffers.size()); + } + LOG_DEBUG("handles size: ", handles.size(), ", in_buffers size: ", in_buffers.size()); + + for (size_t buf_idx = 0; buf_idx < in_buffers.size(); buf_idx++) { + auto mem_ptr = in_buffers[buf_idx].first; + CCL_THROW_IF_NOT(mem_ptr, "memory pointer is nullptr"); + auto mem_type = in_buffers[buf_idx].second; + mem_info_t mem_info{}; + + ze_ipc_mem_handle_t handle{}; + if (rank != skip_rank) { + if (mem_type == ccl::ze::ipc_mem_type::memory) { + // zeMemGetIpcHandle requires the provided pointer to be the base of an allocation. + // We handle this the following way: for an input buffer retrieve its base pointer + // and the offset from this base ptr. The base ptr is used for zeMemGetIpcHandle + // and the offset is sent to the other rank. On that rank the base ptr is retrieved + // and offsetted to get the actual input buffer ptr. + mem_info = get_mem_info(mem_ptr); + sched->get_memory().handle_manager.get_handle(mem_info.first, &handle); + } + else if (mem_type == ccl::ze::ipc_mem_type::pool) { + ze_ipc_event_pool_handle_t pool_handle; + sched->get_memory().handle_manager.get_handle( + static_cast(mem_ptr), &pool_handle); + // since ze_ipc_event_pool_handle_t and ze_ipc_mem_handle_t are similar, + // we cast ze_ipc_event_pool_handle_t to ze_ipc_mem_handle_t, but + // maybe this is not the most correct way + cast_pool_to_mem_handle(&handle, &pool_handle); + } + else { + CCL_THROW("unknown memory type"); + } + } + + handles[rank][buf_idx] = { handle, mem_info.second, mem_type }; + LOG_DEBUG("set IPC handle: { rank: ", + rank, + ", buf_idx: ", + buf_idx, + ", mem_type: ", + to_string(mem_type), + " }"); + } + + std::string unique_tag = std::to_string(sched->get_comm_id()) + "-" + + std::to_string(sched->sched_id) + "-" + + std::to_string(sched->get_op_id()); + right_peer_socket_name = + "/tmp/ccl-handle-" + std::to_string((rank + 1) % comm_size) + "-" + unique_tag; + left_peer_socket_name = "/tmp/ccl-handle-" + std::to_string(rank) + "-" + unique_tag; + + // This is a temporary workaround around to provide uniqueness of socket files created + // in /tmp folder, otherwise this could result in issues in case of parallel runs + // by a single/multiple users. + // Ideally we should use process pid for this, but right now we don't have this information + // available for all the processes, so use this env variable instead. This works with mpiexec + // only(this is why it's the workaround rather than a complete solution) + static const char* mpi_uuid = getenv("I_MPI_HYDRA_UUID"); + if (mpi_uuid) { + right_peer_socket_name += std::string("-") + mpi_uuid; + left_peer_socket_name += std::string("-") + mpi_uuid; + } + + LOG_DEBUG("initialization complete"); +} + +ze_handle_exchange_entry::~ze_handle_exchange_entry() { + close_sockets(); + unlink_sockets(); +} + +void ze_handle_exchange_entry::start() { + start_buf_idx = start_peer_idx = 0; + skip_first_send = false; + status = ccl_sched_entry_status_started; +} + +void ze_handle_exchange_entry::update() { + if (!is_created) { + // server + left_peer_connect_socket = create_server_socket( + left_peer_socket_name, &left_peer_addr, &left_peer_addr_len, comm_size); + + // client + right_peer_socket = + create_client_socket(right_peer_socket_name, &right_peer_addr, &right_peer_addr_len); + + is_created = true; + } + + if (!is_connected) { + if (connect_call( + right_peer_socket, &right_peer_addr, right_peer_addr_len, right_peer_socket_name)) { + return; + } + is_connected = true; + } + + if (!is_accepted) { + if (accept_call(left_peer_connect_socket, + &left_peer_addr, + &left_peer_addr_len, + left_peer_socket_name, + left_peer_socket)) { + return; + } + + struct pollfd poll_fd {}; + poll_fd.fd = left_peer_socket; + poll_fd.events = POLLIN; + poll_fd.revents = 0; + poll_fds.push_back(poll_fd); + + is_accepted = true; + } + + CCL_THROW_IF_NOT(poll_fds.size() == 1, "unexpected poll_fds size: ", poll_fds.size()); + + for (size_t buf_idx = start_buf_idx; buf_idx < in_buffers.size(); buf_idx++) { + for (int peer_idx = start_peer_idx; peer_idx < comm_size - 1; peer_idx++) { + int peer = (comm_size + rank - 1 - peer_idx) % comm_size; + + if ((peer_idx == 0) && !skip_first_send && (rank != skip_rank)) { + int send_fd = 0; + // send own handle to right peer + get_fd_from_handle(&(handles[rank][buf_idx].handle), &send_fd); + sendmsg_call(right_peer_socket, send_fd, handles[rank][buf_idx].offset); + skip_first_send = true; + } + + if (peer == skip_rank) + continue; + + int poll_ret = poll(&poll_fds[0], poll_fds.size(), timeout_ms); + + if (poll_ret == poll_expire_err_code) { + LOG_DEBUG("poll: timeout is expired"); + return; + } + else if (poll_ret == POLL_ERR) { + CCL_THROW("poll: error: ", strerror(errno), ", ret: ", poll_ret); + } + + CCL_THROW_IF_NOT(poll_ret > 0, "unexpected poll ret: ", poll_ret); + + if (poll_fds[0].revents & POLLIN) { + int recv_fd = 0; + ze_ipc_mem_handle_t tmp_handle{}; + + size_t mem_offset = 0; + // recv data from left peer + recvmsg_call(left_peer_socket, recv_fd, mem_offset); + + // invoke get_handle_from_fd to store the handle + get_handle_from_fd(&recv_fd, &tmp_handle); + + // we don't know anything about the memory type on the other side, + // so we take it from our list. This assumes that the lists of types (exactly types) + // on the sending and receiving side are the same in both value and quantity + auto mem_type = in_buffers[buf_idx].second; + handles[peer][buf_idx] = { tmp_handle, mem_offset, mem_type }; + LOG_DEBUG("get IPC handle: { peer: ", + peer, + ", buf_idx: ", + buf_idx, + ", mem_type: ", + to_string(mem_type), + " }"); + + if (peer_idx < (comm_size - 2)) { + // proxy data to right peer + sendmsg_call(right_peer_socket, recv_fd, mem_offset); + } + start_peer_idx++; + } + else if (poll_fds[0].revents & POLLERR) { + CCL_THROW("poll: POLLERR, buf_idx: ", buf_idx, ", peer_idx ", peer_idx); + } + else if (poll_fds[0].revents & POLLHUP) { + CCL_THROW("poll: POLLHUP, buf_idx: ", buf_idx, ", peer_idx ", peer_idx); + } + else { + LOG_TRACE("poll: nothing to receive, buf_idx: ", buf_idx, ", peer_idx ", peer_idx); + // nothing to receive + // continue with the same buf_idx/peer_idx in the next update() call + return; + } + } + start_peer_idx = 0; + skip_first_send = false; + start_buf_idx++; + } + + LOG_DEBUG("handles size: ", handles.size(), ", in_buffers size: ", in_buffers.size()); + + sched->get_memory().handle_manager.set(handles); + + status = ccl_sched_entry_status_complete; + + LOG_DEBUG("completed: ", name()); +} + +int ze_handle_exchange_entry::create_server_socket(const std::string& socket_name, + struct sockaddr_un* socket_addr, + int* addr_len, + int comm_size) { + int ret = 0; + memset(&(*socket_addr), 0, sizeof((*socket_addr))); + + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + if (sock < 0) { + unlink_sockets(); + CCL_THROW("cannot create a server socket: ", + sock, + ", errno: ", + strerror(errno), + ", socket_name: ", + socket_name); + } + + socket_addr->sun_family = AF_UNIX; + strncpy(socket_addr->sun_path, socket_name.c_str(), sizeof(socket_addr->sun_path) - 1); + socket_addr->sun_path[sizeof(socket_addr->sun_path) - 1] = '\0'; + *addr_len = sizeof((*socket_addr)); + + ret = fcntl(sock, F_SETFL, O_NONBLOCK); + CCL_THROW_IF_NOT( + !ret, "fcntl error: ", ret, ", errno: ", strerror(errno), ", socket_name: ", socket_name); + + unlink(socket_name.c_str()); + + ret = bind(sock, ((struct sockaddr*)&(*socket_addr)), *addr_len); + CCL_THROW_IF_NOT( + !ret, "bind error: ", ret, ", errno: ", strerror(errno), ", socket_name: ", socket_name); + + ret = listen(sock, comm_size); + CCL_THROW_IF_NOT( + !ret, "listen error: ", ret, ", errno: ", strerror(errno), ", socket_name: ", socket_name); + + return sock; +} + +int ze_handle_exchange_entry::create_client_socket(const std::string& socket_name, + struct sockaddr_un* socket_addr, + int* addr_len) { + memset(&(*socket_addr), 0, sizeof(*(socket_addr))); + + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + CCL_THROW_IF_NOT( + sock >= 0, "cannot create a client socket: ", sock, ", errno: ", strerror(errno)); + + socket_addr->sun_family = AF_UNIX; + strncpy(socket_addr->sun_path, socket_name.c_str(), sizeof(socket_addr->sun_path) - 1); + socket_addr->sun_path[sizeof(socket_addr->sun_path) - 1] = '\0'; + *addr_len = sizeof((*socket_addr)); + + return sock; +} + +int ze_handle_exchange_entry::accept_call(int connect_socket, + struct sockaddr_un* socket_addr, + int* addr_len, + const std::string& socket_name, + int& sock) { + sock = accept(connect_socket, ((struct sockaddr*)&(*socket_addr)), ((socklen_t*)&(*addr_len))); + if (sock < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + LOG_TRACE("accept eagain: ", strerror(errno), ", socket_name: ", socket_name); + return errno; + } + + if (errno == EMFILE) { + LOG_TRACE("accept no free fd: ", strerror(errno), ", socket_name: ", socket_name); + return errno; + } + + CCL_THROW( + "accept error: ", strerror(errno), " sock: ", sock, ", socket_name: ", socket_name); + } + + LOG_DEBUG("accept from [", comm->rank(), "] (wait) on: ", socket_name); + return 0; +} + +int ze_handle_exchange_entry::connect_call(int sock, + struct sockaddr_un* socket_addr, + int addr_len, + const std::string& socket_name) { + int ret = connect(sock, ((struct sockaddr*)&(*socket_addr)), addr_len); + if (ret < 0) { + if (errno == ECONNREFUSED || errno == ENOENT) { + return errno; + } + CCL_THROW( + "connect error: ", ret, ", errno: ", strerror(errno), ", socket_name: ", socket_name); + } + + LOG_DEBUG("connect from: [", + comm->rank(), + "] to [", + (comm->rank() - 1 + comm->size()) % comm->size(), + "] with: ", + socket_name); + + return 0; +} + +void ze_handle_exchange_entry::sendmsg_fd(int sock, int fd, size_t mem_offset) { + CCL_THROW_IF_NOT(fd > 0, "unexpected fd value"); + + struct iovec iov {}; + iov.iov_base = &mem_offset; + iov.iov_len = sizeof(size_t); + + char ctrl_buf[CMSG_SPACE(sizeof(fd))]{}; + struct msghdr msg {}; + msg.msg_control = ctrl_buf; + msg.msg_controllen = CMSG_SPACE(sizeof(fd)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + auto cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + *reinterpret_cast(CMSG_DATA(cmsg)) = fd; + + ssize_t send_bytes = sendmsg(sock, &msg, 0); + CCL_THROW_IF_NOT(send_bytes >= 0, + "sendmsg error: ", + send_bytes, + ", socket: ", + sock, + ", fd: ", + fd, + ", from: ", + comm->rank(), + ", errno: ", + strerror(errno)); +} + +void ze_handle_exchange_entry::recvmsg_fd(int sock, int& fd, size_t& mem_offset) { + size_t buf{}; + struct iovec iov {}; + iov.iov_base = &buf; + iov.iov_len = sizeof(size_t); + + char ctrl_buf[CMSG_SPACE(sizeof(int))]{}; + struct msghdr msg {}; + msg.msg_control = ctrl_buf; + msg.msg_controllen = CMSG_SPACE(sizeof(int)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ssize_t recv_bytes = recvmsg(sock, &msg, 0); + CCL_THROW_IF_NOT(recv_bytes >= 0, + "recvmsg error: ", + recv_bytes, + ", socket: ", + sock, + ", fd: ", + fd, + ", from: ", + comm->rank(), + ", errno: ", + strerror(errno)); + + if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) { + CCL_THROW("control message is truncated"); + } + + for (auto cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_len == CMSG_LEN(sizeof(int)) && cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_RIGHTS) { + memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); + break; + } + } + + // we assume that the message has a strict format and size, if not this means that something + // is wrong. + if (msg.msg_iovlen != 1 || msg.msg_iov[0].iov_len != sizeof(size_t)) { + CCL_THROW("received data in unexpected format"); + } + + memcpy(&mem_offset, msg.msg_iov[0].iov_base, sizeof(size_t)); +} + +void ze_handle_exchange_entry::sendmsg_call(int sock, int fd, size_t mem_offset) { + sendmsg_fd(sock, fd, mem_offset); + LOG_DEBUG("send: rank[", + comm->rank(), + "], send fd: ", + fd, + ", sock: ", + sock, + ", mem_offset: ", + mem_offset); +} + +void ze_handle_exchange_entry::recvmsg_call(int sock, int& fd, size_t& mem_offset) { + recvmsg_fd(sock, fd, mem_offset); + LOG_DEBUG( + "recv: rank[", rank, "], got fd: ", fd, ", sock: ", sock, ", mem_offset: ", mem_offset); +} + +void ze_handle_exchange_entry::get_fd_from_handle(const ze_ipc_mem_handle_t* handle, + int* fd) noexcept { + memcpy(fd, static_cast(handle), sizeof(*fd)); +} + +void ze_handle_exchange_entry::get_handle_from_fd(const int* fd, + ze_ipc_mem_handle_t* handle) noexcept { + memcpy(handle, static_cast(fd), sizeof(*fd)); +} + +ze_handle_exchange_entry::mem_info_t ze_handle_exchange_entry::get_mem_info(const void* ptr) { + void* base_ptr{}; + size_t alloc_size{}; + sched->get_memory().handle_manager.get_address_range(ptr, &base_ptr, &alloc_size); + return { base_ptr, static_cast(ptr) - static_cast(base_ptr) }; +} + +void ze_handle_exchange_entry::unlink_sockets() { + unlink(left_peer_socket_name.c_str()); +} + +void ze_handle_exchange_entry::close_sockets() { + close(left_peer_connect_socket); + close(left_peer_socket); + close(right_peer_socket); +} diff --git a/src/sched/entry/gpu/ze_handle_exchange_entry.hpp b/src/sched/entry/gpu/ze_handle_exchange_entry.hpp new file mode 100644 index 000000000..9fce2662e --- /dev/null +++ b/src/sched/entry/gpu/ze_handle_exchange_entry.hpp @@ -0,0 +1,140 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/comm/comm.hpp" +#include "sched/entry/entry.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#include "sched/sched.hpp" +#include "sched/ze_handle_manager.hpp" + +#include +#include +#include + +class ze_handle_exchange_entry : public sched_entry { +public: + using mem_desc_t = typename std::pair; + + static constexpr const char* class_name() noexcept { + return "ZE_HANDLES"; + } + + const char* name() const noexcept override { + return class_name(); + } + + ze_handle_exchange_entry() = delete; + explicit ze_handle_exchange_entry(ccl_sched* sched, + ccl_comm* comm, + const std::vector& in_buffers, + int skip_rank = -1); + ~ze_handle_exchange_entry(); + + void start() override; + void update() override; + + bool is_strict_order_satisfied() noexcept override { + return (status >= ccl_sched_entry_status_complete); + } + +protected: + void dump_detail(std::stringstream& str) const override { + ccl_logger::format(str, + "rank ", + rank, + ", comm_size ", + comm_size, + ", right_peer ", + right_peer_socket_name, + ", left_peer ", + left_peer_socket_name, + ", in_buffers size ", + in_buffers.size(), + ", handles size ", + handles.size(), + "\n"); + } + +private: + static constexpr size_t socket_max_str_len = 100; + static constexpr int poll_expire_err_code = 0; + static constexpr int timeout_ms = 1; + static constexpr size_t max_pfds = 1; + + const ccl_comm* comm; + + std::vector in_buffers; + ccl::ze::ipc_handle_manager::mem_handle_map_t handles; + + const int rank; + const int comm_size; + const int skip_rank; + + int start_buf_idx{}; + int start_peer_idx{}; + + std::vector poll_fds; + + int right_peer_socket{}; + int left_peer_socket{}; + int left_peer_connect_socket{}; + + struct sockaddr_un right_peer_addr { + }, left_peer_addr{}; + int right_peer_addr_len{}, left_peer_addr_len{}; + + std::string right_peer_socket_name; + std::string left_peer_socket_name; + + bool is_created{}; + bool is_connected{}; + bool is_accepted{}; + bool skip_first_send{}; + + void get_fd_from_handle(const ze_ipc_mem_handle_t* handle, int* fd) noexcept; + void get_handle_from_fd(const int* fd, ze_ipc_mem_handle_t* handle) noexcept; + + int create_server_socket(const std::string& socket_name, + struct sockaddr_un* socket_addr, + int* addr_len, + int comm_size); + int create_client_socket(const std::string& left_peer_socket_name, + struct sockaddr_un* sockaddr_cli, + int* len); + + int accept_call(int connect_socket, + struct sockaddr_un* socket_addr, + int* addr_len, + const std::string& socket_name, + int& sock); + int connect_call(int sock, + struct sockaddr_un* socket_addr, + int addr_len, + const std::string& socket_name); + + void sendmsg_fd(int sock, int fd, size_t mem_offset); + void recvmsg_fd(int sock, int& fd, size_t& mem_offset); + + void sendmsg_call(int sock, int fd, size_t mem_offset); + void recvmsg_call(int sock, int& fd, size_t& mem_offset); + + using mem_info_t = typename std::pair; + mem_info_t get_mem_info(const void* ptr); + + void unlink_sockets(); + void close_sockets(); +}; diff --git a/src/sched/entry/gpu/ze_primitives.cpp b/src/sched/entry/gpu/ze_primitives.cpp new file mode 100644 index 000000000..89104fd7d --- /dev/null +++ b/src/sched/entry/gpu/ze_primitives.cpp @@ -0,0 +1,361 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include +#include + +#include "common/global/global.hpp" +#include "common/log/log.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" + +namespace ccl { + +namespace ze { + +void load_module(std::string dir, + std::string file_name, + ze_device_handle_t device, + ze_context_handle_t context, + ze_module_handle_t* module) { + LOG_DEBUG("module loading started: directory: ", dir, ", file: ", file_name); + + if (!dir.empty()) { + if (*dir.rbegin() != '/') { + dir += '/'; + } + } + + std::string file_path = dir + file_name; + std::ifstream file(file_path, std::ios_base::in | std::ios_base::binary); + if (!file.good() || dir.empty() || file_name.empty()) { + CCL_THROW("failed to load module: file: ", file_path); + } + + file.seekg(0, file.end); + size_t filesize = file.tellg(); + file.seekg(0, file.beg); + + std::vector module_data(filesize); + file.read(reinterpret_cast(module_data.data()), filesize); + file.close(); + + ze_module_desc_t desc = {}; + ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV; + desc.format = format; + desc.pInputModule = reinterpret_cast(module_data.data()); + desc.inputSize = module_data.size(); + ZE_CALL(zeModuleCreate, (context, device, &desc, module, nullptr)); + LOG_DEBUG("module loading completed: directory: ", dir, ", file: ", file_name); +} + +void create_kernel(ze_module_handle_t module, std::string kernel_name, ze_kernel_handle_t* kernel) { + ze_kernel_desc_t desc = default_kernel_desc; + // convert to lowercase + std::transform(kernel_name.begin(), kernel_name.end(), kernel_name.begin(), ::tolower); + desc.pKernelName = kernel_name.c_str(); + ze_result_t res = zeKernelCreate(module, &desc, kernel); + if (res != ZE_RESULT_SUCCESS) { + CCL_THROW("error at zeKernelCreate: kernel name: ", kernel_name, " ret: ", to_string(res)); + } +} + +void get_suggested_group_size(ze_kernel_handle_t kernel, + size_t count, + ze_group_size_t* group_size) { + CCL_ASSERT(count > 0, "count == 0"); + ZE_CALL(zeKernelSuggestGroupSize, + (kernel, + count, + 1, + 1, + &group_size->groupSizeX, + &group_size->groupSizeY, + &group_size->groupSizeZ)); + CCL_THROW_IF_NOT(group_size->groupSizeX >= 1, + "wrong group size calculation: group size: ", + to_string(*group_size), + ", count: ", + count); +} + +void get_suggested_group_count(const ze_group_size_t& group_size, + size_t count, + ze_group_count_t* group_count) { + group_count->groupCountX = count / group_size.groupSizeX; + group_count->groupCountY = 1; + group_count->groupCountZ = 1; + + auto rem = count % group_size.groupSizeX; + CCL_THROW_IF_NOT(group_count->groupCountX >= 1 && rem == 0, + "wrong group count calculation: group size: ", + to_string(group_size), + ", group count: ", + to_string(*group_count), + ", count: ", + std::to_string(count)); +} + +void set_kernel_args(ze_kernel_handle_t kernel, const ze_kernel_args_t& kernel_args) { + uint32_t idx = 0; + for (const auto& arg : kernel_args) { + auto res = zeKernelSetArgumentValue(kernel, idx, arg.first, arg.second); + if (res != ZE_RESULT_SUCCESS) { + CCL_THROW("zeKernelSetArgumentValue failed with error ", + to_string(res), + " on idx ", + idx, + " with value ", + *((void**)arg.second)); + } + ++idx; + } +} + +void get_num_queue_groups(ze_device_handle_t device, uint32_t* num) { + *num = 0; + ZE_CALL(zeDeviceGetCommandQueueGroupProperties, (device, num, nullptr)); + CCL_THROW_IF_NOT(*num != 0, "no queue groups found"); +} + +void get_queues_properties(ze_device_handle_t device, + uint32_t num_queue_groups, + ze_queue_properties_t* props) { + props->resize(num_queue_groups); + ZE_CALL(zeDeviceGetCommandQueueGroupProperties, (device, &num_queue_groups, props->data())); +} + +void get_comp_queue_ordinal(ze_device_handle_t device, + const ze_queue_properties_t& props, + uint32_t* ordinal) { + uint32_t comp_ordinal = std::numeric_limits::max(); + + for (uint32_t idx = 0; idx < props.size(); ++idx) { + if (props[idx].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) { + comp_ordinal = idx; + break; + } + } + + LOG_DEBUG("find queue: { ordinal: ", + comp_ordinal, + ", queue properties params: ", + to_string(props[comp_ordinal]), + " }"); + + if (comp_ordinal != std::numeric_limits::max()) { + *ordinal = comp_ordinal; + } + else { + LOG_WARN("could not find queue ordinal, ordinal 0 will be used"); + *ordinal = 0; + } +} + +void get_copy_queue_ordinal(ze_device_handle_t device, + const ze_queue_properties_t& props, + uint32_t* ordinal) { + uint32_t copy_ordinal = std::numeric_limits::max(); + + for (uint32_t idx = 0; idx < props.size(); ++idx) { + /* only compute property */ + if ((props[idx].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) && + global_data::env().ze_copy_engine == ccl_ze_copy_engine_none) { + copy_ordinal = idx; + break; + } + + /* only copy property */ + if ((props[idx].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) && + ((props[idx].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) == 0)) { + /* main */ + if (props[idx].numQueues == 1 && + global_data::env().ze_copy_engine == ccl_ze_copy_engine_main) { + copy_ordinal = idx; + break; + } + /* link */ + if (props[idx].numQueues > 1 && + global_data::env().ze_copy_engine == ccl_ze_copy_engine_link) { + copy_ordinal = idx; + break; + } + } + } + + LOG_DEBUG("find copy queue: { ordinal: ", + copy_ordinal, + ", queue properties params: ", + to_string(props[copy_ordinal]), + " }"); + + if (copy_ordinal != std::numeric_limits::max()) { + *ordinal = copy_ordinal; + } + else { + LOG_WARN("could not find queue ordinal for copy engine mode: ", + global_data::env().ze_copy_engine, + ", ordinal 0 will be used"); + *ordinal = 0; + } +} + +void get_queue_index(const ze_queue_properties_t& props, + uint32_t ordinal, + int idx, + uint32_t* index) { + CCL_ASSERT(props.size() > ordinal, "props.size() <= ordinal"); + *index = idx % props[ordinal].numQueues; + LOG_DEBUG("set queue index: ", *index); +} + +std::string to_string(const ze_result_t result) { + switch (result) { + case ZE_RESULT_SUCCESS: return "ZE_RESULT_SUCCESS"; + case ZE_RESULT_NOT_READY: return "ZE_RESULT_NOT_READY"; + case ZE_RESULT_ERROR_DEVICE_LOST: return "ZE_RESULT_ERROR_DEVICE_LOST"; + case ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY: return "ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY"; + case ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY: return "ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY"; + case ZE_RESULT_ERROR_MODULE_BUILD_FAILURE: return "ZE_RESULT_ERROR_MODULE_BUILD_FAILURE"; + case ZE_RESULT_ERROR_MODULE_LINK_FAILURE: return "ZE_RESULT_ERROR_MODULE_LINK_FAILURE"; + case ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS: + return "ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS"; + case ZE_RESULT_ERROR_NOT_AVAILABLE: return "ZE_RESULT_ERROR_NOT_AVAILABLE"; + case ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE: + return "ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE"; + case ZE_RESULT_ERROR_UNINITIALIZED: return "ZE_RESULT_ERROR_UNINITIALIZED"; + case ZE_RESULT_ERROR_UNSUPPORTED_VERSION: return "ZE_RESULT_ERROR_UNSUPPORTED_VERSION"; + case ZE_RESULT_ERROR_UNSUPPORTED_FEATURE: return "ZE_RESULT_ERROR_UNSUPPORTED_FEATURE"; + case ZE_RESULT_ERROR_INVALID_ARGUMENT: return "ZE_RESULT_ERROR_INVALID_ARGUMENT"; + case ZE_RESULT_ERROR_INVALID_NULL_HANDLE: return "ZE_RESULT_ERROR_INVALID_NULL_HANDLE"; + case ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE: return "ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE"; + case ZE_RESULT_ERROR_INVALID_NULL_POINTER: return "ZE_RESULT_ERROR_INVALID_NULL_POINTER"; + case ZE_RESULT_ERROR_INVALID_SIZE: return "ZE_RESULT_ERROR_INVALID_SIZE"; + case ZE_RESULT_ERROR_UNSUPPORTED_SIZE: return "ZE_RESULT_ERROR_UNSUPPORTED_SIZE"; + case ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT: return "ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT"; + case ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT: + return "ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT"; + case ZE_RESULT_ERROR_INVALID_ENUMERATION: return "ZE_RESULT_ERROR_INVALID_ENUMERATION"; + case ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION: + return "ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION"; + case ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT: + return "ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT"; + case ZE_RESULT_ERROR_INVALID_NATIVE_BINARY: return "ZE_RESULT_ERROR_INVALID_NATIVE_BINARY"; + case ZE_RESULT_ERROR_INVALID_GLOBAL_NAME: return "ZE_RESULT_ERROR_INVALID_GLOBAL_NAME"; + case ZE_RESULT_ERROR_INVALID_KERNEL_NAME: return "ZE_RESULT_ERROR_INVALID_KERNEL_NAME"; + case ZE_RESULT_ERROR_INVALID_FUNCTION_NAME: return "ZE_RESULT_ERROR_INVALID_FUNCTION_NAME"; + case ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION: + return "ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION"; + case ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION: + return "ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION"; + case ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX: + return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX"; + case ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE: + return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE"; + case ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE: + return "ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE"; + case ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED: + return "ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED"; + case ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE: + return "ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE"; + case ZE_RESULT_ERROR_OVERLAPPING_REGIONS: return "ZE_RESULT_ERROR_OVERLAPPING_REGIONS"; + case ZE_RESULT_ERROR_UNKNOWN: return "ZE_RESULT_ERROR_UNKNOWN"; + case ZE_RESULT_FORCE_UINT32: return "ZE_RESULT_FORCE_UINT32"; + default: return "unknown ze_result_t value: " + std::to_string(static_cast(result)); + } +} + +std::string to_string(const ze_group_size_t& group_size) { + std::stringstream ss; + ss << "{ x: " << group_size.groupSizeX << ", y: " << group_size.groupSizeY + << ", z: " << group_size.groupSizeZ << " }"; + return ss.str(); +} + +std::string to_string(const ze_group_count_t& group_count) { + std::stringstream ss; + ss << "{ x: " << group_count.groupCountX << ", y: " << group_count.groupCountY + << ", z: " << group_count.groupCountZ << " }"; + return ss.str(); +} + +std::string to_string(const ze_kernel_args_t& kernel_args) { + std::stringstream ss; + ss << "{\n"; + size_t idx = 0; + for (const auto& arg : kernel_args) { + // TODO: can we distinguish argument types in order to properly print them instead of printing + // as a void* ptr? + ss << " idx: " << idx << ", { " << arg.first << ", " << *(void**)arg.second << " }\n"; + ++idx; + } + ss << "}"; + return ss.str(); +} + +std::string to_string(const ze_command_queue_group_property_flag_t& flag) { + switch (flag) { + case ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE: + return "ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE"; + case ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY: + return "ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY"; + case ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COOPERATIVE_KERNELS: + return "ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COOPERATIVE_KERNELS"; + case ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_METRICS: + return "ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_METRICS"; + case ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_FORCE_UINT32: + return "ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_FORCE_UINT32"; + default: + return "unknown ze_command_queue_group_property_flag_t value: " + + std::to_string(static_cast(flag)); + } +} + +std::string to_string(const ze_command_queue_group_properties_t& queue_property) { + std::stringstream ss; + ss << "stype: " << queue_property.stype << ", pNext: " << (void*)queue_property.pNext + << ", flags: " + << flags_to_string(queue_property.flags) + << ", maxMemoryFillPatternSize: " << queue_property.maxMemoryFillPatternSize + << ", numQueues: " << queue_property.numQueues; + return ss.str(); +} + +std::string join_strings(const std::vector& tokens, const std::string& delimeter) { + std::stringstream ss; + for (size_t i = 0; i < tokens.size(); ++i) { + ss << tokens[i]; + if (i < tokens.size() - 1) { + ss << delimeter; + } + } + return ss.str(); +} + +template +std::string flags_to_string(uint32_t flags) { + const size_t bits = 8; + std::vector output; + for (size_t i = 0; i < sizeof(flags) * bits; ++i) { + const size_t mask = 1UL << i; + const auto flag = flags & mask; + if (flag != 0) { + output.emplace_back(to_string(static_cast(flag))); + } + } + return join_strings(output, " | "); +} + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_primitives.hpp b/src/sched/entry/gpu/ze_primitives.hpp new file mode 100644 index 000000000..b87ed9fd6 --- /dev/null +++ b/src/sched/entry/gpu/ze_primitives.hpp @@ -0,0 +1,157 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "sched/entry/gpu/ze_call.hpp" + +#include +#include +#include +#include + +namespace ccl { + +namespace ze { + +#define ZE_CALL(ze_name, ze_args) ccl::ze::ze_call().do_call(ze_name ze_args, #ze_name) + +enum class init_mode : int { + compute = 1, + copy = 2, +}; + +constexpr ze_context_desc_t default_context_desc = { .stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC, + .pNext = nullptr, + .flags = 0 }; + +constexpr ze_fence_desc_t default_fence_desc = { .stype = ZE_STRUCTURE_TYPE_FENCE_DESC, + .pNext = nullptr, + .flags = 0 }; + +constexpr ze_kernel_desc_t default_kernel_desc = { .stype = ZE_STRUCTURE_TYPE_KERNEL_DESC, + .pNext = nullptr, + .flags = 0, + .pKernelName = nullptr }; + +constexpr ze_command_list_desc_t default_cmd_list_desc = { + .stype = ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, + .pNext = nullptr, + .commandQueueGroupOrdinal = 0, + .flags = 0, +}; + +constexpr ze_command_queue_desc_t default_cmd_queue_desc = { + .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, + .pNext = nullptr, + .ordinal = 0, + .index = 0, + .flags = 0, + .mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS, + .priority = ZE_COMMAND_QUEUE_PRIORITY_NORMAL +}; + +constexpr ze_device_mem_alloc_desc_t default_device_mem_alloc_desc = { + .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + .pNext = nullptr, + .flags = 0, + .ordinal = 0 +}; + +constexpr ze_memory_allocation_properties_t default_alloc_props = { + .stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES, + .pNext = nullptr, + .type = ZE_MEMORY_TYPE_UNKNOWN +}; + +constexpr ze_device_properties_t default_device_props = { .stype = + ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES, + .pNext = nullptr }; + +constexpr ze_event_pool_desc_t default_event_pool_desc = { .stype = + ZE_STRUCTURE_TYPE_EVENT_POOL_DESC, + .pNext = nullptr, + .flags = 0, + .count = 0 }; + +constexpr ze_event_desc_t default_event_desc = { .stype = ZE_STRUCTURE_TYPE_EVENT_DESC, + .pNext = nullptr, + .index = 0, + .signal = 0, + .wait = 0 }; + +inline init_mode operator|(init_mode mode1, init_mode mode2) { + return static_cast(static_cast(mode1) | static_cast(mode2)); +} + +inline bool operator&(init_mode mode1, init_mode mode2) { + return static_cast(mode1) & static_cast(mode2); +} + +void load_module(std::string dir, + std::string file_name, + ze_device_handle_t device, + ze_context_handle_t context, + ze_module_handle_t* module); +void create_kernel(ze_module_handle_t module, std::string kernel_name, ze_kernel_handle_t* kernel); + +// this structure is just to align with ze_group_count_t +// L0 doesn't have ze_group_size_t +struct ze_group_size_t { + uint32_t groupSizeX = 0; + uint32_t groupSizeY = 0; + uint32_t groupSizeZ = 0; +}; + +void get_suggested_group_size(ze_kernel_handle_t kernel, size_t count, ze_group_size_t* group_size); +void get_suggested_group_count(const ze_group_size_t& group_size, + size_t count, + ze_group_count_t* group_count); + +using ze_kernel_arg_t = std::pair; +using ze_kernel_args_t = std::initializer_list; +void set_kernel_args(ze_kernel_handle_t kernel, const ze_kernel_args_t& kernel_args); + +using ze_queue_properties_t = std::vector; + +void get_num_queue_groups(ze_device_handle_t device, uint32_t* num); +void get_queues_properties(ze_device_handle_t device, + uint32_t num_queue_groups, + ze_queue_properties_t* props); +void get_comp_queue_ordinal(ze_device_handle_t device, + const ze_queue_properties_t& props, + uint32_t* ordinal); +void get_copy_queue_ordinal(ze_device_handle_t device, + const ze_queue_properties_t& props, + uint32_t* ordinal); +void get_queue_index(const ze_queue_properties_t& props, + uint32_t ordinal, + int idx, + uint32_t* index); + +std::string to_string(const ze_result_t result); +std::string to_string(const ze_group_size_t& group_size); +std::string to_string(const ze_group_count_t& group_count); +std::string to_string(const ze_kernel_args_t& kernel_args); +std::string to_string(const ze_command_queue_group_property_flag_t& flag); +std::string to_string(const ze_command_queue_group_properties_t& queue_property); + +std::string join_strings(const std::vector& tokens, const std::string& delimeter); + +template +std::string flags_to_string(uint32_t flags); + +} // namespace ze +} // namespace ccl diff --git a/src/sched/entry/gpu/ze_reduce_entry.cpp b/src/sched/entry/gpu/ze_reduce_entry.cpp new file mode 100644 index 000000000..066442898 --- /dev/null +++ b/src/sched/entry/gpu/ze_reduce_entry.cpp @@ -0,0 +1,256 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/comm/l0/modules/kernel_utils.hpp" +#include "common/stream/stream.hpp" +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#include "sched/entry/gpu/ze_reduce_entry.hpp" +#include "sched/queue/queue.hpp" + +#include + +using namespace ccl; +using namespace ccl::ze; + +ze_reduce_entry::ze_reduce_entry(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t cnt, + const ccl_datatype& dtype, + reduction op, + int root, + ccl_comm* comm) + : ze_base_entry(sched, comm, 2 /* request additional events */), + send_buf(send_buf), + recv_buf(recv_buf), + cnt(cnt), + dtype(dtype), + op(op), + root(root), + buf_size_bytes(dtype.size() * cnt), + is_initialized(false), + empty_kernel_event(nullptr), + empty_kernel(nullptr), + empty_kernel_name("empty_kernel") {} + +ze_reduce_entry::~ze_reduce_entry() { + finalize(); +} + +void ze_reduce_entry::init() { + if (is_initialized) { + return; + } + + LOG_DEBUG("initialization"); + + init_mode init_mode_type; + if (global_data::env().enable_kernel_1s_copy_ops) { + init_mode_type = (init_mode::copy | init_mode::compute); + } + else { + init_mode_type = init_mode::compute; + } + + CCL_THROW_IF_NOT(comm_rank == root, "unexpected comm_rank ", comm_rank, ", expected ", root); + + ze_base_entry::init(init_mode_type); + + /* create kernels */ + ccl_buffer right_send_buf; + int peer_rank = (comm_rank + 1) % comm_size; + sched->get_memory().handle_manager.get(peer_rank, 0, right_send_buf, comm); + LOG_DEBUG( + "get IPC pointers from ", peer_rank, " by ", root, ", right_send_buf: ", right_send_buf); + + send_buf_ptr = send_buf.get_ptr(); + recv_buf_ptr = recv_buf.get_ptr(); + // TODO: in place case check! diff idx for handle_mngr + + right_send_buf_ptr = right_send_buf.get_ptr(); + + ze_kernel_args_t reduce_local_kernel_args = { { sizeof(comm_rank), &comm_rank }, + { sizeof(comm_size), &comm_size }, + { sizeof(cnt), &cnt }, + { sizeof(send_buf_ptr), &send_buf_ptr }, + { sizeof(tmp_buf_ptr), &tmp_buf_ptr }, + { sizeof(recv_buf_ptr), &recv_buf_ptr } }; + + ccl::global_data::get().ze_cache->get(context, device, "kernels.spv", &module); + + device_mem_alloc_desc = default_device_mem_alloc_desc; + ccl::global_data::get().ze_cache->get(worker_idx, + context, + device, + device_mem_alloc_desc, + buf_size_bytes, + 0, /*alignment*/ + &tmp_buf_ptr); + + main_kernel_name = + "reduce_local_outofplace_kernel_" + to_string(dtype.idx()) + "_" + ccl_reduction_to_str(op); + LOG_DEBUG("get kernel: name: ", main_kernel_name); + ccl::global_data::get().ze_cache->get(worker_idx, module, main_kernel_name, &main_kernel); + + auto& main_kernel_args = reduce_local_kernel_args; + LOG_DEBUG("kernel ", main_kernel, " args:\n", to_string(main_kernel_args)); + set_kernel_args(main_kernel, main_kernel_args); + + ze_group_size_t group_size; + get_suggested_group_size(main_kernel, cnt, &group_size); + LOG_DEBUG("suggested group size: ", to_string(group_size)); + + get_suggested_group_count(group_size, cnt, &group_count); + LOG_DEBUG("suggested group count: ", to_string(group_count)); + + ZE_CALL(zeKernelSetGroupSize, + (main_kernel, group_size.groupSizeX, group_size.groupSizeY, group_size.groupSizeZ)); + + if (ccl::global_data::env().enable_kernel_1s_ipc_wa) { + LOG_DEBUG("get kernel: name: ", empty_kernel_name); + ccl::global_data::get().ze_cache->get(worker_idx, module, empty_kernel_name, &empty_kernel); + CCL_THROW_IF_NOT(empty_kernel, "null empty_kernel"); + /* use allreduce_kernel_args since they have pointers to peer mem */ + set_kernel_args(empty_kernel, main_kernel_args); + } + + ze_event_desc_t event_desc = default_event_desc; + event_desc.signal = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + event_desc.wait = ZE_EVENT_SCOPE_FLAG_SUBDEVICE; + + uint32_t last_event_idx = 1; // 0 is used to track entry progress + + if (empty_kernel) { + LOG_DEBUG("create event for empty kernel"); + event_desc.index = last_event_idx++; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, &empty_kernel_event)); + } + + event_desc.index = last_event_idx++; + ZE_CALL(zeEventCreate, (event_pool, &event_desc, ©_from_peer_event)); + + LOG_DEBUG("real event count: ", last_event_idx); + + /* do appends */ + if (empty_kernel) { + LOG_DEBUG("append empty kernel"); + ze_group_count_t empty_group_count = { 1, 1, 1 }; + ZE_CALL(zeCommandListAppendLaunchKernel, + (comp_primitives.list, + empty_kernel, + &empty_group_count, + empty_kernel_event, + 0, + nullptr)); + } + + LOG_DEBUG("one-sided multi-phase algorithm"); + + ZE_CALL(zeCommandListAppendMemoryCopy, + (ze_base_entry::get_copy_list(), + tmp_buf_ptr, + right_send_buf_ptr, + buf_size_bytes, + copy_from_peer_event, + (empty_kernel_event) ? 1 : 0, + &empty_kernel_event)); + + ZE_CALL( + zeCommandListAppendLaunchKernel, + (comp_primitives.list, main_kernel, &group_count, entry_event, 1, ©_from_peer_event)); + + ZE_CALL(zeCommandListClose, (comp_primitives.list)); + if (global_data::env().enable_kernel_1s_copy_ops) { + ZE_CALL(zeCommandListClose, (ze_base_entry::copy_primitives.list)); + } + + is_initialized = true; + + LOG_DEBUG("initialization complete"); +} + +void ze_reduce_entry::start() { + init(); + + if (is_initialized && status == ccl_sched_entry_status_not_started) { + reset_sync_objects(); + } + + size_t kernel_counter = 0; + if (ccl::global_data::env().enable_kernel_sync) { + kernel_counter = ccl::global_data::get().kernel_counter++; + } + + if (kernel_counter == 0) { + ze_base_entry::start(); + status = ccl_sched_entry_status_started; + } + else { + ccl::global_data::get().kernel_counter--; + status = ccl_sched_entry_status_again; + } +} + +void ze_reduce_entry::update() { + ze_base_entry::update(); + if (status == ccl_sched_entry_status_complete && !sched->coll_attr.to_cache) { + finalize(); + } + + if (ccl::global_data::env().enable_kernel_sync && ccl::global_data::get().kernel_counter > 0) { + ccl::global_data::get().kernel_counter--; + } +} + +void ze_reduce_entry::finalize() { + if (!is_initialized) { + return; + } + + LOG_DEBUG("finalization"); + + /* events */ + LOG_DEBUG("copy event finalization"); + ZE_CALL(zeEventDestroy, (copy_from_peer_event)); + /* device mem */ + ccl::global_data::get().ze_cache->push(worker_idx, + context, + device, + device_mem_alloc_desc, + buf_size_bytes, + 0, /*alignment*/ + tmp_buf_ptr); + + /* kernels */ + if (empty_kernel_event) { + ZE_CALL(zeEventDestroy, (empty_kernel_event)); + ccl::global_data::get().ze_cache->push(worker_idx, module, empty_kernel_name, empty_kernel); + } + ccl::global_data::get().ze_cache->push(worker_idx, module, main_kernel_name, main_kernel); + + ze_base_entry::finalize(); + + is_initialized = false; + + LOG_DEBUG("finalization complete"); +} + +void ze_reduce_entry::reset_sync_objects() { + if (empty_kernel_event) { + ZE_CALL(zeEventHostReset, (empty_kernel_event)); + } + ZE_CALL(zeEventHostReset, (copy_from_peer_event)); +} diff --git a/src/sched/entry/gpu/ze_reduce_entry.hpp b/src/sched/entry/gpu/ze_reduce_entry.hpp new file mode 100644 index 000000000..48f45eeb6 --- /dev/null +++ b/src/sched/entry/gpu/ze_reduce_entry.hpp @@ -0,0 +1,105 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/utils/buffer.hpp" +#include "comp/comp.hpp" +#include "sched/entry/gpu/ze_base_entry.hpp" + +#include +#include + +class ze_reduce_entry : public ze_base_entry { +public: + static constexpr const char* class_name() noexcept { + return "ZE_REDUCE"; + } + + const char* name() const noexcept override { + return class_name(); + } + + ze_reduce_entry() = delete; + explicit ze_reduce_entry(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t cnt, + const ccl_datatype& dtype, + ccl::reduction op, + int root, + ccl_comm* comm); + ~ze_reduce_entry(); + + void init(); + void start() override; + void update() override; + void finalize(); + + void reset_sync_objects(); + + bool is_strict_order_satisfied() override { + return (status >= ccl_sched_entry_status_complete); + } + +protected: + void dump_detail(std::stringstream& str) const override { + ccl_logger::format(str, + "dt ", + ccl::global_data::get().dtypes->name(dtype), + ", cnt ", + cnt, + ", send_buf ", + send_buf, + ", recv_buf ", + recv_buf, + ", op ", + ccl_reduction_to_str(op), + ", comm_id ", + sched->get_comm_id(), + ", context ", + context, + "\n"); + } + +private: + ccl_buffer send_buf; + ccl_buffer recv_buf; + void* send_buf_ptr; + void* recv_buf_ptr; + void* right_send_buf_ptr; + void* tmp_buf_ptr; + const unsigned long cnt; + const ccl_datatype dtype; + const ccl::reduction op; + int root; + const size_t buf_size_bytes; + bool is_initialized; + + ze_event_handle_t empty_kernel_event; + ze_event_handle_t copy_from_peer_event; + + ze_module_handle_t module; + + ze_group_count_t group_count; + + ze_kernel_handle_t main_kernel; + std::string main_kernel_name; + + ze_kernel_handle_t empty_kernel; + std::string empty_kernel_name; + + ze_device_mem_alloc_desc_t device_mem_alloc_desc; +}; diff --git a/src/sched/entry/l0/l0_allgather_handles_entry.hpp b/src/sched/entry/l0/l0_allgather_handles_entry.hpp deleted file mode 100644 index df5a2c7d6..000000000 --- a/src/sched/entry/l0/l0_allgather_handles_entry.hpp +++ /dev/null @@ -1,309 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include -#include -#include "oneapi/ccl/types.hpp" -#include "common/datatype/datatype.hpp" -#include "comp/comp.hpp" -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "sched/entry/coll/direct/base_coll_entry.hpp" - -#include "common/comm/l0/context/device_storage.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" -namespace native { - -template -class l0_allgather_handles_entry : public base_coll_entry { -public: - using dependent_entry = from_entry; - using gpu_comm = typename dependent_entry::gpu_comm; - using processing_type = typename dependent_entry::processing_type; - - friend class ccl_gpu_comm; - - static constexpr const char* class_name() noexcept { - return "L0_ALLGATHER_HANDLES"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_allgatherv; //TODO - } - - static constexpr ccl_coll_type dependent_type() noexcept { - return dependent_entry::type(); - } - - static constexpr ccl::group_split_type dependent_topology() { - return dependent_entry::get_topology(); - } - - static constexpr ccl::device_topology_type dependent_topology_class() { - return dependent_entry::get_topology_class(); - } - - l0_allgather_handles_entry() = delete; - - l0_allgather_handles_entry(ccl_sched* sched, - std::shared_ptr comm, - std::shared_ptr ccl_comm, - device_storage& global_device_storage, - ccl_driver_context_ptr in_ctx, - std::vector&& send_data) - : base_coll_entry(sched), - comm_addr( - comm->template get_comm_data()), - ccl_communicator(ccl_comm), - node_device_storage(global_device_storage), - send_handles(std::move(send_data)), - ctx(in_ctx) { - LOG_DEBUG(class_name(), " entry req ", &req, ", rank: ", comm_addr.to_string()); - } - - void start() override { - int comm_size = ccl_communicator->size(); - LOG_INFO(class_name(), " entry req ", &req, ", rank: ", comm_addr.to_string()); - - // serialize data for native allgather algo - plain_send_data.clear(); - constexpr size_t handle_size = - ccl_device::device_ipc_memory_handle::get_size_for_serialize(); - size_t send_bytes = handle_size * send_handles.size() + sizeof(size_t); - plain_send_data.resize(send_bytes); - - // fill send_buf - size_t serialize_offset = 0; - *(reinterpret_cast(plain_send_data.data())) = comm_addr.rank; - serialize_offset += sizeof(size_t); - for (auto& ipc_handle : send_handles) { - serialize_offset += ipc_handle.serialize(plain_send_data, serialize_offset); - } - - CCL_ASSERT(serialize_offset == send_bytes, - "Expected data to send and actually serialized are differ"); - - //prepare recv_buf - plain_recv_data.resize(send_bytes * (comm_size)); //all others and me - recv_bytes.resize(comm_size); - std::fill(recv_bytes.begin(), recv_bytes.end(), send_bytes); - - int step = 0; - offsets.resize(comm_size); - std::generate(offsets.begin(), offsets.end(), [&step, send_bytes] { - int prev = step; - step += send_bytes; - return prev; - }); - - LOG_INFO(class_name(), - " entry req ", - &req, - ", send_bytes ", - send_bytes, - ", waiting recv_bytes: ", - plain_recv_data.size()); - - ccl::stream::impl_value_t empty{}; - event = ccl_communicator->allgatherv_impl((int8_t*)plain_send_data.data(), - send_bytes, - (int8_t*)plain_recv_data.data(), - recv_bytes, - empty, - ccl::default_allgatherv_attr, - {}); - status = ccl_sched_entry_status_started; - - //TODO prepare foreign_device_ipc_mem_storage handles array - } - - void update() override { - if (event.test()) { - LOG_DEBUG(class_name(), - " entry req ", - &req, - ", rank: ", - comm_addr.to_string(), - "gathering completed"); - - //TODO - /* - std::stringstream ss; - std::copy(plain_recv_data.begin(), plain_recv_data.end(), std::ostream_iterator(ss, ",")); - LOG_INFO(class_name(), " recevied: ", ss.str()); - */ - //get ipc handles - size_t recv_data_size = plain_recv_data.size(); - const uint8_t* recv_data_start = plain_recv_data.data(); - - //TODO make preallocation for ipc_memory_container in start or in cnstructor!!! - while (recv_data_size > 0) { - size_t received_rank_idx = *(reinterpret_cast(recv_data_start)); - recv_data_start += sizeof(size_t); - recv_data_size -= sizeof(size_t); - LOG_DEBUG( - "Received IPC rank: ", received_rank_idx, ", on rank: ", comm_addr.to_string()); - - //TODO - size_t num_handles = 0; - while (num_handles < send_handles.size()) { - /* TODO - do not deserilize own rank IPC handles, just skip all - * Current deserizliation just for testing - */ - auto recv_ip_handle = ccl_device::device_ipc_memory_handle::deserialize< - ccl_device::device_ipc_memory_handle>( - &recv_data_start, recv_data_size, ctx, get_platform()); - - std::shared_ptr ipc_mem_owner; - { - auto acc = node_device_storage.get_node_storage(); - auto& device_cont = acc.get(); - auto& ipc_device_cont = - ccl_tuple_get>(device_cont); - - auto ipc_device_cont_it = ipc_device_cont.find(received_rank_idx); - if (ipc_device_cont_it == ipc_device_cont.end()) { - if (received_rank_idx != comm_addr.rank) { - LOG_ERROR("No device owner for ipc handle detected, ipc handle: ", - native::to_string(recv_ip_handle->get()), - ", suggested device handle: ", - *recv_ip_handle->get_owner().lock(), - ", suggested rank: ", - received_rank_idx, - ". Please check your configuration setup"); - - status = ccl_sched_entry_status_failed; - abort(); - return; - } - LOG_INFO("Find own gpu device, skip"); - num_handles++; - continue; - } - - ipc_mem_owner = ipc_device_cont_it->second; - } - - LOG_DEBUG("Find gpu device: ", - ipc_mem_owner->to_string(), - ", IPC handle: ", - native::to_string(recv_ip_handle->get())); - - // create IPC memory object & remember in shared storage - - // TODO: resolve issue to provide ctx correctly - std::shared_ptr ctx; - foreign_device_ipc_mem_storage[ipc_mem_owner].push_back( - ipc_mem_owner->get_device().get_ipc_memory(std::move(recv_ip_handle), ctx)); - - num_handles++; - } - } - - LOG_INFO("All handles deserialized. Start ipc kernel arguments binding", - ", rank: ", - comm_addr.to_string()); - for (auto& dev_handle_pair : foreign_device_ipc_mem_storage) { - auto& ipc_device = dev_handle_pair.first; - ipc_memory_container& handles = dev_handle_pair.second; - - LOG_INFO( - "Bind kernel arguments: ", handles.size(), ", rank: ", comm_addr.to_string()); - CCL_ASSERT(handles.size() == send_handles.size(), - "Received unexpected memory handles count"); - - //Bind - - using kernel_ipc_typed = typename dependent_entry::kernel_ipc_typed; - kernel_ipc_typed& unreach_rank_main_func = - ipc_device->get_gpu_kernel(); - - typename kernel_ipc_typed::tmp_recv_buf_arg_type tmp_recv_buf = - reinterpret_cast( - handles.at(0).get().pointer); - unreach_rank_main_func - .template set_arg(tmp_recv_buf); - - typename kernel_ipc_typed::income_data_flag_arg_type inc = - reinterpret_cast( - handles.at(1).get().pointer); - unreach_rank_main_func - .template set_arg(inc); - - typename kernel_ipc_typed::ready_to_recv_flag_arg_type ready = - reinterpret_cast( - handles.at(2).get().pointer); - unreach_rank_main_func - .template set_arg(ready); - } - - status = ccl_sched_entry_status_complete; - } - else { - LOG_TRACE(class_name(), - " entry req ", - &req, - ", rank: ", - comm_addr.to_string(), - " is not ready yet"); - } - } - - const char* name() const override { - return class_name(); - } - -protected: - void dump_detail(std::stringstream& str) const override { - ccl_logger::format(str, - class_name(), - ", dt ", - ccl::global_data::get().dtypes->name(dtype), - ", cnt ", - cnt, - ", comm_id ", - sched->coll_param.comm->id(), - ", req ", - &req, - "\n"); - } - -private: - topology_addr comm_addr; - std::shared_ptr ccl_communicator; - device_storage& node_device_storage; - - std::vector send_handles; - std::vector plain_send_data; - std::vector plain_recv_data; - std::vector recv_bytes; - std::vector offsets; - - using ipc_memory_container = std::vector; - std::map, ipc_memory_container> - foreign_device_ipc_mem_storage; - size_t cnt; - ccl_datatype dtype; - - ccl::event event; - atl_req_t req{}; - - ccl_driver_context_ptr ctx; -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_allgatherv_typed_entry.hpp b/src/sched/entry/l0/l0_allgatherv_typed_entry.hpp deleted file mode 100644 index 805464937..000000000 --- a/src/sched/entry/l0/l0_allgatherv_typed_entry.hpp +++ /dev/null @@ -1,265 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include -#include - -#include "sched/entry/l0/l0_entry.hpp" - -//TODO L0 Workaround - -namespace native { -template -class l0_allgatherv_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::get_local_kernel; - using kernel_main_typed = ring::allgatherv::main_kernel; - using processing_type = void; - - using income_data_flag_gpu_type = - typename std::remove_pointer::type; - using ready_to_recv_flag_gpu_type = - typename std::remove_pointer::type; - using recv_counts_typed_entry_type = typename std::remove_pointer< - typename ring::allgatherv::recv_elem_counts_buf_arg_type>::type; - using recv_offsets_typed_entry_type = typename std::remove_pointer< - typename ring::allgatherv::recv_elem_offsets_buf_arg_type>::type; - - static constexpr const char* class_name() noexcept { - return "L0_ALLGATHERV_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_allgatherv; - } - - l0_allgatherv_typed_entry() = delete; - l0_allgatherv_typed_entry( - ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, send_buf, params, device_stream), - // left_wrote_to_me_flag - income_data_flag(this->template alloc_memory_wrap( - typename ring::allgatherv::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - // ready_to_recv_flag_arg - ready_to_recv_flag(this->template alloc_memory_wrap( - typename ring::allgatherv::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - recv_counts_buf(parent_communicator->get_device() - .template alloc_memory( - comm_addr.size, - sizeof(recv_counts_typed_entry_type), - get_ctx())), - - recv_offsets_buf(parent_communicator->get_device() - .template alloc_memory( - comm_addr.size, - sizeof(recv_offsets_typed_entry_type), - get_ctx())) - - { - // copy recv_buf into alloced recv_buf_entry - recv_buf_entry = recv_buf; - cnt_entry = send_count; - // same as parent_communicator->template - // get_comm_data().size; - int local_topology_size = comm_addr.size; - std::vector recv_offsets_v(local_topology_size, 0); - - for (int idx = 0; idx < local_topology_size; idx++) { - if (idx > 0) - recv_offsets_v[idx] += recv_offsets_v[idx - 1] + recv_counts[idx - 1]; - } - - recv_counts_buf.enqueue_write_sync(recv_counts, local_topology_size); - recv_offsets_buf.enqueue_write_sync(recv_offsets_v); - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_allgatherv_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::created); - } - - ~l0_allgatherv_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - ENTRY_LOG_DEBUG("Start entry, cnt ", cnt_entry); - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_entry.get_ptr()); - - //create implementation specified primitives - main_entry_function - .template set_args, - typename ring::allgatherv::recv_elem_counts_buf_arg, - typename ring::allgatherv::recv_elem_offsets_buf_arg, - typename kernel_main_typed::common_entry_buf_size_arg>( - income_data_flag.get(), - ready_to_recv_flag.get(), - recv_buf_ptr, - recv_counts_buf.get(), - recv_offsets_buf.get(), - cnt_entry); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_entry.get_ptr()); - - std::vector ret; - ret.reserve(3); - ret.push_back(owned_device.create_ipc_memory_handle(recv_buf_ptr, get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - return ret; - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_buffer recv_buf_entry; - ccl_device::device_memory recv_counts_buf; - ccl_device::device_memory recv_offsets_buf; - size_t cnt_entry; - std::shared_ptr ctx; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel - .template test_args, - typename ring::allgatherv::income_data_flag_arg, - typename ring::allgatherv::ready_to_recv_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_recv_buf_arg = - right_kernel - .template get_arg>(); - auto right_income_data_flag_arg = - right_kernel.template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel.template get_arg(); - - // ENTRY_LOG_DEBUG("Bind right arguments from ", - // right_kernel_t::name(), - // " kernel", - // " to ", - // left_kernel_t::name(), - // " kernel. " - // "Right arguments:\n{ ", - // right_recv_buf_arg.first, - // ", ", - // right_recv_buf_arg.second, - // "}\n", - // "{ ", - // right_income_data_flag_arg.first, - // ", ", - // right_income_data_flag_arg.second, - // "}\n", - // "{ ", - // right_ready_to_recv_flag_arg.first, - // ", ", - // right_ready_to_recv_flag_arg.second, - // "}\n"); - - left_kernel - .template set_args, - typename ring::allgatherv::right_income_data_flag_arg, - typename ring::allgatherv::right_ready_to_recv_flag_arg>( - right_recv_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_allreduce_typed_entry.hpp b/src/sched/entry/l0/l0_allreduce_typed_entry.hpp deleted file mode 100644 index f0a1395d2..000000000 --- a/src/sched/entry/l0/l0_allreduce_typed_entry.hpp +++ /dev/null @@ -1,254 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include -#include - -#include "sched/entry/l0/l0_entry.hpp" -#include "common/comm/l0/context/scale/ipc/ipc_ctx_impl.hpp" -#include "kernels/shared.h" - -namespace native { -template -class l0_allreduce_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::alloc_memory_wrap; - using base::get_local_kernel; - using kernel_main_typed = ring::allreduce::main_kernel; - - using income_data_flag_gpu_type = - typename std::remove_pointer::type; - using ready_to_recv_flag_gpu_type = - typename std::remove_pointer::type; - using local_barrier_flag_gpu_type = - typename std::remove_pointer::type; - - static constexpr const char* class_name() noexcept { - return "L0_ALLREDUCE_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_allreduce; - } - - l0_allreduce_typed_entry() = delete; - l0_allreduce_typed_entry( - ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t cnt, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, send_buf, params, device_stream), - - temp_buffer(this->template alloc_memory_wrap( - typename ring::allreduce::tmp_recv_buf_arg{}, - parent_communicator, - ring_allreduce_get_tmp_buffer_size( - ccl::get_datatype_size(params.get_datatype()) * cnt, - base::comm_addr.size), - get_ctx())), - income_data_flag( - this->template alloc_memory_wrap(typename ring::allreduce::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - ready_to_recv_flag(this->template alloc_memory_wrap( - typename ring::allreduce::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - local_barrier_flag(parent_communicator->get_device() - .template alloc_memory( - 1, - sizeof(local_barrier_flag_gpu_type), - get_ctx())) { - recv_buf_typed_entry = recv_buf; - cnt_entry = cnt; - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_allreduce_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - this->set_state(gpu_entry_state::created); - } - - ~l0_allreduce_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - ENTRY_LOG_DEBUG("Start entry, cnt ", cnt_entry); - - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - // TODO: try to remove indirect buffer - void* recv_buf_ptr = recv_buf_typed_entry.get_ptr(); - - //create implementation specified primitives - main_entry_function - .template set_args, - typename ring::allreduce::income_data_flag_arg, - typename ring::allreduce::ready_to_recv_flag_arg, - typename ring::allreduce::local_barrier_flag_arg, - typename ring::allreduce::recv_buf_arg, - typename kernel_main_typed::common_entry_buf_size_arg>( - temp_buffer.get(), - income_data_flag.get(), - ready_to_recv_flag.get(), - local_barrier_flag.get(), - recv_buf_ptr, - cnt_entry); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - std::vector ret; - ret.reserve(3); - ret.push_back(owned_device.create_ipc_memory_handle(temp_buffer.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - return ret; - } - - observer::invoke_params get_numa_data() override { - observer::producer_description in_params{ - .rank = comm_addr.rank, //TODO unused - .comm_size = comm_addr.size, //TODO unused - .staged_buffer_elem_count = cnt_entry, - .context = get_ctx(), - .device = parent_communicator->get_device(), - .immediate_list = parent_communicator->get_device().create_immediate_cmd_list(get_ctx()) - }; - // TODO: Should get_params() be a part of in_params? - return observer::invoke_params(std::move(in_params), base::get_params()); - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory temp_buffer; - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_device::device_memory local_barrier_flag; - ccl_buffer recv_buf_typed_entry; - size_t cnt_entry; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel.template test_args, - typename ring::allreduce::income_data_flag_arg, - typename ring::allreduce::ready_to_recv_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_tmp_recv_buf_arg = - right_kernel.template get_arg>(); - auto right_income_data_flag_arg = - right_kernel.template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel.template get_arg(); - - /*ENTRY_LOG_DEBUG("Bind right arguments from ", - right_kernel_t::name(), - " kernel", - " to ", - left_kernel_t::name(), - " kernel. " - "Right arguments:\n{ ", - right_tmp_recv_buf_arg.first, - ", ", - right_tmp_recv_buf_arg.second, - "}\n", - "{ ", - right_income_data_flag_arg.first, - ", ", - right_income_data_flag_arg.second, - "}\n", - "{ ", - right_ready_to_recv_flag_arg.first, - ", ", - right_ready_to_recv_flag_arg.second, - "}\n");*/ - - left_kernel.template set_args, - typename ring::allreduce::right_income_data_flag_arg, - typename ring::allreduce::right_ready_to_recv_flag_arg>( - right_tmp_recv_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_alltoallv_typed_entry.hpp b/src/sched/entry/l0/l0_alltoallv_typed_entry.hpp deleted file mode 100644 index 8970e06a8..000000000 --- a/src/sched/entry/l0/l0_alltoallv_typed_entry.hpp +++ /dev/null @@ -1,318 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include -#include - -#include "sched/entry/l0/l0_entry.hpp" - -//TODO L0 Workaround - -namespace native { -template -class l0_alltoallv_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::get_local_kernel; - using kernel_main_typed = ring::alltoallv::main_kernel; - - using income_data_flag_gpu_type = - typename std::remove_pointer::type; - using ready_to_recv_flag_gpu_type = - typename std::remove_pointer::type; - - using recv_counts_typed_entry_type = - typename std::remove_pointer::type; - using recv_offsets_typed_entry_type = typename std::remove_pointer< - typename ring::alltoallv::recv_elem_offsets_buf_arg_type>::type; - - using proxy_size_flag_gpu_type = - typename std::remove_pointer::type; - - using send_counts_typed_entry_type = - typename std::remove_pointer::type; - using send_offsets_typed_entry_type = typename std::remove_pointer< - typename ring::alltoallv::send_elem_offsets_buf_arg_type>::type; - - static constexpr const char* class_name() noexcept { - return "L0_ALLTOALLV_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_alltoallv; - } - - l0_alltoallv_typed_entry() = delete; - l0_alltoallv_typed_entry( - ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - const size_t* send_counts, - size_t total_send_counts, - ccl_buffer recv_buf, - const size_t* recv_counts, - size_t total_recv_counts, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, send_buf, params, device_stream), - temp_buffer(this->template alloc_memory_wrap( - typename ring::alltoallv::tmp_recv_buf_arg{}, - parent_communicator, - total_recv_counts, - get_ctx())), - // left_wrote_to_me_flag - income_data_flag( - this->template alloc_memory_wrap(typename ring::alltoallv::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - // ready_to_recv_flag_arg - ready_to_recv_flag(this->template alloc_memory_wrap( - typename ring::alltoallv::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - proxy_size_flag_entry( - this->template alloc_memory_wrap(typename ring::alltoallv::proxy_size_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - recv_counts_buf(parent_communicator->get_device() - .template alloc_memory( - total_recv_counts, - sizeof(recv_counts_typed_entry_type), - get_ctx())), - recv_offsets_buf(parent_communicator->get_device() - .template alloc_memory( - comm_addr.size, - sizeof(recv_offsets_typed_entry_type), - get_ctx())), - send_counts_buf(parent_communicator->get_device() - .template alloc_memory( - total_send_counts, - sizeof(recv_counts_typed_entry_type), - get_ctx())), - send_offsets_buf(parent_communicator->get_device() - .template alloc_memory( - comm_addr.size, - sizeof(send_offsets_typed_entry_type), - get_ctx())) { - // copy recv_buf into recv_buf_entry - recv_buf_entry = recv_buf; - - // same as parent_communicator->template - // get_comm_data().size; - int local_topology_size = comm_addr.size; - std::vector recv_offsets_v(local_topology_size, 0); - - for (int idx = 0; idx < local_topology_size; idx++) { - if (idx > 0) - recv_offsets_v[idx] += recv_offsets_v[idx - 1] + recv_counts[idx - 1]; - } - - std::vector send_offsets_v(local_topology_size, 0); - for (int idx = 0; idx < local_topology_size; idx++) { - if (idx > 0) - send_offsets_v[idx] += send_offsets_v[idx - 1] + send_counts[idx - 1]; - } - // recv - recv_counts_buf.enqueue_write_sync(recv_counts, local_topology_size); - recv_offsets_buf.enqueue_write_sync(recv_offsets_v); - // send - send_counts_buf.enqueue_write_sync(send_counts, local_topology_size); - send_offsets_buf.enqueue_write_sync(send_offsets_v); - // flag - proxy_size_flag_entry.enqueue_write_sync({ (int)0 }); - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_alltoallv_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::created); - } - - ~l0_alltoallv_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - LOG_DEBUG(class_name(), " entry req ", &req, ", rank: ", comm_addr.to_string()); - - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_entry.get_ptr()); - - //create implementation specified primitives - main_entry_function.template set_args, - typename ring::alltoallv::income_data_flag_arg, - typename ring::alltoallv::ready_to_recv_flag_arg, - typename ring::alltoallv::recv_buf_arg, - typename ring::alltoallv::recv_elem_counts_buf_arg, - typename ring::alltoallv::recv_elem_offsets_buf_arg, - typename ring::alltoallv::proxy_size_flag_arg, - typename ring::alltoallv::send_buf_size_arg>( - temp_buffer.get(), - income_data_flag.get(), - ready_to_recv_flag.get(), - recv_buf_ptr, - recv_counts_buf.get(), - recv_offsets_buf.get(), - proxy_size_flag_entry.get(), - send_counts_buf.get()); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - std::vector ret; - ret.reserve(4); - ret.push_back(owned_device.create_ipc_memory_handle(temp_buffer.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - ret.push_back( - owned_device.create_ipc_memory_handle(proxy_size_flag_entry.get(), get_ctx())); - - return ret; - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory temp_buffer; - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_device::device_memory proxy_size_flag_entry; - ccl_buffer recv_buf_entry; - ccl_device::device_memory recv_counts_buf; - ccl_device::device_memory recv_offsets_buf; - ccl_device::device_memory send_counts_buf; - ccl_device::device_memory send_offsets_buf; - std::shared_ptr ctx; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel.template test_args, - typename ring::alltoallv::income_data_flag_arg, - typename ring::alltoallv::ready_to_recv_flag_arg, - typename ring::alltoallv::proxy_size_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_tmp_recv_buf_arg = - right_kernel.template get_arg>(); - auto right_income_data_flag_arg = - right_kernel.template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel.template get_arg(); - auto right_proxy_size_flag_arg = - right_kernel.template get_arg(); - - // ENTRY_LOG_DEBUG("Bind right arguments from ", - // right_kernel_t::name(), - // " kernel", - // " to ", - // left_kernel_t::name(), - // " kernel. " - // "Right arguments:\n{ ", - // right_tmp_recv_buf_arg.first, - // ", ", - // right_tmp_recv_buf_arg.second, - // "}\n", - // "{ ", - // right_income_data_flag_arg.first, - // ", ", - // right_income_data_flag_arg.second, - // "}\n", - // "{ ", - // right_ready_to_recv_flag_arg.first, - // ", ", - // right_ready_to_recv_flag_arg.second, - // "}\n", - // "{ ", - // right_proxy_size_flag_arg.first, - // ", ", - // right_proxy_size_flag_arg.second, - // "}\n"); - - left_kernel.template set_args, - typename ring::alltoallv::right_income_data_flag_arg, - typename ring::alltoallv::right_ready_to_recv_flag_arg, - typename ring::alltoallv::right_proxy_size_flag_arg>( - right_tmp_recv_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second, - right_proxy_size_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_bcast_typed_entry.hpp b/src/sched/entry/l0/l0_bcast_typed_entry.hpp deleted file mode 100644 index c40b7ddd3..000000000 --- a/src/sched/entry/l0/l0_bcast_typed_entry.hpp +++ /dev/null @@ -1,228 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include - -#include "sched/entry/l0/l0_entry.hpp" - -//TODO L0 Workaround - -namespace native { -template -class l0_bcast_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = - base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::get_local_kernel; - using kernel_main_typed = ring::bcast::main_kernel; - using processing_type = void; - - using income_data_flag_gpu_type = - typename std::remove_pointer::type; - using ready_to_recv_flag_gpu_type = - typename std::remove_pointer::type; - using local_barrier_flag_gpu_type = - typename std::remove_pointer::type; - - static constexpr const char* class_name() noexcept { - return "L0_BCAST_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_bcast; - } - - l0_bcast_typed_entry() = delete; - l0_bcast_typed_entry(ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - ccl_buffer buf, - size_t cnt, - int root, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, buf, params, device_stream), - - income_data_flag( - this->template alloc_memory_wrap(typename ring::bcast::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - ready_to_recv_flag( - this->template alloc_memory_wrap(typename ring::bcast::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - local_barrier_flag(parent_communicator->get_device() - .template alloc_memory( - 1, - sizeof(local_barrier_flag_gpu_type), - get_ctx())) { - root_typed_entry = root; - cnt_entry = cnt; - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_bcast_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::created); - } - - ~l0_bcast_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - ENTRY_LOG_DEBUG("Start entry, cnt ", cnt_entry); - - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - //create implementation specified primitives - main_entry_function - .template set_args( - income_data_flag.get(), - ready_to_recv_flag.get(), - local_barrier_flag.get(), - root_typed_entry, - cnt_entry); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - auto recv_buf_ptr = reinterpret_cast(base::send_buf.get_ptr()); - - std::vector ret; - ret.reserve(3); - ret.push_back(owned_device.create_ipc_memory_handle(recv_buf_ptr, get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - return ret; - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_device::device_memory local_barrier_flag; - int root_typed_entry; - size_t cnt_entry; - std::shared_ptr ctx; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel.template test_args, - typename ring::bcast::income_data_flag_arg, - typename ring::bcast::ready_to_recv_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_buf_arg = - right_kernel.template get_arg>(); - auto right_income_data_flag_arg = - right_kernel.template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel.template get_arg(); - - // ENTRY_LOG_DEBUG("Bind right arguments from ", - // right_kernel_t::name(), - // " kernel", - // " to ", - // left_kernel_t::name(), - // " kernel. " - // "Right arguments:\n{ ", - // right_buf_arg.first, - // ", ", - // right_buf_arg.second, - // "}\n", - // "{ ", - // right_income_data_flag_arg.first, - // ", ", - // right_income_data_flag_arg.second, - // "}\n", - // "{ ", - // right_ready_to_recv_flag_arg.first, - // ", ", - // right_ready_to_recv_flag_arg.second, - // "}\n"); - - left_kernel.template set_args, - typename ring::bcast::right_income_data_flag_arg, - typename ring::bcast::right_ready_to_recv_flag_arg>( - right_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_entry.hpp b/src/sched/entry/l0/l0_entry.hpp deleted file mode 100644 index 7bcb1add3..000000000 --- a/src/sched/entry/l0/l0_entry.hpp +++ /dev/null @@ -1,795 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include "oneapi/ccl/types.hpp" -#include "common/datatype/datatype.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/native_device_api/l0/primitives.hpp" -#include "common/comm/l0/modules/kernel_functions.hpp" - -#include "oneapi/ccl.hpp" - -#include "comp/comp.hpp" -#include "common/comm/l0/devices/devices_declaration.hpp" -#include "sched/entry/coll/direct/base_coll_entry.hpp" - -#include "common/comm/l0/modules_connector.hpp" -#include "common/global/global.hpp" -#include "common/stream/stream.hpp" - -#include "common/comm/l0/context/scale/ipc/ipc_session_key.hpp" -#include "common/comm/l0/context/scale/base/base_session.hpp" - -//TODO L0 Workaround -#include -static std::mutex global_fence_mutex; - -#define ENTRY_LOG_TRACE(...) \ - if (unlikely(logger.get_log_level() >= ccl_log_level::trace)) { \ - do { \ - std::stringstream ss; \ - this->dump_detail(ss); \ - logger.trace("|TRACE| ", \ - basedir_static(__FILE__), \ - ":", \ - __LINE__, \ - " ", \ - ss.str(), \ - " - ", \ - ##__VA_ARGS__); \ - } while (0); \ - } - -#define ENTRY_LOG_DEBUG(...) \ - if (unlikely(logger.get_log_level() >= ccl_log_level::debug)) { \ - do { \ - std::stringstream ss; \ - this->dump_detail(ss); \ - logger.debug("|DEBUG| ", \ - basedir_static(__FILE__), \ - ":", \ - __LINE__, \ - " ", \ - ss.str(), \ - " - ", \ - ##__VA_ARGS__); \ - } while (0); \ - } - -#define ENTRY_LOG_INFO(...) \ - if (unlikely(logger.get_log_level() >= ccl_log_level::info)) { \ - do { \ - std::stringstream ss; \ - this->dump_detail(ss); \ - logger.info("|INFO| ", \ - basedir_static(__FILE__), \ - ":", \ - __LINE__, \ - " ", \ - ss.str(), \ - " - ", \ - ##__VA_ARGS__); \ - } while (0); \ - } - -#define ENTRY_LOG_WARN(...) \ - if (unlikely(logger.get_log_level() >= ccl_log_level::warn)) { \ - do { \ - std::stringstream ss; \ - this->dump_detail(ss); \ - logger.info("|WARN| ", \ - basedir_static(__FILE__), \ - ":", \ - __LINE__, \ - " ", \ - ss.str(), \ - " - ", \ - ##__VA_ARGS__); \ - } while (0); \ - } - -#define ENTRY_LOG_ERROR(...) \ - if (unlikely(logger.get_log_level() >= ccl_log_level::error)) { \ - do { \ - std::stringstream ss; \ - this->dump_detail(ss); \ - logger.info("|ERROR| ", \ - basedir_static(__FILE__), \ - ":", \ - __LINE__, \ - " ", \ - ss.str(), \ - " - ", \ - ##__VA_ARGS__); \ - } while (0); \ - } - -// This is a internal gpu entry state to keep track of the progress -// filling and submitting a kernel as well as its execution -enum class gpu_entry_state { - // default state - initial, - // Entry is fully constructed - created, - // Local parameters are filled and the entry is waiting for - // parameters from the neighbour entry. No further progress - // until these parameters are received. After that - // the entry appends it's kernel. After that the kernel is - // submited to the queue(only one rank does that in case of - // virtual device. - wait_for_entry, - // The kernel is submited and the entry is waiting for kernel - // completion by checking fence status. - wait_for_completion, - // Execution is done, it's possible to reuse the entry by - // moving the entry to created state - completed, - // Last element in the enum, not used as state - last -}; - -inline std::string to_string(gpu_entry_state state) { - return utils::enum_to_str(gpu_entry_state::last)>{ - "initial", "created", "wait_for_entry", "wait_for_completion", "completed" - } - .choose(state); -} - -namespace native { -template -class base_gpu_entry : public sched_entry { -public: - using gpu_comm = gpu_comm_impl; - using kernel_main_typed = typename gpu_comm::template gpu_kernel_t; - using kernel_ipc_typed = - typename ccl_ipc_gpu_comm::template gpu_kernel_t; - - template - using device_memory = memory; - - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - static constexpr const char *class_name() noexcept { - return ccl_coll_type_to_str(type_op); - } - static constexpr ccl_coll_type type() noexcept { - return type_op; - } - - static constexpr ccl::group_split_type get_topology() { - return group_id; - } - - static constexpr ccl::device_topology_type get_topology_class() { - return class_id; - } - - base_gpu_entry() = delete; - base_gpu_entry(ccl_sched *sched, - std::shared_ptr comm, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - const coll_param_gpu ¶ms, - std::shared_ptr &stream) - : sched_entry(sched), - parent_communicator(comm), - comm_addr(parent_communicator - ->template get_comm_data()), - send_buf(send_buf), - params(params), - device_stream(stream), - ctx(in_ctx), - entry_state(gpu_entry_state::initial), - queue_descr(init_queue_descr(parent_communicator->get_device())), - list_descr(init_list_descr(parent_communicator->get_device())), - dev_queue(init_default_queue(parent_communicator->get_device())), - dev_cmd_list(init_default_cmd_list()) { - // TODO: remove once all the child entries are refactored to not - // use fence field directly - fence = get_fence(); - } - - kernel_main_typed &get_local_kernel() noexcept { - return parent_communicator - ->template get_gpu_kernel(params); - } - - virtual ~base_gpu_entry() {} - - virtual void start() override { - { - //TODO make check, that device_stream belong to the device - auto &cmd_queue = get_dev_queue(); - - auto fence = parent_communicator->get_fence(cmd_queue, get_ctx()); - - ENTRY_LOG_DEBUG("start base entry initialization, ctx: ", - ctx.get(), - ", queue: ", - cmd_queue.get(), - ", fence: ", - fence.get()); - } - - //set kernel args for main kernel on current device - kernel_main_typed &main_entry_function = - parent_communicator->template register_entry(*this); - - auto send_buf_ptr = send_buf.get_ptr(); - - //bind data - main_entry_function.template set_args( - send_buf_ptr); - - status = ccl_sched_entry_status_started; - ENTRY_LOG_DEBUG("started"); - } - - bool submit_for_execution() { - ready_to_exec = finalize_entry(); - ENTRY_LOG_TRACE("submission result: ", ready_to_exec); - return ready_to_exec; - } - - virtual void update() override { - if (!ready_to_exec) { - // TODO: what if submit_for_execution() return false? - submit_for_execution(); - } - else { - //wait execution - auto &cmd_queue = get_dev_queue(); - - ENTRY_LOG_TRACE(" waiting for finished execution, queue: ", cmd_queue.get()); - - ze_result_t ret; - - // Quering fence doesn't sync kernel output with the host, so if we need this - // we use QuerySyncronize API. - if (ccl::global_data::env().comm_kernels_debug == 0) { - ret = get_fence_impl().query_status(); - } - else { - ret = zeCommandQueueSynchronize(cmd_queue.get(), 0); - } - - ENTRY_LOG_TRACE( - "Fence query status: ", native::to_string(ret), ", queue: ", cmd_queue.get()); - if (ret == ZE_RESULT_SUCCESS) { - this->set_state(gpu_entry_state::completed); - - // Once all the ranks in the group got the notification, reset the state for further launches - reset_state(); - - status = ccl_sched_entry_status_complete; - ENTRY_LOG_DEBUG(" Completed on queue: ", cmd_queue.get()); - } - else if (ret == ZE_RESULT_NOT_READY) { - // Just return in case if the kernel is not ready yet, will check again on the next iteration - return; - } - } - } - - virtual const char *name() const override { - return class_name(); - } - - // getters - ccl_device::device_queue &get_dev_queue() const { - return dev_queue; - } - - ze_fence_handle_t get_fence() { - return get_fence_impl().get(); - } - - ze_command_queue_desc_t &get_queue_descr() { - return queue_descr; - } - - //USE GPU cache binding - virtual std::vector get_ipc_data() = 0; - virtual observer::invoke_params get_numa_data() { - //TODO make pure-virtual - ENTRY_LOG_ERROR("NOT implemented for that collective type"); - abort(); - } - - virtual observer::invoke_params get_scaleout_data() { - //TODO make pure-virtual - ENTRY_LOG_ERROR("NOT implemented for that collective type"); - abort(); - } - - virtual native::ipc_session_key get_ipc_session_key() const { - return native::ipc_session_key{ this }; - } - - virtual native::observer::session_key get_numa_session_key() const { - return native::observer::session_key{ this }; - } - - virtual native::observer::session_key get_scaleout_session_key() const { - return native::observer::session_key{ this }; - } - - const coll_param_gpu &get_params() const { - return params; - } - -protected: - size_t get_work_group_size(size_t buffer_size, ccl_device &device) { - size_t group_size; - size_t val_vector_size; - auto dtype = params.get_datatype(); - - if (ccl::global_data::env().gpu_thread_count != CCL_ENV_SIZET_NOT_SPECIFIED) { - group_size = ccl::global_data::env().gpu_thread_count; - - ENTRY_LOG_DEBUG( - "Set group size for x dimension by CCL_GPU_THREAD_COUNT=", group_size, " by user"); - } - else { - if (dtype == ccl::datatype::bfloat16) { - val_vector_size = 1; - } - else { - // For comm kernels, we have float4 {float x, float y, float z, float w}; - // data type, that's why we set a divider for group_size, wchich equals to 4. - // The vecsize of 4 goes with all data types except bfloat16 - val_vector_size = 4; - } - - group_size = buffer_size / val_vector_size; - - ENTRY_LOG_DEBUG("Set group size for x dimension: ", group_size); - } - if (group_size > device.get_compute_properties().maxGroupSizeX || group_size == 0) { - group_size = device.get_compute_properties().maxGroupSizeX; - ENTRY_LOG_DEBUG( - "Group size is limited by compute_properties.maxGroupSizeX and should NOT equal to 0, set group_size: ", - group_size, - " by default"); - } - - //TODO: remove 'return 1' and retrun 'group_size', when fix small msg sizes issue - return 1; //group_size; - } - - void get_suggested_group_size(ze_kernel_handle_t &kernel, size_t buffer_size) { - // zeKernelSuggestGroupSize ignores the group size that is set using zeKernelSetGroupSize - uint32_t group_size_x = 1u; - uint32_t group_size_y = 1u; - uint32_t group_size_z = 1u; - ze_result_t result = zeKernelSuggestGroupSize( - kernel, buffer_size, 1u, 1u, &group_size_x, &group_size_y, &group_size_z); - if (result != ZE_RESULT_SUCCESS) { - throw std::runtime_error( - std::string(__FUNCTION__) + - " - zeKernelSuggestGroupSize failed. Result: " + native::to_string(result)); - } - ENTRY_LOG_DEBUG("Suggested kernel group sizes, which is based on buffer_size: ", - buffer_size, - ", are: groupSizeX: ", - group_size_x, - " groupSizeY: ", - group_size_y, - " groupSizeZ: ", - group_size_z); - } - - void set_group_size(ze_kernel_handle_t &kernel, size_t buffer_size) { - // setting the group size to control resource consumption - // assuming that group_size_x can be adjusted by changing the value or CCL_GPU_THREAD_COUNT knob - // group_size_y / group_size_z shouldn't be > 1 - uint32_t group_size_x = get_work_group_size(buffer_size, parent_communicator->get_device()); - uint32_t group_size_y = 1u; - uint32_t group_size_z = 1u; - - ze_result_t result = zeKernelSetGroupSize(kernel, group_size_x, group_size_y, group_size_z); - if (result != ZE_RESULT_SUCCESS) { - throw std::runtime_error( - std::string(__FUNCTION__) + - " - zeKernelSetGroupSize failed. Result: " + native::to_string(result) + - " groupSizeX: " + std::to_string(static_cast(group_size_x)) + - " groupSizeY: " + std::to_string(static_cast(group_size_y)) + - " groupSizeZ: " + std::to_string(static_cast(group_size_z))); - } - - ENTRY_LOG_DEBUG("Set kernel group size successfully: groupSizeX: ", - group_size_x, - " groupSizeY: ", - group_size_y, - " groupSizeZ: ", - group_size_z); - } - - bool finalize_entry() { - kernel_main_typed &main_entry_function = get_local_kernel(); - - if (this->get_state() == gpu_entry_state::wait_for_entry) { - if (!(*kernel_router)(main_entry_function)) { - // Parameters are not ready yet, will try again later - return false; - } - } - - ENTRY_LOG_TRACE("Try to finalize"); - - auto &&cmd_list = get_dev_cmd_list(); - - // setting group size - set_group_size(main_entry_function.handle, send_buf.get_size()); - - // get suggested group size for info usage only - get_suggested_group_size(main_entry_function.handle, send_buf.get_size()); - - cmd_list.append_kernel(main_entry_function.handle, &launch_args); - - ENTRY_LOG_DEBUG("Append kernel successfully: ", - main_entry_function.to_string(), - " in list: ", - cmd_list.get()); - - assert(this->get_state() != gpu_entry_state::wait_for_completion); - - if (get_topology() == ccl::group_split_type::cluster) { - // TODO: in case of (vitual device + IPC) we can get the data race here - // How we can detect such case? - // In the case when we use one GPU queue per process, everything should be ok - // throw ccl::exception(std::string(__PRETTY_FUNCTION__) + - // "TODO: implement process communicator case"); - cmd_list.close_and_execute(get_ctx(), this->get_fence()); - } - else { - // TODO: how to ensure that fence update is thread safe? - cmd_list.close_and_execute(get_ctx(), this->get_fence()); - } - - ENTRY_LOG_INFO("List closed:", cmd_list.get(), ", go to submit entry"); - this->set_state(gpu_entry_state::wait_for_completion); - return true; - } - - virtual void dump_detail(std::stringstream &str) const override { - ccl_logger::format(str, "{", name(), ", addr: ", comm_addr.to_string(), "}"); - } - - void reset_state() { - // Reset the state of the used handles - get_fence_impl().reset(); - get_dev_cmd_list().reset(); - } - -protected: - ccl_driver_context_ptr get_ctx() const { - return ctx; - } - - ze_command_list_desc_t get_list_descr() const { - return list_descr; - } - - template - ze_device_mem_alloc_desc_t get_mem_descr(options opt) { - ze_device_mem_alloc_desc_t mem_descr = ccl_device::get_default_mem_alloc_desc(); - // Explicitly reset flags to avoid potential conflicts with the default value - mem_descr.flags = 0; - mem_descr.flags |= (opt.is_uncached() ? ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_UNCACHED - : ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_CACHED); - - return mem_descr; - } - - // Wrapper to handle memory allocation with different options - template ::type> - device_memory alloc_memory_wrap(const kernel_arg &arg, - std::shared_ptr parent_communicator, - size_t cnt, - std::shared_ptr ctx) { - auto mem_descr = get_mem_descr(typename kernel_arg::options_t{}); - auto memory = parent_communicator->get_device().template alloc_memory( - cnt, sizeof(arg_type), ctx, mem_descr); - - LOG_DEBUG("Allocation memory by default: ", - kernel_arg::index, - ", ctx: ", - (void *)ctx.get(), - ", memory: ", - (void *)memory.get(), - ", mem_descr: ", - native::to_string(mem_descr)); - - return memory; - } - - std::shared_ptr parent_communicator; - topology_addr comm_addr; - ccl_buffer send_buf; - coll_param_gpu params; - - // TODO: we don't need dtype anymore? - // ccl::datatype dtype; - atl_req_t req{}; - std::shared_ptr device_stream; - // GPU - bool ready_to_exec = false; - ze_fence_handle_t fence; - - auto get_fence_impl() -> decltype(parent_communicator->get_fence(get_dev_queue(), get_ctx())) { - return parent_communicator->get_fence(get_dev_queue(), get_ctx()); - } - - auto get_dev_cmd_list() - -> decltype(parent_communicator->get_cmd_list(get_ctx(), get_list_descr())) { - return dev_cmd_list; - } - - void set_state(gpu_entry_state new_state) noexcept { - ENTRY_LOG_DEBUG( - "switching entry state from ", to_string(entry_state), " to ", to_string(new_state)); - entry_state = new_state; - } - - gpu_entry_state get_state() const noexcept { - return entry_state; - } - - //TODO - ze_group_count_t launch_args = { 1, 1, 1 }; - - template - static std::unique_ptr> - create_kernel_router_for_rank(executor &exec, - int next_rank, - specific_indexed_device_storage &group_devices, - const coll_param_gpu ¶ms) { - std::unique_ptr> kernel_router; - while (!kernel_router) { - //Gather data from in-process GPU - using right_gpu_type = ccl_gpu_comm; - auto &map_devices = std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with real device - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //Virtual GPU - using right_gpu_type = ccl_virtual_gpu_comm; - auto &map_devices = std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //concurrent GPU - using right_gpu_type = ccl_thread_comm; - native::indexed_device_container &map_devices = - std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - /*std::shared_ptr gpu = map_devices.find(next_rank); - if(gpu == nullptr) - { - break; // not ready yet! - }*/ - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with real device from another thread - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //concurrent GPU - using right_gpu_type = ccl_thread_comm; - native::indexed_device_container &map_devices = - std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - /* - std::shared_ptr gpu = map_devices.find(next_rank); - if(gpu == nullptr) - { - break; // not ready yet! - }*/ - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with virtual device from another thread - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //ipc-source GPU REAL - using right_gpu_type = ccl_ipc_source_gpu_comm; - native::indexed_device_container &map_devices = - std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with real device from another thread - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //ipc-source GPU VIRTUAL - using right_gpu_type = ccl_ipc_source_gpu_comm; - native::indexed_device_container &map_devices = - std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with virtual device from another thread - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - while (!kernel_router) { - //ipc-source GPU VIRTUAL - using right_gpu_type = ccl_ipc_gpu_comm; - native::indexed_device_container &map_devices = - std::get(group_devices); - auto it = map_devices.find(next_rank); - if (it == map_devices.end()) { - break; // not ready yet! - } - std::shared_ptr gpu = it->second; - using right_kernel_main_type = typename right_gpu_type:: - template gpu_kernel_t; - right_kernel_main_type &right_main_func = - gpu->get_gpu_kernel(params); - - //communicate with virtual device from another thread - kernel_router.reset( - new kernel_connector( - exec, right_main_func)); - } - - //sanity - if (!kernel_router) { - LOG_ERROR("Cannot bind communicators in group for next rank: ", next_rank); - } - return kernel_router; - } - - std::unique_ptr> kernel_router; - -private: - ccl_driver_context_ptr ctx; - // Internal gpu entry state to keep track of kernel status, it's not directly related to status field - gpu_entry_state entry_state; - ze_command_queue_desc_t queue_descr; - ze_command_list_desc_t list_descr; - ccl_device::device_queue &dev_queue; - decltype(parent_communicator->get_cmd_list(ctx, list_descr)) dev_cmd_list; - - // initialize - ze_command_queue_desc_t init_queue_descr(ccl_device &device) { - native::ccl_device::queue_group_properties queue_props = device.get_queue_group_prop(); - - queue_descr = device.get_default_queue_desc(); - - // find compute ordinal - uint32_t computeOrdinal = std::numeric_limits::max(); - for (uint32_t i = 0; i < queue_props.size(); i++) { - // Prefer CCS - if (queue_props[i].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE && - queue_props[i].numQueues > 1) { - queue_descr.ordinal = i; - break; - } - } - // if CCS not found, look for RCS/CCS - if (computeOrdinal == std::numeric_limits::max()) { - for (uint32_t i = 0; i < queue_props.size(); i++) { - if (queue_props[i].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) { - queue_descr.ordinal = i; - break; - } - } - } - - //calculate rank (remember it is a local rank) - queue_descr.index = comm_addr.rank % queue_props[queue_descr.ordinal].numQueues; - ENTRY_LOG_DEBUG("Rank to calculate for queue idx:", - comm_addr.rank, - ", queue : ", - to_string(queue_descr)); - return queue_descr; - } - - ze_command_list_desc_t init_list_descr(ccl_device &device) { - list_descr = parent_communicator->get_device().get_default_list_desc(); - return list_descr; - } - - ccl_device::device_queue &init_default_queue(ccl_device &device) { - return device.get_cmd_queue(queue_descr, ctx); - } - - auto init_default_cmd_list() -> decltype(parent_communicator->get_cmd_list(ctx, list_descr)) { - list_descr.commandQueueGroupOrdinal = queue_descr.ordinal; - ENTRY_LOG_DEBUG("cmd_list: ", to_string(list_descr)); - return parent_communicator->get_cmd_list(ctx, list_descr); - } -}; - -} // namespace native diff --git a/src/sched/entry/l0/l0_reduce_scatter_typed_entry.hpp b/src/sched/entry/l0/l0_reduce_scatter_typed_entry.hpp deleted file mode 100644 index 2e5f1b616..000000000 --- a/src/sched/entry/l0/l0_reduce_scatter_typed_entry.hpp +++ /dev/null @@ -1,260 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include -#include - -#include "sched/entry/l0/l0_entry.hpp" - -#include "kernels/shared.h" - -//TODO L0 Workaround - -namespace native { -template -class l0_reduce_scatter_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::get_local_kernel; - using kernel_main_typed = ring::reduce_scatter::main_kernel; - - using income_data_flag_gpu_type = typename std::remove_pointer< - typename ring::reduce_scatter::income_data_flag_arg_type>::type; - using ready_to_recv_flag_gpu_type = typename std::remove_pointer< - typename ring::reduce_scatter::ready_to_recv_flag_arg_type>::type; - using local_barrier_flag_gpu_type = typename std::remove_pointer< - typename ring::reduce_scatter::local_barrier_flag_arg_type>::type; - - static constexpr const char* class_name() noexcept { - return "L0_REDUCE_SCATTER_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_reduce_scatter; - } - - l0_reduce_scatter_typed_entry() = delete; - l0_reduce_scatter_typed_entry( - ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t cnt, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, send_buf, params, device_stream), - - temp_buffer(this->template alloc_memory_wrap( - typename ring::reduce_scatter::tmp_recv_buf_arg{}, - parent_communicator, - ring_reduce_scatter_tmp_buffer_size(cnt, base::comm_addr.size) * - ccl::get_datatype_size(params.get_datatype()), - get_ctx())), - income_data_flag(this->template alloc_memory_wrap( - typename ring::reduce_scatter::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - ready_to_recv_flag(this->template alloc_memory_wrap( - typename ring::reduce_scatter::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - local_barrier_flag(parent_communicator->get_device() - .template alloc_memory( - 1, - sizeof(local_barrier_flag_gpu_type), - get_ctx())) { - recv_buf_typed_entry = recv_buf; - cnt_entry = cnt; - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_reduce_scatter_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::created); - } - - ~l0_reduce_scatter_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - ENTRY_LOG_DEBUG("Start entry, cnt ", cnt_entry); - - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_typed_entry.get_ptr()); - - //create implementation specified primitives - main_entry_function - .template set_args, - typename ring::reduce_scatter::income_data_flag_arg, - typename ring::reduce_scatter::ready_to_recv_flag_arg, - typename ring::reduce_scatter::local_barrier_flag_arg, - typename ring::reduce_scatter::recv_buf_arg, - typename kernel_main_typed::common_entry_buf_size_arg>( - temp_buffer.get(), - income_data_flag.get(), - ready_to_recv_flag.get(), - local_barrier_flag.get(), - recv_buf_ptr, - cnt_entry); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_typed_entry.get_ptr()); - - std::vector ret; - ret.reserve(4); - ret.push_back(owned_device.create_ipc_memory_handle(recv_buf_ptr, get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(temp_buffer.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - return ret; - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory<> temp_buffer; - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_device::device_memory local_barrier_flag; - ccl_buffer recv_buf_typed_entry; - size_t cnt_entry; - std::shared_ptr ctx; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel - .template test_args, - typename ring::reduce_scatter::tmp_recv_buf_arg, - typename ring::reduce_scatter::income_data_flag_arg, - typename ring::reduce_scatter::ready_to_recv_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_output_buf_arg = - right_kernel.template get_arg>(); - auto right_tmp_recv_buf_arg = - right_kernel - .template get_arg>(); - auto right_income_data_flag_arg = - right_kernel - .template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel - .template get_arg(); - - // ENTRY_LOG_DEBUG("Bind right arguments from ", - // right_kernel_t::name(), - // " kernel", - // " to ", - // left_kernel_t::name(), - // " kernel. " - // "Right arguments:\n{ ", - // right_output_buf_arg.first, - // ", ", - // right_output_buf_arg.second, - // "}\n", - // "{ ", - // right_tmp_recv_buf_arg.first, - // ", ", - // right_tmp_recv_buf_arg.second, - // "}\n", - // "{ ", - // right_income_data_flag_arg.first, - // ", ", - // right_income_data_flag_arg.second, - // "}\n", - // "{ ", - // right_ready_to_recv_flag_arg.first, - // ", ", - // right_ready_to_recv_flag_arg.second, - // "}\n"); - - left_kernel - .template set_args, - typename ring::reduce_scatter::right_tmp_recv_buf_arg, - typename ring::reduce_scatter::right_income_data_flag_arg, - typename ring::reduce_scatter::right_ready_to_recv_flag_arg>( - right_output_buf_arg.second, - right_tmp_recv_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/l0/l0_reduce_typed_entry.hpp b/src/sched/entry/l0/l0_reduce_typed_entry.hpp deleted file mode 100644 index 72ea07031..000000000 --- a/src/sched/entry/l0/l0_reduce_typed_entry.hpp +++ /dev/null @@ -1,246 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once -#include -#include - -#include "sched/entry/l0/l0_entry.hpp" - -//TODO L0 Workaround - -namespace native { -template -class l0_reduce_typed_entry : public base_gpu_entry { -public: - friend class ccl_gpu_comm; - friend class ccl_virtual_gpu_comm; - - using base = - base_gpu_entry; - using base::parent_communicator; - using base::comm_addr; - using base::req; - using base::status; - using base::launch_args; - using base::kernel_router; - using base::get_ctx; - using base::get_local_kernel; - using kernel_main_typed = ring::reduce::main_kernel; - // TODO: fix type - using processing_type = uint8_t; - - using income_data_flag_gpu_type = - typename std::remove_pointer::type; - using ready_to_recv_flag_gpu_type = - typename std::remove_pointer::type; - using local_barrier_flag_gpu_type = - typename std::remove_pointer::type; - - static constexpr const char* class_name() noexcept { - return "L0_REDUCE_TYPED"; - } - - static constexpr ccl_coll_type type() noexcept { - return ccl_coll_reduce; - } - - l0_reduce_typed_entry() = delete; - l0_reduce_typed_entry(ccl_sched* sched, - std::shared_ptr comm, - specific_indexed_device_storage& available_devices, - ccl_driver_context_ptr in_ctx, - const ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t cnt, - ccl::reduction op, - int root, - const coll_param_gpu& params, - std::shared_ptr device_stream = std::shared_ptr()) - : base(sched, comm, in_ctx, send_buf, params, device_stream), - - temp_buffer(this->template alloc_memory_wrap( - typename ring::reduce::tmp_recv_buf_arg{}, - parent_communicator, - ring_reduce_tmp_buffer_size(cnt, comm_addr.size) * - ccl::get_datatype_size(params.get_datatype()), - get_ctx())), - income_data_flag( - this->template alloc_memory_wrap(typename ring::reduce::income_data_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - ready_to_recv_flag( - this->template alloc_memory_wrap(typename ring::reduce::ready_to_recv_flag_arg{}, - parent_communicator, - 1, - get_ctx())), - local_barrier_flag(parent_communicator->get_device() - .template alloc_memory( - 1, - sizeof(local_barrier_flag_gpu_type), - get_ctx())) { - recv_buf_typed_entry = recv_buf; - root_typed_entry = root; - cnt_entry = cnt; - - int next_rank = (comm_addr.rank + 1) % comm_addr.size; - kernel_router = base::template create_kernel_router_for_rank< - l0_reduce_typed_entry>( - *this, next_rank, available_devices, base::get_params()); - - ENTRY_LOG_DEBUG("Init phase of current entry for ext_rank:", next_rank); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::created); - } - - ~l0_reduce_typed_entry() { - // TODO: remove the memory once the entry is destroyed if it's not cleared automatically - // TODO: should we destroy handles here? - } - - void start() override { - ENTRY_LOG_DEBUG("Start entry, cnt ", cnt_entry); - - //Create base primitives - base::start(); - - auto& main_entry_function = get_local_kernel(); - - auto recv_buf_ptr = reinterpret_cast(recv_buf_typed_entry.get_ptr()); - //create implementation specified primitives - main_entry_function - .template set_args, - typename ring::reduce::income_data_flag_arg, - typename ring::reduce::ready_to_recv_flag_arg, - typename ring::reduce::local_barrier_flag_arg, - typename ring::reduce::recv_buf_arg, - typename ring::reduce::root_arg, - typename kernel_main_typed::common_entry_buf_size_arg>( - temp_buffer.get(), - income_data_flag.get(), - ready_to_recv_flag.get(), - local_barrier_flag.get(), - recv_buf_ptr, - root_typed_entry, - cnt_entry); - - // Once we filled our local parameters, we go wait for another entry to set its - // parameters so we can use them - this->set_state(gpu_entry_state::wait_for_entry); - - //make sure, that kernel ready for launch - // TODO: what if submit_for_execution() return false? - this->submit_for_execution(); - status = ccl_sched_entry_status_started; - } - - const char* name() const override { - return class_name(); - } - - std::vector get_ipc_data() override { - ccl_device& owned_device = parent_communicator->get_device(); - - std::vector ret; - ret.reserve(3); - ret.push_back(owned_device.create_ipc_memory_handle(temp_buffer.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(income_data_flag.get(), get_ctx())); - ret.push_back(owned_device.create_ipc_memory_handle(ready_to_recv_flag.get(), get_ctx())); - return ret; - } - -protected: - void dump_detail(std::stringstream& str) const override { - base::dump_detail(str); - } - -private: - ccl_device::device_memory<> temp_buffer; - ccl_device::device_memory income_data_flag; - ccl_device::device_memory ready_to_recv_flag; - ccl_device::device_memory local_barrier_flag; - ccl_buffer recv_buf_typed_entry; - int root_typed_entry; - size_t cnt_entry; - std::shared_ptr ctx; - -public: - template - bool execute(left_kernel_t& left_kernel, right_kernel_t& right_kernel) { - bool is_right_kernel_ready = - right_kernel - .template test_args, - typename ring::reduce::income_data_flag_arg, - typename ring::reduce::ready_to_recv_flag_arg>(); - - // Once we're sure that the parameters ready read them from the right kernel - // Note: we not only read the parameters but also reset their 'ready' flag - // (since we're using a destructive-copying policy) meaning that they must be stored - // in order to be read again. - // This is a protection to a case of multiple kernel launches - // (i.e. the collective is ran multiple times) where we might read not up-to-date - // values from the previous run. - - if (is_right_kernel_ready) { - auto right_tmp_recv_buf_arg = - right_kernel.template get_arg>(); - auto right_income_data_flag_arg = - right_kernel.template get_arg(); - auto right_ready_to_recv_flag_arg = - right_kernel.template get_arg(); - - // ENTRY_LOG_DEBUG("Bind right arguments from ", - // right_kernel_t::name(), - // " kernel", - // " to ", - // left_kernel_t::name(), - // " kernel. " - // "Right arguments:\n{ ", - // right_tmp_recv_buf_arg.first, - // ", ", - // right_tmp_recv_buf_arg.second, - // "}\n", - // "{ ", - // right_income_data_flag_arg.first, - // ", ", - // right_income_data_flag_arg.second, - // "}\n", - // "{ ", - // right_ready_to_recv_flag_arg.first, - // ", ", - // right_ready_to_recv_flag_arg.second, - // "}\n"); - - left_kernel.template set_args, - typename ring::reduce::right_income_data_flag_arg, - typename ring::reduce::right_ready_to_recv_flag_arg>( - right_tmp_recv_buf_arg.second, - right_income_data_flag_arg.second, - right_ready_to_recv_flag_arg.second); - - ENTRY_LOG_DEBUG("Binding arguments between kernels is complete. ", - "Arguments of the left kernel after binding:\n", - left_kernel.to_string()); - } - return is_right_kernel_ready; - } -}; -} // namespace native diff --git a/src/sched/entry/probe_entry.hpp b/src/sched/entry/probe_entry.hpp index 3cbc402a1..972e680d4 100644 --- a/src/sched/entry/probe_entry.hpp +++ b/src/sched/entry/probe_entry.hpp @@ -35,7 +35,7 @@ class probe_entry : public sched_entry { void start() override { int global_src = comm->get_global_rank(src); atl_tag = comm->atl->tag->create( - sched->get_comm_id(), global_src, sched->sched_id, sched->get_op_id()); + global_src, sched->get_comm_id(), sched->sched_id, sched->get_op_id()); LOG_DEBUG("PROBE entry src ", src, ", tag ", atl_tag); status = ccl_sched_entry_status_started; } diff --git a/src/sched/entry/recv_entry.hpp b/src/sched/entry/recv_entry.hpp index a267dd3e4..4172f5e96 100644 --- a/src/sched/entry/recv_entry.hpp +++ b/src/sched/entry/recv_entry.hpp @@ -55,7 +55,7 @@ class recv_entry : public sched_entry, int global_src = comm->get_global_rank(src); atl_tag = comm->atl->tag->create( - sched->get_comm_id(), global_src, sched->sched_id, sched->get_op_id()); + global_src, sched->get_comm_id(), sched->sched_id, sched->get_op_id()); size_t bytes = cnt * dtype.size(); LOG_DEBUG( diff --git a/src/sched/entry/recv_reduce_entry.hpp b/src/sched/entry/recv_reduce_entry.hpp index 9c36c633b..0239d9d32 100644 --- a/src/sched/entry/recv_reduce_entry.hpp +++ b/src/sched/entry/recv_reduce_entry.hpp @@ -83,7 +83,7 @@ class recv_reduce_entry final : public sched_entry { void start() override { int global_src = comm->get_global_rank(src); atl_tag = comm->atl->tag->create( - sched->get_comm_id(), global_src, sched->sched_id, sched->get_op_id()); + global_src, sched->get_comm_id(), sched->sched_id, sched->get_op_id()); size_t bytes = in_cnt * dtype.size(); LOG_DEBUG("starting RECV in RECV_REDUCE entry, src ", global_src, diff --git a/src/sched/entry/reduce_local_entry.cpp b/src/sched/entry/reduce_local_entry.cpp new file mode 100644 index 000000000..b5f52d77a --- /dev/null +++ b/src/sched/entry/reduce_local_entry.cpp @@ -0,0 +1,134 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sched/entry/reduce_local_entry.hpp" + +#include "common/comm/l0/modules/kernel_utils.hpp" +#include "common/datatype/datatype.hpp" +#include "common/stream/stream.hpp" +#include "common/utils/sycl_utils.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/queue/queue.hpp" + +#include + +using namespace ccl; +using namespace ccl::ze; + +void reduce_local_entry::init() { + if (ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("initialization"); + + ze_base_entry::init(init_mode::compute); + + ccl::global_data::get().ze_cache->get(context, device, "kernels.spv", &module); + + kernel_name = + "reduce_local_inplace_kernel_" + to_string(dtype.idx()) + "_" + ccl_reduction_to_str(op); + ccl::global_data::get().ze_cache->get(worker_idx, module, kernel_name, &kernel); + LOG_DEBUG("get kernel: name: ", kernel_name); + + ze_group_size_t group_size; + get_suggested_group_size(kernel, in_cnt, &group_size); + LOG_DEBUG("suggested group size: ", to_string(group_size)); + + get_suggested_group_count(group_size, in_cnt, &group_count); + LOG_DEBUG("suggested group count: ", to_string(group_count)); + + ZE_CALL(zeKernelSetGroupSize, + (kernel, group_size.groupSizeX, group_size.groupSizeY, group_size.groupSizeZ)); + + size_t bytes = in_cnt * dtype.size(); + in_buf_ptr = in_buf.get_ptr(bytes); + inout_buf_ptr = inout_buf.get_ptr(bytes); + ze_kernel_args_t kernel_args = { { sizeof(in_cnt), &in_cnt }, + { sizeof(in_buf_ptr), &in_buf_ptr }, + { sizeof(inout_buf_ptr), &inout_buf_ptr } }; + + LOG_DEBUG("kernel ", kernel, " args:\n", to_string(kernel_args)); + set_kernel_args(kernel, kernel_args); + + ZE_CALL(zeCommandListAppendLaunchKernel, + (ze_base_entry::comp_primitives.list, + kernel, + &group_count, + ze_base_entry::entry_event, + 0, + nullptr)); + ZE_CALL(zeCommandListClose, (ze_base_entry::comp_primitives.list)); + + LOG_DEBUG("initialization complete"); +} + +void reduce_local_entry::update() { + CCL_THROW_IF_NOT(use_device); + + ze_base_entry::update(); + if (status == ccl_sched_entry_status_complete && !sched->coll_attr.to_cache) { + finalize(); + } +} + +void reduce_local_entry::check_use_device() { + use_device = false; + ccl_stream* stream = (ccl_stream*)sched->coll_param.stream; + if (fn || !stream) + return; + + size_t bytes = in_cnt * dtype.size(); + sycl::queue* q = stream->get_native_stream(sched->queue->get_idx()); + CCL_THROW_IF_NOT(q, "null sycl queue"); + auto in_ptr_type = sycl::get_pointer_type(in_buf.get_ptr(bytes), q->get_context()); + auto inout_ptr_type = sycl::get_pointer_type(inout_buf.get_ptr(bytes), q->get_context()); + + LOG_DEBUG("in_ptr_type: ", + ccl::utils::usm_type_to_str(in_ptr_type), + ", inout_ptr_type: ", + ccl::utils::usm_type_to_str(inout_ptr_type), + ", native_stream: ", + stream->to_string(), + ", in_count: ", + in_cnt) + + if ((in_ptr_type == sycl::usm::alloc::device) && (inout_ptr_type == sycl::usm::alloc::device)) { + use_device = true; + } +} + +void reduce_local_entry::start_on_device() { + init(); + + ze_base_entry::start(); + status = ccl_sched_entry_status_started; +} + +void reduce_local_entry::finalize() { + if (!ze_base_entry::is_initialized) { + return; + } + + LOG_DEBUG("finalization"); + + // kernel cache + ccl::global_data::get().ze_cache->push(worker_idx, module, kernel_name, kernel); + + ze_base_entry::finalize(); + + LOG_DEBUG("finalization complete"); +} diff --git a/src/sched/entry/reduce_local_entry.hpp b/src/sched/entry/reduce_local_entry.hpp index 0a7b58a74..2a6686296 100644 --- a/src/sched/entry/reduce_local_entry.hpp +++ b/src/sched/entry/reduce_local_entry.hpp @@ -14,11 +14,19 @@ limitations under the License. */ #pragma once - +#include "common/global/global.hpp" #include "comp/comp.hpp" #include "sched/entry/entry.hpp" +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#include "sched/entry/gpu/ze_base_entry.hpp" +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +class reduce_local_entry : public ze_base_entry { +#else class reduce_local_entry : public sched_entry { +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT public: static constexpr const char* class_name() noexcept { return "REDUCE_LOCAL"; @@ -32,19 +40,38 @@ class reduce_local_entry : public sched_entry { size_t* out_cnt, const ccl_datatype& dtype, ccl::reduction reduction_op) - : sched_entry(sched), + : +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + ze_base_entry(sched), +#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + sched_entry(sched), +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT in_buf(in_buf), in_cnt(in_cnt), inout_buf(inout_buf), out_cnt(out_cnt), dtype(dtype), op(reduction_op), - fn(sched->coll_attr.reduction_fn) { + fn(sched->coll_attr.reduction_fn), + use_device(false) { CCL_THROW_IF_NOT(op != ccl::reduction::custom || fn, "custom reduction requires user provided callback"); } - void start() override { +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + ~reduce_local_entry() override { + finalize(); + } + void init(); + void finalize(); + void update() override; + void check_use_device(); + void start_on_device(); +#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + void check_use_device() {} + void start_on_device() {} +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + void start_on_host() { size_t bytes = in_cnt * dtype.size(); size_t offset = inout_buf.get_offset(); const ccl::fn_context context = { sched->coll_attr.match_id.c_str(), offset }; @@ -62,7 +89,19 @@ class reduce_local_entry : public sched_entry { status = ccl_sched_entry_status_complete; } - const char* name() const override { + void start() override { + check_use_device(); + if (use_device) { + LOG_DEBUG("start on device"); + start_on_device(); + } + else { + LOG_DEBUG("start on host"); + start_on_host(); + } + } + + const char* name() const noexcept override { return class_name(); } @@ -94,4 +133,15 @@ class reduce_local_entry : public sched_entry { ccl_datatype dtype; ccl::reduction op; ccl::reduction_fn fn; + void* in_buf_ptr; + void* inout_buf_ptr; + + bool use_device; + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + ze_module_handle_t module; + ze_kernel_handle_t kernel; + std::string kernel_name; + ze_group_count_t group_count; +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT }; diff --git a/src/sched/entry/send_entry.hpp b/src/sched/entry/send_entry.hpp index 5250a8c4f..5e3ddd6c2 100644 --- a/src/sched/entry/send_entry.hpp +++ b/src/sched/entry/send_entry.hpp @@ -18,6 +18,7 @@ #include "common/global/global.hpp" #include "sched/entry/entry.hpp" #include "sched/queue/queue.hpp" +#include "sched/entry/copy/copy_entry.hpp" class send_entry : public sched_entry, public postponed_fieldsget_global_rank(dst); int global_rank = comm->get_global_rank(comm->rank()); atl_tag = comm->atl->tag->create( - sched->get_comm_id(), global_rank, sched->sched_id, sched->get_op_id()); + global_rank, sched->get_comm_id(), sched->sched_id, sched->get_op_id()); size_t bytes = cnt * dtype.size(); LOG_DEBUG( "SEND entry dst ", global_dst, ", tag ", atl_tag, ", req ", &req, ", bytes ", bytes); atl_status_t atl_status = comm->atl->atl_ep_send( - sched->bin->get_atl_ep(), buf.get_ptr(bytes), bytes, global_dst, atl_tag, &req); + sched->bin->get_atl_ep(), send_buf.get_ptr(bytes), bytes, global_dst, atl_tag, &req); update_status(atl_status); } + void reset(size_t idx) override { + sched_entry::reset(idx); +#ifdef CCL_ENABLE_SYCL + if (proxy_copy_entry) { + proxy_copy_entry->reset(idx); + } +#endif // CCL_ENABLE_SYCL + } + + void start() override { + update_fields(); + + send_buf = buf; + +#ifdef CCL_ENABLE_SYCL + if (sched->coll_param.stream && cnt && + (ccl::global_data::env().atl_send_proxy != ccl_atl_send_proxy_none) && + (proxy_mode == proxy_copy_mode::unknown)) { + sycl::usm::alloc ptr_type = sycl::usm::alloc::unknown; + if (sched->coll_param.stream->get_type() == stream_type::gpu) { + auto sycl_queue = sched->coll_param.stream->get_native_stream(); + ptr_type = sycl::get_pointer_type(buf.get_ptr(), sycl_queue.get_context()); + } + proxy_mode = (ptr_type == sycl::usm::alloc::device) ? proxy_copy_mode::enabled + : proxy_copy_mode::disabled; + } + + if (proxy_mode == proxy_copy_mode::enabled) { + if (!proxy_buf) { + ccl_sched_buf_type buf_type = + (ccl::global_data::env().atl_send_proxy == ccl_atl_send_proxy_regular) + ? ccl_sched_buf_system + : ccl_sched_buf_runtime; + send_buf = proxy_buf = sched->alloc_buffer(cnt * dtype.size(), buf_type); + } + if (!proxy_copy_entry) { + proxy_copy_entry = + std::shared_ptr(new copy_entry(sched, buf, proxy_buf, cnt, dtype)); + } + + proxy_copy_entry->do_progress(); + + if (proxy_copy_entry->get_status() != ccl_sched_entry_status_complete) { + status = ccl_sched_entry_status_again; + return; + } + } +#endif // CCL_ENABLE_SYCL + + start_send(); + } + void update() override { int req_status; + atl_status_t atl_status = comm->atl->atl_ep_check(sched->bin->get_atl_ep(), &req_status, &req); @@ -116,4 +168,13 @@ class send_entry : public sched_entry, ccl_comm* comm; uint64_t atl_tag = 0; atl_req_t req{}; + + ccl_buffer send_buf; + +#ifdef CCL_ENABLE_SYCL + enum class proxy_copy_mode { unknown, enabled, disabled }; + proxy_copy_mode proxy_mode = proxy_copy_mode::unknown; + std::shared_ptr proxy_copy_entry; + ccl_buffer proxy_buf{}; +#endif // CCL_ENABLE_SYCL }; diff --git a/src/sched/entry/sync_entry.hpp b/src/sched/entry/sync_entry.hpp index 180573923..780b16559 100644 --- a/src/sched/entry/sync_entry.hpp +++ b/src/sched/entry/sync_entry.hpp @@ -49,7 +49,7 @@ class sync_entry : public sched_entry { status = ccl_sched_entry_status_complete; } else { - LOG_TRACE("waiting SYNC entry cnt ", counter); + // LOG_TRACE("waiting SYNC entry cnt ", counter); ccl_yield(ccl::global_data::env().yield_type); } } diff --git a/src/sched/extra_sched.cpp b/src/sched/extra_sched.cpp index 7a0d7e673..7ec71bbb5 100644 --- a/src/sched/extra_sched.cpp +++ b/src/sched/extra_sched.cpp @@ -36,16 +36,6 @@ void ccl_extra_sched::dump(std::ostream& out) const { entries[i]->dump(msg, i); } out << msg.str(); -#ifdef ENABLE_TIMERS - ccl_logger::format( - out, - "\nlife time [us] ", - std::setw(5), - std::setbase(10), - std::chrono::duration_cast(exec_complete_time - exec_start_time) - .count(), - "\n"); -#endif ccl_logger::format(out, "--------------------------------\n"); } diff --git a/src/sched/master_sched.cpp b/src/sched/master_sched.cpp index 61afd234f..19d9fac93 100644 --- a/src/sched/master_sched.cpp +++ b/src/sched/master_sched.cpp @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "coll/coll_check.hpp" #include "common/global/global.hpp" #include "common/utils/sync_object.hpp" +#include "common/utils/sycl_utils.hpp" #include "parallelizer/parallelizer.hpp" #include "sched/cache/cache.hpp" #include "sched/cache/key.hpp" @@ -23,13 +25,92 @@ #include "sched/master_sched.hpp" #include "sched/queue/queue.hpp" +#ifdef CCL_ENABLE_SYCL +#include +#include + +#ifdef MULTI_GPU_SUPPORT +#include "sched/entry/gpu/ze_cache.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" +#endif // MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL + +#ifdef CCL_ENABLE_SYCL +constexpr ze_event_pool_desc_t get_event_pool_desc() { + auto desc = ccl::ze::default_event_pool_desc; + + desc.count = 1; + desc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE; + + return desc; +} +#endif + +ccl_master_sched::ccl_master_sched(const ccl_coll_param& coll_param) + : ccl_sched_base(coll_param), + ccl_request(), + partial_scheds() { +#ifdef ENABLE_DEBUG + set_dump_callback([this](std::ostream& out) { + dump(out); + }); +#endif + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + if (ccl::utils::should_use_sycl_output_event(coll_param.stream)) { + auto l0_context = coll_param.stream->get_native_stream() + .get_context() + .template get_native(); + + auto pool_desc = get_event_pool_desc(); + + ccl::global_data::get().ze_cache->get(0, l0_context, pool_desc, &get_memory().sync_pool); + + ze_event_desc_t event_desc = ccl::ze::default_event_desc; + event_desc.signal = ZE_EVENT_SCOPE_FLAG_HOST; + event_desc.wait = ZE_EVENT_SCOPE_FLAG_HOST; + event_desc.index = 0; + + ZE_CALL(zeEventCreate, (get_memory().sync_pool, &event_desc, &get_memory().sync_event)); + LOG_DEBUG("created sync event: ", get_memory().sync_event); + } + else { + LOG_DEBUG("skip sync event creation"); + } +#endif +} + ccl_master_sched::~ccl_master_sched() { for (auto& part_sched : partial_scheds) { part_sched.reset(); } + if (!memory.mr_list.empty()) + LOG_WARN("memory region list should be empty for master sched"); + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + if (ccl::utils::should_use_sycl_output_event(coll_param.stream)) { + auto l0_context = coll_param.stream->get_native_stream() + .get_context() + .template get_native(); + + // Sycl event might call wait on destruction meaning that it should be valid at that time + // The problem is that the sync event is stored in request, which descrutor is called + // after ccl_master_sched, which means its underlying l0 event will be already destroyed + // by that time. As a workaround, reset the event, essentially calling its destructor before + // destroying the corresponding l0 event + set_sync_event(sycl::event()); + + LOG_DEBUG("destroying sync event: ", get_memory().sync_event); + ZE_CALL(zeEventDestroy, (get_memory().sync_event)); - CCL_ASSERT(memory.mr_list.empty(), "memory list is not empty"); - free_buffers(); + auto pool_desc = get_event_pool_desc(); + + ccl::global_data::get().ze_cache->push(0, l0_context, pool_desc, get_memory().sync_pool); + } + else { + LOG_DEBUG("skip sync event destruction"); + } +#endif } void ccl_master_sched::commit(ccl_parallelizer* parallelizer) { @@ -63,6 +144,19 @@ void ccl_master_sched::commit(ccl_parallelizer* parallelizer) { partial_scheds.size()); } +void ccl_master_sched::reset_state() { + reset_request(); + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + if (ccl::utils::should_use_sycl_output_event(coll_param.stream)) { + // Reset sycl event while it's in complete state, similar case to destruction in ~ccl_master_sched + set_sync_event(sycl::event()); + LOG_DEBUG("reset sync event: ", get_memory().sync_event); + ZE_CALL(zeEventHostReset, (get_memory().sync_event)); + } +#endif +} + ccl_request* ccl_master_sched::start(ccl_executor* exec, bool reset_sched) { /* sanity check the schedule */ CCL_ASSERT(coll_param.comm); @@ -72,14 +166,35 @@ ccl_request* ccl_master_sched::start(ccl_executor* exec, bool reset_sched) { prepare_partial_scheds(); if (reset_sched) { - reset_request(); + reset_state(); } if (ccl::global_data::env().sched_dump) { std::stringstream ostream; dump(ostream); - LOG_INFO(ostream.str()); + logger.info(ostream.str()); + } + +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + if (ccl::utils::should_use_sycl_output_event(coll_param.stream)) { + LOG_DEBUG("convert L0 event: ", + get_memory().sync_event, + "into a SYCL event and submit a barrier"); + auto q = coll_param.stream->get_native_stream(); + auto context = q.get_context(); +#ifdef CCL_ENABLE_SYCL_INTEROP_EVENT + auto e = sycl::level_zero::make( + context, get_memory().sync_event, sycl::level_zero::ownership::keep); + set_sync_event(e); + + set_native_event(q.submit_barrier({ e })); +#else + CCL_THROW( + "interop event functionality is not available with current configuration, please rebuild oneCCL using ENABLE_SYCL_INTEROP_EVENT option" + "and a DPCPP compiler that supports that feature"); +#endif } +#endif exec->start(this); return this; @@ -90,7 +205,7 @@ ccl_request* ccl_master_sched::reset_request() { return this; } -void ccl_master_sched::add_partial_sched(ccl_coll_param& coll_param) { +void ccl_master_sched::add_partial_sched(const ccl_coll_param& coll_param) { partial_scheds.emplace_back(std::make_shared(coll_param, this)); partial_scheds.back()->internal_type = internal_type; } @@ -160,70 +275,11 @@ void ccl_master_sched::dump(std::ostream& out) const { sched->dump(out); } -#ifdef ENABLE_TIMERS - ccl_logger::format( - out, - "\nlife time [us] ", - std::setw(5), - std::setbase(10), - std::chrono::duration_cast(exec_complete_time - exec_start_time) - .count(), - "\n"); -#endif - ccl_logger::format(out, "--------------------------------\n"); } ccl_master_sched::ccl_master_sched_ptr ccl_master_sched::create(const ccl_coll_param& param, const ccl_coll_attr& attr) { - /* check contracts at first */ - - CCL_THROW_IF_NOT(ccl::global_data::env().atl_transport == ccl_atl_ofi || !(attr.reduction_fn), - "custom reduction is supported for OFI transport only"); - - CCL_THROW_IF_NOT(ccl_datatype_storage::is_predefined_datatype(param.dtype.idx()) || - ccl::global_data::env().atl_transport == ccl_atl_ofi, - "custom datatype is supported for OFI transport only"); - - CCL_THROW_IF_NOT((param.ctype != ccl_coll_allreduce && param.ctype != ccl_coll_reduce && - param.ctype != ccl_coll_sparse_allreduce) || - ccl_datatype_storage::is_predefined_datatype(param.dtype.idx()) || - attr.reduction_fn, - "custom datatype requires custom reduction"); - - CCL_THROW_IF_NOT(param.ctype == ccl_coll_allreduce || - !(attr.prologue_fn || attr.epilogue_fn || attr.reduction_fn), - "prologue/epilogue/custom reduction is supported for allreduce only"); - - CCL_THROW_IF_NOT(param.ctype == ccl_coll_allgatherv || !(attr.vector_buf), - "vector buffer is supported for allgatherv only"); - - if (param.ctype == ccl_coll_sparse_allreduce) { - CCL_THROW_IF_NOT( - ccl::global_data::env().sparse_allreduce_algo_raw != "mask" || !(attr.reduction_fn), - "mask algorithm for sparse_allreduce does not support custom reduction"); - - CCL_THROW_IF_NOT( - (attr.sparse_allreduce_completion_fn || attr.sparse_allreduce_alloc_fn) && - !(reinterpret_cast(attr.sparse_allreduce_completion_fn) & - reinterpret_cast(attr.sparse_allreduce_alloc_fn)), - "sparse_allreduce requires completion callback only or allocation callback only"); - } - - if (param.dtype.idx() == ccl::datatype::float16) { - CCL_THROW_IF_NOT(ccl::global_data::env().fp16_impl_type != ccl_fp16_no_compiler_support, - "FP16 datatype is requested but not supported by CCL compiler"); - CCL_THROW_IF_NOT(ccl::global_data::env().fp16_impl_type != ccl_fp16_no_hardware_support, - "FP16 datatype is requested but not supported by hardware"); - } - - if (param.dtype.idx() == ccl::datatype::bfloat16) { - CCL_THROW_IF_NOT(ccl::global_data::env().bf16_impl_type != ccl_bf16_no_compiler_support, - "BF16 datatype is requested but not supported by CCL compiler"); - CCL_THROW_IF_NOT(ccl::global_data::env().bf16_impl_type != ccl_bf16_no_hardware_support, - "BF16 datatype is requested but not supported by hardware"); - } - ccl_sched_key key; ccl_master_sched_ptr sched; bool is_created = false; diff --git a/src/sched/master_sched.hpp b/src/sched/master_sched.hpp index d3c2e45a5..dbcbf6c44 100644 --- a/src/sched/master_sched.hpp +++ b/src/sched/master_sched.hpp @@ -26,22 +26,13 @@ class ccl_master_sched : public ccl_sched_base, public ccl_request { return "master_sched"; } - ccl_master_sched(const ccl_coll_param& coll_param) - : ccl_sched_base(coll_param), - ccl_request(), - partial_scheds() { -#ifdef ENABLE_DEBUG - set_dump_callback([this](std::ostream& out) { - dump(out); - }); -#endif - } + ccl_master_sched(const ccl_coll_param& coll_param); ccl_master_sched(const ccl_master_sched& src) = delete; ~ccl_master_sched() override; - void add_partial_sched(ccl_coll_param& param); + void add_partial_sched(const ccl_coll_param& param); void commit(ccl_parallelizer* parallelizer = nullptr); ccl_request* start(ccl_executor* exec, bool reset_sched = true); @@ -56,13 +47,14 @@ class ccl_master_sched : public ccl_sched_base, public ccl_request { void sync_partial_scheds(); void dump(std::ostream& out) const; - //TODO encapsulate it in private. + // TODO encapsulate it in private. std::vector> partial_scheds; - //factory method (TODO: wrap into smart-pointer) + // TODO: wrap into smart-pointer using ccl_master_sched_ptr = ccl_master_sched*; static ccl_master_sched_ptr create(const ccl_coll_param& param, const ccl_coll_attr& attr); private: + void reset_state(); void prepare_partial_scheds(); }; diff --git a/src/sched/queue/flow_control.cpp b/src/sched/queue/flow_control.cpp index 1c67546b6..174ca1214 100644 --- a/src/sched/queue/flow_control.cpp +++ b/src/sched/queue/flow_control.cpp @@ -43,7 +43,7 @@ bool flow_control::take_credit() { if (credits) { credits--; CCL_THROW_IF_NOT( - credits >= 0, "unexpected credits ", credits, ", max_credits ", max_credits); + credits <= max_credits, "unexpected credits ", credits, ", max_credits ", max_credits); min_credits = std::min(min_credits, credits); return true; } diff --git a/src/sched/queue/queue.cpp b/src/sched/queue/queue.cpp index 8654e1470..2dbe8f12b 100644 --- a/src/sched/queue/queue.cpp +++ b/src/sched/queue/queue.cpp @@ -77,14 +77,24 @@ ccl_sched_queue::ccl_sched_queue(size_t idx, std::vector atl_eps) } ccl_sched_queue::~ccl_sched_queue() { - if (!bins.empty()) - LOG_WARN("unexpected bins size ", bins.size(), ", expected 0"); + size_t expected_max_priority = 0; + ccl_sched_bin* expected_cached_max_priority_bin = nullptr; + + if (bins.size() >= 1) { + ccl_sched_bin* bin = &(bins.begin()->second); + expected_max_priority = bin->priority; + expected_cached_max_priority_bin = bin; + if (bins.size() > 1) + LOG_WARN("unexpected bins size ", bins.size(), ", expected <= 1"); + } - if (max_priority != 0) - LOG_WARN("unexpected max_priority ", max_priority, ", expected 0"); + if (max_priority != expected_max_priority) + LOG_WARN("unexpected max_priority ", max_priority, ", expected ", expected_max_priority); - if (cached_max_priority_bin) + if (cached_max_priority_bin != expected_cached_max_priority_bin) LOG_WARN("unexpected cached_max_priority_bin"); + + clear(); } void ccl_sched_queue::add(ccl_sched* sched) { @@ -156,7 +166,7 @@ size_t ccl_sched_queue::erase(ccl_sched_bin* bin, size_t idx) { std::lock_guard lock{ bins_guard }; { // no need to lock 'bin' here, because all adding are under bins_guard protection - if (bin->sched_list.elems.empty()) { + if (bin->sched_list.elems.empty() /* && (bins.size() > 1)*/) { bins.erase(bin_priority); // change priority diff --git a/src/sched/queue/queue.hpp b/src/sched/queue/queue.hpp index e25ed9e71..a155ebfc4 100644 --- a/src/sched/queue/queue.hpp +++ b/src/sched/queue/queue.hpp @@ -29,7 +29,7 @@ using sched_bin_list_t = std::unordered_map; // key - pri using sched_queue_lock_t = ccl_spinlock; /* ATL EP is limited resource, each priority bucket consumes single ATL EP and uses it for all bins in bucket */ -#define CCL_PRIORITY_BUCKET_COUNT (4) +#define CCL_PRIORITY_BUCKET_COUNT (1) /* the size of priority bucket, each bin in bucket use the same ATL EP although bins have different priorities */ #define CCL_PRIORITY_BUCKET_SIZE (8) @@ -52,13 +52,13 @@ class ccl_sched_list { } ~ccl_sched_list() { - if (elems.size() != 0 && !ccl::global_data::get().is_ft_enabled) { + if (!elems.empty()) { LOG_WARN("unexpected elem_count ", elems.size(), ", expected 0"); } + clear(); + } - for (size_t i = 0; i < elems.size(); i++) { - elems[i]->clear(); - } + void clear() { elems.clear(); } @@ -88,7 +88,7 @@ class ccl_sched_list { } } - size_t size() { + size_t size() const { { std::lock_guard lock(elem_guard); return elems.size(); @@ -127,9 +127,18 @@ class ccl_sched_list { void dump(std::ostream& out) const { { + auto sched_dump = ccl::global_data::env().sched_dump; std::lock_guard lock(elem_guard); - for (auto& e : elems) { - e->dump(out); + if (sched_dump) { + for (auto& e : elems) { + e->dump(out); + } + } + else { + for (size_t idx = 0; idx < elems.size(); idx++) { + out << " [" << idx + << "]: " << ccl_coll_type_to_str(elems[idx]->coll_param.ctype) << "\n"; + } } } } @@ -154,14 +163,17 @@ class ccl_sched_bin { sched->queue = queue; } - ~ccl_sched_bin() = default; + ~ccl_sched_bin() { + sched_list.clear(); + } + ccl_sched_bin() = delete; ccl_sched_bin& operator=(const ccl_sched_bin& other) = delete; ccl_sched_bin(ccl_sched_bin&& src) = default; ccl_sched_bin& operator=(ccl_sched_bin&& other) = default; - size_t size() { + size_t size() const { return sched_list.size(); } size_t get_priority() { @@ -218,14 +230,15 @@ class ccl_sched_queue { void dump(std::ostream& out) const { { std::lock_guard lock(bins_guard); - if (bins.empty()) { - out << "empty sched_queue"; - } - else { - for (auto& b : bins) { - b.second.dump(out); - } + out << "{\n"; + out << " sched_queue: idx: " << idx << " size: " << bins.size() << "\n"; + size_t idx = 0; + for (auto& bin : bins) { + out << " bin: idx: " << idx << " priority: " << bin.first + << " size: " << bin.second.size() << "\n"; + bin.second.dump(out); } + out << "}\n"; } } diff --git a/src/sched/sched.cpp b/src/sched/sched.cpp index 2c6080867..834c438d3 100644 --- a/src/sched/sched.cpp +++ b/src/sched/sched.cpp @@ -16,7 +16,6 @@ #include "common/global/global.hpp" #include "common/utils/sync_object.hpp" #include "parallelizer/parallelizer.hpp" -#include "sched/entry/factory/entry_factory.hpp" #include "sched/extra_sched.hpp" #include "sched/queue/queue.hpp" #include "sched/sched.hpp" @@ -34,35 +33,6 @@ ccl_sched::~ccl_sched() { if (finalize_fn) { finalize_fn(this, finalize_fn_ctx); } - - if (!memory.mr_list.empty()) { - /* perform deregistration in worker thread */ - { - ccl_coll_param param{}; - param.ctype = ccl_coll_internal; - param.comm = coll_param.comm; - std::unique_ptr dereg_sched(new ccl_extra_sched(param, sched_id)); - entry_factory::make_entry( - dereg_sched.get(), memory.mr_list, param.comm); - if (ccl::global_data::get().is_worker_thread || - !ccl::global_data::env().worker_offload) { - dereg_sched->do_progress(); - } - else { - /* release ownership, because ccl_wait_impl use delete inside */ - ccl_wait_impl(ccl::global_data::get().executor.get(), - start_subsched(dereg_sched.release())); - } - } - - if (!memory.mr_list.empty()) { - LOG_ERROR("memory list is not empty"); - } - - CCL_ASSERT(memory.mr_list.empty()); - } - - free_buffers(); } void ccl_sched::do_progress() { @@ -125,13 +95,37 @@ bool ccl_sched::is_strict_order_satisfied() { } void ccl_sched::complete() { -#ifdef ENABLE_TIMERS - exec_complete_time = timer_type::now(); - if (ccl::global_data::env().sched_dump) { - dump(std::cout); - } -#endif CCL_ASSERT(req, "ccl_sched must have req"); + + if (ccl::global_data::env().sched_profile) { + timer.stop(); + if (entries.size() > 0) { + std::stringstream ss; + ss << "\ncoll:"; + + ccl_coll_param* profile_param = &(static_cast(req)->coll_param); + ss << ccl_coll_type_to_str(profile_param->ctype); + + /* TODO: tmp check, replace ccl_coll_entry_param by ccl_coll_param */ + if (!profile_param->send_counts.empty()) { + ss << " count:" << profile_param->get_send_count(); + } + + ss << " time(uses):\ntotal: " << timer.str() << "\n"; + for (size_t idx = 0; idx < entries.size(); ++idx) { + ss << "[" << idx << "] " << entries[idx]->name() << ": " + << entries[idx]->timer.str() << "\n"; + } + ss << "-----------------------------"; + logger.info(ss.str()); + } + } + + if (!coll_attr.to_cache) { + /* don't wait sched dtor to free memory */ + free_memory(); + } + req->complete(); } @@ -139,10 +133,9 @@ void ccl_sched::renew(bool need_update_id /* = false*/) { if (need_update_id) { update_id(); } -#ifdef ENABLE_TIMERS - exec_start_time = timer_type::now(); - exec_complete_time = exec_start_time; -#endif + if (ccl::global_data::env().sched_profile) { + timer.start(); + } start_idx = 0; for (size_t idx = 0; idx < entries.size(); idx++) { entries[idx].get()->reset(idx); @@ -171,12 +164,13 @@ ccl_request* ccl_sched::start_subsched(ccl_extra_sched* subsched) { ccl::global_data::get().executor->update_wait_condition( queue->get_idx(), ccl_base_thread::wait_data::update_type::increment, 1); + queue->add(subsched); if (ccl::global_data::env().sched_dump) { std::stringstream ostream; subsched->dump(ostream); - LOG_INFO(ostream.str()); + logger.info(ostream.str()); } return subsched->req; diff --git a/src/sched/sched.hpp b/src/sched/sched.hpp index f390bc6bd..f190cfcaf 100644 --- a/src/sched/sched.hpp +++ b/src/sched/sched.hpp @@ -16,6 +16,7 @@ #pragma once #include "sched/sched_base.hpp" +#include "sched/sched_timer.hpp" #include "sched/queue/flow_control.hpp" #include "internal_types.hpp" @@ -59,10 +60,6 @@ class alignas(CACHELINE_SIZE) ccl_sched : public ccl_sched_base { virtual void complete(); - void clear() { - entries.clear(); - } - size_t get_start_idx() const { return start_idx; } @@ -147,7 +144,7 @@ class alignas(CACHELINE_SIZE) ccl_sched : public ccl_sched_base { ccl_sched_queue* queue = nullptr; /* cached pointer to queue, valid even after execution */ size_t start_idx = 0; /* index to start */ - /* + /* used for unique ATL tag creation in algorithms with multiple parallel sub-schedules set once and then used for all entries */ @@ -161,10 +158,10 @@ class alignas(CACHELINE_SIZE) ccl_sched : public ccl_sched_base { /* whether sched should be executed in the same order as in user code */ /* currently applicable for start phase only */ - bool strict_order; + bool strict_order = false; /* - limits number of active entries + limits number of active entries mostly makes sense for ATL entries */ ccl::flow_control flow_control; @@ -181,9 +178,5 @@ class alignas(CACHELINE_SIZE) ccl_sched : public ccl_sched_base { ccl_sched_finalize_fn_t finalize_fn = nullptr; void* finalize_fn_ctx = nullptr; -#ifdef ENABLE_TIMERS - using timer_type = std::chrono::system_clock; - timer_type::time_point exec_start_time{}; - timer_type::time_point exec_complete_time{}; -#endif + ccl::sched_timer timer; }; diff --git a/src/sched/sched_base.cpp b/src/sched/sched_base.cpp index 6803ea257..ff13786c8 100644 --- a/src/sched/sched_base.cpp +++ b/src/sched/sched_base.cpp @@ -17,9 +17,23 @@ #include "coll/algorithms/algorithms_enum.hpp" #include "coll/coll_param.hpp" +#include "coll/selection/selection.hpp" #include "common/global/global.hpp" +#include "common/comm/comm.hpp" +#include "common/comm/host_communicator/host_communicator.hpp" +#include "sched/buffer_cache.hpp" +#include "sched/entry/factory/entry_factory.hpp" #include "sched/sched_base.hpp" +ccl_sched_base::ccl_sched_base(const ccl_coll_param& coll_param) : coll_param(coll_param) { +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) + if (coll_param.stream) { + ccl_comm* node_comm = + coll_param.comm->get_host_comm()->get_node_comm().get()->get_ccl_comm().get(); + memory.handle_manager.init(node_comm, coll_param.stream); + } +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +} std::string to_string(ccl_sched_add_mode mode) { switch (mode) { case ccl_sched_add_front: return "FRONT"; @@ -29,6 +43,10 @@ std::string to_string(ccl_sched_add_mode mode) { return "DEFAULT"; } +ccl_sched_base::~ccl_sched_base() { + free_memory(); +} + void ccl_sched_base::set_coll_attr(const ccl_coll_attr& attr) { coll_attr = attr; } @@ -36,49 +54,46 @@ void ccl_sched_base::set_coll_attr(const ccl_coll_attr& attr) { void ccl_sched_base::update_coll_param_and_attr(const ccl_coll_param& param, const ccl_coll_attr& attr) { #ifdef CCL_ENABLE_SYCL - copy_deps(param.deps, coll_param.deps); - if (param.stream && param.stream->is_sycl_device_stream()) { - /* update device buffers only if they are already non-null - i.e. were set on previous call */ - if (coll_param.device_send_buf) { - coll_param.device_send_buf = static_cast((void*)param.send_buf); - } - if (coll_param.device_recv_buf) { - coll_param.device_recv_buf = static_cast(param.recv_buf); - } + coll_param.sync_deps(param.stream, param.deps); +#endif // CCL_ENABLE_SYCL + + bool has_pre_post_copies = + (!coll_param.device_send_bufs.empty() || !coll_param.device_recv_bufs.empty()) ? true + : false; + + if (has_pre_post_copies) { + CCL_THROW_IF_NOT(coll_param.device_send_bufs.size() == param.send_bufs.size(), + "send_bufs sizes mismatch"); + CCL_THROW_IF_NOT(coll_param.device_recv_bufs.size() == param.recv_bufs.size(), + "recv_bufs sizes mismatch"); + coll_param.device_send_bufs = param.send_bufs; + coll_param.device_recv_bufs = param.recv_bufs; } else { -#endif /* CCL_ENABLE_SYCL */ - coll_param.send_buf = param.send_buf; - coll_param.recv_buf = param.recv_buf; -#ifdef CCL_ENABLE_SYCL + CCL_THROW_IF_NOT(coll_param.send_bufs.size() == param.send_bufs.size(), + "send_bufs sizes mismatch"); + CCL_THROW_IF_NOT(coll_param.recv_bufs.size() == param.recv_bufs.size(), + "recv_bufs sizes mismatch"); + coll_param.send_bufs = param.send_bufs; + coll_param.recv_bufs = param.recv_bufs; } -#endif /* CCL_ENABLE_SYCL */ + + int comm_size = coll_param.comm->size(); if (coll_param.ctype == ccl_coll_allgatherv) { - coll_param.recv_counts = param.recv_counts; - CCL_THROW_IF_NOT((int)coll_param_copy.ag_recv_counts.size() == coll_param.comm->size()); - coll_param_copy.ag_recv_counts.assign((size_t*)param.recv_counts, - (size_t*)param.recv_counts + coll_param.comm->size()); - - if (coll_attr.vector_buf) { - CCL_THROW_IF_NOT((int)coll_param_copy.ag_recv_bufs.size() == coll_param.comm->size()); - coll_param_copy.ag_recv_bufs.assign((void**)param.recv_buf, - (void**)param.recv_buf + coll_param.comm->size()); - } + if (coll_attr.is_vector_buf) + CCL_THROW_IF_NOT(static_cast(coll_param.recv_bufs.size()) == comm_size); + CCL_THROW_IF_NOT(static_cast(coll_param.recv_counts.size()) == comm_size); } if (coll_param.ctype == ccl_coll_alltoallv) { - coll_param.send_counts = param.send_counts; - coll_param.recv_counts = param.recv_counts; + if (coll_attr.is_vector_buf) + CCL_THROW_IF_NOT(static_cast(coll_param.send_bufs.size()) == comm_size); + CCL_THROW_IF_NOT(static_cast(coll_param.send_counts.size()) == comm_size); - CCL_THROW_IF_NOT((int)coll_param_copy.a2av_send_counts.size() == coll_param.comm->size()); - CCL_THROW_IF_NOT((int)coll_param_copy.a2av_recv_counts.size() == coll_param.comm->size()); - - coll_param_copy.a2av_send_counts.assign( - (size_t*)param.send_counts, (size_t*)param.send_counts + coll_param.comm->size()); - coll_param_copy.a2av_recv_counts.assign( - (size_t*)param.recv_counts, (size_t*)param.recv_counts + coll_param.comm->size()); + if (coll_attr.is_vector_buf) + CCL_THROW_IF_NOT(static_cast(coll_param.recv_bufs.size()) == comm_size); + CCL_THROW_IF_NOT(static_cast(coll_param.recv_counts.size()) == comm_size); } if (coll_param.ctype == ccl_coll_sparse_allreduce) { @@ -110,60 +125,108 @@ size_t ccl_sched_base::get_priority() const { return priority; } -ccl_buffer ccl_sched_base::alloc_buffer(size_t bytes) { +void* ccl_sched_base::alloc_buffer_unmanaged(size_t bytes, ccl_sched_buf_type buf_type) { LOG_DEBUG("try to allocate buffer size: ", bytes); CCL_THROW_IF_NOT(bytes > 0, "incorrect buffer size: ", bytes); - ccl_buffer buffer = - ccl_buffer(CCL_MALLOC(bytes, "sched_buffer"), bytes, 0, ccl_buffer_type::DIRECT); - memory.buf_list.emplace_back(buffer, bytes); - CCL_THROW_IF_NOT(buffer.get_ptr(), "null ptr"); + void* ptr = nullptr; + if (buf_type == ccl_sched_buf_system) { + ccl::global_data::get().buffer_cache->get(sched_id, bytes, &ptr); + } +#ifdef CCL_ENABLE_SYCL + else if (buf_type == ccl_sched_buf_runtime) { + CCL_THROW_IF_NOT(coll_param.stream, "null stream"); + sycl::context ctx = coll_param.stream->get_native_stream().get_context(); + ccl::global_data::get().buffer_cache->get(sched_id, bytes, ctx, &ptr); + } +#endif // CCL_ENABLE_SYCL + else { + CCL_THROW("unexpected buf_type ", buf_type); + } - LOG_DEBUG("allocated buffer ptr: ", buffer.get_ptr(), ", size: ", buffer.get_size()); - return buffer; + LOG_DEBUG("allocated buffer: ", ptr, ", size: ", bytes); + return ptr; } +void ccl_sched_base::free_buffer_unmanaged(void* ptr, size_t bytes, ccl_sched_buf_type buf_type) { + LOG_DEBUG("free buffer: ", ptr, ", buf_type: ", buf_type); + + if (buf_type == ccl_sched_buf_system) { + ccl::global_data::get().buffer_cache->push(sched_id, bytes, ptr); + } #ifdef CCL_ENABLE_SYCL + else if (buf_type == ccl_sched_buf_runtime) { + CCL_THROW_IF_NOT(coll_param.stream, "null stream"); + sycl::context ctx = coll_param.stream->get_native_stream().get_context(); + ccl::global_data::get().buffer_cache->push(sched_id, bytes, ctx, ptr); + } +#endif // CCL_ENABLE_SYCL + else { + CCL_THROW("unexpected buf_type ", buf_type); + } +} -ccl_buffer ccl_sched_base::alloc_staging_buffer(size_t bytes) { - LOG_DEBUG("try to allocate usm host buffer size: ", bytes); - CCL_THROW_IF_NOT(bytes > 0, "incorrect buffer size: ", bytes); +ccl_buffer ccl_sched_base::alloc_buffer(size_t bytes, ccl_sched_buf_type buf_type) { + ccl_buffer buffer = + ccl_buffer(alloc_buffer_unmanaged(bytes, buf_type), bytes, 0, ccl_buffer_type::DIRECT); - ccl_buffer buffer; - if (ccl::global_data::env().staging_buffer == ccl_staging_usm) { - CCL_ASSERT(coll_param.stream); + if (buf_type == ccl_sched_buf_system) { + memory.buf_list.emplace_back(buffer, bytes); + } +#ifdef CCL_ENABLE_SYCL + else if (buf_type == ccl_sched_buf_runtime) { + CCL_THROW_IF_NOT(coll_param.stream, "null stream"); sycl::context ctx = coll_param.stream->get_native_stream().get_context(); - buffer = ccl_buffer(aligned_alloc_host(64, bytes, ctx), bytes, 0, ccl_buffer_type::DIRECT); memory.sycl_buf_list.emplace_back(buffer, bytes, ctx); LOG_DEBUG( "allocated host usm buffer ptr: ", buffer.get_ptr(), ", size: ", buffer.get_size()); } - else { - buffer = alloc_buffer(bytes); +#endif // CCL_ENABLE_SYCL + + CCL_THROW_IF_NOT(buffer.get_ptr(), "null ptr"); + return buffer; +} + +#ifdef CCL_ENABLE_SYCL +ccl_buffer ccl_sched_base::alloc_staging_buffer(size_t bytes) { + LOG_DEBUG("try to allocate usm host buffer size: ", bytes); + CCL_THROW_IF_NOT(bytes > 0, "incorrect buffer size: ", bytes); + + ccl_sched_buf_type buf_type = ccl_sched_buf_system; + if (ccl::global_data::env().staging_buffer == ccl_staging_usm) { + buf_type = ccl_sched_buf_runtime; } + ccl_buffer buffer = alloc_buffer(bytes, buf_type); CCL_THROW_IF_NOT(buffer.get_ptr(), "null ptr"); return buffer; } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL -void ccl_sched_base::free_buffers() { +void ccl_sched_base::free_memory() { std::list::iterator it; for (it = memory.buf_list.begin(); it != memory.buf_list.end(); it++) { - LOG_DEBUG("free ", it->buffer.get_ptr()); - CCL_FREE(it->buffer.get_ptr()); + free_buffer_unmanaged(it->buffer.get_ptr(), it->size, ccl_sched_buf_system); } memory.buf_list.clear(); + free_memory_regions(); + #ifdef CCL_ENABLE_SYCL std::list::iterator sycl_it; for (sycl_it = memory.sycl_buf_list.begin(); sycl_it != memory.sycl_buf_list.end(); sycl_it++) { LOG_DEBUG("free host usm ", sycl_it->buffer.get_ptr()); - free(sycl_it->buffer.get_ptr(), sycl_it->ctx); + ccl::global_data::get().buffer_cache->push( + sched_id, sycl_it->size, sycl_it->ctx, sycl_it->buffer.get_ptr()); } memory.sycl_buf_list.clear(); -#endif /* CCL_ENABLE_SYCL */ + +#ifdef MULTI_GPU_SUPPORT + memory.handle_manager.clear(); +#endif // MULTI_GPU_SUPPORT + +#endif // CCL_ENABLE_SYCL } ccl_buffer ccl_sched_base::update_buffer(ccl_buffer buffer, size_t new_size) { @@ -261,72 +324,95 @@ void ccl_sched_base::add_memory_region(atl_mr_t* mr) { memory.mr_list.emplace_back(mr); } -void ccl_sched_base::alloc_buffers_for_pre_post_copy() { -#ifdef CCL_ENABLE_SYCL - ccl_coll_param& param = coll_param; - param.device_send_buf = param.device_recv_buf = nullptr; - - if (!param.stream || (!param.stream->is_sycl_device_stream())) +void ccl_sched_base::free_memory_regions() { + if (memory.mr_list.empty()) { return; + } - // check both recv and send buffers, for some algorithms(i.e. alltoallv) one of them is allowed to - // be invalid(i.e. unknown return type) as long as the corresponding count is 0 so we won't dereference it. - // TODO: should we add a special handling for case when both buffers are invalid? - auto send_ptr_type = sycl::get_pointer_type((void*)param.send_buf, - param.stream->get_native_stream().get_context()); - auto recv_ptr_type = - sycl::get_pointer_type(param.recv_buf, param.stream->get_native_stream().get_context()); + /* perform deregistration in worker thread */ - // TODO: we currently don't correctly handle cases when there are 2 different types at the same time - // i.e. device memory for send buffer and shared memory for recv buffer - bool should_alloc_buffers = true; - if ((send_ptr_type == sycl::usm::alloc::shared || recv_ptr_type == sycl::usm::alloc::shared) || - ((send_ptr_type == sycl::usm::alloc::device || recv_ptr_type == sycl::usm::alloc::device) && - atl_wrapper::attr.out.enable_device_buf)) { - should_alloc_buffers = false; + ccl_coll_param param{}; + param.ctype = ccl_coll_internal; + param.comm = coll_param.comm; + std::unique_ptr dereg_sched(new ccl_extra_sched(param, sched_id)); + entry_factory::make_entry(dereg_sched.get(), memory.mr_list, param.comm); + + if (ccl::global_data::get().is_worker_thread || !ccl::global_data::env().worker_offload) { + dereg_sched->do_progress(); + } + else { + CCL_THROW("unsupported path"); + /* release ownership, because ccl_wait_impl use delete inside */ + // ccl_wait_impl( + // ccl::global_data::get().executor.get(), + // start_subsched(dereg_sched.release())); } - if (!should_alloc_buffers) { - return; + if (!memory.mr_list.empty()) { + LOG_ERROR("memory region list is not empty after deregister_entry completion"); } +} - param.device_send_buf = static_cast((void*)param.send_buf); - param.device_recv_buf = static_cast(param.recv_buf); +void ccl_sched_base::get_pre_post_copy_counts(std::vector& d2h_counts, + std::vector& h2d_counts, + bool& reuse_buffers) { + ccl_coll_param& param = coll_param; - param.send_buf = param.recv_buf = nullptr; + d2h_counts.clear(); + h2d_counts.clear(); + reuse_buffers = false; - size_t send_alloc_count = 0, recv_alloc_count = 0; switch (param.ctype) { case ccl_coll_allgatherv: - send_alloc_count = param.send_count; - recv_alloc_count = - std::accumulate(param.recv_counts, param.recv_counts + param.comm->size(), 0); + d2h_counts.push_back(param.get_send_count()); + if (param.recv_bufs.size() > 1) { + h2d_counts.insert( + h2d_counts.end(), param.recv_counts.begin(), param.recv_counts.end()); + } + else { + h2d_counts.push_back( + std::accumulate(param.recv_counts.begin(), param.recv_counts.end(), 0)); + } break; case ccl_coll_allreduce: + d2h_counts.push_back(param.get_send_count()); + h2d_counts.push_back(param.get_recv_count()); /* use in-place to avoid allocation of extra staging buffer*/ - send_alloc_count = 0; - recv_alloc_count = param.count; + reuse_buffers = true; break; case ccl_coll_alltoall: - send_alloc_count = recv_alloc_count = param.count * param.comm->size(); + d2h_counts.push_back(param.get_send_count() * param.comm->size()); + h2d_counts.push_back(param.get_recv_count() * param.comm->size()); break; case ccl_coll_alltoallv: - send_alloc_count = - std::accumulate(param.send_counts, param.send_counts + param.comm->size(), 0); - recv_alloc_count = - std::accumulate(param.recv_counts, param.recv_counts + param.comm->size(), 0); + if (param.recv_bufs.size() > 1) { + /* expect that is_vector_buf is enabled for send/recv both */ + d2h_counts.insert( + d2h_counts.end(), param.send_counts.begin(), param.send_counts.end()); + h2d_counts.insert( + h2d_counts.end(), param.recv_counts.begin(), param.recv_counts.end()); + } + else { + d2h_counts.push_back( + std::accumulate(param.send_counts.begin(), param.send_counts.end(), 0)); + h2d_counts.push_back( + std::accumulate(param.recv_counts.begin(), param.recv_counts.end(), 0)); + } break; case ccl_coll_bcast: - send_alloc_count = 0; - recv_alloc_count = param.count; + if (param.comm->rank() == param.root) + d2h_counts.push_back(param.get_send_count()); + h2d_counts.push_back(param.get_recv_count()); + reuse_buffers = true; break; case ccl_coll_reduce: - send_alloc_count = param.count; - recv_alloc_count = (param.comm->rank() == param.root) ? param.count : 0; + d2h_counts.push_back(param.get_send_count()); + if (param.comm->rank() == param.root) + h2d_counts.push_back(param.get_recv_count()); break; case ccl_coll_reduce_scatter: - send_alloc_count = param.count * param.comm->size(); - recv_alloc_count = param.count; + d2h_counts.push_back(param.get_send_count()); + h2d_counts.push_back(param.get_recv_count()); break; case ccl_coll_sparse_allreduce: CCL_FATAL("SYCL stream is not supported for sparse_allreduce yet"); @@ -334,6 +420,61 @@ void ccl_sched_base::alloc_buffers_for_pre_post_copy() { break; default: break; } +} + +void ccl_sched_base::alloc_buffers_for_pre_post_copy() { +#ifdef CCL_ENABLE_SYCL + + ccl_coll_param& param = coll_param; + + param.device_send_bufs.clear(); + param.device_recv_bufs.clear(); + + // TODO: WA skip sycl pre_post_copy for allreduce gpu algo + ccl_selector_param selector_param; + selector_param.ctype = param.ctype; + selector_param.count = param.get_send_count(); + selector_param.dtype = param.dtype; + selector_param.comm = param.comm; + selector_param.stream = param.stream; + selector_param.is_sycl_buf = coll_attr.is_sycl_buf; + + if (!param.stream || !param.stream->is_sycl_device_stream() || + ccl_is_topo_ring_algo(selector_param)) { + return; + } + + bool should_alloc_buffers = true; + + if (!coll_attr.is_sycl_buf) { + auto bufs = param.get_all_non_zero_bufs(); + if (!bufs.empty()) { + auto usm_type = + sycl::get_pointer_type(bufs[0], param.stream->get_native_stream().get_context()); + if ((usm_type == sycl::usm::alloc::host) || (usm_type == sycl::usm::alloc::shared) || + ((usm_type == sycl::usm::alloc::device) && atl_wrapper::attr.out.enable_hmem)) { + should_alloc_buffers = false; + } + } + } + + LOG_DEBUG("coll_type ", param.ctype, ", should_alloc_buffers ", should_alloc_buffers); + + if (!should_alloc_buffers) { + return; + } + + /* + move user-supplied pointers into device_* fields + they will be used further for pre-post copies + */ + param.device_send_bufs = param.send_bufs; + param.device_recv_bufs = param.recv_bufs; + + std::vector d2h_counts; + std::vector h2d_counts; + bool reuse_buffers; + get_pre_post_copy_counts(d2h_counts, h2d_counts, reuse_buffers); LOG_DEBUG("alloc tmp buffers for D2H and H2D copies, coll_type ", ccl_coll_type_to_str(param.ctype), @@ -341,21 +482,53 @@ void ccl_sched_base::alloc_buffers_for_pre_post_copy() { param.dtype.size(), ", comm_size ", param.comm->size(), - ", count ", - param.count); + ", d2h_counts_size ", + d2h_counts.size(), + ", h2d_counts_size ", + h2d_counts.size(), + ", reuse_buffers ", + reuse_buffers); + + if (reuse_buffers) { + /* keep only single vector with counts */ + if (d2h_counts.size() < h2d_counts.size()) + d2h_counts = h2d_counts; + h2d_counts.clear(); + } - if (send_alloc_count) { - param.send_buf = alloc_staging_buffer(send_alloc_count * param.dtype.size()).get_ptr(); + for (size_t idx = 0; idx < d2h_counts.size(); idx++) { + if (d2h_counts[idx]) + param.send_bufs[idx] = + alloc_staging_buffer(d2h_counts[idx] * param.dtype.size()).get_ptr(); + else + param.send_bufs[idx] = nullptr; } - if (recv_alloc_count) { - param.recv_buf = alloc_staging_buffer(recv_alloc_count * param.dtype.size()).get_ptr(); + for (size_t idx = 0; idx < h2d_counts.size(); idx++) { + if (h2d_counts[idx]) + param.recv_bufs[idx] = + alloc_staging_buffer(h2d_counts[idx] * param.dtype.size()).get_ptr(); + else + param.recv_bufs[idx] = nullptr; + } - if (param.ctype == ccl_coll_allreduce || param.ctype == ccl_coll_bcast) { - param.send_buf = param.recv_buf; - } + if (reuse_buffers) { + param.recv_bufs = param.send_bufs; } -#endif /* CCL_ENABLE_SYCL */ + + CCL_THROW_IF_NOT(param.send_bufs.size() == param.device_send_bufs.size(), + "send_bufs.size() mismatch: ", + param.send_bufs.size(), + " vs ", + param.device_send_bufs.size()); + + CCL_THROW_IF_NOT(param.recv_bufs.size() == param.device_recv_bufs.size(), + "recv_bufs.size() mismatch: ", + param.recv_bufs.size(), + " vs ", + param.device_recv_bufs.size()); + +#endif // CCL_ENABLE_SYCL } void ccl_sched_base::update_id() { diff --git a/src/sched/sched_base.hpp b/src/sched/sched_base.hpp index 80edcb737..8775a9f84 100644 --- a/src/sched/sched_base.hpp +++ b/src/sched/sched_base.hpp @@ -26,6 +26,10 @@ #include "common/utils/buffer.hpp" #include "sched/entry/entry.hpp" +#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#include "sched/ze_handle_manager.hpp" +#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + class ccl_sched_queue; class ccl_sched_bin; class ccl_request; @@ -45,6 +49,13 @@ enum ccl_sched_add_mode { ccl_sched_add_mode_last_value }; +enum ccl_sched_buf_type { + ccl_sched_buf_system, + ccl_sched_buf_runtime, + + ccl_sched_buf_last_value +}; + std::string to_string(ccl_sched_add_mode mode); struct ccl_sched_buffer_handler { @@ -62,7 +73,7 @@ struct ccl_sched_sycl_buffer_handler : public ccl_sched_buffer_handler { : ccl_sched_buffer_handler(buffer, size), ctx(ctx) {} }; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL struct ccl_sched_memory { std::list buf_list; @@ -70,7 +81,16 @@ struct ccl_sched_memory { #ifdef CCL_ENABLE_SYCL std::list sycl_buf_list; -#endif /* CCL_ENABLE_SYCL */ +#ifdef MULTI_GPU_SUPPORT + ccl::ze::ipc_handle_manager handle_manager; + // sync event which we use to signal to the user about collective completion + // and the pool it's created from(need to keep it to know what to return to the cache) + // TODO: this is not the best place for these objects, think about moving them + // to ccl_master_sched where they actually used + ze_event_handle_t sync_event; + ze_event_pool_handle_t sync_pool; +#endif // MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL }; static size_t lifo_priority = 0; @@ -89,18 +109,28 @@ struct ccl_sched_base { size_t get_priority() const; - ccl_buffer alloc_buffer(size_t bytes); + void* alloc_buffer_unmanaged(size_t bytes, ccl_sched_buf_type buf_type = ccl_sched_buf_system); + void free_buffer_unmanaged(void* ptr, + size_t bytes, + ccl_sched_buf_type buf_type = ccl_sched_buf_system); + + ccl_buffer alloc_buffer(size_t bytes, ccl_sched_buf_type buf_type = ccl_sched_buf_system); #ifdef CCL_ENABLE_SYCL ccl_buffer alloc_staging_buffer(size_t bytes); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL - void free_buffers(); + void add_memory_region(atl_mr_t* mr); + void free_memory_regions(); + + void free_memory(); ccl_buffer update_buffer(ccl_buffer buffer, size_t new_size); ccl_buffer find_and_realloc_buffer(void* buffer, size_t new_size, size_t expected_size = 0); - void add_memory_region(atl_mr_t* mr); + void get_pre_post_copy_counts(std::vector& d2h_counts, + std::vector& h2d_counts, + bool& reuse_buffers); void alloc_buffers_for_pre_post_copy(); @@ -116,9 +146,15 @@ struct ccl_sched_base { add_mode = mode; } + ccl_sched_memory& get_memory() { + return memory; + } + ccl_coll_param coll_param{}; ccl_coll_attr coll_attr{}; - ccl_coll_param_copy coll_param_copy{}; + + /* TODO: schedule doesn't necessarily map on single algo */ + ccl_coll_algo hint_algo{}; /* sequence number of the schedule in the communicator */ ccl_sched_id_t sched_id = 0; @@ -131,13 +167,13 @@ struct ccl_sched_base { } protected: - ~ccl_sched_base() = default; + ~ccl_sched_base(); ccl_sched_base() { CCL_THROW("unsupported"); } - ccl_sched_base(const ccl_coll_param& coll_param) : coll_param(coll_param) {} + ccl_sched_base(const ccl_coll_param& coll_param); void update_id(); diff --git a/src/sched/sched_timer.cpp b/src/sched/sched_timer.cpp new file mode 100644 index 000000000..94141ef45 --- /dev/null +++ b/src/sched/sched_timer.cpp @@ -0,0 +1,54 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include +#include +#include + +#include "common/log/log.hpp" +#include "sched_timer.hpp" + +namespace ccl { + +void sched_timer::start() noexcept { + start_time = std::chrono::high_resolution_clock::now(); +} + +void sched_timer::stop() { + auto stop_time = std::chrono::high_resolution_clock::now(); + std::chrono::duration time_span = stop_time - start_time; + time_usec = time_span.count(); +} + +std::string sched_timer::str() const { + std::stringstream ss; + ss.precision(2); + ss << std::fixed << get_time(); + return ss.str(); +} + +void sched_timer::print(std::string title) const { + logger.info(title, ": ", this->str()); +} + +void sched_timer::reset() noexcept { + time_usec = 0; +} + +long double sched_timer::get_time() const noexcept { + return time_usec; +} + +} // namespace ccl diff --git a/src/kernels/event_declaration.h b/src/sched/sched_timer.hpp similarity index 57% rename from src/kernels/event_declaration.h rename to src/sched/sched_timer.hpp index b468d8a76..532a3240e 100644 --- a/src/kernels/event_declaration.h +++ b/src/sched/sched_timer.hpp @@ -1,41 +1,39 @@ /* Copyright 2016-2020 Intel Corporation - + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef HOST_CTX -#define __global - -#include -using namespace ccl; +#pragma once -template -struct shared_event_traits {}; +#include +#include -#else -typedef ushort bfloat16; -#endif +namespace ccl { -typedef struct __attribute__((packed)) shared_event_float { - __global int* produced_bytes; - __global float* mem_chunk; -} shared_event_float; +class sched_timer { +public: + sched_timer() = default; + void start() noexcept; + void stop(); + std::string str() const; + void print(std::string title = {}) const; + void reset() noexcept; -#ifdef HOST_CTX +private: + long double time_usec; + std::chrono::high_resolution_clock::time_point start_time{}; -template <> -struct shared_event_traits { - using impl_t = shared_event_float; + long double get_time() const noexcept; }; -#endif +} //namespace ccl diff --git a/src/sched/ze_handle_manager.cpp b/src/sched/ze_handle_manager.cpp new file mode 100644 index 000000000..31d958a55 --- /dev/null +++ b/src/sched/ze_handle_manager.cpp @@ -0,0 +1,259 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "common/comm/comm.hpp" +#include "sched/entry/gpu/ze_call.hpp" +#include "sched/ze_handle_manager.hpp" + +#include + +namespace ccl { + +namespace ze { + +static void cast_mem_to_pool_handle(ze_ipc_event_pool_handle_t* pool, + const ze_ipc_mem_handle_t* mem) { + static_assert(sizeof(ze_ipc_event_pool_handle_t) == sizeof(ze_ipc_mem_handle_t)); + memcpy(pool, mem, sizeof(*mem)); +} + +std::string to_string(ipc_mem_type type) { + switch (type) { + case ipc_mem_type::memory: return "buffer"; + case ipc_mem_type::pool: return "pool"; + default: return "unknown"; + } +} + +ipc_handle_info::ipc_handle_info() { + memset(&handle, 0, sizeof(handle)); +} + +ipc_handle_info::ipc_handle_info(const ze_ipc_mem_handle_t& handle, + size_t offset, + ipc_mem_type type) + : handle(handle), + offset(offset), + type(type) {} + +ipc_handle_manager::~ipc_handle_manager() { + clear(); +} + +void ipc_handle_manager::init(const ccl_comm* init_comm, const ccl_stream* init_stream) { + LOG_DEBUG("initialization"); + CCL_THROW_IF_NOT(init_comm, "no comm"); + CCL_THROW_IF_NOT(init_stream, "no stream"); + + comm = const_cast(init_comm); + + for (int idx = 0; idx < comm->size(); idx++) { + rank_map.insert({ comm->get_global_rank(idx), idx }); + } + + auto sycl_device = init_stream->get_native_stream().get_device(); + auto sycl_context = init_stream->get_native_stream().get_context(); + + device = sycl_device.template get_native(); + context = sycl_context.template get_native(); + + CCL_THROW_IF_NOT(device, "device is not valid"); + CCL_THROW_IF_NOT(context, "context is not valid"); + + LOG_DEBUG("initialization completed"); +} + +void ipc_handle_manager::clear() { + for (int rank = 0; rank < static_cast(handles.size()); rank++) { + for (size_t buf_idx = 0; buf_idx < handles[rank].size(); buf_idx++) { + const auto& handle_info = handles[rank][buf_idx]; + ze_ipc_mem_handle_t handle = handle_info.handle; + auto mem_ptr = handle_info.ptr; + auto mem_type = handle_info.type; + size_t mem_offset = handle_info.offset; + + LOG_DEBUG("close handle: { base_ptr: ", + mem_ptr, + ", offset: ", + mem_offset, + ", fd: ", + *(int*)handle.data, + ", rank: ", + rank, + ", buf_idx: ", + buf_idx, + " }"); + + // when closing the handle we need to take care of pointers that points to the + // same level zero allocation. They're simply offsetted from some base pointer + // although represented by different FDs. If we close this base pointer, + // all the derived pointers are closed(unmapped) as well. To handle this case + // we ignore the result of close function which would fail if we close a pointer + // which is already closed. The function has semantic of free() call, so the result + // is not much useful anyway. + if (mem_ptr) { + ze_result_t res{}; + if (mem_type == ipc_mem_type::memory) { + res = zeMemCloseIpcHandle(context, mem_ptr); + } + else if (mem_type == ipc_mem_type::pool) { + res = zeEventPoolCloseIpcHandle((ze_event_pool_handle_t)mem_ptr); + } + else { + CCL_THROW("unknown memory type"); + } + + if (res != ZE_RESULT_SUCCESS) { + LOG_TRACE("unable to close memory handle: ", + "level-zero res: ", + to_string(res), + ", rank: ", + rank, + ", buf_idx: ", + buf_idx, + ", ptr: ", + mem_ptr); + } + } + + // TODO: remove, when the fix arrives from L0 side: XDEPS-2302 + int fd; + memcpy(&fd, handle.data, sizeof(fd)); + close(fd); + } + } + + if (!handles.empty()) { + LOG_DEBUG("handles are cleared successfully"); + } + + handles.clear(); +} + +void ipc_handle_manager::set(const mem_handle_map_t& handles_arg) { + CCL_THROW_IF_NOT(!handles_arg.empty(), "handles_arg argument is empty"); + CCL_THROW_IF_NOT(handles_arg.size() == static_cast(comm->size()), + "handles_arg and comm sizes should be equal"); + CCL_THROW_IF_NOT(handles.empty(), "handles should be empty before set"); + + handles = handles_arg; + LOG_DEBUG("handles are set successfully, size of handles: ", handles.size()); +} + +void ipc_handle_manager::get(int rank, size_t buf_idx, ccl_buffer& buf, ccl_comm* map_comm) { + check_rank(rank, (map_comm) ? map_comm : comm); + if (map_comm && (map_comm->id() != comm->id())) { + int old_rank = rank; + rank = map_comm->get_global_rank(rank); + auto rank_it = rank_map.find(rank); + if (rank_it == rank_map.end()) { + CCL_THROW("handle manager can not handle global rank ", rank); + } + rank = rank_it->second; + LOG_DEBUG("convert rank: old_rank: ", + old_rank, + " old_comm: id: ", + map_comm->id(), + ", size: ", + map_comm->size(), + ", new_rank: ", + rank, + " new_comm: id: ", + comm->id(), + ", size: ", + comm->size()); + check_rank(rank, comm); + } + CCL_THROW_IF_NOT(buf_idx < handles[rank].size(), "buf_idx is not valid value: ", buf_idx); + + const auto& handle_info = handles[rank][buf_idx]; + auto handle = handle_info.handle; + auto mem_ptr = handle_info.ptr; + auto mem_type = handle_info.type; + + LOG_DEBUG("context: ", context, ", device: ", device, ", rank: ", rank, ", buf_idx: ", buf_idx); + if (mem_ptr == nullptr) { + if (mem_type == ccl::ze::ipc_mem_type::memory) { + open_handle(handle, &mem_ptr); + } + else if (mem_type == ccl::ze::ipc_mem_type::pool) { + ze_ipc_event_pool_handle_t pool_handle; + cast_mem_to_pool_handle(&pool_handle, &handle); + open_handle(pool_handle, (ze_event_pool_handle_t*)&mem_ptr); + } + else { + CCL_THROW("unknown memory type"); + } + } + + LOG_DEBUG("get handle: { mem_ptr: ", + mem_ptr, + ", fd: ", + *(int*)handle.data, + ", rank: ", + rank, + ", buf_idx: ", + buf_idx, + " }"); + + // add offset that we received along with the handle + size_t mem_offset = handle_info.offset; + void* final_ptr = static_cast(static_cast(mem_ptr) + mem_offset); + buf.set(final_ptr); +} + +void ipc_handle_manager::get_handle(const void* ptr, ze_ipc_mem_handle_t* handle) { + CCL_THROW_IF_NOT(ptr, "no mem pointer"); + ZE_CALL(zeMemGetIpcHandle, (context, ptr, handle)); +} + +void ipc_handle_manager::get_handle(ze_event_pool_handle_t pool, + ze_ipc_event_pool_handle_t* handle) { + CCL_THROW_IF_NOT(pool, "no pool"); + ZE_CALL(zeEventPoolGetIpcHandle, (pool, handle)); +} + +void ipc_handle_manager::open_handle(const ze_ipc_mem_handle_t& handle, void** ptr) { + ZE_CALL(zeMemOpenIpcHandle, (context, device, handle, 0 /* cache allocation */, ptr)); +} + +void ipc_handle_manager::open_handle(const ze_ipc_event_pool_handle_t& handle, + ze_event_pool_handle_t* pool) { + ZE_CALL(zeEventPoolOpenIpcHandle, (context, handle, pool)); +} + +void ipc_handle_manager::get_address_range(const void* ptr, void** base_ptr, size_t* size) { + ZE_CALL(zeMemGetAddressRange, (context, ptr, base_ptr, size)); + LOG_DEBUG("zeMemGetAddressRange: ptr: ", + ptr, + ", base ptr: ", + *base_ptr, + ", offset: ", + ccl_get_ptr_diff(*base_ptr, ptr), + ", size: ", + *size); +} + +void ipc_handle_manager::check_rank(int rank, ccl_comm* check_comm) { + CCL_THROW_IF_NOT( + (rank >= 0) && (rank < static_cast(handles.size())) && (rank < check_comm->size()), + "rank is not valid value: ", + rank); + CCL_THROW_IF_NOT( + rank != check_comm->rank(), "don't expect to open handle for own rank: ", rank); +} + +} // namespace ze +} // namespace ccl diff --git a/src/sched/ze_handle_manager.hpp b/src/sched/ze_handle_manager.hpp new file mode 100644 index 000000000..2920e392d --- /dev/null +++ b/src/sched/ze_handle_manager.hpp @@ -0,0 +1,80 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/log/log.hpp" +#include "common/stream/stream.hpp" +#include "common/utils/buffer.hpp" +#include "sched/entry/gpu/ze_primitives.hpp" + +#include +#include + +class ccl_comm; + +namespace ccl { + +namespace ze { + +enum class ipc_mem_type : int { unknown = 0, memory, pool }; + +std::string to_string(ipc_mem_type type); + +struct ipc_handle_info { + ze_ipc_mem_handle_t handle{}; + size_t offset{}; + void* ptr{}; + ipc_mem_type type{}; + + ipc_handle_info(); + ipc_handle_info(const ze_ipc_mem_handle_t& handle, size_t offset, ipc_mem_type type); + ipc_handle_info& operator=(const ipc_handle_info&) = default; +}; + +class ipc_handle_manager { +public: + using mem_handle_map_t = typename std::vector>; + + ipc_handle_manager() = default; + ipc_handle_manager(const ipc_handle_manager&) = delete; + ipc_handle_manager& operator=(const ipc_handle_manager&) = delete; + ~ipc_handle_manager(); + + void init(const ccl_comm* comm, const ccl_stream* stream); + void clear(); + + void set(const mem_handle_map_t& handles_arg); + void get(int rank, size_t buf_idx, ccl_buffer& buf, ccl_comm* map_comm = nullptr); + + void get_handle(const void* buffer, ze_ipc_mem_handle_t* handle); + void get_handle(ze_event_pool_handle_t pool, ze_ipc_event_pool_handle_t* handle); + void open_handle(const ze_ipc_mem_handle_t& handle, void** ptr); + void open_handle(const ze_ipc_event_pool_handle_t& handle, ze_event_pool_handle_t* pool); + + void get_address_range(const void* ptr, void** base_ptr, size_t* size); + +private: + ze_context_handle_t context{}; + ze_device_handle_t device{}; + ccl_comm* comm{}; + std::unordered_map rank_map{}; + mem_handle_map_t handles; + + void check_rank(int rank, ccl_comm* check_comm); +}; + +} // namespace ze +} // namespace ccl diff --git a/src/stream_impl.hpp b/src/stream_impl.hpp index c2fd61067..17d5c9443 100644 --- a/src/stream_impl.hpp +++ b/src/stream_impl.hpp @@ -29,44 +29,12 @@ namespace ccl { namespace v1 { -/* TODO temporary function for UT compilation: would be part of ccl::detail::environment in final*/ -template -stream stream::create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_val_type&&... avs) { - auto version = utils::get_library_version(); - - stream str{ stream_provider_dispatcher::create(device, version) }; - int expander[]{ (str.template set(avs.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; -} - -template -stream stream::create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_context_type::ccl_native_t context, - attr_val_type&&... avs) { - auto version = utils::get_library_version(); - - stream str{ stream_provider_dispatcher::create(device, context, version) }; - int expander[]{ (str.template set(avs.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; -} - template stream stream::create_stream(native_stream_type& native_stream) { auto version = utils::get_library_version(); return { stream_provider_dispatcher::create(native_stream, version) }; } -template -stream stream::create_stream(device_type& device, native_context_type& native_ctx) { - auto version = utils::get_library_version(); - return { stream_provider_dispatcher::create(device, native_ctx, version) }; -} - template CCL_API const typename detail::ccl_api_type_attr_traits::return_type& stream::get() const { @@ -83,12 +51,6 @@ CCL_API typename detail::ccl_api_type_attr_traits::retur v, detail::ccl_api_type_attr_traits{}); } -/* -stream::stream(const typename detail::ccl_api_type_attr_traits::type& version) : - base_t(stream_provider_dispatcher::create(version)) -{ -}*/ - } // namespace v1 } // namespace ccl @@ -97,10 +59,6 @@ stream::stream(const typename detail::ccl_api_type_attr_traits::return_type \ diff --git a/src/supported_topologies.hpp b/src/supported_topologies.hpp index 0244b550c..8c308d3c9 100644 --- a/src/supported_topologies.hpp +++ b/src/supported_topologies.hpp @@ -30,7 +30,7 @@ namespace ccl { } // namespace ccl using device_group_split_type_names = - utils::enum_to_str::type>( + ::utils::enum_to_str::type>( ccl::group_split_type::last_value)>; inline std::string to_string(ccl::group_split_type type) { return device_group_split_type_names({ @@ -41,7 +41,8 @@ inline std::string to_string(ccl::group_split_type type) { .choose(type, "INVALID_VALUE"); } -using device_topology_type_names = utils::enum_to_str; +using device_topology_type_names = + ::utils::enum_to_str; inline std::string to_string(ccl::device_topology_type class_value) { return device_topology_type_names({ "RING_CLASS", "A2A_CLASS" }) .choose(class_value, "INVALID_VALUE"); diff --git a/tests/functional/CMakeLists.txt b/tests/functional/CMakeLists.txt index fb6416133..d6ba5177b 100644 --- a/tests/functional/CMakeLists.txt +++ b/tests/functional/CMakeLists.txt @@ -98,6 +98,8 @@ if (COMPUTE_BACKEND) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${COMPUTE_BACKEND_LIBRARIES}") endif() +link_directories(${EXAMPLES_LIB_DIRS}) + foreach(src ${sources}) get_filename_component(executable ${src} NAME_WE) add_executable(${executable} ${src} ${SERVICE_SRC}) @@ -107,6 +109,7 @@ foreach(src ${sources}) if (${executable} MATCHES ".*bcast.*") target_compile_definitions(${executable} PRIVATE TEST_CCL_BCAST) endif() + target_include_directories(${executable} PRIVATE ${EXAMPLES_INC_DIRS}) target_link_libraries(${executable} PUBLIC gtest_main) target_link_libraries(${executable} PUBLIC gtest) target_link_libraries(${executable} PRIVATE ccl) @@ -129,8 +132,19 @@ endforeach() add_test (NAME allreduce_fusion CONFIGURATIONS allreduce_fusion COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/allreduce_test --gtest_output=xml:${CCL_INSTALL_TESTS}/allreduce_fusion_report.junit.xml) -foreach(algo direct; rabenseifner; starlike; ring; ring_rma; double_tree; recursive_doubling; 2d) -add_test (NAME allreduce_${algo} CONFIGURATIONS allreduce_${algo} COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/allreduce_test --gtest_output=xml:${CCL_INSTALL_TESTS}/allreduce_${algo}_report.junit.xml) +foreach(ppn 1; 2) + foreach(algo direct; rabenseifner; starlike; ring; ring_rma; double_tree; recursive_doubling; 2d; topo_ring) + add_test (NAME allreduce_${algo}_${ppn} CONFIGURATIONS allreduce_${algo}_${ppn} COMMAND mpiexec.hydra -l -n 2 -ppn ${ppn} ${CCL_INSTALL_TESTS}/allreduce_test --gtest_output=xml:${CCL_INSTALL_TESTS}/allreduce_${algo}_${ppn}_report.junit.xml) + endforeach() + + foreach(algo direct; ring; double_tree; naive; topo_ring) + add_test (NAME bcast_${algo}_${ppn} CONFIGURATIONS bcast_${algo}_${ppn} COMMAND mpiexec.hydra -l -n 2 -ppn ${ppn} ${CCL_INSTALL_TESTS}/bcast_test --gtest_output=xml:${CCL_INSTALL_TESTS}/bcast_${algo}_${ppn}_report.junit.xml) + endforeach() + + foreach(algo direct; rabenseifner; tree; double_tree; topo_ring) + add_test (NAME reduce_${algo}_${ppn} CONFIGURATIONS reduce_${algo}_${ppn} COMMAND mpiexec.hydra -l -n 2 -ppn ${ppn} ${CCL_INSTALL_TESTS}/reduce_test --gtest_output=xml:${CCL_INSTALL_TESTS}/reduce_${algo}_${ppn}_report.junit.xml) + endforeach() + endforeach() foreach(algo starlike; ring; 2d) @@ -153,19 +167,6 @@ foreach(algo scatter_barrier) add_test (NAME alltoallv_${algo}_chunked CONFIGURATIONS alltoallv_${algo}_chunked COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/alltoallv_test --gtest_output=xml:${CCL_INSTALL_TESTS}/alltoallv_${algo}_chunked_report.junit.xml) endforeach() -foreach(algo direct; ring; double_tree; naive) -add_test (NAME bcast_${algo} CONFIGURATIONS bcast_${algo} COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/bcast_test --gtest_output=xml:${CCL_INSTALL_TESTS}/bcast_${algo}_report.junit.xml) -endforeach() - -foreach(algo direct; rabenseifner; tree; double_tree) -add_test (NAME reduce_${algo} CONFIGURATIONS reduce_${algo} COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/reduce_test --gtest_output=xml:${CCL_INSTALL_TESTS}/reduce_${algo}_report.junit.xml) -endforeach() - foreach(algo direct; ring) add_test (NAME reduce_scatter_${algo} CONFIGURATIONS reduce_scatter_${algo} COMMAND mpiexec.hydra -l -n 2 -ppn 1 ${CCL_INSTALL_TESTS}/reduce_scatter_test --gtest_output=xml:${CCL_INSTALL_TESTS}/reduce_scatter_${algo}_report.junit.xml) endforeach() - -if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") - # right now all regression tests require dpcpp, might be changed in the future - add_subdirectory(regression) -endif() diff --git a/tests/functional/lp.cpp b/tests/functional/lp.cpp index 8e7b28e6e..2d07bf156 100644 --- a/tests/functional/lp.cpp +++ b/tests/functional/lp.cpp @@ -82,14 +82,14 @@ void convert_fp16_to_fp32(const void* src, void* dst) { // _mm512_storeu_si512(dst, (__m512i)(_mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)src)))); _mm256_storeu_si256((__m256i*)dst, (__m256i)(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)))); } -#else /* CCL_FP16_COMPILER */ +#else // CCL_FP16_COMPILER void convert_fp32_to_fp16(const void* src, void* dst) { ASSERT(0, "FP16 is unsupported"); } void convert_fp16_to_fp32(const void* src, void* dst) { ASSERT(0, "FP16 is unsupported"); } -#endif /* CCL_FP16_COMPILER */ +#endif // CCL_FP16_COMPILER #ifdef CCL_BF16_COMPILER void convert_fp32_to_bf16(const void* src, void* dst) { @@ -105,14 +105,14 @@ void convert_bf16_to_fp32(const void* src, void* dst) { __m512i y = _mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i const*)src)); _mm512_storeu_si512(dst, _mm512_bslli_epi128(y, 2)); } -#else /* CCL_BF16_COMPILER */ +#else // CCL_BF16_COMPILER void convert_fp32_to_bf16(const void* src, void* dst) { ASSERT(0, "BF16 is unsupported"); } void convert_bf16_to_fp32(const void* src, void* dst) { ASSERT(0, "BF16 is unsupported"); } -#endif /* CCL_BF16_COMPILER */ +#endif // CCL_BF16_COMPILER void convert_lp_to_fp32(const void* src, void* dst, ccl_data_type dtype) { if (dtype == DATATYPE_FLOAT16) { diff --git a/tests/functional/regression/CMakeLists.txt b/tests/functional/regression/CMakeLists.txt deleted file mode 100644 index c17eac8a3..000000000 --- a/tests/functional/regression/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -# -# Copyright 2016-2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -set(sources alltoallv_empty_count.cpp) - -set(CCL_INSTALL_TESTS "${CMAKE_CURRENT_BINARY_DIR}") - -message(WARNING $"TEST DIR: ${CCL_INSTALL_TESTS}") - -foreach(src ${sources}) - get_filename_component(executable ${src} NAME_WE) - add_executable(${executable} ${src}) - target_link_libraries(${executable} PRIVATE ccl gtest_main gtest mpi) - - install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_TESTS} OPTIONAL) - add_test (NAME ${executable} CONFIGURATIONS regression COMMAND mpiexec.hydra -l -n 3 -ppn 1 ${CCL_INSTALL_TESTS}/${executable} --gtest_output=xml:${CCL_INSTALL_TESTS}/${executable}_default_report.junit.xml) - -endforeach(src ${sources}) diff --git a/tests/functional/regression/alltoallv_empty_count.cpp b/tests/functional/regression/alltoallv_empty_count.cpp deleted file mode 100644 index 23e8409d8..000000000 --- a/tests/functional/regression/alltoallv_empty_count.cpp +++ /dev/null @@ -1,163 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include -#include -#include - -#include "oneapi/ccl.hpp" -#include "gtest/gtest.h" -#include "mpi.h" - -class alltoallv_test : public ::testing::Test { -protected: - void SetUp() override { - ccl::init(); - - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - } - - void TearDown() override { - // Don't do finalize if the case has failed, this - // could lead to a deadlock due to inconsistent state. - if (HasFatalFailure()) { - return; - } - - int is_finalized = 0; - MPI_Finalized(&is_finalized); - - if (!is_finalized) - MPI_Finalize(); - } - - int size; - int rank; -}; - -// there are 3 ranks, rank 0 is able to send and receive data to/from others(its send and receive total count > 0) -// rank 1 only sends data but not receives it(its recv_count == 0 for all ranks), and rank 2 only receives data but -// not sends it. -// also rank 1 sets its recv_buf to nullptr(it's not used anyway due to 0 recv count), the same is done on rank 2 for send buf -// in the testcase we simply run alltoallv with these parameters and after that check that both rank 0 and rank 2 received -// the correct data. -// TODO: once we add more tests, move some common parts out of this test -TEST_F(alltoallv_test, alltoallv_empty_recv_count) { - const size_t count = 1000; - - int i = 0; - - ASSERT_EQ(size, 3) << "Test expects 3 ranks"; - - sycl::queue q; - ASSERT_TRUE(q.get_device().is_gpu()) - << "Test expects gpu device, please use SYCL_DEVICE_FILTER accordingly"; - - /* create communicator */ - auto dev = ccl::create_device(q.get_device()); - auto ctx = ccl::create_context(q.get_context()); - auto comm = ccl::create_communicator(size, rank, dev, ctx, /*kvs*/ {}); - - /* create stream */ - auto stream = ccl::create_stream(q); - - // TODO: find a proper way to choose between shared and device pointers(i.e. env variable) - /* create buffers */ - auto send_buf = sycl::malloc_device(count * size, q); - auto recv_buf = sycl::malloc_device(count * size, q); - - // we have 2 ranks in total: rank 1 doesn't receive anything, rank 2 - doesn't send anything - int empty_recv_rank = 1; - int empty_send_rank = 2; - - std::vector send_counts(size, count); - std::vector recv_counts(size, count); - - // update counts so the corresponding rank doesn't receive anything and others doesn't send anything to it - send_counts[empty_recv_rank] = 0; - if (rank == empty_recv_rank) { - std::fill(recv_counts.begin(), recv_counts.end(), 0); - } - - recv_counts[empty_send_rank] = 0; - if (rank == empty_send_rank) { - std::fill(send_counts.begin(), send_counts.end(), 0); - } - q.memset(recv_buf, 0, count * size).wait(); - - std::vector events; - size_t offset = 0; - for (int i = 0; i < send_counts.size(); ++i) { - auto e = q.submit([&](auto& h) { - h.parallel_for(send_counts[i], [=](auto id) { - send_buf[id + offset] = i + 1; - }); - }); - offset += send_counts[i]; - events.push_back(e); - } - - // do not wait completion of kernel and provide it as dependency for operation - std::vector deps; - for (auto e : events) { - deps.push_back(ccl::create_event(e)); - } - - // invoke alltoall - auto attr = ccl::create_operation_attr(); - int* invalid_ptr = (int*)0x00ffff; - // pass an invalid pointer to make sure it's correctly handled and not dereferenced due to 0 count - if (rank == empty_recv_rank) { - recv_buf = invalid_ptr; - } - else if (rank == empty_send_rank) { - send_buf = invalid_ptr; - } - - ccl::alltoallv(send_buf, send_counts, recv_buf, recv_counts, comm, stream, attr, deps).wait(); - - // if our rank is the one that didn't receive anything, than just exit and don't do any checking - if (rank == empty_recv_rank) - return; - - size_t total_recv = std::accumulate(recv_counts.begin(), recv_counts.end(), 0); - - sycl::buffer check_buf(count * size); - q.submit([&](auto& h) { - sycl::accessor check_buf_acc(check_buf, h, sycl::write_only); - h.parallel_for(total_recv, [=, rnk = rank](auto id) { - // we expect that size - 1 chunks are properly filled with data and the last one is - // unchanged as we have one rank that doesn't send anything - if (recv_buf[id] != rnk + 1) { - check_buf_acc[id] = -1; - } - else { - check_buf_acc[id] = 0; - } - }); - }).wait_and_throw(); - - /* print out the result of the test on the host side */ - { - sycl::host_accessor check_buf_acc(check_buf, sycl::read_only); - for (i = 0; i < total_recv; i++) { - ASSERT_NE(check_buf_acc[i], -1) << "Check failed for receive buffer"; - } - } - - return; -} diff --git a/tests/functional/test.hpp b/tests/functional/test.hpp index c8c9e01c1..8adbf3422 100644 --- a/tests/functional/test.hpp +++ b/tests/functional/test.hpp @@ -48,7 +48,7 @@ struct test_operation { #ifdef CCL_ENABLE_SYCL std::vector device_send_bufs; std::vector device_recv_bufs; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL std::vector events; ccl::string_class match_id; @@ -83,17 +83,17 @@ struct test_operation { void* get_send_buf(size_t buf_idx) { #ifdef CCL_ENABLE_SYCL return device_send_bufs[buf_idx]; -#else /* CCL_ENABLE_SYCL */ +#else // CCL_ENABLE_SYCL return send_bufs[buf_idx].data(); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } void* get_recv_buf(size_t buf_idx) { #ifdef CCL_ENABLE_SYCL return device_recv_bufs[buf_idx]; -#else /* CCL_ENABLE_SYCL */ +#else // CCL_ENABLE_SYCL return recv_bufs[buf_idx].data(); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } size_t get_check_step(size_t elem_idx) { @@ -129,6 +129,7 @@ class base_test { } base_test(); + virtual ~base_test() = default; void alloc_buffers_base(test_operation& op); virtual void alloc_buffers(test_operation& op); @@ -141,7 +142,7 @@ class base_test { #ifdef CCL_ENABLE_SYCL void copy_to_device_send_buffers(test_operation& op); void copy_from_device_recv_buffers(test_operation& op); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL virtual T calculate_reduce_value(test_operation& op, size_t buf_idx, size_t elem_idx); diff --git a/tests/functional/test_impl.hpp b/tests/functional/test_impl.hpp index b37a85d0e..e3b3b68d9 100644 --- a/tests/functional/test_impl.hpp +++ b/tests/functional/test_impl.hpp @@ -42,7 +42,7 @@ void copy_buffer(void* dst, void* src, size_t bytes) { void fill_buffer(void* ptr, int value, size_t bytes) { transport_data::instance().get_stream().get_native().memset(ptr, value, bytes).wait(); } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL template template @@ -215,7 +215,7 @@ void base_test::free_buffers(test_operation& op) { free_buffer(op.device_send_bufs[buf_idx]); free_buffer(op.device_recv_bufs[buf_idx]); } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } template @@ -234,7 +234,7 @@ void base_test::alloc_buffers_base(test_operation& op) { op.device_send_bufs[buf_idx] = alloc_buffer(op.elem_count * sizeof(T) * op.comm_size); op.device_recv_bufs[buf_idx] = alloc_buffer(op.elem_count * sizeof(T) * op.comm_size); } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } template @@ -364,7 +364,7 @@ void base_test::change_buffers(test_operation& op) { void* new_device_recv_buf = op.device_recv_bufs[0]; ASSERT(device_send_buf != new_device_send_buf, "device send buffers should differ"); ASSERT(device_recv_buf != new_device_recv_buf, "device recv buffers should differ"); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL } } @@ -376,12 +376,12 @@ void base_test::copy_to_device_send_buffers(test_operation& op) { #ifdef TEST_CCL_BCAST void* host_buf = op.recv_bufs[buf_idx].data(); void* device_buf = op.device_recv_bufs[buf_idx]; -#else /* TEST_CCL_BCAST */ +#else // TEST_CCL_BCAST void* host_buf = (op.param.place_type == PLACE_IN) ? op.recv_bufs[buf_idx].data() : op.send_bufs[buf_idx].data(); void* device_buf = (op.param.place_type == PLACE_IN) ? op.device_recv_bufs[buf_idx] : op.device_send_bufs[buf_idx]; -#endif /* TEST_CCL_BCAST */ +#endif // TEST_CCL_BCAST size_t bytes = op.send_bufs[buf_idx].size() * sizeof(T); copy_buffer(device_buf, host_buf, bytes); } @@ -395,10 +395,11 @@ void base_test::copy_from_device_recv_buffers(test_operation& op) { op.recv_bufs[buf_idx].size() * sizeof(T)); } } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL template int base_test::run(test_operation& op) { + static int run_counter = 0; size_t iter = 0, result = 0; char* algo = getenv(ALGO_SELECTION_ENV); @@ -441,7 +442,7 @@ int base_test::run(test_operation& op) { #ifdef CCL_ENABLE_SYCL copy_to_device_send_buffers(op); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL op.define_start_order(rand_engine); run_derived(op); @@ -449,14 +450,18 @@ int base_test::run(test_operation& op) { #ifdef CCL_ENABLE_SYCL copy_from_device_recv_buffers(op); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL if (is_lp_datatype(op.param.datatype)) { make_lp_epilogue(op, op.comm_size * op.elem_count); } result += check(op); + if ((run_counter % 10) == 0) { + ccl::barrier(transport_data::instance().get_service_comm()); + } } + run_counter++; free_buffers(op); } catch (const std::exception& ex) { diff --git a/tests/functional/transport.cpp b/tests/functional/transport.cpp index b6b29a9bd..6c9b04b03 100644 --- a/tests/functional/transport.cpp +++ b/tests/functional/transport.cpp @@ -17,7 +17,7 @@ #ifdef CCL_ENABLE_SYCL #include -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL #include "transport.hpp" @@ -79,7 +79,8 @@ void transport_data::init_comms() { #ifdef CCL_ENABLE_SYCL auto sycl_queues = create_sycl_queues("gpu", local_ranks); ASSERT(!sycl_queues.empty(), "queues should contain at least one queue"); - ASSERT(ranks_per_proc == sycl_queues.size(), "ranks and queues sizes should match"); + ASSERT(static_cast(ranks_per_proc) == sycl_queues.size(), + "ranks and queues sizes should match"); auto sycl_context = sycl_queues[0].get_context(); context = ccl::create_context(sycl_context); @@ -89,12 +90,12 @@ void transport_data::init_comms() { devices.push_back(ccl::create_device(sycl_queues[idx].get_device())); allocators.push_back(buf_allocator(streams[0].get_native())); } -#else /* CCL_ENABLE_SYCL */ +#else // CCL_ENABLE_SYCL for (int idx = 0; idx < ranks_per_proc; idx++) { streams.push_back(ccl::create_stream()); devices.push_back(ccl::create_device()); } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL for (int idx = 0; idx < ranks_per_proc; idx++) { r2d_map.emplace(local_ranks[idx], devices[idx]); @@ -109,6 +110,7 @@ void transport_data::init_comms() { } void transport_data::reset_comms() { + ccl::barrier(get_service_comm()); comms.clear(); service_comms.clear(); } @@ -141,4 +143,4 @@ ccl::stream& transport_data::get_stream() { buf_allocator& transport_data::get_allocator() { return allocators[0]; } -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL diff --git a/tests/functional/transport.hpp b/tests/functional/transport.hpp index f7287064e..12b9ef478 100644 --- a/tests/functional/transport.hpp +++ b/tests/functional/transport.hpp @@ -22,7 +22,7 @@ #include "oneapi/ccl.hpp" #ifdef CCL_ENABLE_SYCL #include "sycl_base.hpp" -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL class transport_data { public: @@ -41,7 +41,7 @@ class transport_data { #ifdef CCL_ENABLE_SYCL buf_allocator& get_allocator(); -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL private: transport_data(); @@ -60,7 +60,7 @@ class transport_data { #ifdef CCL_ENABLE_SYCL std::vector> allocators; -#endif /* CCL_ENABLE_SYCL */ +#endif // CCL_ENABLE_SYCL const int ranks_per_proc = 1; }; diff --git a/tests/functional/utils.hpp b/tests/functional/utils.hpp index 494381865..e4885eeeb 100644 --- a/tests/functional/utils.hpp +++ b/tests/functional/utils.hpp @@ -50,7 +50,7 @@ printf("\n(%ld): %s: " fmt "\n", GETTID(), __FUNCTION__, ##__VA_ARGS__); \ fflush(stdout); \ } while (0) -#endif /* PRINT */ +#endif // PRINT #define OUTPUT_NAME_ARG "--gtest_output=" #define PATCH_OUTPUT_NAME_ARG(argc, argv, comm) \ diff --git a/third-party-programs.txt b/third-party-programs.txt index 274a06f11..27c2c65f6 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -1,5 +1,5 @@ Intel(R) oneAPI Collective Communications Library (oneCCL) -2021.3.0 Third Party Programs File +2021.4.0 Third Party Programs File This file is the "third-party-programs.txt" file specified in the associated Intel end user license agreement for the Intel software you are licensing.